Source code for chia.models.claude

from __future__ import annotations

import json
import logging
import os
import re
import subprocess
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
import threading
from typing import TYPE_CHECKING, List, Optional
from uuid import uuid4

import ray

from chia.base.ChiaFunction import ChiaFunction, ObjectRefCallback
from chia.base.llm_call import QueryResult, LLMCallBase

if TYPE_CHECKING:
    from chia.base.tools.ChiaTool import ChiaTool


# ---------------------------------------------------------------------------
# Result type — claude-specific
# ---------------------------------------------------------------------------


[docs] @dataclass class ClaudeCodeQueryResult(QueryResult): """ Derived QueryResult specialized for the Claude Code CLI backend. Carries the on-disk ``<session_id>.jsonl`` transcript bytes back to the caller so a ``resume_session=True`` LLM can continue the same conversation on a different worker — written by :meth:`ClaudeCodeLLM._capture_transcript` after a CLI run and consumed by :meth:`ClaudeCodeLLM._restore_transcript` before the next ``--resume`` invocation. """ session_transcript: Optional[bytes] = None session_transcript_path: Optional[str] = None
# ``CLIResult`` is the generic alias for any LLM result — kept pointing at # ``QueryResult`` for back-compat. Use ``ClaudeCodeQueryResult`` directly when # you need to construct or access the session-transcript fields. CLIResult = QueryResult # --------------------------------------------------------------------------- # Exceptions # ---------------------------------------------------------------------------
[docs] class ClaudeCodeError(Exception): """Base for all Claude Code CLI errors. Every subclass must implement ``__reduce__`` for Ray serialization. """ def __init__( self, node_id: str, error_type: str, exit_code: int = -1, raw_message: str = "", ): self.node_id = node_id self.error_type = error_type self.exit_code = exit_code self.raw_message = raw_message super().__init__(f"{error_type} on {node_id}: {raw_message[:200]}") def __reduce__(self): return ( self.__class__, (self.node_id, self.error_type, self.exit_code, self.raw_message), )
[docs] class RateLimitError(ClaudeCodeError): """Raised when the Claude CLI response indicates a usage-limit hit.""" def __init__( self, node_id: str, reset_time: datetime, raw_message: str = "", exit_code: int = -1, ): self.reset_time = reset_time super().__init__( node_id=node_id, error_type="rate_limit", exit_code=exit_code, raw_message=raw_message, ) def __reduce__(self): return ( self.__class__, (self.node_id, self.reset_time, self.raw_message, self.exit_code), )
[docs] class AuthenticationError(ClaudeCodeError): """Raised when the CLI's auth token/API key is invalid or expired.""" def __init__(self, node_id: str, exit_code: int = -1, raw_message: str = ""): super().__init__(node_id, "authentication_failed", exit_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.exit_code, self.raw_message))
[docs] class BillingError(ClaudeCodeError): """Raised when the billing account has payment issues.""" def __init__(self, node_id: str, exit_code: int = -1, raw_message: str = ""): super().__init__(node_id, "billing_error", exit_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.exit_code, self.raw_message))
[docs] class InvalidRequestError(ClaudeCodeError): """Raised when the request is malformed (bad prompt, unsupported params, etc.).""" def __init__(self, node_id: str, exit_code: int = -1, raw_message: str = ""): super().__init__(node_id, "invalid_request", exit_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.exit_code, self.raw_message))
[docs] class ServerError(ClaudeCodeError): """Raised when Anthropic's API returns a server-side error (500/503).""" def __init__( self, node_id: str, exit_code: int = -1, raw_message: str = "", retry_after: Optional[int] = None, ): self.retry_after = retry_after super().__init__(node_id, "server_error", exit_code, raw_message) def __reduce__(self): return ( self.__class__, (self.node_id, self.exit_code, self.raw_message, self.retry_after), )
[docs] class MaxOutputTokensError(ClaudeCodeError): """Raised when the LLM's response was truncated by the output token limit.""" def __init__( self, node_id: str, exit_code: int = -1, raw_message: str = "", partial_text: str = "", ): self.partial_text = partial_text super().__init__(node_id, "max_output_tokens", exit_code, raw_message) def __reduce__(self): return ( self.__class__, (self.node_id, self.exit_code, self.raw_message, self.partial_text), )
[docs] class UnknownClaudeError(ClaudeCodeError): """Raised for unclassified CLI errors.""" def __init__( self, node_id: str, exit_code: int = -1, raw_message: str = "", stderr: str = "", ): self.stderr = stderr super().__init__(node_id, "unknown", exit_code, raw_message) def __reduce__(self): return ( self.__class__, (self.node_id, self.exit_code, self.raw_message, self.stderr), )
# --------------------------------------------------------------------------- # Rate-limit text parser # --------------------------------------------------------------------------- _RATE_LIMIT_RE = re.compile( r"You've hit your limit\s*[·•\-—]\s*resets?\s+(\d{1,2})\s*(am|pm)\s*\(([^)]+)\)", re.IGNORECASE, )
[docs] def parse_rate_limit_reset(text: str) -> Optional[datetime]: """Parse a Claude rate-limit message and return the UTC reset time. Expected format: ``"You've hit your limit · resets 4pm (America/Los_Angeles)"`` Returns ``None`` when no rate-limit message is found. """ m = _RATE_LIMIT_RE.search(text) if m is None: return None hour = int(m.group(1)) ampm = m.group(2).lower() tz_str = m.group(3).strip() # Convert 12-hour → 24-hour if ampm == "pm" and hour != 12: hour += 12 elif ampm == "am" and hour == 12: hour = 0 # Resolve timezone try: import zoneinfo tz = zoneinfo.ZoneInfo(tz_str) except Exception: # Fallback: common abbreviations _abbrev = { "PST": -8, "PDT": -7, "MST": -7, "MDT": -6, "CST": -6, "CDT": -5, "EST": -5, "EDT": -4, "UTC": 0, "GMT": 0, } offset_hours = _abbrev.get(tz_str.upper(), 0) tz = timezone(timedelta(hours=offset_hours)) now_in_tz = datetime.now(tz) reset_local = now_in_tz.replace(hour=hour, minute=0, second=0, microsecond=0) # If reset hour is in the past, it means tomorrow if reset_local <= now_in_tz: reset_local += timedelta(days=1) return reset_local.astimezone(timezone.utc)
[docs] def parse_rate_limit_event(event: dict) -> Optional[datetime]: """Parse a ``rate_limit_event`` JSON object and return the UTC reset time. Only triggers when ``rate_limit_info.status`` is ``"rejected"`` — the event is also emitted with other statuses as an informational notice, which should NOT be treated as a rate limit. """ info = event.get("rate_limit_info", {}) if info.get("status") != "rejected": return None resets_at = info.get("resetsAt") if resets_at is None: return None try: return datetime.fromtimestamp(int(resets_at), tz=timezone.utc) except (ValueError, TypeError, OSError): return None
# --------------------------------------------------------------------------- # Session-tracking decorator # --------------------------------------------------------------------------- def _session_tracked(chia_fn): """Stack ABOVE ``@ChiaFunction`` on ``prompt`` to auto-attach transcript sync. Remote dispatch still targets the inner ``@ChiaFunction`` (the CLI runs on a worker). This outer layer only changes what ``chia_remote`` *returns*: when the instance is resuming (``resume_session=True`` -> ``_session_id`` set) it wraps the ObjectRef in an :class:`ObjectRefCallback` carrying ``_sync_transcript``, so ``get(ref)`` harvests the transcript onto the instance with no explicit ``callback=``. When not resuming it returns the plain ObjectRef unchanged. Dispatch stays async (the ObjectRef/ObjectRefCallback returns immediately). ``instance.prompt(...)`` (local) is unchanged; ``ClaudeCodeLLM.prompt`` (class access) exposes the raw inner ``@ChiaFunction``. """ def _wrap(instance, ref): if instance._session_id is not None: return ObjectRefCallback(ref, instance._sync_transcript) return ref class _TrackedHandle: def __init__(self, inner_handle, instance): self._inner = inner_handle self._instance = instance def chia_remote(self, *args, **kwargs): return _wrap(self._instance, self._inner.chia_remote(*args, **kwargs)) def remote(self, *args, **kwargs): return self.chia_remote(*args, **kwargs) class _BoundTracked: def __init__(self, instance): self._instance = instance def __call__(self, *args, **kwargs): # Local, in-process call — prompt() captures the transcript onto # self directly (it runs here), so no wrapping is needed. return chia_fn(self._instance, *args, **kwargs) def chia_remote(self, *args, **kwargs): return _wrap(self._instance, chia_fn.chia_remote(*args, **kwargs)) def options(self, **opts): return _TrackedHandle(chia_fn.options(**opts), self._instance) def __getattr__(self, name): return getattr(chia_fn, name) class _TrackedDescriptor: def __get__(self, obj, objtype=None): if obj is None: return chia_fn # class access -> raw ChiaFunction return _BoundTracked(obj) def __getattr__(self, name): return getattr(chia_fn, name) return _TrackedDescriptor()
[docs] class ClaudeCodeLLM(LLMCallBase): """Wraps the Claude Code CLI (``claude --print``) as an LLM backend. Each call to :meth:`prompt` spawns a ``claude`` subprocess that can optionally connect to MCP tool servers (e.g. :class:`BashTool`). """ def __init__( self, model: str = "claude-sonnet-4-6", system_message: str = "", timeout_seconds: int = 600, retries: int = 3, logging_name: str = "claude_code", logging_level: int = logging.DEBUG, log_dir: Optional[str] = None, resume_session: bool = False, projects_cwd: Optional[str] = "/home/ray/.claude/projects/-home-ray-llm-env", extra_cli_args: Optional[List[str]] = None, log_stream: bool = True, log_all: bool = False, backend: str = "cli", api_key: Optional[str] = None, max_tokens: int = 16000, thinking: Optional[str] = "adaptive", max_tool_iterations: int = 100, ): super().__init__(system_message=system_message) self.logging_level = logging_level self.logging_name = logging_name self.retries = retries self.timeout_seconds = timeout_seconds self.model = model self.extra_cli_args = extra_cli_args or [] self.logger = logging.getLogger(logging_name) self.log_stream = log_stream self.log_all = log_all # Backend selection: "cli" (default, the ``claude --print`` subprocess) # or "api" (the Anthropic Python SDK; see the Anthropic API backend # section below). if backend not in ("cli", "api"): raise ValueError(f"backend must be 'cli' or 'api', got {backend!r}") self.backend = backend self.api_key = api_key self.max_tokens = max_tokens self.thinking = thinking self.max_tool_iterations = max_tool_iterations # The CLI backend ignores the API-only parameters; warn if any were # set away from their defaults so a misdirected config doesn't pass # silently. if self.backend == "cli": ignored = [ name for name, value, default in ( ("api_key", api_key, None), ("max_tokens", max_tokens, 16000), ("thinking", thinking, "adaptive"), ("max_tool_iterations", max_tool_iterations, 100), ) if value != default ] if ignored: self.logger.warning( "backend='cli' ignores API-only parameter(s): %s", ", ".join(ignored), ) # NOTE: the "api" backend has so far only been exercised by the tests # in chia/models/tests/test_claude_api.py (mocked unit tests, plus the # opt-in live tests). It has NOT been validated in production use — # treat it as experimental until it has real mileage. elif self.backend == "api": self.logger.warning( "ClaudeCodeLLM backend='api' is experimental: it has only been " "exercised by unit tests so far, not validated in production." ) self._call_counter = 0 self._session_id = str(uuid4()) if resume_session else None self._last_metadata: dict = {} # populated by _process_event_line self._rate_limit_event: Optional[dict] = None # populated by _process_event_line self._projects_cwd = projects_cwd self._session_transcript: Optional[bytes] = None self._session_transcript_path: Optional[str] = None self._log_dir = log_dir if log_dir is not None: os.makedirs(log_dir, exist_ok=True) session_tag = f"_{self._session_id[:8]}" if self._session_id else "" run_id = datetime.now().strftime("%Y%m%d_%H%M%S") self._log_prefix = os.path.join(log_dir, f"{logging_name}_{run_id}{session_tag}") else: self._log_prefix = None # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------
[docs] @_session_tracked @ChiaFunction(resources={"claude_creds": 0.01}) def prompt( self, user_message: str, tools: Optional[List[ChiaTool]] = [], ) -> ClaudeCodeQueryResult: """Send *user_message* to Claude Code CLI and return the response. Returns: :class:`ClaudeCodeQueryResult` with ``success=True`` when the CLI ran cleanly, or ``success=False`` when every retry attempt failed (in which case ``result`` is empty and ``returncode`` is ``-1``). Raises: RateLimitError: Usage limit hit — propagates immediately. AuthenticationError: Auth failure — propagates immediately. BillingError: Billing/payment issue — propagates immediately. InvalidRequestError: Malformed request — propagates immediately. ServerError: After all retries with exponential backoff. MaxOutputTokensError: After one retry attempt. """ import time as _time from chia.trace.profiler import get_profiler profiler = get_profiler() for attempt in range(self.retries): try: self._last_metadata = {} self._rate_limit_event = None # Paste any carried transcript onto this machine so a --resume # run finds the conversation, regardless of which worker the # previous call landed on. No-op when not resuming. self._restore_transcript() if self.backend == "api": cli = self._run_api(user_message, tools) elif self.log_stream: cli = self._run_claude_streaming(user_message, tools) else: cli = self._run_claude(user_message, tools) self._call_counter += 1 self._last_metadata["model"] = self.model self._last_metadata["tools"] = [ {"name": t.name, "hostname": getattr(t, "hostname", None), "port": getattr(t, "port", None), "node_id": getattr(t, "node_id", None)} for t in tools ] if profiler.enabled and self._last_metadata: profiler.add_info(self._last_metadata) # Classify and raise typed errors self._classify_error(cli) # Read the freshly-written transcript back into memory (and # onto the ClaudeCodeQueryResult) so the caller can resume on another worker. self._capture_transcript(cli) cli.success = True return cli # -- Never retry: propagate immediately -- except (RateLimitError, AuthenticationError, BillingError, InvalidRequestError): raise # -- Retry once: stochastic generation may produce shorter output -- except MaxOutputTokensError: if attempt == 0: self.logger.warning( "Max output tokens on attempt %d/%d, retrying once", attempt + 1, self.retries, ) continue raise # -- Retry with exponential backoff: transient API issue -- except ServerError: backoff = min(5 * 2 ** attempt, 60) self.logger.warning( "Server error on attempt %d/%d, backing off %ds", attempt + 1, self.retries, backoff, ) _time.sleep(backoff) # -- Standard retry for unknown errors -- except UnknownClaudeError as exc: self.logger.warning( "Unknown error on attempt %d/%d: %s", attempt + 1, self.retries, exc, ) except subprocess.TimeoutExpired: # A timeout means the session was likely created; # switch to --resume for subsequent attempts. if self._session_id is not None and self._call_counter == 0: self._call_counter = 1 self.logger.warning( "Timeout on attempt %d/%d", attempt + 1, self.retries, ) except Exception as exc: self.logger.warning( "Unexpected error on attempt %d/%d: %s", attempt + 1, self.retries, exc, ) return ClaudeCodeQueryResult(result="", returncode=-1, stderr="", stream_result="", success=False)
def _sync_transcript(self, cli: ClaudeCodeQueryResult) -> ClaudeCodeQueryResult: """Copy a worker-captured transcript off *cli* onto this instance. When ``resume_session=True``, ``prompt.chia_remote`` returns an :class:`ObjectRefCallback` carrying this method, so ``get(ref)`` runs it automatically — the caller doesn't pass a ``callback``:: cli = get(llm.prompt.chia_remote(llm, msg)) # auto-syncs transcript The remote worker's own ``self`` mutations are discarded, so the transcript must be harvested locally (in the process that calls ``get``). A follow-up call then resumes the same session even if it lands on a different worker. Guarded so a transcript-less result (api backend / error path) doesn't clobber a prior capture, and a no-op when ``resume_session`` is False. Returns *cli* (so it's a pass-through callback). """ if self._session_id is not None and cli.session_transcript is not None: self._session_transcript = cli.session_transcript self._session_transcript_path = cli.session_transcript_path return cli def _get_node_id(self) -> str: try: return ray.get_runtime_context().get_node_id() except Exception: return "unknown" def _classify_error(self, cli: ClaudeCodeQueryResult) -> None: """Inspect *cli* and raise a typed error if something went wrong. Check order: 1. Rate limit (text regex + streaming event) — highest priority 2. Non-zero exit code — classify by stderr patterns 3. Success — return without raising """ # --- 1. Rate limit (can appear even with exit code 0) --- reset_time = parse_rate_limit_reset(cli.result) if reset_time is None and self._rate_limit_event is not None: reset_time = parse_rate_limit_event(self._rate_limit_event) if reset_time is not None: node_id = self._get_node_id() self.logger.warning( "Rate limit detected on %s (resets %s). Response text:\n%s", node_id, reset_time.isoformat(), cli.result, ) raise RateLimitError( node_id=node_id, reset_time=reset_time, raw_message=cli.result[:300], exit_code=cli.returncode, ) # --- 2. Non-zero exit code → classify by stderr --- if cli.returncode != 0: node_id = self._get_node_id() stderr_lower = cli.stderr.lower() if any(kw in stderr_lower for kw in ( "authentication", "unauthorized", "401", "not authenticated", "login", "auth token", )): raise AuthenticationError( node_id=node_id, exit_code=cli.returncode, raw_message=cli.stderr[:300], ) if any(kw in stderr_lower for kw in ( "billing", "payment", "402", "overdue", "subscription", "plan expired", )): raise BillingError( node_id=node_id, exit_code=cli.returncode, raw_message=cli.stderr[:300], ) if any(kw in stderr_lower for kw in ( "invalid request", "malformed", "400", "bad request", "invalid model", )): raise InvalidRequestError( node_id=node_id, exit_code=cli.returncode, raw_message=cli.stderr[:300], ) if any(kw in stderr_lower for kw in ( "500", "503", "server error", "overloaded", "internal error", "service unavailable", )): raise ServerError( node_id=node_id, exit_code=cli.returncode, raw_message=cli.stderr[:300], ) if any(kw in stderr_lower for kw in ( "max_output_tokens", "output token limit", "maximum output", "response too long", )): raise MaxOutputTokensError( node_id=node_id, exit_code=cli.returncode, raw_message=cli.stderr[:300], partial_text=cli.result, ) # Fallback: unknown error raise UnknownClaudeError( node_id=node_id, exit_code=cli.returncode, raw_message=cli.stderr[:300], stderr=cli.stderr, ) # ------------------------------------------------------------------ # Session transcript (CLI backend, resume across machines) # ------------------------------------------------------------------ # # The Claude Code CLI persists each session to # ``<projects_dir>/<session_id>.jsonl`` where ``<projects_dir>`` is # derived from the CLI process CWD (leading "/" dropped, remaining "/" # turned into "-") under ``~/.claude/projects/``. Because :meth:`prompt` # may be scheduled on a different ``claude_creds`` worker each call, the # on-disk transcript does not follow the conversation. We therefore carry # the bytes in-process (``self._session_transcript``, also surfaced on # :class:`ClaudeCodeQueryResult`) and re-paste them before a ``--resume`` run. def _resolve_projects_dir(self) -> str: """Directory holding ``<session_id>.jsonl`` on the current machine.""" if self._projects_cwd: return self._projects_cwd # The CLI escapes every non-alphanumeric char in its CWD to "-". escaped = re.sub(r"[^a-zA-Z0-9]", "-", os.getcwd()) return os.path.join(os.path.expanduser("~"), ".claude", "projects", escaped) def _transcript_path(self) -> Optional[str]: """Full path to this session's transcript, or None when not resuming.""" if self._session_id is None: return None return os.path.join(self._resolve_projects_dir(), f"{self._session_id}.jsonl") def _restore_transcript(self) -> None: """Paste a carried transcript onto this machine before a resume run. Unconditional: overwrites any existing file so the resumed session always reflects the conversation we hold. Forces ``--resume`` semantics even when this object/worker has never run the CLI itself. No-op unless ``resume_session=True`` and a transcript has been captured. """ if self._session_id is None or self._session_transcript is None: return path = self._transcript_path() os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "wb") as fh: fh.write(self._session_transcript) self._session_transcript_path = path if self._call_counter == 0: self._call_counter = 1 def _capture_transcript(self, cli: ClaudeCodeQueryResult) -> None: """Read the on-disk transcript after a run into memory and onto *cli*. No-op unless ``resume_session=True`` and the CLI actually wrote a transcript (the ``api`` backend keeps history in memory and writes nothing, so this leaves the state untouched there). """ if self._session_id is None: return path = self._transcript_path() if path and os.path.exists(path): with open(path, "rb") as fh: data = fh.read() self._session_transcript = data self._session_transcript_path = path cli.session_transcript = data cli.session_transcript_path = path # ------------------------------------------------------------------ # Internals # ------------------------------------------------------------------ def _build_mcp_config(self, tools: List[ChiaTool]) -> dict: """Build the JSON object expected by ``--mcp-config``.""" servers = {} for tool in tools: port = getattr(tool, "port", 8000) servers[tool.name] = { "type": "http", "url": f"http://{tool.hostname}:{port}/{tool.name}/mcp", } return {"mcpServers": servers} def _build_allowed_tools(self, tools: List[ChiaTool]) -> List[str]: """Return ``--allowedTools`` entries for every registered MCP tool.""" allowed: list[str] = [] for tool in tools: # FastMCP registers tools under server_name; the MCP tool ID # that Claude Code recognises is mcp__<server>__<tool_name>. for fn_info in tool.mcp._tool_manager.list_tools(): allowed.append(f"mcp__{tool.name}__{fn_info.name}") return allowed def _build_cmd(self, tools: Optional[List[ChiaTool]] = None) -> list: """Build the ``claude`` CLI command list. The user message is piped via stdin (``-p -``) to avoid OS argument-length limits with long prompts. """ cmd = [ "claude", "--print", "--model", self.model, "--dangerously-skip-permissions", ] if self.extra_cli_args: cmd += self.extra_cli_args if self.log_stream: cmd += ["--output-format", "stream-json", "--verbose"] if self.system_message: cmd += ["--system-prompt", self.system_message] if self._session_id is not None: if self._call_counter > 0: cmd += ["--resume", self._session_id] else: cmd += ["--session-id", self._session_id] if tools: cfg = self._build_mcp_config(tools) tmp = tempfile.NamedTemporaryFile( mode="w", suffix=".json", delete=False ) json.dump(cfg, tmp) tmp.close() cmd += ["--mcp-config", tmp.name] allowed = self._build_allowed_tools(tools) if allowed: cmd += ["--allowedTools", ",".join(allowed)] cmd += ["-p", "-"] return cmd def _run_claude( self, user_message: str, tools: Optional[List[ChiaTool]] = None, ) -> ClaudeCodeQueryResult: """Run claude with simple capture (no event streaming).""" cmd = self._build_cmd(tools) self.logger.info("Running: %s", " ".join(cmd[:6]) + " ...") env = {k: v for k, v in os.environ.items() if k != "CLAUDECODE"} result = subprocess.run( cmd, input=user_message, capture_output=True, text=True, timeout=self.timeout_seconds, env=env, ) if result.returncode != 0: self.logger.warning("claude exited %d: %s", result.returncode, result.stderr[:500]) if self._log_prefix is not None: truncated = user_message[:500] + ("..." if len(user_message) > 500 else "") with open(f"{self._log_prefix}.log", "a") as f: f.write("=" * 80 + "\n") f.write(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Prompt #{self._call_counter}\n") f.write("=" * 80 + "\n\n") f.write(f"[User Message]\n{truncated}\n\n") f.write(f"[Response]\n{result.stdout}\n\n") f.write("-" * 80 + "\n\n") return ClaudeCodeQueryResult( result=result.stdout, returncode=result.returncode, stderr=result.stderr, stream_result="", ) # ------------------------------------------------------------------ # Streaming log implementation # ------------------------------------------------------------------ def _run_claude_streaming( self, user_message: str, tools: Optional[List[ChiaTool]] = None, ) -> ClaudeCodeQueryResult: """Run claude with ``--output-format stream-json``. Stdout (NDJSON events) is parsed by ``_process_event_line`` into structured entries. Stderr lines are captured with a ``[stderr]`` prefix. A lock serialises writes from the two drain threads. Every parsed entry is appended to an in-memory ``stream_result`` buffer (returned on ``ClaudeCodeQueryResult``) so callers can surface the event trace — tool calls, thinking, metadata — directly. When ``log_dir`` was set on the constructor, the same entries are also mirrored to ``<prefix>.log`` on disk. """ cmd = self._build_cmd(tools) self.logger.info("Running: %s", " ".join(cmd[:6]) + " ...") env = {k: v for k, v in os.environ.items() if k != "CLAUDECODE"} result_text_parts: list[str] = [] stderr_parts: list[str] = [] stream_parts: list[str] = [] lock = threading.Lock() class _TeeWriter: """File-like wrapper that always mirrors to *accumulator* and optionally writes through to *file* when one is provided.""" def __init__(self, file, accumulator: list[str]): self._file = file self._accumulator = accumulator def write(self, s: str): if self._file is not None: self._file.write(s) self._accumulator.append(s) def flush(self): if self._file is not None: self._file.flush() def close(self): if self._file is not None: self._file.close() proc = subprocess.Popen( cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env, ) proc.stdin.write(user_message) proc.stdin.close() # Only mirror the live event stream to disk when log_all is set. # Without log_all we still accumulate stream_parts in memory but # write a compact result-only entry to the log after the run. file_handle = ( open(f"{self._log_prefix}.log", "a") if self._log_prefix is not None and self.log_all else None ) log_file = _TeeWriter(file_handle, stream_parts) log_file.write("=" * 80 + "\n") log_file.write(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Prompt #{self._call_counter}\n") log_file.write("=" * 80 + "\n\n") truncated = user_message[:500] + ("..." if len(user_message) > 500 else "") log_file.write(f"[User Message]\n{truncated}\n\n") log_file.flush() def drain_stdout(): for line in proc.stdout: line = line.strip() if line: with lock: self._process_event_line(line, log_file, result_text_parts) log_file.flush() def drain_stderr(): for line in proc.stderr: with lock: stderr_parts.append(line) log_file.write(f"[stderr] {line}") log_file.flush() t1 = threading.Thread(target=drain_stderr) t2 = threading.Thread(target=drain_stdout) t1.start() t2.start() proc.wait() t1.join() t2.join() if not result_text_parts: log_file.write("[DEBUG] No events parsed.\n") log_file.write("-" * 80 + "\n\n") log_file.flush() log_file.close() if self._log_prefix is not None and not self.log_all: with open(f"{self._log_prefix}.log", "a") as f: f.write("=" * 80 + "\n") f.write(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Prompt #{self._call_counter}\n") f.write("=" * 80 + "\n\n") f.write(f"[User Message]\n{truncated}\n\n") f.write(f"[Response]\n{''.join(result_text_parts)}\n\n") f.write("-" * 80 + "\n\n") if proc.returncode != 0: self.logger.warning("claude exited %d", proc.returncode) return ClaudeCodeQueryResult( result="".join(result_text_parts), returncode=proc.returncode, stderr="".join(stderr_parts), stream_result="".join(stream_parts), ) def _process_event_line(self, line: str, f, result_text_parts: list) -> None: """Parse a single NDJSON event line and write to the log file.""" if not line: return try: event = json.loads(line) except json.JSONDecodeError: f.write(f"[UNPARSED] {line[:200]}\n") f.flush() return event_type = event.get("type", "") if event_type == "assistant": msg = event.get("message", {}) for block in msg.get("content", []): block_type = block.get("type", "") if block_type == "thinking": f.write("[Thinking]\n") f.write(block.get("thinking", "")) f.write("\n\n") f.flush() elif block_type == "text": text = block.get("text", "") result_text_parts.append(text) f.write("[Response]\n") f.write(text) f.write("\n\n") f.flush() elif block_type == "tool_use": tool_name = block.get("name", "unknown") tool_input = json.dumps(block.get("input", {})) if len(tool_input) > 2000: tool_input = tool_input[:2000] + "\n... [truncated]" f.write(f"[Tool Call: {tool_name}]\n") f.write(f"Args: {tool_input}\n\n") f.flush() elif block_type == "tool_result": content = block.get("content", "") if isinstance(content, list): content = "\n".join( c.get("text", "") for c in content if isinstance(c, dict) ) if len(content) > 2000: content = content[:2000] + "\n... [truncated]" f.write(f"[Tool Result]\n{content}\n\n") f.flush() elif event_type == "user": # Tool results ride on user events — the CLI echoes each tool # reply back to the assistant as a user-turn ``tool_result`` # block. Capturing them here is the only way to see what the # assistant actually saw when it chose its next action. msg = event.get("message", {}) content_blocks = msg.get("content", []) if isinstance(content_blocks, str): # Plain-text user message (initial prompt echo); skip. return for block in content_blocks: if not isinstance(block, dict): continue if block.get("type") != "tool_result": continue content = block.get("content", "") if isinstance(content, list): content = "\n".join( c.get("text", "") for c in content if isinstance(c, dict) ) if len(content) > 2000: content = content[:2000] + "\n... [truncated]" label = "Tool Result (error)" if block.get("is_error") else "Tool Result" f.write(f"[{label}]\n{content}\n\n") f.flush() elif event_type == "result": result_text = event.get("result", "") if result_text and not result_text_parts: result_text_parts.append(result_text) parts = [] meta: dict = {} cost = event.get("total_cost_usd") if cost is not None: parts.append(f"Cost: ${cost:.4f}") meta["cost_usd"] = cost duration = event.get("duration_ms") if duration is not None: parts.append(f"Duration: {duration / 1000:.1f}s") meta["duration_s"] = round(duration / 1000, 2) turns = event.get("num_turns") if turns is not None: parts.append(f"Turns: {turns}") meta["num_turns"] = turns usage = event.get("usage", {}) in_tok = usage.get("input_tokens", 0) out_tok = usage.get("output_tokens", 0) cc_tok = usage.get("cache_creation_input_tokens", 0) cr_tok = usage.get("cache_read_input_tokens", 0) if in_tok: parts.append(f"Input tokens: {in_tok}") meta["input_tokens"] = in_tok if cc_tok: parts.append(f"Cache creation: {cc_tok}") meta["cache_creation_input_tokens"] = cc_tok if cr_tok: parts.append(f"Cache read: {cr_tok}") meta["cache_read_input_tokens"] = cr_tok if out_tok: parts.append(f"Output tokens: {out_tok}") meta["output_tokens"] = out_tok if parts: f.write(f"[Metadata]\n{' | '.join(parts)}\n") f.flush() if meta: self._last_metadata = meta elif event_type == "rate_limit_event": self._rate_limit_event = event # system — skip silently # ================================================================== # Anthropic Python API backend # ------------------------------------------------------------------ # Everything below is NEW: an alternative to the ``claude --print`` # CLI path, selected with ``backend="api"`` in the constructor. It # calls the Anthropic Messages API (``anthropic`` SDK) directly and # runs the agentic tool loop client-side, executing each ChiaTool's # MCP server over HTTP. Returns the same :class:`ClaudeCodeQueryResult` shape as # the CLI path so callers don't change; ``returncode`` is synthesised # (0 on success). Errors are translated to the same typed exceptions # the CLI path raises, so :meth:`prompt`'s retry loop is unchanged. # # WARNING: this backend is experimental. It is only exercised by the # tests in chia/models/tests/test_claude_api.py and has not been # validated in production — selecting backend="api" logs a warning. # ================================================================== def _run_api( self, user_message: str, tools: Optional[List[ChiaTool]] = None, ) -> ClaudeCodeQueryResult: """Synchronous entry point for the API backend. Drives an async agent loop (MCP connections + tool execution are async) from :meth:`prompt`, which is synchronous. """ return self._run_coroutine(self._run_api_async(user_message, tools or [])) def _run_coroutine(self, coro): """Run *coro* to completion, whether or not an event loop is live. ``prompt`` is sync and may be invoked either from plain sync code or from inside a running event loop (e.g. a Ray async actor). When a loop is already running, ``asyncio.run`` would raise, so the coroutine is offloaded to a worker thread with its own loop. """ import asyncio try: asyncio.get_running_loop() except RuntimeError: return asyncio.run(coro) import concurrent.futures with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: return pool.submit(asyncio.run, coro).result() async def _run_api_async( self, user_message: str, tools: Optional[List[ChiaTool]] = None, ) -> ClaudeCodeQueryResult: """Connect to each ChiaTool's MCP server, then run the agent loop. The raw Messages API returns one assistant turn at a time; when it emits ``tool_use`` blocks we execute them against the MCP servers and feed ``tool_result`` blocks back, looping until the model stops requesting tools (or :attr:`max_tool_iterations` is reached). """ import anthropic from contextlib import AsyncExitStack from mcp import ClientSession from mcp.client.streamable_http import streamable_http_client client = ( anthropic.AsyncAnthropic(api_key=self.api_key) if self.api_key else anthropic.AsyncAnthropic() ) stream_parts: list[str] = [] stream_parts.append("=" * 80 + "\n") stream_parts.append( f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Prompt #{self._call_counter} (api)\n" ) stream_parts.append("=" * 80 + "\n\n") truncated = user_message[:500] + ("..." if len(user_message) > 500 else "") stream_parts.append(f"[User Message]\n{truncated}\n\n") async with AsyncExitStack() as stack: # --- Connect to every MCP server and gather tool schemas --- anthropic_tools: list[dict] = [] dispatch: dict = {} # api tool name -> (session, mcp tool name) for tool in tools or []: port = getattr(tool, "port", 8000) url = f"http://{tool.hostname}:{port}/{tool.name}/mcp" transport = await stack.enter_async_context(streamable_http_client(url)) read, write = transport[0], transport[1] session = await stack.enter_async_context(ClientSession(read, write)) await session.initialize() listed = await session.list_tools() for fn in listed.tools: api_name = f"{tool.name}__{fn.name}"[:64] anthropic_tools.append({ "name": api_name, "description": fn.description or "", "input_schema": fn.inputSchema or {"type": "object", "properties": {}}, }) dispatch[api_name] = (session, fn.name) # --- Prompt caching: cache the stable tools+system prefix. # Render order is tools -> system, so a breakpoint on the last # system block caches both; with no system prompt, cache the # last tool definition instead. --- system = None if self.system_message: system = [{ "type": "text", "text": self.system_message, "cache_control": {"type": "ephemeral"}, }] elif anthropic_tools: anthropic_tools[-1]["cache_control"] = {"type": "ephemeral"} messages: list[dict] = [{"role": "user", "content": user_message}] meta = { "input_tokens": 0, "output_tokens": 0, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0, "num_turns": 0, } final_text = "" for _ in range(self.max_tool_iterations): kwargs: dict = { "model": self.model, "max_tokens": self.max_tokens, "messages": messages, } if system is not None: kwargs["system"] = system if anthropic_tools: kwargs["tools"] = anthropic_tools if self.thinking: kwargs["thinking"] = {"type": self.thinking} try: resp = await client.messages.create(**kwargs) except Exception as exc: translated = self._translate_api_error(exc) if translated is not None: raise translated from exc raise meta["num_turns"] += 1 usage = getattr(resp, "usage", None) if usage is not None: meta["input_tokens"] += getattr(usage, "input_tokens", 0) or 0 meta["output_tokens"] += getattr(usage, "output_tokens", 0) or 0 meta["cache_creation_input_tokens"] += ( getattr(usage, "cache_creation_input_tokens", 0) or 0 ) meta["cache_read_input_tokens"] += ( getattr(usage, "cache_read_input_tokens", 0) or 0 ) # --- Log this turn's content blocks --- turn_text_parts: list[str] = [] tool_uses = [] for block in resp.content: btype = getattr(block, "type", "") if btype == "thinking": stream_parts.append( f"[Thinking]\n{getattr(block, 'thinking', '')}\n\n" ) elif btype == "text": turn_text_parts.append(block.text) stream_parts.append(f"[Response]\n{block.text}\n\n") elif btype == "tool_use": tool_uses.append(block) tool_input = json.dumps(block.input) if len(tool_input) > 2000: tool_input = tool_input[:2000] + "\n... [truncated]" stream_parts.append( f"[Tool Call: {block.name}]\nArgs: {tool_input}\n\n" ) if turn_text_parts: final_text = "".join(turn_text_parts) # Preserve the assistant turn verbatim (thinking + tool_use # blocks must round-trip for the next request). messages.append({"role": "assistant", "content": resp.content}) if resp.stop_reason == "max_tokens": raise MaxOutputTokensError( node_id=self._get_node_id(), exit_code=-1, raw_message="response truncated at max_tokens", partial_text=final_text, ) if resp.stop_reason != "tool_use" or not tool_uses: break # --- Execute each requested tool over its MCP server --- tool_results = [] for tu in tool_uses: session, fn_name = dispatch.get(tu.name, (None, None)) if session is None: result_text = f"Unknown tool: {tu.name}" is_error = True else: try: mcp_result = await session.call_tool( fn_name, tu.input or {} ) result_text = self._mcp_result_to_text(mcp_result) is_error = bool(getattr(mcp_result, "isError", False)) except Exception as exc: # tool failure is recoverable result_text = f"Tool execution error: {exc}" is_error = True logged = result_text if len(logged) > 2000: logged = logged[:2000] + "\n... [truncated]" label = "Tool Result (error)" if is_error else "Tool Result" stream_parts.append(f"[{label}]\n{logged}\n\n") tool_results.append({ "type": "tool_result", "tool_use_id": tu.id, "content": result_text, "is_error": is_error, }) messages.append({"role": "user", "content": tool_results}) else: stream_parts.append( f"[DEBUG] Reached max_tool_iterations={self.max_tool_iterations}\n" ) # --- Metadata + log file (mirrors the CLI path's output) --- self._last_metadata = {k: v for k, v in meta.items() if v} stream_parts.append("-" * 80 + "\n\n") if self._log_prefix is not None: with open(f"{self._log_prefix}.log", "a") as f: f.write("".join(stream_parts)) return ClaudeCodeQueryResult( result=final_text, returncode=0, stderr="", stream_result="".join(stream_parts), ) @staticmethod def _mcp_result_to_text(result) -> str: """Flatten an MCP ``CallToolResult`` into plain text for a tool_result.""" parts = [] for item in getattr(result, "content", None) or []: text = getattr(item, "text", None) parts.append(text if text is not None else str(item)) return "\n".join(parts) def _translate_api_error(self, exc) -> Optional[Exception]: """Map an ``anthropic`` SDK exception to a local typed error. Returns the local exception to raise, or ``None`` when *exc* is not an Anthropic API error (so the caller re-raises it untouched). The local types mirror the CLI path so :meth:`prompt`'s retry loop sees the same exceptions regardless of backend. """ import anthropic node_id = self._get_node_id() msg = str(exc)[:300] if isinstance(exc, anthropic.RateLimitError): retry_after = 60 try: ra = exc.response.headers.get("retry-after") if ra: retry_after = int(ra) except Exception: pass reset_time = datetime.now(timezone.utc) + timedelta(seconds=retry_after) return RateLimitError( node_id=node_id, reset_time=reset_time, raw_message=msg, exit_code=-1 ) if isinstance(exc, anthropic.AuthenticationError): return AuthenticationError(node_id=node_id, raw_message=msg) if isinstance(exc, anthropic.PermissionDeniedError): if "billing" in (getattr(exc, "type", "") or ""): return BillingError(node_id=node_id, raw_message=msg) return AuthenticationError(node_id=node_id, raw_message=msg) if isinstance(exc, (anthropic.BadRequestError, anthropic.NotFoundError)): return InvalidRequestError(node_id=node_id, raw_message=msg) if isinstance(exc, (anthropic.InternalServerError, anthropic.APIConnectionError, anthropic.APITimeoutError)): return ServerError(node_id=node_id, raw_message=msg) if isinstance(exc, anthropic.APIStatusError): if getattr(exc, "status_code", 0) >= 500: return ServerError(node_id=node_id, raw_message=msg) return UnknownClaudeError(node_id=node_id, raw_message=msg, stderr=msg) return None