Source code for chia.models.openai_compat

"""OpenAI-compatible LLM backend built on the openai SDK Chat Completions API.

:class:`OpenAICompatLLM` talks to **any** provider that implements the OpenAI
Chat Completions wire format and drives the agentic tool loop client-side —
executing each ChiaTool's MCP server over HTTP, like the other backends.

It is deliberately *not* provider-specific. The OpenAI Chat Completions format
is a de-facto multi-vendor standard, so the only provider-specific inputs are:

* ``base_url`` — selects the provider's endpoint (default: OpenAI itself).
* auth — the credential for that endpoint.

The same loop/tool/parse/error code therefore covers OpenAI, Fireworks, Groq,
OpenRouter, self-hosted vLLM/TGI, Vertex MaaS, etc. Each provider is
configuration, not a new module.

Auth, by default, lives entirely in the environment (``OPENAI_API_KEY`` /
``OPENAI_BASE_URL``) — construct with no key and the SDK reads them, matching
the other backends. For providers whose credential is a **rotating token**
(Vertex MaaS GCP token, Azure AD), pass ``token_provider`` — a zero-arg
callable that returns a fresh token; it's invoked when the client is built.

WARNING: experimental. Only exercised by the tests in
chia/models/tests/test_openai_compat.py. Not validated in production.
``openai`` is imported lazily, so importing this module does not require it.
"""

from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Callable, List, Optional

import ray

from chia.base.ChiaFunction import ChiaFunction
from chia.base.llm_call import QueryResult, LLMCallBase

if TYPE_CHECKING:
    from chia.base.tools.ChiaTool import ChiaTool


# ---------------------------------------------------------------------------
# Exceptions
#
# A parallel taxonomy to the other backends. Kept separate so this module
# stands alone; each carries ``__reduce__`` for Ray transport.
# ---------------------------------------------------------------------------


