"""Codex CLI LLM backend.
``CodexLLM`` wraps ``codex exec`` behind the same synchronous ``prompt`` shape
as the other Chia LLM backends. Chia MCP tools are passed as per-run Codex
config overrides, so this backend does not mutate the user's persistent Codex
configuration.
"""
from __future__ import annotations

import fnmatch
import json
import logging
import os
import re
import subprocess
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from glob import escape as glob_escape
from glob import glob
from typing import TYPE_CHECKING, Any

import ray

from chia.base.ChiaFunction import ChiaFunction, ObjectRefCallback
from chia.base.llm_call import QueryResult, LLMCallBase

if TYPE_CHECKING:
    from chia.base.tools.ChiaTool import ChiaTool
[docs]
class CodexError(Exception):
    """Root of the Codex CLI error hierarchy.

    Each instance remembers the Ray node it was raised on, the CLI exit code
    and the raw CLI output.  ``__reduce__`` rebuilds the exception from those
    three values, so it survives pickling (and therefore Ray transport).
    Subclasses only override ``error_type``.
    """

    error_type = "unknown"

    def __init__(self, node_id: str, exit_code: int = -1, raw_message: str = ""):
        self.node_id, self.exit_code, self.raw_message = node_id, exit_code, raw_message
        # Keep the human-readable message short; the full text stays on raw_message.
        summary = raw_message[:200]
        super().__init__(f"{self.error_type} on {node_id}: {summary}")

    def __reduce__(self):
        ctor_args = (self.node_id, self.exit_code, self.raw_message)
        return self.__class__, ctor_args
[docs]
class RateLimitError(CodexError):
    """Codex reported a rate or usage limit.

    ``reset_time`` is when the limit is expected to lift.  When the caller
    cannot supply one it defaults to one minute after construction (UTC).
    ``__reduce__`` follows this class's own argument order so the exception
    round-trips through pickling.
    """

    error_type = "rate_limit"

    def __init__(
        self,
        node_id: str,
        reset_time: datetime | None = None,
        raw_message: str = "",
        exit_code: int = -1,
    ):
        if not reset_time:
            reset_time = datetime.now(timezone.utc) + timedelta(minutes=1)
        self.reset_time = reset_time
        super().__init__(node_id=node_id, exit_code=exit_code, raw_message=raw_message)

    def __reduce__(self):
        ctor_args = (self.node_id, self.reset_time, self.raw_message, self.exit_code)
        return self.__class__, ctor_args
[docs]
class AuthenticationError(CodexError):
    """Failure output looked like an authentication problem (e.g. not logged in, HTTP 401)."""

    error_type = "authentication_failed"
[docs]
class BillingError(CodexError):
    """Failure output looked like a billing/payment/quota problem (e.g. HTTP 402)."""

    error_type = "billing_error"
[docs]
class InvalidRequestError(CodexError):
    """Failure output looked like a rejected request (bad model, bad config/option, HTTP 400)."""

    error_type = "invalid_request"
[docs]
class ServerError(CodexError):
    """Failure output looked transient: server-side 5xx, overload, connection or timeout."""

    error_type = "server_error"
[docs]
class MaxOutputTokensError(CodexError):
    """Failure output mentioned an output-token or context-window limit."""

    error_type = "max_output_tokens"
[docs]
class UnknownCodexError(CodexError):
    """A failed Codex run whose output matched none of the known error patterns."""

    error_type = "unknown"
