"""Amazon Bedrock LLM backend built on the boto3 Converse API.
:class:`BedrockLLM` talks to **any** tool-capable chat model on Amazon
Bedrock (Claude, Amazon Nova, Llama, Mistral, Command R, ...). It uses
the Bedrock Runtime ``converse`` API, which normalises
messages, system prompts, and tool use across model families, and runs the
agentic tool loop client-side — executing each ChiaTool's MCP server over
HTTP exactly like :class:`chia.models.claude.ClaudeCodeLLM`'s API backend.
WARNING: experimental. Only exercised by the tests in
chia/models/tests/test_bedrock.py (mocked unit tests, plus opt-in live
tests). Not validated in production.
Auth/config come from the standard AWS chain (env vars, shared profile, or
IAM role); pass ``region`` or rely on ``AWS_REGION`` / ``AWS_DEFAULT_REGION``.
``boto3`` is imported lazily, so importing this module does not require it.
"""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, List, Optional
import ray
from chia.base.ChiaFunction import ChiaFunction
from chia.base.llm_call import QueryResult, LLMCallBase
if TYPE_CHECKING:
from chia.base.tools.ChiaTool import ChiaTool
# ---------------------------------------------------------------------------
# Exceptions
#
# A parallel taxonomy to claude.py's. Kept separate (rather than imported)
# so this module stands alone; each carries ``__reduce__`` for Ray transport.
# ---------------------------------------------------------------------------
[docs]
class BedrockError(Exception):
    """Root of the Bedrock backend's exception hierarchy.

    Every instance records where the failure was observed (``node_id``), a
    coarse category (``error_type``), the service-reported code
    (``status_code``) and the service's message (``raw_message``).  The
    exception's string form embeds at most the first 200 characters of
    ``raw_message``; the attribute itself keeps the full text.
    """

    def __init__(
        self,
        node_id: str,
        error_type: str,
        status_code: str = "",
        raw_message: str = "",
    ):
        summary = f"{error_type} on {node_id}: {raw_message[:200]}"
        super().__init__(summary)
        self.node_id = node_id
        self.error_type = error_type
        self.status_code = status_code
        self.raw_message = raw_message

    def __reduce__(self):
        """Return ``(class, constructor_args)`` so a pickled copy is rebuilt
        by calling the constructor with the original arguments (needed for
        Ray transport)."""
        ctor_args = (
            self.node_id,
            self.error_type,
            self.status_code,
            self.raw_message,
        )
        return (type(self), ctor_args)
[docs]
class RateLimitError(BedrockError):
    """Throttling / quota exhaustion (``ThrottlingException`` etc.).

    In addition to the base attributes, ``reset_time`` holds the datetime
    supplied by whoever raised the error; it is stored as given.  The
    ``error_type`` is always ``"rate_limit"``.
    """

    def __init__(
        self,
        node_id: str,
        reset_time: datetime,
        raw_message: str = "",
        status_code: str = "",
    ):
        super().__init__(
            node_id=node_id,
            error_type="rate_limit",
            status_code=status_code,
            raw_message=raw_message,
        )
        self.reset_time = reset_time

    def __reduce__(self):
        """Return ``(class, constructor_args)`` in this class's own
        positional order, which differs from the base class."""
        ctor_args = (
            self.node_id,
            self.reset_time,
            self.raw_message,
            self.status_code,
        )
        return (type(self), ctor_args)
[docs]
class AuthenticationError(BedrockError):
    """Invalid / expired / unauthorized AWS credentials.

    The ``error_type`` is always ``"authentication_failed"``.
    """

    def __init__(self, node_id: str, status_code: str = "", raw_message: str = ""):
        super().__init__(
            node_id=node_id,
            error_type="authentication_failed",
            status_code=status_code,
            raw_message=raw_message,
        )

    def __reduce__(self):
        """Return ``(class, constructor_args)`` for pickling."""
        ctor_args = (self.node_id, self.status_code, self.raw_message)
        return (type(self), ctor_args)
