Source code for chia.base.tools.ChiaTool

from __future__ import annotations

from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict, Any

from mcp.server.fastmcp import FastMCP
from mcp.server.transport_security import TransportSecuritySettings
import ray
import logging
import sys
from chia.base.tools.util import make_router_lifespan

[docs] @dataclass class ToolInfo: name: str port: int node_id: str
[docs] class ChiaTool: """Base class for MCP tool servers deployed onto Ray workers. Subclasses define a setup() method, which calls:: ``self.mcp.add_tool(self.method, name=...)`` (one or more times) to register functions as tools with instances of this ChiaTool. Subclass can shut down the tool server with:: ``self.stop()`` - Tells the actor to shut down uvicorn, then kills the actor. - Because start and stop run in the same actor process, the uvicorn server reference is always reachable. The resulting MCP endpoint is at:: http://{self.hostname}:{self.port}/{self.name}/mcp Example subclass:: class BashTool(ChiaTool): def setup(self): self.mcp.add_tool(self.run_command, name=f"{name}_run_command") def run_command(self, command: str) -> str: ... Alternatively, instead of a setup method a subclass can define an __init__ method which must do the following:: def __init__(self, name, task_options): super().__init__(name, task_options=task_options) # Registers fns with self.mcp.add_tool super().__post_init__() """ # Stores {"name": str, "port": int, "node_id": str} _tool_registry: List[ToolInfo] = [] def __init__(self, name: str, task_options: Optional[Dict] = None, logging_level = logging.DEBUG): """Initializes ChiaTool with a name and optional resource requirements. """ self.name = name self.logging_name = name self.mcp = FastMCP( name, stateless_http=True, transport_security=TransportSecuritySettings( enable_dns_rebinding_protection=False, ), ) self.hostname = None # Will be set when the tool starts up and finds its IP address. self.port = 8000 # Will be set to the actual port by start_tool. self.task_options = task_options self.logger = logging.getLogger(self.logging_name) self.logger.setLevel(logging_level) self.node_id = None # Will be set to the actual node_id by __post_init__. self.tool_info = None self._server_actor = None # Ray actor handle, set by __post_init__ def __post_init__(self): # Idempotency guard: deploying twice would spin up a second actor and # orphan the first. No-op if the server is already up — this makes the # setup()-hook construction style (see __init_subclass__) safe even when # a subclass also calls super().__post_init__() from a hand-written # __init__. if self._server_actor is not None: return if self.task_options is not None: self._server_actor = _ToolServerActor.options(**self.task_options).remote() else: self._server_actor = _ToolServerActor.remote() self.hostname, self.port, self.node_id = ray.get( self._server_actor.start.remote(self) ) self.logger.info(f"{self.name} started at {self.hostname}:{self.port} on node {self.node_id}") self.tool_info = ToolInfo( name=self.name, port=self.port, node_id=self.node_id ) ChiaTool._tool_registry.append(self.tool_info)
[docs] def setup(self, *args, **kwargs): """Hook for the auto-constructed style — override to register tools and set instance state, *instead of* writing ``__init__``. A subclass that defines ``setup`` and no ``__init__`` is given an ``__init__`` automatically (see :meth:`__init_subclass__`) that runs ``ChiaTool.__init__`` before and ``__post_init__`` after ``setup``, so the contract with the subclass can't be written incorrectly. Inside ``setup`` the base ``__init__`` has already run, so ``self.name`` / ``self.mcp`` are available:: class BashTool(ChiaTool): def setup(self, work_dir="/"): self.work_dir = work_dir self.mcp.add_tool(self.run_command, name=f"{self.name}_run_command") Positional/keyword args from the constructor (other than ``name``, ``task_options``, and ``logging_level``, which the base consumes) are forwarded here. Multi-level subclasses override ``setup`` and call ``super().setup(...)`` to chain. """ raise NotImplementedError( f"{type(self).__name__} defines neither __init__ nor setup(); " "implement one of them." )
def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) # Opt-in convenience: a subclass that defines setup() but no __init__ # gets an __init__ that brackets setup() with ChiaTool.__init__ (before) # and __post_init__ (after) automatically. Subclasses that write their # own __init__ are left untouched and keep using the explicit # super().__init__() / super().__post_init__() pattern. if "setup" in cls.__dict__ and "__init__" not in cls.__dict__: def _auto_init(self, name, *args, task_options=None, logging_level=logging.DEBUG, **kwargs): # Call ChiaTool.__init__ explicitly (not super()/self.__init__): # self.__init__ is *this* function, so that would recurse. ChiaTool.__init__(self, name, task_options=task_options, logging_level=logging_level) self.setup(*args, **kwargs) self.__post_init__() _auto_init.__name__ = "__init__" _auto_init.__qualname__ = f"{cls.__qualname__}.__init__" cls.__init__ = _auto_init
[docs] def stop(self): """Stop the tool's MCP server and clean up resources.""" if self.tool_info in ChiaTool._tool_registry: ChiaTool._tool_registry.remove(self.tool_info) if self._server_actor is not None: try: ray.get(self._server_actor.stop.remote()) except Exception as e: self.logger.warning(f"Error stopping tool {self.name}: {e}") ray.kill(self._server_actor, no_restart=True) self._server_actor = None
def dict_entry(self): return self.mcp def __getstate__(self): """Exclude actor handle when serialized by Ray (e.g. passed to remote tasks).""" state = self.__dict__.copy() state['_server_actor'] = None return state def __setstate__(self, state): self.__dict__.update(state)
[docs] def resolve_tool_url(url: str) -> str: """Rewrite a tool URL so it is routable from the current node. On tunnelled EC2 workers ``CHIA_TOOL_ADVERTISE_HOST`` and ``CHIA_TOOL_RELAY_HOST`` are set. Tool URLs advertise the head's real IP (``CHIA_TOOL_ADVERTISE_HOST``) which is only directly reachable from the local network. EC2 workers must connect via a reverse-tunnel relay instead, so this function replaces the host portion of the URL with the relay loopback (``CHIA_TOOL_RELAY_HOST``). On non-tunnelled nodes (or when the env vars are absent) the URL is returned unchanged. """ import os advertise = os.environ.get("CHIA_TOOL_ADVERTISE_HOST") relay = os.environ.get("CHIA_TOOL_RELAY_HOST") if advertise and relay and advertise in url: return url.replace(advertise, relay, 1) return url
def makeMCPDeploymentClass( name: str, fastapi_app, autoscaling_config: dict = { "min_replicas": 1, "max_replicas": 20, "target_ongoing_requests": 5 }, ray_actor_options: dict = {"num_cpus": 0.2} ): from ray import serve @serve.deployment( autoscaling_config=autoscaling_config, ray_actor_options=ray_actor_options, name = name ) @serve.ingress(fastapi_app) class _MCPDeployment: def __init__(self): pass return _MCPDeployment def make_lifespan(mcpInst: FastMCP[Any], name=""): @asynccontextmanager async def lifespan(app): app.mount(f"/", mcpInst.streamable_http_app()) async with mcpInst.session_manager.run(): yield return lifespan class _PortRegistry: """Tracks which ports are taken per IP, so concurrent start_router calls in the same process don't race on the same port.""" _taken: Dict[str, set] = {} # ip -> set of ports @classmethod def reserve(cls, ip: str, port: int): cls._taken.setdefault(ip, set()).add(port) @classmethod def is_taken(cls, ip: str, port: int) -> bool: return port in cls._taken.get(ip, set()) @classmethod def release(cls, ip: str, port: int) -> bool: """ Releases port from reservation for ip Returns false if port was not already reserved """ try: cls._taken[ip].remove(port) except KeyError: return False return True # Worker-side registry: lives in the actor process so start_router and # stop_router always see the same dict. import threading as _threading _active_servers: Dict[str, Tuple["uvicorn.Server", _threading.Thread, str, int]] = {} """tool_name -> (server, thread, ip, port).""" @ray.remote(num_cpus=0) class _ToolServerActor: """Persistent Ray actor that manages a uvicorn server for one MCP tool. Because it is an actor, start() and stop() always execute in the same process, so the uvicorn.Server reference (in _active_servers) is never lost. """ def __init__(self): self._name = None def start(self, tool: "ChiaTool") -> Tuple[str, int, str]: """Start the MCP server. Returns (advertised_ip, port, node_id). Reads CHIA_TOOL_BASE_PORT / CHIA_TOOL_MAX_PORT from the **worker** environment so that tunnelled nodes bind to the SSH-forwarded port range instead of the default (8000). When CHIA_TOOL_ADVERTISE_HOST is set (tunnelled workers), the returned IP is the head's resolved IP rather than the tunnel loopback. Uvicorn still binds to the real node IP so the SSH forward tunnel can reach it. """ import os base_port = int(os.environ.get("CHIA_TOOL_BASE_PORT", "8000")) max_port = int(os.environ.get("CHIA_TOOL_MAX_PORT", "0")) max_tries = (max_port - base_port + 1) if max_port else 100 bind_ip = ray.util.get_node_ip_address() port = start_router(tool, bind_ip, base_port=base_port, max_tries=max_tries) node_id = ray.get_runtime_context().get_node_id() self._name = tool.name advertise_ip = os.environ.get("CHIA_TOOL_ADVERTISE_HOST", bind_ip) return advertise_ip, port, node_id def stop(self) -> bool: """Stop the uvicorn server. Returns True if it was running.""" if self._name: return stop_router(self._name) return False
[docs] def start_router(tool: ChiaTool, ip_address: str, base_port: int = 8000, max_tries: int = 100) -> int: """Start MCP tool servers using a local uvicorn instance. Uses a plain uvicorn server (background thread) instead of Ray Serve so that multiple nodes can each host their own independent tool servers without route-prefix conflicts. Tries ports starting from *base_port*, skipping any that are already in use (e.g. host-networked Docker containers sharing the same port space). Returns the port that was successfully bound. """ import threading import time import uvicorn from fastapi import FastAPI # Try ports sequentially, starting uvicorn on each until one succeeds. for port in range(base_port, base_port + max_tries): if _PortRegistry.is_taken(ip_address, port): continue # Create a fresh app each attempt — MCP's StreamableHTTPSessionManager # can only .run() once per instance. tool.mcp._session_manager = None app = FastAPI(lifespan=make_router_lifespan([tool.mcp])) app.mount(f"/{tool.name}", tool.mcp.streamable_http_app()) config = uvicorn.Config(app, host=ip_address, port=port, log_level="info") server = uvicorn.Server(config) thread = threading.Thread(target=server.run, daemon=True) thread.start() # Wait for uvicorn to confirm it bound the port. for _ in range(50): # up to 5 seconds time.sleep(0.1) if server.started: _PortRegistry.reserve(ip_address, port) _active_servers[tool.name] = (server, thread, ip_address, port) return port # Uvicorn didn't start — shut it down and try next port. server.should_exit = True thread.join(timeout=2) continue raise RuntimeError( f"Could not find available port in range {base_port}-{base_port + max_tries - 1}" )
[docs] def stop_router(tool_name: str) -> bool: """Stop a running uvicorn server by tool name. Looks up the server in the worker-side _active_servers registry, signals it to exit, waits for the thread, and releases the port. Returns True if the server was found and stopped. """ entry = _active_servers.pop(tool_name, None) if entry is None: print(f"Warning: stop_router called for '{tool_name}' but no active server found") return False server, thread, ip, port = entry server.should_exit = True thread.join(timeout=5) if thread.is_alive(): print(f"Warning: uvicorn thread for '{tool_name}' still alive after 5s") _PortRegistry.release(ip, port) return True