_RESET_RE = re.compile(
r"(?:reset|resets|retry(?:\s|-)?after)\D+(\d{1,2})\s*(am|pm)?(?:\s*\(([^)]+)\))?",
re.IGNORECASE,
)
# Lower-case substring patterns for classifying a failed ``codex exec`` run.
# Entries are in priority order: a first-match scan picks the earliest class
# with any pattern contained in the lower-cased output.
# NOTE(review): matching is plain substring containment, so bare numbers such
# as "400"/"500" also match inside unrelated numbers (e.g. "1500 tokens").
_ERROR_PATTERNS: tuple[tuple[type[CodexError], tuple[str, ...]], ...] = (
    # Not logged in / rejected credentials.
    (AuthenticationError, (
        "not logged in",
        "login",
        "authentication",
        "unauthorized",
        "401",
        "api key",
        "auth token",
    )),
    # Account cannot pay for the request.
    (BillingError, (
        "billing",
        "payment",
        "402",
        "credit",
        "quota exceeded",
    )),
    # The request or local configuration is wrong; retrying will not help.
    (InvalidRequestError, (
        "invalid request",
        "malformed",
        "bad request",
        "invalid model",
        "unknown model",
        "invalid config",
        "unrecognized option",
        "400",
    )),
    # Transient server / network trouble.
    (ServerError, (
        "500",
        "503",
        "server error",
        "overloaded",
        "internal error",
        "service unavailable",
        "connection",
        "timeout",
        "timed out",
    )),
    # Output or context size limits.
    (MaxOutputTokensError, (
        "max output",
        "maximum output",
        "output token limit",
        "context length",
        "context window",
        "truncated",
    )),
)
_TOKEN_ALIASES = {
"input_tokens": ("input_tokens", "prompt_tokens", "input"),
"output_tokens": ("output_tokens", "completion_tokens", "output"),
"total_tokens": ("total_tokens",),
"reasoning_tokens": ("reasoning_tokens", "reasoning_output_tokens", "reasoning"),
"cache_read_input_tokens": ("cached_input_tokens", "cache_read_input_tokens"),
"cache_creation_input_tokens": ("cache_creation_input_tokens",),
}
_RATE_LIMIT_429_RE = re.compile(
r"\b(?:http(?:\s+status)?|status(?:code)?|code|apierror|error)\s*[:=]?\s*429\b"
r"|\b429\s+(?:too many requests|rate limit(?:ed)?)\b",
re.IGNORECASE,
)
_UUID_RE = re.compile(
r"\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b"
)
_CODEX_SESSION_STATE_PATTERNS = (
"state_*.sqlite",
"state_*.sqlite-wal",
"state_*.sqlite-shm",
)
[docs]
@dataclass
class CodexQueryResult(QueryResult):
    """``QueryResult`` extended with what is needed to resume a Codex session.

    Codex keeps resumable conversations inside ``CODEX_HOME`` instead of a
    portable transcript.  These extra fields carry the session id plus the raw
    bytes of the relevant CODEX_HOME files back to the caller, so a later
    ``codex exec resume <session_id> -`` can run on a different Chia worker.
    """

    # Codex session/thread id, when one was found in the CLI output.
    session_id: str | None = None
    # File contents keyed by path relative to CODEX_HOME; None if not captured.
    session_state: dict[str, bytes] | None = None
    # Sorted keys of ``session_state``; empty when nothing was captured.
    session_state_paths: tuple[str, ...] = ()
[docs]
def parse_session_id(stdout: str) -> str | None:
    """Extract a Codex session id from ``codex exec`` stdout.

    JSONL events are searched first, via ``_find_session_id``.  If none yields
    an id, fall back to plain text: return the first UUID that appears on a
    line mentioning "session", "conversation" or "thread" (e.g. a
    ``session id: <uuid>`` banner).  Restricting the fallback to the same line
    avoids picking up an unrelated UUID (tool-call or item ids) from elsewhere
    in the output.

    Returns:
        The session id, or ``None`` when nothing plausible is found.
    """
    lines = stdout.splitlines()
    for line in lines:
        event = CodexLLM._json_or_none(line.strip())
        if event is None:
            continue
        sid = _find_session_id(event)
        if sid:
            return sid
    for line in lines:
        lowered = line.lower()
        if any(token in lowered for token in ("session", "conversation", "thread")):
            match = _UUID_RE.search(line)
            if match:
                return match.group(0)
    return None
def _find_session_id(value: Any) -> str | None:
if isinstance(value, dict):
type_text = str(value.get("type") or value.get("event") or "").lower()
id_context = any(token in type_text for token in ("session", "conversation", "thread"))
for key, item in value.items():
key_text = str(key).lower()
if isinstance(item, str):
if any(token in key_text for token in ("session", "conversation", "thread")):
return item
if key_text == "id" and id_context:
return item
for item in value.values():
sid = _find_session_id(item)
if sid:
return sid
elif isinstance(value, list):
for item in value:
sid = _find_session_id(item)
if sid:
return sid
return None
def _session_tracked(chia_fn):
    """Attach Codex session sync to remote prompt calls when persistence is on.

    ``chia_fn`` is the ``ChiaFunction``-decorated method.  The return value is
    a descriptor:

    * accessed on the class, it returns ``chia_fn`` unchanged;
    * accessed on an instance, it returns a bound helper.  When the instance
      has ``_resume_session`` set, the helper's ``chia_remote``, and
      ``options(...).chia_remote`` / ``options(...).remote``, wrap the
      returned ref in ``ObjectRefCallback(ref, instance._sync_session)``.  The
      session state captured on the worker is then copied back onto the
      caller's instance.
    * Calling the bound helper directly runs the underlying function locally
      on the instance.

    Any other attribute lookup is forwarded to ``chia_fn``.
    """
    def _wrap(instance, ref):
        # Only attach the sync callback when session persistence is enabled.
        if getattr(instance, "_resume_session", False):
            return ObjectRefCallback(ref, instance._sync_session)
        return ref
    class _TrackedHandle:
        # Returned by ``bound.options(...)``: forwards to the configured handle
        # and wraps whatever ref it returns.
        def __init__(self, inner_handle, instance):
            self._inner = inner_handle
            self._instance = instance
        def chia_remote(self, *args, **kwargs):
            return _wrap(self._instance, self._inner.chia_remote(*args, **kwargs))
        def remote(self, *args, **kwargs):
            # ``remote`` is routed through the same tracked path as ``chia_remote``.
            return self.chia_remote(*args, **kwargs)
    class _BoundTracked:
        # What ``instance.prompt`` evaluates to.
        def __init__(self, instance):
            self._instance = instance
        def __call__(self, *args, **kwargs):
            # Local call: use the undecorated function (``_chia_original`` when
            # ChiaFunction exposes it) and pass the instance explicitly.  No sync
            # callback is attached here; presumably none is needed because the
            # call runs against this very instance.
            original = getattr(chia_fn, "_chia_original", chia_fn)
            return original(self._instance, *args, **kwargs)
        def chia_remote(self, *args, **kwargs):
            # NOTE(review): the instance is not passed to ``chia_fn.chia_remote``;
            # presumably ChiaFunction binds/serialises it itself — confirm.
            return _wrap(self._instance, chia_fn.chia_remote(*args, **kwargs))
        def options(self, **opts):
            return _TrackedHandle(chia_fn.options(**opts), self._instance)
        def __getattr__(self, name):
            # NOTE(review): unlike _TrackedHandle there is no ``remote`` here, so
            # ``instance.prompt.remote`` falls through to ``chia_fn.remote`` and
            # skips session tracking — confirm this is intended.
            return getattr(chia_fn, name)
    class _TrackedDescriptor:
        def __get__(self, obj, objtype=None):
            # Class access (``CodexLLM.prompt``) yields the raw ChiaFunction.
            if obj is None:
                return chia_fn
            return _BoundTracked(obj)
        def __getattr__(self, name):
            return getattr(chia_fn, name)
    return _TrackedDescriptor()