[docs]
class InvalidRequestError(BedrockError):
    """Malformed request or unknown model (``ValidationException`` etc.).

    The ``error_type`` is always ``"invalid_request"``.
    """

    def __init__(self, node_id: str, status_code: str = "", raw_message: str = ""):
        super().__init__(
            node_id=node_id,
            error_type="invalid_request",
            status_code=status_code,
            raw_message=raw_message,
        )

    def __reduce__(self):
        """Return ``(class, constructor_args)`` for pickling."""
        ctor_args = (self.node_id, self.status_code, self.raw_message)
        return (type(self), ctor_args)
[docs]
class ServerError(BedrockError):
    """Transient service-side failure (5xx, model timeout, connection).

    The ``error_type`` is always ``"server_error"``.
    """

    def __init__(self, node_id: str, status_code: str = "", raw_message: str = ""):
        super().__init__(
            node_id=node_id,
            error_type="server_error",
            status_code=status_code,
            raw_message=raw_message,
        )

    def __reduce__(self):
        """Return ``(class, constructor_args)`` for pickling."""
        ctor_args = (self.node_id, self.status_code, self.raw_message)
        return (type(self), ctor_args)
[docs]
class MaxOutputTokensError(BedrockError):
    """The response was truncated at ``maxTokens``.

    ``partial_text`` carries whatever text the raiser had collected before
    the truncation was detected.  The ``error_type`` is always
    ``"max_output_tokens"``.
    """

    def __init__(
        self,
        node_id: str,
        status_code: str = "",
        raw_message: str = "",
        partial_text: str = "",
    ):
        super().__init__(
            node_id=node_id,
            error_type="max_output_tokens",
            status_code=status_code,
            raw_message=raw_message,
        )
        self.partial_text = partial_text

    def __reduce__(self):
        """Return ``(class, constructor_args)`` including ``partial_text``."""
        ctor_args = (
            self.node_id,
            self.status_code,
            self.raw_message,
            self.partial_text,
        )
        return (type(self), ctor_args)
[docs]
class UnknownBedrockError(BedrockError):
    """Unclassified Bedrock error.

    The ``error_type`` is always ``"unknown"``.
    """

    def __init__(self, node_id: str, status_code: str = "", raw_message: str = ""):
        super().__init__(
            node_id=node_id,
            error_type="unknown",
            status_code=status_code,
            raw_message=raw_message,
        )

    def __reduce__(self):
        """Return ``(class, constructor_args)`` for pickling."""
        ctor_args = (self.node_id, self.status_code, self.raw_message)
        return (type(self), ctor_args)
