Source code for chia.models.vertex

"""Google Vertex AI LLM backend built on the google-genai SDK.

:class:`VertexGeminiLLM` runs Google's **Gemini** models on Vertex AI and drives the
agentic tool loop client-side — executing each ChiaTool's MCP server over HTTP,
exactly like the Bedrock and Claude API backends.

Vertex has **no single unified API** across model families —
Gemini goes through google-genai, Claude-on-Vertex through ``AnthropicVertex``,
and Llama/Mistral through OpenAI-compatible MaaS endpoints. Two separate
classes are provided here, one for Gemini and one for MaaS.

WARNING: experimental. Only exercised by the tests in
chia/models/tests/test_vertex.py (mocked unit tests, plus opt-in live tests).
Not validated in production.

Auth/config: Vertex needs a GCP project + location and Application Default
Credentials (``gcloud auth application-default login``, a service-account key
via ``GOOGLE_APPLICATION_CREDENTIALS``, or a workload identity). Pass
``project``/``location`` or rely on ``GOOGLE_CLOUD_PROJECT`` /
``GOOGLE_CLOUD_LOCATION``. ``google-genai`` is imported lazily.
"""

from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, List, Optional

import ray

from chia.base.ChiaFunction import ChiaFunction
from chia.base.llm_call import QueryResult, LLMCallBase
from chia.models.openai_compat import OpenAICompatLLM

if TYPE_CHECKING:
    from chia.base.tools.ChiaTool import ChiaTool


# ---------------------------------------------------------------------------
# Exceptions
#
# A parallel taxonomy to claude.py / bedrock.py. Kept separate so this module
# stands alone; each carries ``__reduce__`` for Ray transport.
# ---------------------------------------------------------------------------