[docs]
def parse_rate_limit_reset(text: str) -> datetime | None:
    """Parse when a Codex rate limit resets from a human-readable message.

    Understood forms:
    * wall-clock: ``resets 4pm (America/Los_Angeles)``, ``resets at 4:30 pm``,
      ``resets 16:00``;
    * relative: ``resets in 2 hours``, ``retry after 30 seconds``;
    * a bare ``retry-after: 120``, read as seconds as in the HTTP
      ``Retry-After`` header.

    Returns:
        The reset moment as an aware UTC ``datetime``, or ``None`` when *text*
        has no recognisable reset phrase or the clock time is out of range.
        A wall-clock time that has already passed today is taken to mean the
        same time tomorrow.  Timezone names that ``zoneinfo`` cannot load
        fall back to UTC.
    """
    match = _RESET_RE.search(text)
    if match is None:
        return None
    tail = text[match.end(1):]
    # _RESET_RE captures at most two digits; recover the whole number so that
    # "retry-after: 120" is not read as "12".
    extra_digits = re.match(r"\d*", tail).group(0)
    number = int(match.group(1) + extra_digits)
    tail = tail[len(extra_digits):]
    ampm = (match.group(2) or "").lower()
    tz_name = match.group(3)
    minute = 0

    clock = None
    if not ampm and not extra_digits:
        # "4:30pm (UTC)": _RESET_RE stops at the colon, so read minutes,
        # meridiem and timezone from the remainder.
        clock = re.match(r":([0-5]\d)\s*(am|pm)?(?:\s*\(([^)]+)\))?", tail, re.IGNORECASE)
        if clock is not None:
            minute = int(clock.group(1))
            ampm = (clock.group(2) or "").lower()
            tz_name = clock.group(3) or tz_name

    if not ampm and clock is None:
        unit = re.match(
            r"\s*(hours?|hrs?|h|minutes?|mins?|m|seconds?|secs?|s)\b", tail, re.IGNORECASE
        )
        # A number with a time unit, or a bare number after "retry after",
        # is a delay rather than a clock hour.
        if unit is not None or match.group(0).lower().startswith("retry"):
            scale = {"h": 3600, "m": 60, "s": 1}[unit.group(1)[0].lower()] if unit else 1
            try:
                return datetime.now(timezone.utc) + timedelta(seconds=number * scale)
            except OverflowError:
                return None

    hour = number
    if ampm:
        if not 1 <= hour <= 12:
            return None
        if ampm == "pm" and hour != 12:
            hour += 12
        elif ampm == "am" and hour == 12:
            hour = 0
    elif hour > 23:
        # Previously this reached datetime.replace() and raised ValueError.
        return None

    try:
        import zoneinfo
        tz = zoneinfo.ZoneInfo((tz_name or "UTC").strip())
    except Exception:
        tz = timezone.utc
    now = datetime.now(tz)
    reset = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
    if reset <= now:
        reset += timedelta(days=1)
    return reset.astimezone(timezone.utc)