[docs] class OpenAICompatError(Exception): """Base for all OpenAI-compatible backend errors.""" def __init__( self, node_id: str, error_type: str, status_code: Optional[int] = None, raw_message: str = "", ): self.node_id = node_id self.error_type = error_type self.status_code = status_code self.raw_message = raw_message super().__init__(f"{error_type} on {node_id}: {raw_message[:200]}") def __reduce__(self): return ( self.__class__, (self.node_id, self.error_type, self.status_code, self.raw_message), )
[docs] class RateLimitError(OpenAICompatError): """HTTP 429.""" def __init__( self, node_id: str, reset_time: datetime, raw_message: str = "", status_code: Optional[int] = None, ): self.reset_time = reset_time super().__init__(node_id, "rate_limit", status_code, raw_message) def __reduce__(self): return ( self.__class__, (self.node_id, self.reset_time, self.raw_message, self.status_code), )
[docs] class AuthenticationError(OpenAICompatError): """HTTP 401 / 403.""" def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "authentication_failed", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class InvalidRequestError(OpenAICompatError): """HTTP 400 / 404.""" def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "invalid_request", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class ServerError(OpenAICompatError): """HTTP 5xx, connection, or timeout.""" def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "server_error", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class MaxOutputTokensError(OpenAICompatError): """The response was truncated at ``max_tokens`` (finish_reason='length').""" def __init__( self, node_id: str, status_code: Optional[int] = None, raw_message: str = "", partial_text: str = "", ): self.partial_text = partial_text super().__init__(node_id, "max_output_tokens", status_code, raw_message) def __reduce__(self): return ( self.__class__, (self.node_id, self.status_code, self.raw_message, self.partial_text), )
[docs] class ContextLengthExceededError(InvalidRequestError): """HTTP 400 (or 413) whose body shows the prompt exceeded the context window. Subclasses :class:`InvalidRequestError` so it inherits never-retry semantics (re-sending the same oversized prompt cannot help) while letting callers detect the context-overflow case specifically. Distinct from :class:`MaxOutputTokensError`, which is about *output* truncation. """ def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): OpenAICompatError.__init__( self, node_id, "context_length_exceeded", status_code, raw_message ) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class BillingError(OpenAICompatError): """Payment / quota problem (HTTP 402, or a 429 carrying ``insufficient_quota``). Never retried: spending more requests cannot restore quota or credit. """ def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "billing", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class UnknownOpenAIError(OpenAICompatError): """Unclassified OpenAI-compatible error.""" def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "unknown", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
# --------------------------------------------------------------------------- # Body-level disambiguation # # Some providers collapse semantically distinct failures into the same HTTP # status (e.g. context-overflow and a bad parameter are both 400; quota # exhaustion arrives as a 429). These patterns let ``_translate_error`` split # them by inspecting the error code / message. Kept deliberately conservative # to avoid false positives. # # - context-length: backed by a REAL OpenRouter capture (the message hits both # "context length" and "reduce the length"); the string codes cover # OpenAI/vLLM-native ("context_length_exceeded"). # - billing: follows common provider conventions (OpenAI insufficient_quota, # 402 Payment Required). NOT yet backed by a real capture. # --------------------------------------------------------------------------- _CONTEXT_LENGTH_CODES = {"context_length_exceeded", "string_above_max_length"} _CONTEXT_LENGTH_PHRASES = ( "context length", "context_length_exceeded", "maximum context", "prompt is too long", "too many tokens", "reduce the length", "maximum prompt", ) _BILLING_CODES = {"insufficient_quota", "billing_hard_limit_reached"} _BILLING_PHRASES = ( "insufficient_quota", "exceeded your current quota", "billing", "payment required", "not enough credits", "insufficient credits", ) def _error_code_and_text(exc): """Extract a string error ``code`` (or None) and a lowercased message.""" code = getattr(exc, "code", None) if not isinstance(code, str): code = None msg = getattr(exc, "message", None) or str(exc) return code, (msg or "").lower() def _unwrap_exception_group(exc): """Drill through (possibly nested) ExceptionGroups to a representative leaf. MCP's anyio task group re-wraps any error that propagates out of the tool-connected ``AsyncExitStack`` in an ``ExceptionGroup`` (the ``exceptiongroup`` backport on Python < 3.11), which would otherwise hide a perfectly classifiable provider error. Prefer an already-typed :class:`OpenAICompatError`, then an ``openai`` API error, else the first leaf. Non-group exceptions are returned unchanged. """ subs = getattr(exc, "exceptions", None) if not subs or not isinstance(subs, (list, tuple)): return exc leaves = [_unwrap_exception_group(s) for s in subs] for leaf in leaves: if isinstance(leaf, OpenAICompatError): return leaf try: import openai for leaf in leaves: if isinstance(leaf, openai.APIError): return leaf except Exception: pass return leaves[0]
[docs] class OpenAICompatLLM(LLMCallBase): """OpenAI-compatible Chat Completions backend with client-side MCP tools. Returns the same :class:`QueryResult` shape as the other backends so callers are interchangeable; ``returncode`` is synthesised (0 on success, -1 when every retry fails) and ``stderr`` is unused. """ def __init__( self, model: str, system_message: str = "", timeout_seconds: int = 600, retries: int = 3, logging_name: str = "openai_compat_llm", logging_level: int = logging.DEBUG, log_dir: Optional[str] = None, base_url: Optional[str] = None, api_key: Optional[str] = None, token_provider: Optional[Callable[[], str]] = None, max_tokens: int = 16000, max_tool_iterations: int = 100, client_kwargs: Optional[dict] = None, ): super().__init__(system_message=system_message) self.logging_level = logging_level self.logging_name = logging_name self.retries = retries self.timeout_seconds = timeout_seconds self.model = model self.base_url = base_url # None -> OPENAI_BASE_URL env / OpenAI default self.api_key = api_key # None -> OPENAI_API_KEY env self.token_provider = token_provider self.max_tokens = max_tokens self.max_tool_iterations = max_tool_iterations self.client_kwargs = client_kwargs or {} self.logger = logging.getLogger(logging_name) self._last_metadata: dict = {} self.logger.warning( "OpenAICompatLLM is experimental: only exercised by unit tests so " "far, not validated in production." ) self._log_dir = log_dir if log_dir is not None: os.makedirs(log_dir, exist_ok=True) run_id = datetime.now().strftime("%Y%m%d_%H%M%S") self._log_prefix = os.path.join(log_dir, f"{logging_name}_{run_id}") else: self._log_prefix = None # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------
[docs] @ChiaFunction(resources={"openai_creds": 0.01}) def prompt( self, user_message: str, tools: Optional[List[ChiaTool]] = [], ) -> QueryResult: """Send *user_message* via Chat Completions and return the response, retrying transient failures with the same policy the other backends use. """ import time as _time from chia.trace.profiler import get_profiler profiler = get_profiler() for attempt in range(self.retries): try: self._last_metadata = {} cli = self._run_openai(user_message, tools) self._last_metadata["model"] = self.model self._last_metadata["tools"] = [ {"name": t.name, "hostname": getattr(t, "hostname", None), "port": getattr(t, "port", None), "node_id": getattr(t, "node_id", None)} for t in tools ] if profiler.enabled and self._last_metadata: profiler.add_info(self._last_metadata) cli.success = True return cli # -- Never retry: propagate immediately -- # (ContextLengthExceededError is caught here as an InvalidRequestError # subclass; BillingError is listed explicitly.) except (RateLimitError, AuthenticationError, InvalidRequestError, BillingError): raise # -- Retry once: a shorter generation may fit -- except MaxOutputTokensError: if attempt == 0: self.logger.warning( "Max output tokens on attempt %d/%d, retrying once", attempt + 1, self.retries, ) continue raise # -- Retry with exponential backoff: transient service issue -- except ServerError: backoff = min(5 * 2 ** attempt, 60) self.logger.warning( "Server error on attempt %d/%d, backing off %ds", attempt + 1, self.retries, backoff, ) _time.sleep(backoff) # -- Standard retry for unknown errors -- except UnknownOpenAIError as exc: self.logger.warning( "Unknown error on attempt %d/%d: %s", attempt + 1, self.retries, exc, ) except Exception as exc: self.logger.warning( "Unexpected error on attempt %d/%d: %s", attempt + 1, self.retries, exc, ) return QueryResult(result="", returncode=-1, stderr="", stream_result="", success=False)
def _get_node_id(self) -> str: try: return ray.get_runtime_context().get_node_id() except Exception: return "unknown" # ------------------------------------------------------------------ # Chat Completions implementation # ------------------------------------------------------------------ def _run_openai( self, user_message: str, tools: Optional[List[ChiaTool]] = None, ) -> QueryResult: """Synchronous entry point; drives the async agent loop. A second translation guard catches the case where a typed error (or a raw openai error) escapes the tool-connected ``AsyncExitStack`` wrapped in an ``ExceptionGroup`` by MCP's anyio task group — without this, the wrapped error never matches ``prompt()``'s ``except`` clauses and every provider error in the tools path is misfiled as "unexpected". """ try: return self._run_coroutine(self._run_openai_async(user_message, tools or [])) except Exception as exc: translated = self._translate_error(exc) if translated is not None: raise translated from exc raise def _run_coroutine(self, coro): """Run *coro* whether or not an event loop is already running.""" import asyncio try: asyncio.get_running_loop() except RuntimeError: return asyncio.run(coro) import concurrent.futures with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: return pool.submit(asyncio.run, coro).result() def _make_client(self): """Build the AsyncOpenAI client. With no ``base_url``/``api_key``/``token_provider`` the SDK reads ``OPENAI_API_KEY`` / ``OPENAI_BASE_URL`` from the environment — the same env-driven auth model as the other backends. """ from openai import AsyncOpenAI kwargs = dict(self.client_kwargs) if self.base_url is not None: kwargs["base_url"] = self.base_url if self.token_provider is not None: kwargs["api_key"] = self.token_provider() elif self.api_key is not None: kwargs["api_key"] = self.api_key return AsyncOpenAI(**kwargs) async def _run_openai_async( self, user_message: str, tools: Optional[List[ChiaTool]] = None, ) -> QueryResult: """Connect to each ChiaTool's MCP server, then run the Chat Completions loop: on ``tool_calls`` execute them against the MCP servers and feed ``role="tool"`` messages back, until the model stops requesting tools (or :attr:`max_tool_iterations` is hit). """ from contextlib import AsyncExitStack from mcp import ClientSession from mcp.client.streamable_http import streamable_http_client client = self._make_client() stream_parts: list[str] = [] stream_parts.append("=" * 80 + "\n") stream_parts.append( f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] chat.completions ({self.model})\n" ) stream_parts.append("=" * 80 + "\n\n") truncated = user_message[:500] + ("..." if len(user_message) > 500 else "") stream_parts.append(f"[User Message]\n{truncated}\n\n") async with AsyncExitStack() as stack: # --- Connect to every MCP server and gather tool schemas --- tool_schemas: list[dict] = [] dispatch: dict = {} # openai tool name -> (session, mcp tool name) for tool in tools or []: port = getattr(tool, "port", 8000) url = f"http://{tool.hostname}:{port}/{tool.name}/mcp" transport = await stack.enter_async_context(streamable_http_client(url)) read, write = transport[0], transport[1] session = await stack.enter_async_context(ClientSession(read, write)) await session.initialize() listed = await session.list_tools() for fn in listed.tools: api_name = f"{tool.name}__{fn.name}"[:64] tool_schemas.append({ "type": "function", "function": { "name": api_name, "description": fn.description or "", "parameters": fn.inputSchema or {"type": "object", "properties": {}}, }, }) dispatch[api_name] = (session, fn.name) messages: list[dict] = [] if self.system_message: messages.append({"role": "system", "content": self.system_message}) messages.append({"role": "user", "content": user_message}) meta = {"input_tokens": 0, "output_tokens": 0, "num_turns": 0} final_text = "" for _ in range(self.max_tool_iterations): kwargs: dict = { "model": self.model, "messages": messages, "max_tokens": self.max_tokens, } if tool_schemas: kwargs["tools"] = tool_schemas kwargs["tool_choice"] = "auto" try: resp = await client.chat.completions.create(**kwargs) except Exception as exc: translated = self._translate_error(exc) if translated is not None: raise translated from exc raise meta["num_turns"] += 1 usage = getattr(resp, "usage", None) if usage is not None: meta["input_tokens"] += getattr(usage, "prompt_tokens", 0) or 0 meta["output_tokens"] += getattr(usage, "completion_tokens", 0) or 0 choice = resp.choices[0] msg = choice.message finish_reason = choice.finish_reason tool_calls = getattr(msg, "tool_calls", None) or [] if msg.content: final_text = msg.content stream_parts.append(f"[Response]\n{msg.content}\n\n") # Echo the assistant turn back (content + tool_calls must # round-trip for the follow-up request). assistant_msg: dict = {"role": "assistant", "content": msg.content} if tool_calls: assistant_msg["tool_calls"] = [ { "id": tc.id, "type": "function", "function": { "name": tc.function.name, "arguments": tc.function.arguments, }, } for tc in tool_calls ] messages.append(assistant_msg) if finish_reason == "length": raise MaxOutputTokensError( node_id=self._get_node_id(), raw_message="response truncated at max_tokens", partial_text=final_text, ) if not tool_calls: break # --- Execute each requested tool over its MCP server --- for tc in tool_calls: name = tc.function.name try: args = json.loads(tc.function.arguments or "{}") except (json.JSONDecodeError, TypeError): args = None if args is None: result_text = f"Invalid tool arguments (not valid JSON): {tc.function.arguments!r}" is_error = True else: session, fn_name = dispatch.get(name, (None, None)) if session is None: result_text = f"Unknown tool: {name}" is_error = True else: stream_parts.append( f"[Tool Call: {name}]\nArgs: {json.dumps(args)[:2000]}\n\n" ) try: mcp_result = await session.call_tool(fn_name, args) result_text = self._mcp_result_to_text(mcp_result) is_error = bool(getattr(mcp_result, "isError", False)) except Exception as exc: result_text = f"Tool execution error: {exc}" is_error = True logged = result_text if len(logged) > 2000: logged = logged[:2000] + "\n... [truncated]" label = "Tool Result (error)" if is_error else "Tool Result" stream_parts.append(f"[{label}]\n{logged}\n\n") messages.append({ "role": "tool", "tool_call_id": tc.id, "content": result_text, }) else: stream_parts.append( f"[DEBUG] Reached max_tool_iterations={self.max_tool_iterations}\n" ) # --- Metadata + log file --- self._last_metadata = {k: v for k, v in meta.items() if v} stream_parts.append("-" * 80 + "\n\n") if self._log_prefix is not None: with open(f"{self._log_prefix}.log", "a") as f: f.write("".join(stream_parts)) return QueryResult( result=final_text, returncode=0, stderr="", stream_result="".join(stream_parts), ) @staticmethod def _mcp_result_to_text(result) -> str: """Flatten an MCP ``CallToolResult`` into plain text.""" parts = [] for item in getattr(result, "content", None) or []: text = getattr(item, "text", None) parts.append(text if text is not None else str(item)) return "\n".join(parts) def _translate_error(self, exc) -> Optional[Exception]: """Map an openai SDK error to a local typed error. Returns the local exception to raise, or ``None`` when *exc* is not a recognised openai API error (so the caller re-raises it). """ # MCP's anyio task group wraps errors that escape the tool-connected # AsyncExitStack in (possibly nested) ExceptionGroups — drill down to # the real error first, otherwise nothing classifies in the tools path. exc = _unwrap_exception_group(exc) # Already one of ours (e.g. translated inside the loop, then re-wrapped # by the task group) -> pass straight through. if isinstance(exc, OpenAICompatError): return exc try: import openai except Exception: return None node_id = self._get_node_id() msg = str(exc)[:300] err_code, err_text = _error_code_and_text(exc) def _is_billing(): return err_code in _BILLING_CODES or any(p in err_text for p in _BILLING_PHRASES) def _is_context_length(): return err_code in _CONTEXT_LENGTH_CODES or any( p in err_text for p in _CONTEXT_LENGTH_PHRASES ) # Billing/quota first: providers smuggle it inside a 429 # (insufficient_quota) or a 402, which would otherwise look like a rate # limit or a generic status error. Restrict phrase-matching to 4xx # (or status-less) so a 5xx message can't be misread as billing. if isinstance(exc, openai.APIStatusError): status = getattr(exc, "status_code", None) in_4xx = status is None or 400 <= status < 500 if status == 402 or (in_4xx and _is_billing()): return BillingError(node_id, status_code=status, raw_message=msg) if isinstance(exc, openai.RateLimitError): retry_after = 60 try: ra = exc.response.headers.get("retry-after") if ra: retry_after = int(ra) except Exception: pass reset_time = datetime.now(timezone.utc) + timedelta(seconds=retry_after) return RateLimitError( node_id=node_id, reset_time=reset_time, raw_message=msg, status_code=429, ) if isinstance(exc, (openai.AuthenticationError, openai.PermissionDeniedError)): return AuthenticationError( node_id, status_code=getattr(exc, "status_code", None), raw_message=msg ) if isinstance(exc, (openai.BadRequestError, openai.NotFoundError)): status = getattr(exc, "status_code", None) # Context-overflow comes back as a 400 with a length message; split # it out so callers can detect "prompt too big" specifically. if _is_context_length(): return ContextLengthExceededError(node_id, status_code=status, raw_message=msg) return InvalidRequestError(node_id, status_code=status, raw_message=msg) if isinstance(exc, (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError)): return ServerError( node_id, status_code=getattr(exc, "status_code", None), raw_message=msg ) if isinstance(exc, openai.APIStatusError): code = getattr(exc, "status_code", None) if isinstance(code, int) and code >= 500: return ServerError(node_id, status_code=code, raw_message=msg) # e.g. a 413 Payload Too Large used by some providers for overflow. if _is_context_length(): return ContextLengthExceededError(node_id, status_code=code, raw_message=msg) return UnknownOpenAIError(node_id, status_code=code, raw_message=msg) return None