def _unwrap_exception_group(exc):
"""Drill through (possibly nested) ExceptionGroups to a representative leaf.
MCP's anyio task group re-wraps any error that escapes the tool-connected
``AsyncExitStack`` in an ``ExceptionGroup`` (the ``exceptiongroup`` backport
on Python < 3.11), which would otherwise hide a perfectly classifiable AWS
error. Prefer an already-typed :class:`BedrockError`, then a botocore error,
else the first leaf. Non-group exceptions are returned unchanged.
"""
subs = getattr(exc, "exceptions", None)
if not subs or not isinstance(subs, (list, tuple)):
return exc
leaves = [_unwrap_exception_group(s) for s in subs]
for leaf in leaves:
if isinstance(leaf, BedrockError):
return leaf
try:
from botocore.exceptions import BotoCoreError, ClientError
for leaf in leaves:
if isinstance(leaf, (ClientError, BotoCoreError)):
return leaf
except Exception:
pass
return leaves[0]
[docs]
class BedrockLLM(LLMCallBase):
    """Bedrock Converse-API LLM backend with client-side MCP tool execution.

    Returns the same :class:`QueryResult` shape as the other backends so callers
    are interchangeable; ``returncode`` is synthesised (0 on success, -1 when
    every retry fails) and ``stderr`` is unused.
    """

    # ClientError codes grouped by the local error type they map to.
    _THROTTLE_CODES = frozenset({
        "ThrottlingException",
        "TooManyRequestsException",
        "ServiceQuotaExceededException",
    })
    _AUTH_CODES = frozenset({
        "AccessDeniedException",
        "UnauthorizedException",
        "ExpiredTokenException",
        "InvalidSignatureException",
        "UnrecognizedClientException",
    })
    _INVALID_REQUEST_CODES = frozenset({
        "ValidationException",
        "ResourceNotFoundException",
    })
    _SERVER_CODES = frozenset({
        "ServiceUnavailableException",
        "InternalServerException",
        "ModelTimeoutException",
        "ModelNotReadyException",
        "ModelErrorException",
    })

    def __init__(
        self,
        model: str,
        system_message: str = "",
        timeout_seconds: int = 600,
        retries: int = 3,
        logging_name: str = "bedrock_llm",
        logging_level: int = logging.DEBUG,
        log_dir: Optional[str] = None,
        region: Optional[str] = None,
        max_tokens: int = 16000,
        max_tool_iterations: int = 100,
        client_kwargs: Optional[dict] = None,
    ):
        """Store configuration; no AWS client is created here.

        Args:
            model: Sent as ``modelId`` on every Converse request.
            system_message: Passed to ``LLMCallBase.__init__``; this class
                later reads ``self.system_message`` and, when non-empty,
                sends it as the Converse ``system`` prompt.
            timeout_seconds: Stored on the instance.
                NOTE(review): nothing in this class reads it — confirm
                whether the base class or callers are expected to enforce it.
            retries: Number of attempts :meth:`prompt` makes.
            logging_name: Logger name, and prefix of the log file name.
            logging_level: Stored on the instance.
                NOTE(review): not applied to the logger here — confirm.
            log_dir: When given, the directory is created and each call
                appends its transcript to
                ``<log_dir>/<logging_name>_<timestamp>.log``.
            region: AWS region; falls back to ``AWS_REGION`` and then
                ``AWS_DEFAULT_REGION``.
            max_tokens: Sent as ``inferenceConfig.maxTokens``.
            max_tool_iterations: Upper bound on Converse calls per prompt.
            client_kwargs: Extra keyword arguments for ``boto3.client``.
        """
        super().__init__(system_message=system_message)
        self.logging_level = logging_level
        self.logging_name = logging_name
        self.retries = retries
        self.timeout_seconds = timeout_seconds
        self.model = model
        self.region = (
            region
            or os.environ.get("AWS_REGION")
            or os.environ.get("AWS_DEFAULT_REGION")
        )
        self.max_tokens = max_tokens
        self.max_tool_iterations = max_tool_iterations
        self.client_kwargs = client_kwargs or {}
        self.logger = logging.getLogger(logging_name)
        self._last_metadata: dict = {}
        self.logger.warning(
            "BedrockLLM is experimental: only exercised by unit tests so far, "
            "not validated in production."
        )
        self._log_dir = log_dir
        if log_dir is not None:
            os.makedirs(log_dir, exist_ok=True)
            run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
            self._log_prefix = os.path.join(log_dir, f"{logging_name}_{run_id}")
        else:
            self._log_prefix = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    @ChiaFunction(resources={"bedrock_creds": 0.01})
    def prompt(
        self,
        user_message: str,
        tools: Optional[List[ChiaTool]] = None,
    ) -> QueryResult:
        """Send *user_message* via the Bedrock Converse API and return the
        response.

        Up to ``self.retries`` attempts are made:

        * ``RateLimitError``, ``AuthenticationError`` and
          ``InvalidRequestError`` are re-raised immediately.
        * ``MaxOutputTokensError`` is retried only when it occurs on the
          first attempt; on any later attempt it is re-raised.
        * ``ServerError`` is retried after sleeping
          ``min(5 * 2 ** attempt, 60)`` seconds.
        * ``UnknownBedrockError`` and any other exception are logged and
          retried without delay.

        Args:
            user_message: The user turn to send.
            tools: ChiaTools whose MCP servers the model may call.  ``None``
                and an empty list both mean "no tools".

        Returns:
            The :class:`QueryResult` of the first successful attempt (with
            ``success`` set to ``True``), or a result with ``returncode=-1``
            and ``success=False`` when every attempt failed.
        """
        import time as _time

        from chia.trace.profiler import get_profiler

        # Normalise once.  A ``None`` here used to blow up in the metadata
        # comprehension *after* a successful (billed) Converse call; the
        # blanket ``except Exception`` below then retried the whole call and
        # finally reported failure.
        tools = list(tools) if tools else []

        profiler = get_profiler()
        for attempt in range(self.retries):
            try:
                self._last_metadata = {}
                cli = self._run_converse(user_message, tools)
                self._last_metadata["model"] = self.model
                self._last_metadata["tools"] = [
                    {
                        "name": t.name,
                        "hostname": getattr(t, "hostname", None),
                        "port": getattr(t, "port", None),
                        "node_id": getattr(t, "node_id", None),
                    }
                    for t in tools
                ]
                if profiler.enabled and self._last_metadata:
                    profiler.add_info(self._last_metadata)
                cli.success = True
                return cli
            # -- Never retry: propagate immediately --
            except (RateLimitError, AuthenticationError, InvalidRequestError):
                raise
            # -- Retry once: a shorter generation may fit --
            except MaxOutputTokensError:
                if attempt == 0:
                    self.logger.warning(
                        "Max output tokens on attempt %d/%d, retrying once",
                        attempt + 1, self.retries,
                    )
                    continue
                raise
            # -- Retry with exponential backoff: transient service issue --
            except ServerError:
                backoff = min(5 * 2 ** attempt, 60)
                self.logger.warning(
                    "Server error on attempt %d/%d, backing off %ds",
                    attempt + 1, self.retries, backoff,
                )
                _time.sleep(backoff)
            # -- Standard retry for unknown errors --
            except UnknownBedrockError as exc:
                self.logger.warning(
                    "Unknown error on attempt %d/%d: %s",
                    attempt + 1, self.retries, exc,
                )
            except Exception as exc:
                self.logger.warning(
                    "Unexpected error on attempt %d/%d: %s",
                    attempt + 1, self.retries, exc,
                )
        return QueryResult(
            result="", returncode=-1, stderr="", stream_result="", success=False
        )

    def _get_node_id(self) -> str:
        """Return the current Ray node id, or ``"unknown"`` if it cannot be
        obtained (e.g. Ray is not initialised)."""
        try:
            return ray.get_runtime_context().get_node_id()
        except Exception:
            return "unknown"

    # ------------------------------------------------------------------
    # Converse implementation
    # ------------------------------------------------------------------
    def _run_converse(
        self,
        user_message: str,
        tools: Optional[List[ChiaTool]] = None,
    ) -> QueryResult:
        """Synchronous entry point; drives the async agent loop.

        A second translation guard catches the case where a typed error (or a
        raw botocore error) escapes the tool-connected ``AsyncExitStack`` wrapped
        in an ``ExceptionGroup`` by MCP's anyio task group — without this, the
        wrapped error never matches ``prompt()``'s ``except`` clauses and every
        AWS error in the tools path is misfiled as "unexpected".

        Raises:
            BedrockError: When the failure could be classified.
            Exception: Anything unclassifiable is re-raised untouched.
        """
        try:
            return self._run_coroutine(
                self._run_converse_async(user_message, tools or [])
            )
        except Exception as exc:
            translated = self._translate_error(exc)
            # ``translated is exc`` means the error was already typed;
            # re-raise it as-is instead of chaining it to itself.
            if translated is None or translated is exc:
                raise
            raise translated from exc

    def _run_coroutine(self, coro):
        """Run *coro* whether or not an event loop is already running.

        With no running loop, ``asyncio.run`` is used directly.  Otherwise
        the coroutine is run via ``asyncio.run`` on a one-off worker thread
        and this call blocks until it finishes.
        """
        import asyncio

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    @staticmethod
    def _clip_for_log(text: str, limit: int = 2000) -> str:
        """Return *text*, cut to *limit* characters plus a truncation marker
        when it is longer than that."""
        if len(text) > limit:
            return text[:limit] + "\n... [truncated]"
        return text

    async def _connect_tools(self, stack, tools):
        """Open an MCP session per tool on *stack* and collect tool specs.

        Returns:
            ``(tool_specs, dispatch)`` where ``tool_specs`` is the list for
            Converse's ``toolConfig.tools`` and ``dispatch`` maps each
            Converse tool name to ``(session, mcp_tool_name)``.
        """
        tool_specs: list[dict] = []
        dispatch: dict = {}  # converse tool name -> (session, mcp tool name)
        if not tools:
            # Nothing to connect to, so the MCP client library is not needed.
            return tool_specs, dispatch

        from mcp import ClientSession
        from mcp.client.streamable_http import streamable_http_client

        for tool in tools:
            port = getattr(tool, "port", 8000)
            url = f"http://{tool.hostname}:{port}/{tool.name}/mcp"
            transport = await stack.enter_async_context(streamable_http_client(url))
            read, write = transport[0], transport[1]
            session = await stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            listed = await session.list_tools()
            for fn in listed.tools:
                # Converse limits tool names to 64 characters.
                api_name = f"{tool.name}__{fn.name}"[:64]
                tool_specs.append({
                    "toolSpec": {
                        "name": api_name,
                        "description": fn.description or " ",
                        "inputSchema": {
                            "json": fn.inputSchema
                            or {"type": "object", "properties": {}}
                        },
                    }
                })
                dispatch[api_name] = (session, fn.name)
        return tool_specs, dispatch

    def _record_turn(self, out_message: dict, stream_parts: list) -> tuple:
        """Log one assistant turn's content blocks into *stream_parts*.

        Returns:
            ``(text_parts, tool_uses)``: the ``text`` block strings in order,
            and the ``toolUse`` dicts in order.
        """
        text_parts: list[str] = []
        tool_uses: list = []
        for block in out_message.get("content", []):
            if "text" in block:
                text_parts.append(block["text"])
                stream_parts.append(f"[Response]\n{block['text']}\n\n")
            elif "reasoningContent" in block:
                reasoning = block["reasoningContent"]
                thinking = reasoning.get("reasoningText", {}).get("text", "")
                if thinking:
                    stream_parts.append(f"[Thinking]\n{thinking}\n\n")
            elif "toolUse" in block:
                tool_use = block["toolUse"]
                tool_uses.append(tool_use)
                tool_input = self._clip_for_log(
                    json.dumps(tool_use.get("input", {}))
                )
                stream_parts.append(
                    f"[Tool Call: {tool_use.get('name')}]\nArgs: {tool_input}\n\n"
                )
        return text_parts, tool_uses

    async def _execute_tool_uses(
        self, tool_uses: list, dispatch: dict, stream_parts: list
    ) -> list:
        """Run each requested tool over MCP and build ``toolResult`` blocks.

        Unknown tool names and exceptions from the MCP call are reported to
        the model as ``status: "error"`` results instead of being raised.
        """
        result_blocks = []
        for tool_use in tool_uses:
            name = tool_use.get("name")
            session, fn_name = dispatch.get(name, (None, None))
            if session is None:
                result_text = f"Unknown tool: {name}"
                is_error = True
            else:
                try:
                    mcp_result = await session.call_tool(
                        fn_name, tool_use.get("input") or {}
                    )
                    result_text = self._mcp_result_to_text(mcp_result)
                    is_error = bool(getattr(mcp_result, "isError", False))
                except Exception as exc:
                    result_text = f"Tool execution error: {exc}"
                    is_error = True

            label = "Tool Result (error)" if is_error else "Tool Result"
            stream_parts.append(
                f"[{label}]\n{self._clip_for_log(result_text)}\n\n"
            )
            result_blocks.append({
                "toolResult": {
                    "toolUseId": tool_use.get("toolUseId"),
                    "content": [{"text": result_text}],
                    "status": "error" if is_error else "success",
                }
            })
        return result_blocks

    async def _run_converse_async(
        self,
        user_message: str,
        tools: Optional[List[ChiaTool]] = None,
    ) -> QueryResult:
        """Connect to each ChiaTool's MCP server, then run the Converse loop.

        Converse returns one assistant turn at a time; on ``stopReason ==
        "tool_use"`` we execute the requested tools against their MCP servers
        and feed ``toolResult`` blocks back, looping until the model stops
        (or :attr:`max_tool_iterations` is hit).

        Returns:
            A :class:`QueryResult` whose ``result`` is the text of the last
            turn that contained any text block and whose ``stream_result``
            is the human-readable transcript.

        Raises:
            MaxOutputTokensError: A turn ended with ``stopReason ==
                "max_tokens"``.
            BedrockError: A ``converse`` failure that could be classified.
        """
        import asyncio
        from contextlib import AsyncExitStack

        import boto3

        client = boto3.client(
            "bedrock-runtime", region_name=self.region, **self.client_kwargs
        )

        rule = "=" * 80
        preview = user_message[:500] + ("..." if len(user_message) > 500 else "")
        stream_parts: list[str] = [
            rule + "\n",
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Converse ({self.model})\n",
            rule + "\n\n",
            f"[User Message]\n{preview}\n\n",
        ]

        async with AsyncExitStack() as stack:
            # --- Connect to every MCP server and gather tool specs ---
            tool_specs, dispatch = await self._connect_tools(stack, tools or [])

            messages: list[dict] = [
                {"role": "user", "content": [{"text": user_message}]}
            ]
            meta = {"input_tokens": 0, "output_tokens": 0, "num_turns": 0}
            final_text = ""

            # Only ``messages`` changes between iterations, and it is
            # mutated in place, so the request can be built once.
            request: dict = {
                "modelId": self.model,
                "messages": messages,
                "inferenceConfig": {"maxTokens": self.max_tokens},
            }
            if self.system_message:
                request["system"] = [{"text": self.system_message}]
            if tool_specs:
                request["toolConfig"] = {"tools": tool_specs}

            for _ in range(self.max_tool_iterations):
                try:
                    # boto3 is synchronous; keep the event loop free for MCP.
                    resp = await asyncio.to_thread(client.converse, **request)
                except Exception as exc:
                    translated = self._translate_error(exc)
                    if translated is None or translated is exc:
                        raise
                    raise translated from exc

                meta["num_turns"] += 1
                usage = resp.get("usage", {})
                meta["input_tokens"] += usage.get("inputTokens", 0) or 0
                meta["output_tokens"] += usage.get("outputTokens", 0) or 0
                out_message = resp["output"]["message"]
                stop_reason = resp.get("stopReason")

                # --- Log this turn's content blocks ---
                text_parts, tool_uses = self._record_turn(out_message, stream_parts)
                if text_parts:
                    final_text = "".join(text_parts)

                # Echo the assistant turn back verbatim (toolUse blocks must
                # round-trip for the follow-up request).
                messages.append(out_message)

                if stop_reason == "max_tokens":
                    raise MaxOutputTokensError(
                        node_id=self._get_node_id(),
                        raw_message="response truncated at maxTokens",
                        partial_text=final_text,
                    )
                if stop_reason != "tool_use" or not tool_uses:
                    break

                # --- Execute each requested tool over its MCP server ---
                result_blocks = await self._execute_tool_uses(
                    tool_uses, dispatch, stream_parts
                )
                messages.append({"role": "user", "content": result_blocks})
            else:
                # Loop ran out without a ``break``: iteration cap reached.
                stream_parts.append(
                    f"[DEBUG] Reached max_tool_iterations={self.max_tool_iterations}\n"
                )

            # --- Metadata + log file ---
            self._last_metadata = {k: v for k, v in meta.items() if v}
            stream_parts.append("-" * 80 + "\n\n")
            transcript = "".join(stream_parts)
            if self._log_prefix is not None:
                # Explicit encoding: model output is often non-ASCII and the
                # locale default is not guaranteed to be able to encode it.
                with open(f"{self._log_prefix}.log", "a", encoding="utf-8") as f:
                    f.write(transcript)
            return QueryResult(
                result=final_text,
                returncode=0,
                stderr="",
                stream_result=transcript,
            )

    @staticmethod
    def _mcp_result_to_text(result) -> str:
        """Flatten an MCP ``CallToolResult`` into plain text.

        Each content item contributes its ``text`` attribute when that is
        not ``None``, otherwise ``str(item)``; items are joined by newlines.
        A missing or empty ``content`` yields ``""``.
        """
        parts = []
        for item in getattr(result, "content", None) or []:
            text = getattr(item, "text", None)
            parts.append(text if text is not None else str(item))
        return "\n".join(parts)

    def _translate_error(self, exc) -> Optional[Exception]:
        """Map a botocore exception to a local typed error.

        Returns the local exception to raise, or ``None`` when *exc* is not a
        recognised AWS/botocore error (so the caller re-raises it untouched).
        An *exc* that already is a :class:`BedrockError` is returned as-is,
        and ``None`` is also returned when botocore cannot be imported.
        """
        # MCP's anyio task group wraps errors that escape the tool-connected
        # AsyncExitStack in (possibly nested) ExceptionGroups — drill down to
        # the real error first, otherwise nothing classifies in the tools path.
        exc = _unwrap_exception_group(exc)

        # Already one of ours (e.g. translated inside the loop, then re-wrapped
        # by the task group) -> pass straight through.
        if isinstance(exc, BedrockError):
            return exc

        try:
            from botocore.exceptions import (
                BotoCoreError,
                ClientError,
                ConnectTimeoutError,
                EndpointConnectionError,
                ParamValidationError,
                ReadTimeoutError,
            )
        except Exception:
            return None

        node_id = self._get_node_id()

        if isinstance(exc, ClientError):
            error = exc.response.get("Error", {})
            code = error.get("Code", "")
            msg = error.get("Message", str(exc))[:300]
            if code in self._THROTTLE_CODES:
                # The service gives no reset time here; assume one minute.
                reset_time = datetime.now(timezone.utc) + timedelta(seconds=60)
                return RateLimitError(
                    node_id=node_id, reset_time=reset_time,
                    raw_message=msg, status_code=code,
                )
            if code in self._AUTH_CODES:
                return AuthenticationError(node_id, status_code=code, raw_message=msg)
            if code in self._INVALID_REQUEST_CODES:
                return InvalidRequestError(node_id, status_code=code, raw_message=msg)
            if code in self._SERVER_CODES:
                return ServerError(node_id, status_code=code, raw_message=msg)
            return UnknownBedrockError(node_id, status_code=code, raw_message=msg)

        if isinstance(exc, (EndpointConnectionError, ConnectTimeoutError,
                            ReadTimeoutError)):
            return ServerError(node_id, raw_message=str(exc)[:300])

        # Client-side parameter validation (e.g. maxTokens below the minimum) is
        # a deterministic bad request — never-retry, not an Unknown that burns
        # the retry budget.  Checked before the generic BotoCoreError branch.
        if isinstance(exc, ParamValidationError):
            return InvalidRequestError(node_id, raw_message=str(exc)[:300])

        if isinstance(exc, BotoCoreError):
            return UnknownBedrockError(node_id, raw_message=str(exc)[:300])

        return None