def _toml(value: str) -> str:
return json.dumps(value)
def _toml_key(value: str) -> str:
return value if re.fullmatch(r"[A-Za-z0-9_-]+", value) else json.dumps(value)
def _truncate(text: str, limit: int = 2000) -> str:
return text if len(text) <= limit else text[:limit] + "\n... [truncated]"
def _payload(event: dict) -> dict:
payload = event.get("payload")
if isinstance(payload, dict):
return payload
item = event.get("item")
return item if isinstance(item, dict) else event
[docs]
class CodexLLM(LLMCallBase):
    """Wrap ``codex exec`` as a Chia LLM backend.

    Each ``prompt`` call runs one ``codex exec`` subprocess, or
    ``codex exec resume <session_id>`` once a session id is known and
    ``resume_session`` is enabled.  The prompt is fed on stdin.  The JSONL
    event stream on stdout is parsed into a readable transcript plus token
    usage, and failures are mapped onto the ``CodexError`` hierarchy.
    """

    # Section headers that ``_parse_jsonl_stream``/``_record_event`` write into
    # the stream transcript; ``_diagnostic_text`` uses them to pick out only
    # the error-bearing sections.
    _STREAM_HEADER_RE = re.compile(
        r"\[(?:Error|stderr|Response|Thinking|Tool Result|Tool Call: .*)\]"
    )

    def __init__(
        self,
        model: str | None = None,
        system_message: str = "",
        timeout_seconds: int = 600,
        retries: int = 3,
        logging_name: str = "codex",
        logging_level: int = logging.DEBUG,
        log_dir: str | None = None,
        codex_bin: str = "codex",
        work_dir: str | None = None,
        extra_cli_args: list[str] | None = None,
        sandbox: str = "workspace-write",
        approval_policy: str = "never",
        dangerously_bypass_approvals_and_sandbox: bool = True,
        skip_git_repo_check: bool = True,
        ephemeral: bool = False,
        ignore_rules: bool = False,
        profile: str | None = None,
        reasoning_effort: str | None = None,
        resume_session: bool = False,
        auto_compact_token_limit: int | None = 200_000,
    ):
        """Configure the backend.

        Args:
            model: ``--model`` value; ``None`` lets codex use its configured default.
            system_message: Prepended to every prompt by ``_format_prompt``.
            timeout_seconds: Timeout for each ``codex`` subprocess.
            retries: Maximum number of ``codex exec`` attempts per ``prompt`` call.
            logging_name: Logger name and log-file prefix.
            logging_level: Stored as ``self.logging_level``.  This class does not
                apply it to the logger itself.
            log_dir: If set, prompts/responses are appended to
                ``<log_dir>/<logging_name>_<timestamp>.log``.
            codex_bin: Codex executable to invoke.
            work_dir: Working directory for codex (``--cd`` and subprocess cwd).
            extra_cli_args: Extra arguments appended after the generated ones.
            sandbox: ``--sandbox`` mode, used only when the bypass flag is off.
            approval_policy: ``--ask-for-approval`` value, used only when the
                bypass flag is off.
            dangerously_bypass_approvals_and_sandbox: Pass
                ``--dangerously-bypass-approvals-and-sandbox``.
            skip_git_repo_check: Pass ``--skip-git-repo-check``.
            ephemeral: Pass ``--ephemeral``.
            ignore_rules: Pass ``--ignore-rules``.
            profile: ``--profile`` value.
            reasoning_effort: Sets ``model_reasoning_effort`` via ``-c``.
            resume_session: Resume the previous codex session on later calls and
                ship its CODEX_HOME state between workers.
            auto_compact_token_limit: Sets ``model_auto_compact_token_limit`` via
                ``-c``; ``None`` omits the override.
        """
        super().__init__(system_message=system_message)
        self.logging_level = logging_level
        self.logging_name = logging_name
        self.retries = retries
        self.timeout_seconds = timeout_seconds
        self.model = model
        self.codex_bin = codex_bin
        self.work_dir = work_dir
        # Copy so later mutation of the caller's list cannot leak into commands.
        self.extra_cli_args = list(extra_cli_args or [])
        self.sandbox = sandbox
        self.approval_policy = approval_policy
        self.dangerously_bypass_approvals_and_sandbox = dangerously_bypass_approvals_and_sandbox
        self.skip_git_repo_check = skip_git_repo_check
        self.ephemeral = ephemeral
        self.ignore_rules = ignore_rules
        self.profile = profile
        self.reasoning_effort = reasoning_effort
        self.auto_compact_token_limit = auto_compact_token_limit
        self.logger = logging.getLogger(logging_name)
        self._call_counter = 0
        self._resume_session = resume_session
        self._session_id: str | None = None
        self._session_state: dict[str, bytes] | None = None
        self._session_state_paths: tuple[str, ...] = ()
        self._last_metadata: dict = {}
        self._log_prefix: str | None = None
        self.logger.warning("CodexLLM is experimental and has not been production-validated.")
        if self.model is None:
            self.logger.info("CodexLLM model is unset; codex exec will use its configured default model.")
        if log_dir is not None:
            os.makedirs(log_dir, exist_ok=True)
            stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            self._log_prefix = os.path.join(log_dir, f"{logging_name}_{stamp}")

    @_session_tracked
    @ChiaFunction(resources={"codex_creds": 0.01})
    def prompt(
        self,
        user_message: str,
        tools: list[ChiaTool] | None = None,
    ) -> CodexQueryResult:
        """Send *user_message* to ``codex exec``.

        Makes up to ``self.retries`` attempts:

        * rate-limit, authentication, billing and invalid-request errors are
          raised immediately;
        * a max-output-tokens error is retried once (if an attempt remains),
          then raised;
        * server errors back off 5s, doubling, capped at 60s.  The wait is
          skipped after the final attempt;
        * any other failure is logged and retried.

        Returns:
            The successful ``CodexQueryResult``.  If every attempt fails
            without raising, a result with ``success=False``,
            ``returncode=-1`` and the last failure message in ``stderr``.
        """
        import time as _time
        from chia.trace.profiler import get_profiler
        profiler = get_profiler()
        tool_list = tools or []
        last_error = ""
        for attempt in range(self.retries):
            has_next_attempt = attempt + 1 < self.retries
            try:
                self._last_metadata = {}
                self._restore_session_state()
                cli = self._run_codex(user_message, tool_list)
                self._call_counter += 1
                self._last_metadata.update({
                    "model": self.model or "codex-default",
                    "tools": [
                        {"name": t.name, "hostname": getattr(t, "hostname", None),
                         "port": getattr(t, "port", None), "node_id": getattr(t, "node_id", None)}
                        for t in tool_list
                    ],
                })
                if profiler.enabled:
                    profiler.add_info(self._last_metadata)
                self._classify_error(cli)
                self._capture_session_state(cli)
                cli.success = True
                return cli
            except (RateLimitError, AuthenticationError, BillingError, InvalidRequestError):
                raise
            except MaxOutputTokensError as exc:
                if attempt == 0 and has_next_attempt:
                    self.logger.warning("Max output tokens on attempt %d/%d, retrying once",
                                        attempt + 1, self.retries)
                    last_error = str(exc)
                    continue
                raise
            except ServerError as exc:
                last_error = str(exc)
                if has_next_attempt:
                    backoff = min(5 * 2 ** attempt, 60)
                    self.logger.warning("Server error on attempt %d/%d, backing off %ds",
                                        attempt + 1, self.retries, backoff)
                    _time.sleep(backoff)
                else:
                    self.logger.warning("Server error on final attempt %d/%d: %s",
                                        attempt + 1, self.retries, exc)
            except (UnknownCodexError, subprocess.TimeoutExpired) as exc:
                last_error = str(exc)
                self.logger.warning("Codex attempt %d/%d failed: %s",
                                    attempt + 1, self.retries, exc)
            except Exception as exc:
                last_error = str(exc)
                self.logger.warning("Unexpected Codex error on attempt %d/%d: %s",
                                    attempt + 1, self.retries, exc)
        return CodexQueryResult(
            result="",
            returncode=-1,
            stderr=last_error,
            stream_result="",
            success=False,
            session_id=self._session_id,
            session_state=self._session_state,
            session_state_paths=self._session_state_paths,
        )

    def _sync_session(self, cli: CodexQueryResult) -> CodexQueryResult:
        """Copy worker-captured Codex session state onto this instance.

        Used as the ``ObjectRefCallback`` for remote ``prompt`` calls.  Does
        nothing unless ``resume_session`` is on.  Returns *cli* unchanged.
        """
        if not self._resume_session:
            return cli
        session_id = getattr(cli, "session_id", None)
        session_state = getattr(cli, "session_state", None)
        if session_id:
            self._session_id = session_id
        if session_state is not None:
            self._session_state = session_state
            # An empty paths tuple (not only a missing attribute) falls back to
            # the state's own keys, keeping paths and state consistent.
            paths = getattr(cli, "session_state_paths", None)
            self._session_state_paths = tuple(paths or sorted(session_state))
        return cli

    def _get_node_id(self) -> str:
        """Return the current Ray node id, or ``"unknown"`` outside Ray."""
        try:
            return ray.get_runtime_context().get_node_id()
        except Exception:
            return "unknown"

    def _format_prompt(self, user_message: str) -> str:
        """Prefix *user_message* with the system message, if one is set."""
        if not self.system_message:
            return user_message
        return f"[System Instructions]\n{self.system_message}\n\n[User Request]\n{user_message}"

    def _mcp_config_args(self, tools: list[ChiaTool]) -> list[str]:
        """Build per-run ``-c mcp_servers.<name>.url=...`` overrides for *tools*.

        Each tool is reached at ``http://<hostname>:<port>/<name>/mcp``.  A
        missing or ``None`` port falls back to 8000.
        """
        args: list[str] = []
        for tool in tools:
            port = getattr(tool, "port", None) or 8000
            host = f"{tool.hostname}"
            # A bare IPv6 literal (two or more colons) must be bracketed in a URL.
            if host.count(":") >= 2 and not host.startswith("["):
                host = f"[{host}]"
            url = f"http://{host}:{port}/{tool.name}/mcp"
            args += ["-c", f"mcp_servers.{_toml_key(tool.name)}.url={_toml(url)}"]
        return args

    def _build_cmd(
        self,
        tools: list[ChiaTool] | None = None,
        output_last_message_path: str | None = None,
        resume_session_id: str | None = None,
    ) -> list[str]:
        """Assemble the ``codex`` argv.

        The prompt itself is read from stdin: the argv ends with ``-``.
        ``--cd`` is omitted when resuming; the subprocess cwd still points at
        ``work_dir``.
        """
        cmd = [self.codex_bin]
        if not self.dangerously_bypass_approvals_and_sandbox and self.approval_policy:
            cmd += ["--ask-for-approval", self.approval_policy]
        if resume_session_id:
            cmd += ["exec", "resume", "--json"]
        else:
            cmd += ["exec", "--json", "--color", "never"]
        if self.model:
            cmd += ["--model", self.model]
        if self.profile:
            cmd += ["--profile", self.profile]
        if self.work_dir and not resume_session_id:
            cmd += ["--cd", self.work_dir]
        if self.skip_git_repo_check:
            cmd.append("--skip-git-repo-check")
        if self.ephemeral:
            cmd.append("--ephemeral")
        if self.ignore_rules:
            cmd.append("--ignore-rules")
        if self.dangerously_bypass_approvals_and_sandbox:
            cmd.append("--dangerously-bypass-approvals-and-sandbox")
        else:
            cmd += ["--sandbox", self.sandbox]
        if output_last_message_path:
            cmd += ["--output-last-message", output_last_message_path]
        if self.reasoning_effort:
            cmd += ["-c", f"model_reasoning_effort={_toml(self.reasoning_effort)}"]
        if self.auto_compact_token_limit is not None:
            cmd += ["-c", f"model_auto_compact_token_limit={self.auto_compact_token_limit}"]
        cmd += self._mcp_config_args(tools or [])
        cmd += self.extra_cli_args
        if resume_session_id:
            return cmd + [resume_session_id, "-"]
        return cmd + ["-"]

    def _run_codex(self, user_message: str, tools: list[ChiaTool] | None = None) -> CodexQueryResult:
        """Run one ``codex`` subprocess and package its output.

        The final assistant message is read from a temporary
        ``--output-last-message`` file, falling back to the response text
        parsed from the JSONL stream.  ``self._last_metadata`` is replaced
        with the parsed usage metadata.  The temporary file is always removed.
        """
        fd, output_path = tempfile.mkstemp(suffix=".txt")
        os.close(fd)
        try:
            resume_session_id = self._session_id if self._resume_session else None
            cmd = self._build_cmd(
                tools or [],
                output_last_message_path=output_path,
                resume_session_id=resume_session_id,
            )
            # Decode as UTF-8 explicitly instead of the locale encoding, and
            # never crash on an undecodable byte.
            result = subprocess.run(
                cmd,
                input=self._format_prompt(user_message),
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",
                timeout=self.timeout_seconds,
                cwd=self.work_dir or None,
                env=os.environ.copy(),
            )
            with open(output_path, encoding="utf-8", errors="replace") as f:
                final_text = f.read()
            stream, meta, fallback = self._parse_jsonl_stream(result.stdout, result.stderr)
            parsed_session_id = parse_session_id(result.stdout)
            if self._resume_session and parsed_session_id:
                self._session_id = parsed_session_id
            if self._session_id:
                meta["session_id"] = self._session_id
            self._last_metadata = meta
            final_text = final_text or fallback
            if self._log_prefix is not None:
                self._write_log(user_message, final_text, stream)
            if result.returncode != 0:
                self.logger.warning("codex exited %d: %s", result.returncode, result.stderr[:500])
            return CodexQueryResult(
                final_text,
                result.returncode,
                result.stderr,
                stream,
                session_id=self._session_id,
            )
        finally:
            try:
                os.unlink(output_path)
            except FileNotFoundError:
                pass

    def _codex_home(self) -> str:
        """Return ``$CODEX_HOME``, defaulting to ``~/.codex``."""
        return os.environ.get("CODEX_HOME") or os.path.join(os.path.expanduser("~"), ".codex")

    def _session_state_files(self) -> list[tuple[str, str]]:
        """List ``(path relative to CODEX_HOME, absolute path)`` for state files.

        A file qualifies when it matches one of the
        ``_CODEX_SESSION_STATE_PATTERNS``.  The result is sorted and
        de-duplicated.
        """
        home = self._codex_home()
        found: set[tuple[str, str]] = set()
        for pattern in _CODEX_SESSION_STATE_PATTERNS:
            root = os.path.join(home, os.path.dirname(pattern))
            if not os.path.isdir(root):
                continue
            basename = os.path.basename(pattern)
            for name in os.listdir(root):
                full = os.path.join(root, name)
                if fnmatch.fnmatchcase(name, basename) and os.path.isfile(full):
                    found.add((os.path.relpath(full, home), full))
        return sorted(found)

    def _path_relative_to_codex_home(self, path: str) -> str | None:
        """Return *path* relative to CODEX_HOME, or ``None`` if it resolves outside it."""
        home = os.path.abspath(self._codex_home())
        full = os.path.abspath(path if os.path.isabs(path) else os.path.join(home, path))
        try:
            if os.path.commonpath([home, full]) != home:
                return None
        except ValueError:
            return None
        return os.path.relpath(full, home)

    def _session_rollout_files(self) -> list[tuple[str, str]]:
        """List rollout JSONL files under ``CODEX_HOME/sessions`` for the current session."""
        if self._session_id is None:
            return []
        home = self._codex_home()
        paths: list[tuple[str, str]] = []
        # Escape both parts so glob metacharacters in them are matched literally.
        pattern = os.path.join(
            glob_escape(home), "sessions", "**", f"rollout-*{glob_escape(self._session_id)}*.jsonl"
        )
        for full_rollout in glob(pattern, recursive=True):
            rel_rollout = self._path_relative_to_codex_home(full_rollout)
            if rel_rollout is not None and os.path.isfile(full_rollout):
                paths.append((rel_rollout, full_rollout))
        return sorted(set(paths))

    def _restore_session_state(self) -> None:
        """Write previously captured session files back into CODEX_HOME.

        Entries whose path is absolute, or that normalise to somewhere outside
        CODEX_HOME (e.g. ``sessions/../../x``), are skipped.
        """
        if not self._resume_session or self._session_state is None:
            return
        home = self._codex_home()
        for rel_path, data in self._session_state.items():
            if os.path.isabs(rel_path):
                continue
            safe_rel = self._path_relative_to_codex_home(rel_path)
            if safe_rel is None or safe_rel == os.curdir:
                continue
            path = os.path.join(home, safe_rel)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "wb") as f:
                f.write(data)

    def _capture_session_state(self, cli: CodexQueryResult) -> None:
        """Read the session's CODEX_HOME files into memory and attach them to *cli*."""
        if not self._resume_session or self._session_id is None:
            return
        state: dict[str, bytes] = {}
        for rel_path, path in self._session_state_files() + self._session_rollout_files():
            try:
                with open(path, "rb") as f:
                    state[rel_path] = f.read()
            except OSError:
                continue
        if state:
            self._session_state = state
            self._session_state_paths = tuple(sorted(state))
        cli.session_state = state
        cli.session_state_paths = self._session_state_paths
        cli.session_id = self._session_id

    def _write_log(self, user_message: str, final_text: str, stream: str) -> None:
        """Append one prompt/response record to ``<log prefix>.log``."""
        prompt = user_message[:500] + ("..." if len(user_message) > 500 else "")
        with open(f"{self._log_prefix}.log", "a", encoding="utf-8") as f:
            f.write("=" * 80 + "\n")
            f.write(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
                    f"Prompt #{self._call_counter} (codex)\n")
            f.write("=" * 80 + f"\n\n[User Message]\n{prompt}\n\n")
            f.write(stream if stream else f"[Response]\n{final_text}\n\n")
            if stream and not stream.endswith("\n"):
                f.write("\n")
            f.write("-" * 80 + "\n\n")

    @classmethod
    def _parse_jsonl_stream(cls, stdout: str, stderr: str = "") -> tuple[str, dict, str]:
        """Parse ``codex exec --json`` stdout.

        Returns:
            ``(transcript, metadata, response_text)``, where:
            * ``transcript`` renders every event as a ``[Section]`` block;
              non-JSON lines become ``[UNPARSED]`` and stderr, if any, is
              appended last;
            * ``metadata`` holds the non-empty usage counters;
            * ``response_text`` concatenates the assistant response texts.
        """
        stream_parts: list[str] = []
        result_parts: list[str] = []
        meta: dict[str, Any] = {}
        for line in stdout.splitlines():
            if not line.strip():
                continue
            event = cls._json_or_none(line)
            if event is None:
                stream_parts.append(f"[UNPARSED] {_truncate(line.strip(), 200)}\n")
                continue
            cls._record_usage(event, meta)
            cls._record_event(event, stream_parts, result_parts)
        if stderr:
            stream_parts.append(f"[stderr]\n{_truncate(stderr)}\n\n")
        return "".join(stream_parts), {k: v for k, v in meta.items() if v}, "".join(result_parts)

    @staticmethod
    def _json_or_none(line: str) -> dict | None:
        """Decode *line* as a JSON object, or return ``None``.

        Returns ``None`` for invalid JSON and for valid JSON that is not an
        object (e.g. ``true`` or ``42``), which callers would otherwise treat
        as a dict.
        """
        try:
            value = json.loads(line)
        except (ValueError, RecursionError):
            return None
        return value if isinstance(value, dict) else None

    @classmethod
    def _record_event(cls, event: dict, stream: list[str], results: list[str]) -> None:
        """Render one event into *stream*; append assistant responses to *results*."""
        payload = _payload(event)
        etype = str(payload.get("type") or event.get("type") or "").lower()
        text = cls._text(payload)
        if "tool" in etype and any(k in etype for k in ("result", "output", "finish", "complete")):
            stream.append(f"[Tool Result]\n{_truncate(text)}\n\n")
        elif "tool" in etype and any(k in etype for k in ("call", "start", "begin")):
            name = payload.get("name") or payload.get("tool_name") or payload.get("tool") or "unknown"
            args = payload.get("arguments", payload.get("args", payload.get("input", {})))
            stream.append(f"[Tool Call: {name}]\nArgs: {_truncate(json.dumps(args))}\n\n")
        elif "reason" in etype or "thinking" in etype:
            if text:
                stream.append(f"[Thinking]\n{_truncate(text)}\n\n")
        elif "error" in etype or "fail" in etype:
            # "fail" also captures events such as ``turn.failed``.
            stream.append(f"[Error]\n{_truncate(text or json.dumps(event, sort_keys=True))}\n\n")
        elif any(k in etype for k in ("message", "response", "assistant", "final")) and text:
            results.append(text)
            stream.append(f"[Response]\n{_truncate(text)}\n\n")

    @classmethod
    def _text(cls, value: Any) -> str:
        """Best-effort extraction of human-readable text from a JSON value."""
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        if isinstance(value, list):
            return "\n".join(filter(None, (cls._text(v) for v in value)))
        if isinstance(value, dict):
            for key in ("text", "content", "message", "output", "result", "delta"):
                text = cls._text(value.get(key))
                if text:
                    return text
            return ""
        return str(value)

    @staticmethod
    def _record_usage(event: dict, meta: dict) -> None:
        """Accumulate turn count and token usage from *event* into *meta*."""
        payload = _payload(event)
        etype = str(payload.get("type") or event.get("type") or "").lower()
        if "turn" in etype and any(token in etype for token in ("complete", "end", "done")):
            meta["num_turns"] = meta.get("num_turns", 0) + 1
        usage = payload.get("usage") or payload.get("tokens") or payload.get("token_usage")
        if isinstance(payload.get("info"), dict):
            usage = usage or payload["info"].get("last_token_usage") or payload["info"].get("total_token_usage")
        if not isinstance(usage, dict):
            return
        for dest, sources in _TOKEN_ALIASES.items():
            for source in sources:
                value = usage.get(source)
                if isinstance(value, (int, float)) and not isinstance(value, bool):
                    # Aliases name the same counter: count only the first one
                    # present so several spellings are not double-counted.
                    meta[dest] = meta.get(dest, 0) + value
                    break
        cache = usage.get("cache")
        if isinstance(cache, dict):
            for source, dest in (("read", "cache_read_input_tokens"),
                                 ("write", "cache_creation_input_tokens")):
                if isinstance(cache.get(source), (int, float)):
                    meta[dest] = meta.get(dest, 0) + cache[source]

    @classmethod
    def _diagnostic_text(cls, cli: QueryResult) -> str:
        """Collect the parts of *cli* that can describe a CLI failure.

        Included are stderr, the bodies of ``[Error]`` sections of the stream
        transcript, and stdout lines that were not JSON.  Assistant responses,
        reasoning and tool traffic are excluded.  A model merely *talking*
        about rate limits, logins or HTTP status codes is therefore not
        mistaken for a Codex failure.
        """
        parts: list[str] = [cli.stderr] if cli.stderr else []
        in_error = False
        for line in (cli.stream_result or "").splitlines():
            if line.startswith("[UNPARSED] "):
                parts.append(line[len("[UNPARSED] "):])
                in_error = False
            elif cls._STREAM_HEADER_RE.fullmatch(line):
                in_error = line == "[Error]"
            elif in_error and line != "... [truncated]":
                # Skip _truncate's marker so it is not read as a
                # max-output-tokens error.
                parts.append(line)
        return "\n".join(part for part in parts if part)

    def _classify_error(self, cli: QueryResult) -> None:
        """Raise the matching ``CodexError`` for a failed or rate-limited run.

        A rate limit is raised regardless of exit code.  Otherwise, return
        when the exit code is 0; for non-zero exits, raise the first
        ``_ERROR_PATTERNS`` match, or ``UnknownCodexError``.  Only
        ``_diagnostic_text`` is inspected.
        """
        diagnostic = self._diagnostic_text(cli)
        lower = diagnostic.lower()
        combined = "\n".join(part for part in (cli.stderr, cli.result, cli.stream_result) if part)
        raw_message = (diagnostic or combined)[:300]
        if (
            any(k in lower for k in ("rate limit", "usage limit", "too many requests"))
            or _RATE_LIMIT_429_RE.search(diagnostic)
        ):
            raise RateLimitError(
                node_id=self._get_node_id(),
                reset_time=parse_rate_limit_reset(diagnostic),
                raw_message=raw_message,
                exit_code=cli.returncode,
            )
        if cli.returncode == 0:
            return
        node_id = self._get_node_id()
        for error_cls, patterns in _ERROR_PATTERNS:
            if any(pattern in lower for pattern in patterns):
                raise error_cls(node_id=node_id, exit_code=cli.returncode, raw_message=raw_message)
        raise UnknownCodexError(node_id=node_id, exit_code=cli.returncode, raw_message=raw_message)