[docs] class VertexError(Exception): """Base for all Vertex backend errors.""" def __init__( self, node_id: str, error_type: str, status_code: Optional[int] = None, raw_message: str = "", ): self.node_id = node_id self.error_type = error_type self.status_code = status_code self.raw_message = raw_message super().__init__(f"{error_type} on {node_id}: {raw_message[:200]}") def __reduce__(self): return ( self.__class__, (self.node_id, self.error_type, self.status_code, self.raw_message), )
[docs] class RateLimitError(VertexError): """Quota / rate exhaustion (HTTP 429, ``RESOURCE_EXHAUSTED``).""" def __init__( self, node_id: str, reset_time: datetime, raw_message: str = "", status_code: Optional[int] = None, ): self.reset_time = reset_time super().__init__(node_id, "rate_limit", status_code, raw_message) def __reduce__(self): return ( self.__class__, (self.node_id, self.reset_time, self.raw_message, self.status_code), )
[docs] class AuthenticationError(VertexError): """Invalid / missing credentials or permission (HTTP 401 / 403).""" def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "authentication_failed", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class InvalidRequestError(VertexError): """Malformed request or unknown model (HTTP 400 / 404).""" def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "invalid_request", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class ServerError(VertexError): """Transient service-side failure (HTTP 5xx).""" def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "server_error", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class MaxOutputTokensError(VertexError): """The response was truncated at ``max_output_tokens``.""" def __init__( self, node_id: str, status_code: Optional[int] = None, raw_message: str = "", partial_text: str = "", ): self.partial_text = partial_text super().__init__(node_id, "max_output_tokens", status_code, raw_message) def __reduce__(self): return ( self.__class__, (self.node_id, self.status_code, self.raw_message, self.partial_text), )
[docs] class UnknownVertexError(VertexError): """Unclassified Vertex error.""" def __init__(self, node_id: str, status_code: Optional[int] = None, raw_message: str = ""): super().__init__(node_id, "unknown", status_code, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.status_code, self.raw_message))
[docs] class ContentBlockedError(VertexError): """The model returned no usable content because the prompt or the response was blocked (safety, recitation, blocklist, ...). This is NOT an HTTP/API error — Gemini reports it as a 200-OK response whose candidate carries a blocking ``finish_reason`` (or whose ``prompt_feedback`` carries a ``block_reason`` with no candidates). Surfacing it as a typed error keeps a block from masquerading as a successful empty answer. Never automatically retried: re-sending the same prompt will be blocked again. """ def __init__(self, node_id: str, block_reason: str = "", raw_message: str = ""): self.block_reason = block_reason super().__init__(node_id, "content_blocked", None, raw_message) def __reduce__(self): return (self.__class__, (self.node_id, self.block_reason, self.raw_message))
# Gemini ``finish_reason`` values that mean the candidate was blocked / unusable # rather than a normal STOP or a MAX_TOKENS truncation (which is handled # separately as MaxOutputTokensError). _BLOCK_FINISH_REASONS = frozenset({ "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII", "IMAGE_SAFETY", }) def _unwrap_exception_group(exc): """Drill through (possibly nested) ExceptionGroups to a representative leaf. MCP's anyio task group re-wraps any error that escapes the tool-connected ``AsyncExitStack`` in an ``ExceptionGroup`` (the ``exceptiongroup`` backport on Python < 3.11), which would otherwise hide a perfectly classifiable google-genai error. Prefer an already-typed :class:`VertexError`, then a google-genai ``APIError``, else the first leaf. Non-group exceptions are returned unchanged. """ subs = getattr(exc, "exceptions", None) if not subs or not isinstance(subs, (list, tuple)): return exc leaves = [_unwrap_exception_group(s) for s in subs] for leaf in leaves: if isinstance(leaf, VertexError): return leaf try: from google.genai import errors as genai_errors for leaf in leaves: if isinstance(leaf, genai_errors.APIError): return leaf except Exception: pass return leaves[0]
[docs] class VertexGeminiLLM(LLMCallBase): """Gemini-on-Vertex LLM backend with client-side MCP tool execution. Returns the same :class:`QueryResult` shape as the other backends so callers are interchangeable; ``returncode`` is synthesised (0 on success, -1 when every retry fails) and ``stderr`` is unused. """ def __init__( self, model: str, system_message: str = "", timeout_seconds: int = 600, retries: int = 3, logging_name: str = "vertex_gemini", logging_level: int = logging.DEBUG, log_dir: Optional[str] = None, project: Optional[str] = None, location: Optional[str] = None, max_tokens: int = 16000, max_tool_iterations: int = 100, client_kwargs: Optional[dict] = None, ): super().__init__(system_message=system_message) self.logging_level = logging_level self.logging_name = logging_name self.retries = retries self.timeout_seconds = timeout_seconds self.model = model self.project = project or os.environ.get("GOOGLE_CLOUD_PROJECT") self.location = ( location or os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1" ) self.max_tokens = max_tokens self.max_tool_iterations = max_tool_iterations self.client_kwargs = client_kwargs or {} self.logger = logging.getLogger(logging_name) self._last_metadata: dict = {} self.logger.warning( "VertexGeminiLLM is experimental: only exercised by unit tests so far, " "not validated in production." ) self._log_dir = log_dir if log_dir is not None: os.makedirs(log_dir, exist_ok=True) run_id = datetime.now().strftime("%Y%m%d_%H%M%S") self._log_prefix = os.path.join(log_dir, f"{logging_name}_{run_id}") else: self._log_prefix = None # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------
[docs] @ChiaFunction(resources={"vertex_creds": 0.01}) def prompt( self, user_message: str, tools: Optional[List[ChiaTool]] = [], ) -> QueryResult: """Send *user_message* via Gemini on Vertex and return the response, retrying transient failures with the same policy the other backends use. """ import time as _time from chia.trace.profiler import get_profiler profiler = get_profiler() for attempt in range(self.retries): try: self._last_metadata = {} cli = self._run_generate(user_message, tools) self._last_metadata["model"] = self.model self._last_metadata["tools"] = [ {"name": t.name, "hostname": getattr(t, "hostname", None), "port": getattr(t, "port", None), "node_id": getattr(t, "node_id", None)} for t in tools ] if profiler.enabled and self._last_metadata: profiler.add_info(self._last_metadata) cli.success = True return cli # -- Never retry: propagate immediately -- except (RateLimitError, AuthenticationError, InvalidRequestError, ContentBlockedError): raise # -- Retry once: a shorter generation may fit -- except MaxOutputTokensError: if attempt == 0: self.logger.warning( "Max output tokens on attempt %d/%d, retrying once", attempt + 1, self.retries, ) continue raise # -- Retry with exponential backoff: transient service issue -- except ServerError: backoff = min(5 * 2 ** attempt, 60) self.logger.warning( "Server error on attempt %d/%d, backing off %ds", attempt + 1, self.retries, backoff, ) _time.sleep(backoff) # -- Standard retry for unknown errors -- except UnknownVertexError as exc: self.logger.warning( "Unknown error on attempt %d/%d: %s", attempt + 1, self.retries, exc, ) except Exception as exc: self.logger.warning( "Unexpected error on attempt %d/%d: %s", attempt + 1, self.retries, exc, ) return QueryResult(result="", returncode=-1, stderr="", stream_result="", success=False)
def _get_node_id(self) -> str: try: return ray.get_runtime_context().get_node_id() except Exception: return "unknown" # ------------------------------------------------------------------ # generate_content implementation # ------------------------------------------------------------------ def _run_generate( self, user_message: str, tools: Optional[List[ChiaTool]] = None, ) -> QueryResult: """Synchronous entry point; drives the async agent loop. A second translation guard catches the case where a typed error (or a raw google-genai error) escapes the tool-connected ``AsyncExitStack`` wrapped in an ``ExceptionGroup`` by MCP's anyio task group — without this, the wrapped error never matches ``prompt()``'s ``except`` clauses and every API error in the tools path is misfiled as "unexpected". """ try: return self._run_coroutine(self._run_generate_async(user_message, tools or [])) except Exception as exc: translated = self._translate_error(exc) if translated is not None: raise translated from exc raise def _run_coroutine(self, coro): """Run *coro* whether or not an event loop is already running.""" import asyncio try: asyncio.get_running_loop() except RuntimeError: return asyncio.run(coro) import concurrent.futures with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: return pool.submit(asyncio.run, coro).result() async def _run_generate_async( self, user_message: str, tools: Optional[List[ChiaTool]] = None, ) -> QueryResult: """Connect to each ChiaTool's MCP server, then run the Gemini loop. Gemini returns one model turn at a time; when that turn contains ``function_call`` parts we execute them against their MCP servers and feed ``function_response`` parts back, looping until the model returns no more calls (or :attr:`max_tool_iterations` is hit). """ import asyncio from contextlib import AsyncExitStack from google import genai from google.genai import types from mcp import ClientSession from mcp.client.streamable_http import streamable_http_client client = genai.Client( vertexai=True, project=self.project, location=self.location, **self.client_kwargs, ) stream_parts: list[str] = [] stream_parts.append("=" * 80 + "\n") stream_parts.append( f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] generate_content ({self.model})\n" ) stream_parts.append("=" * 80 + "\n\n") truncated = user_message[:500] + ("..." if len(user_message) > 500 else "") stream_parts.append(f"[User Message]\n{truncated}\n\n") async with AsyncExitStack() as stack: # --- Connect to every MCP server and gather function declarations --- function_decls = [] dispatch: dict = {} # gemini function name -> (session, mcp tool name) for tool in tools or []: port = getattr(tool, "port", 8000) url = f"http://{tool.hostname}:{port}/{tool.name}/mcp" transport = await stack.enter_async_context(streamable_http_client(url)) read, write = transport[0], transport[1] session = await stack.enter_async_context(ClientSession(read, write)) await session.initialize() listed = await session.list_tools() for fn in listed.tools: api_name = f"{tool.name}__{fn.name}"[:64] function_decls.append(types.FunctionDeclaration( name=api_name, description=fn.description or " ", parameters=self._sanitize_schema( fn.inputSchema or {"type": "object", "properties": {}} ), )) dispatch[api_name] = (session, fn.name) config_kwargs: dict = { "max_output_tokens": self.max_tokens, # We run the loop ourselves; don't let the SDK auto-call. "automatic_function_calling": types.AutomaticFunctionCallingConfig( disable=True ), } if self.system_message: config_kwargs["system_instruction"] = self.system_message if function_decls: config_kwargs["tools"] = [types.Tool(function_declarations=function_decls)] config = types.GenerateContentConfig(**config_kwargs) contents = [types.Content( role="user", parts=[types.Part.from_text(text=user_message)] )] meta = {"input_tokens": 0, "output_tokens": 0, "num_turns": 0} final_text = "" for _ in range(self.max_tool_iterations): try: resp = await asyncio.to_thread( client.models.generate_content, model=self.model, contents=contents, config=config, ) except Exception as exc: translated = self._translate_error(exc) if translated is not None: raise translated from exc raise meta["num_turns"] += 1 usage = getattr(resp, "usage_metadata", None) if usage is not None: meta["input_tokens"] += getattr(usage, "prompt_token_count", 0) or 0 meta["output_tokens"] += getattr(usage, "candidates_token_count", 0) or 0 candidate = (resp.candidates or [None])[0] if candidate is None or candidate.content is None: # No candidate at all -> the prompt itself may have been # blocked (safety etc.); surface that rather than returning # a silent empty success. block = self._prompt_block_reason(resp) if block: raise ContentBlockedError( node_id=self._get_node_id(), block_reason=block, raw_message=f"prompt blocked: {block}", ) break finish = getattr(candidate, "finish_reason", None) finish_name = getattr(finish, "name", str(finish) if finish else "") # --- Log this turn's parts; collect function calls --- turn_text_parts: list[str] = [] function_calls = [] for part in candidate.content.parts or []: fc = getattr(part, "function_call", None) if fc is not None: function_calls.append(fc) args_str = json.dumps(dict(fc.args or {})) if len(args_str) > 2000: args_str = args_str[:2000] + "\n... [truncated]" stream_parts.append( f"[Tool Call: {fc.name}]\nArgs: {args_str}\n\n" ) elif getattr(part, "thought", False) and part.text: stream_parts.append(f"[Thinking]\n{part.text}\n\n") elif getattr(part, "text", None): turn_text_parts.append(part.text) stream_parts.append(f"[Response]\n{part.text}\n\n") if turn_text_parts: final_text = "".join(turn_text_parts) # Echo the model turn back verbatim (function_call parts must # round-trip for the follow-up request). contents.append(candidate.content) if finish_name == "MAX_TOKENS": raise MaxOutputTokensError( node_id=self._get_node_id(), raw_message="response truncated at max_output_tokens", partial_text=final_text, ) # The response was content-blocked (safety/recitation/...) -> the # turn has no trustworthy output; surface it instead of breaking # into a silent empty success. if finish_name in _BLOCK_FINISH_REASONS and not function_calls: raise ContentBlockedError( node_id=self._get_node_id(), block_reason=finish_name, raw_message=f"response blocked: {finish_name}", ) if not function_calls: break # --- Execute each requested tool over its MCP server --- response_parts = [] for fc in function_calls: session, fn_name = dispatch.get(fc.name, (None, None)) if session is None: result_text = f"Unknown tool: {fc.name}" is_error = True else: try: mcp_result = await session.call_tool( fn_name, dict(fc.args or {}) ) result_text = self._mcp_result_to_text(mcp_result) is_error = bool(getattr(mcp_result, "isError", False)) except Exception as exc: result_text = f"Tool execution error: {exc}" is_error = True logged = result_text if len(logged) > 2000: logged = logged[:2000] + "\n... [truncated]" label = "Tool Result (error)" if is_error else "Tool Result" stream_parts.append(f"[{label}]\n{logged}\n\n") response_parts.append(types.Part.from_function_response( name=fc.name, response={"error": result_text} if is_error else {"result": result_text}, )) contents.append(types.Content(role="user", parts=response_parts)) else: stream_parts.append( f"[DEBUG] Reached max_tool_iterations={self.max_tool_iterations}\n" ) # --- Metadata + log file --- self._last_metadata = {k: v for k, v in meta.items() if v} stream_parts.append("-" * 80 + "\n\n") if self._log_prefix is not None: with open(f"{self._log_prefix}.log", "a") as f: f.write("".join(stream_parts)) return QueryResult( result=final_text, returncode=0, stderr="", stream_result="".join(stream_parts), ) @staticmethod def _sanitize_schema(schema): """Drop JSON-schema keys Gemini's function-declaration schema rejects. Gemini accepts only a subset of OpenAPI schema; keys like ``$schema``/``additionalProperties``/``title`` cause 400s, so strip them recursively. MCP servers often emit them. """ drop = {"$schema", "$id", "$defs", "definitions", "additionalProperties", "title"} if isinstance(schema, dict): return { k: VertexGeminiLLM._sanitize_schema(v) for k, v in schema.items() if k not in drop } if isinstance(schema, list): return [VertexGeminiLLM._sanitize_schema(v) for v in schema] return schema @staticmethod def _prompt_block_reason(resp) -> str: """The ``prompt_feedback.block_reason`` name when the prompt was blocked (no candidates produced), else ``""``.""" feedback = getattr(resp, "prompt_feedback", None) if feedback is None: return "" reason = getattr(feedback, "block_reason", None) if reason is None: return "" return getattr(reason, "name", str(reason)) @staticmethod def _mcp_result_to_text(result) -> str: """Flatten an MCP ``CallToolResult`` into plain text.""" parts = [] for item in getattr(result, "content", None) or []: text = getattr(item, "text", None) parts.append(text if text is not None else str(item)) return "\n".join(parts) def _translate_error(self, exc) -> Optional[Exception]: """Map a google-genai error to a local typed error. Returns the local exception to raise, or ``None`` when *exc* is not a recognised google-genai API error (so the caller re-raises it). """ # MCP's anyio task group wraps errors that escape the tool-connected # AsyncExitStack in (possibly nested) ExceptionGroups — drill down to # the real error first, otherwise nothing classifies in the tools path. exc = _unwrap_exception_group(exc) # Already one of ours (e.g. translated inside the loop, then re-wrapped # by the task group) -> pass straight through. if isinstance(exc, VertexError): return exc try: from google.genai import errors as genai_errors except Exception: return None if not isinstance(exc, genai_errors.APIError): return None node_id = self._get_node_id() code = getattr(exc, "code", None) msg = (getattr(exc, "message", None) or str(exc))[:300] if code == 429: reset_time = datetime.now(timezone.utc) + timedelta(seconds=60) return RateLimitError( node_id=node_id, reset_time=reset_time, raw_message=msg, status_code=code, ) if code in (401, 403): return AuthenticationError(node_id, status_code=code, raw_message=msg) if code in (400, 404): return InvalidRequestError(node_id, status_code=code, raw_message=msg) if isinstance(code, int) and code >= 500: return ServerError(node_id, status_code=code, raw_message=msg) return UnknownVertexError(node_id, status_code=code, raw_message=msg)
# --------------------------------------------------------------------------- # Non-Gemini Vertex models via the OpenAI-compatible MaaS endpoint # --------------------------------------------------------------------------- def _vertex_adc_token_provider() -> str: """Mint a fresh GCP access token from Application Default Credentials. Used as :class:`OpenAICompatLLM`'s ``token_provider`` for Vertex MaaS, where auth is a short-lived GCP bearer token rather than a static key. It's invoked each time the OpenAI client is built, so the token is always current. Runs wherever the client is constructed — ADC must be available there (``gcloud auth application-default login`` / service account / WI). """ import google.auth import google.auth.transport.requests creds, _ = google.auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"] ) creds.refresh(google.auth.transport.requests.Request()) return creds.token
[docs] class VertexGenericLLM(OpenAICompatLLM): """Non-Gemini Vertex models (Llama, Mistral, ...) via the Vertex Model-as-a- Service **OpenAI-compatible** endpoint. Vertex has no single unified API, so Gemini uses google-genai (:class:`VertexGeminiLLM`) while the open/partner families are reached through the OpenAI-compatible MaaS endpoint — which is exactly what :class:`OpenAICompatLLM` already speaks. The only Vertex-specifics are the endpoint URL (built from project/location) and auth: a GCP ADC bearer token that rotates, supplied via ``token_provider``. Everything else — the agent loop, tool handling, error translation — is inherited unchanged. ``model`` is the MaaS model id, e.g. ``meta/llama-3.1-8b-instruct-maas``. """ def __init__( self, model: str, project: Optional[str] = None, location: Optional[str] = None, **kwargs, ): project = project or os.environ.get("GOOGLE_CLOUD_PROJECT") location = ( location or os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1" ) kwargs.setdefault( "base_url", f"https://{location}-aiplatform.googleapis.com/v1beta1/" f"projects/{project}/locations/{location}/endpoints/openapi", ) kwargs.setdefault("logging_name", "vertex_maas") kwargs.setdefault("token_provider", _vertex_adc_token_provider) # User ADC requires a billing/quota project. google-auth normally sends # it via the ``x-goog-user-project`` header from the ADC quota project, # but we hand the OpenAI client only a bearer token (the token_provider), # so that header would be lost and Vertex 403s with "requires a quota # project". Set it explicitly when we know the project. Harmless for # service-account creds (they carry their own project). if project: client_kwargs = dict(kwargs.get("client_kwargs") or {}) headers = dict(client_kwargs.get("default_headers") or {}) headers.setdefault("x-goog-user-project", project) client_kwargs["default_headers"] = headers kwargs["client_kwargs"] = client_kwargs super().__init__(model=model, **kwargs)
[docs] @ChiaFunction(resources={"vertex_creds": 0.01}) def prompt( self, user_message: str, tools: Optional[List[ChiaTool]] = [], ) -> QueryResult: return OpenAICompatLLM.prompt(self, user_message, tools)