"""Head-node output cache for ChiaFunction tasks.
A small key/value store, pinned to the head node as a Ray actor, that pickles
arbitrary Python objects keyed by a string tag, with an LRU byte budget and a
warm-start scan from disk. It has two halves that work together with the
:mod:`chia.base.bypass` mechanism:
* **Write is automatic.** A function marked ``cache: true`` in the YAML has its
real-run output written to the cache automatically by ``ChiaFunction`` (via an
``ObjectRefCallback`` that fires when ``get()`` resolves the result), keyed by
the call's ``_chia_tag``. No user code is needed.
* **Read is manual, via bypass.** The cache does *not* auto-serve. To replay a
cached value, register a bypass provider that reads it back::
def cache_provider(tag, data_path, *args, **kwargs):
hit, value = ray.get(get_active_cache().read.remote(tag))
if not hit:
raise KeyError(f"cache miss for tag {tag!r}")
return value
bypass.set_provider("run_verilator_test", cache_provider)
Cache = write-through populate; bypass = read path. They share the tag.
Usage
-----
::
from chia.base.cache import start_cache
# Call once on the driver after ray.init(). Idempotent.
start_cache(size=4, units="GB", cache_dir_path="/data/chia_cache",
yaml_path=args.bypass_config)
YAML format (same file as bypass, parallel ``cache:`` section)
--------------------------------------------------------------
::
cache:
run_verilator_test:
cache: true
tags: ["iter.*"] # optional; mirror bypass tag patterns
# shorthand (cache, no tags):
build_megaboom: true
"""
from __future__ import annotations
import hashlib
import logging
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional
import ray
import yaml
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from chia.base.ChiaFunction import chia_actor, get
logger = logging.getLogger("chia.cache")
# Multipliers from a units string to bytes. KiB/MiB/... are aliases for the
# binary KB/MB/... used here (everything is powers of 1024).
UNITS = {
"B": 1,
"KB": 1024,
"MB": 1024 ** 2,
"GB": 1024 ** 3,
"TB": 1024 ** 4,
"KIB": 1024,
"MIB": 1024 ** 2,
"GIB": 1024 ** 3,
"TIB": 1024 ** 4,
}
# Name used to register / look up the cache actor in Ray. The actor is the
# single source of truth: every process resolves it fresh via ray.get_actor
# (see get_active_cache), so there is no process-local handle or config to go
# stale when the cache is restarted or replaced.
_CACHE_ACTOR_NAME = "ChiaCacheStore"
# ---------------------------------------------------------------------------
# Serialization helpers
# ---------------------------------------------------------------------------
def _dumps(obj: Any) -> bytes:
"""Serialize *obj* with Ray's vendored ``ray.cloudpickle``.
cloudpickle is a strict superset of stdlib pickle — it also handles
lambdas, closures, and locally-defined classes — so we use it for
everything. ``ray.cloudpickle`` is always importable (Ray is a hard
dependency); the standalone ``cloudpickle`` package may not be installed.
"""
from ray import cloudpickle
return cloudpickle.dumps(obj)
def _loads(blob: bytes) -> Any:
from ray import cloudpickle
return cloudpickle.loads(blob)
def _tag_filename(tag: str) -> str:
"""Deterministic, filesystem-safe filename for *tag*.
A readable slug plus a hash suffix so distinct tags that slug to the same
string don't collide. Deterministic across runs so the warm-start scan and
later writes address the same file.
"""
slug = re.sub(r"[^A-Za-z0-9._-]", "_", tag)[:80]
digest = hashlib.sha256(tag.encode("utf-8")).hexdigest()[:16]
return f"{slug}_{digest}.pkl"
# ---------------------------------------------------------------------------
# The cache actor
# ---------------------------------------------------------------------------
@ray.remote(num_cpus=0)
class Cache:
"""Head-pinned LRU object cache backed by pickle files on disk.
Single-threaded by design (default Ray actor concurrency): every method is
atomic, so the LRU index, eviction, and flush never race even when many
workers write concurrently. Do NOT make methods async or raise
``max_concurrency``.
Each pickle on disk stores ``(tag, data)`` so a cold actor can recover the
raw tag (the index key) during the warm-start scan.
"""
def __init__(self, cache_dir_path: str, budget_bytes: int, config: dict):
self._dir = Path(cache_dir_path)
self._budget = int(budget_bytes)
self._config = dict(config or {})
# raw_tag -> (filename, nbytes); insertion/access order == LRU order
# (front = least recently used, back = most recently used).
self._index: "OrderedDict[str, tuple[str, int]]" = OrderedDict()
self._size_bytes = 0
self._dir.mkdir(parents=True, exist_ok=True)
self._warm_start()
logger.info(
"Cache started: dir=%s budget=%d bytes, warm-started %d entries (%d bytes)",
self._dir, self._budget, len(self._index), self._size_bytes,
)
# ------------------------------------------------------------------
# Warm start
# ------------------------------------------------------------------
def _warm_start(self) -> None:
"""Rebuild the index from ``*.pkl`` files already on disk.
Seeds LRU order by file mtime (oldest first) so the existing on-disk
recency is approximately preserved across driver restarts.
"""
files = sorted(self._dir.glob("*.pkl"), key=lambda p: p.stat().st_mtime)
for path in files:
try:
tag, _data = _loads(path.read_bytes())
except Exception: # noqa: BLE001 — skip corrupt/foreign files
logger.warning("Cache: skipping unreadable file %s", path)
continue
nbytes = path.stat().st_size
# Drop a stale entry for the same tag if the deterministic filename
# differs (shouldn't happen, but keep the index consistent).
if tag in self._index:
old_name, old_nbytes = self._index.pop(tag)
self._size_bytes -= old_nbytes
self._index[tag] = (path.name, nbytes)
self._size_bytes += nbytes
# ------------------------------------------------------------------
# Storage API
# ------------------------------------------------------------------
def write(self, tag: str, data: Any) -> bool:
"""Pickle ``(tag, data)`` to disk under *tag*, evicting LRU as needed.
Returns True if stored, False if skipped (a single item larger than the
whole budget is never cached).
"""
blob = _dumps((tag, data))
nbytes = len(blob)
if nbytes > self._budget:
logger.warning(
"Cache: item %r is %d bytes > budget %d; not caching",
tag, nbytes, self._budget,
)
return False
# Overwrite: remove the old accounting first so eviction math is correct.
if tag in self._index:
old_name, old_nbytes = self._index.pop(tag)
self._size_bytes -= old_nbytes
# Evict least-recently-used entries until the new item fits.
while self._size_bytes + nbytes > self._budget and self._index:
victim_tag, (victim_name, victim_nbytes) = self._index.popitem(last=False)
self._size_bytes -= victim_nbytes
self._unlink(victim_name)
filename = _tag_filename(tag)
(self._dir / filename).write_bytes(blob)
self._index[tag] = (filename, nbytes)
self._index.move_to_end(tag) # mark MRU
self._size_bytes += nbytes
return True
def read(self, tag: str) -> tuple[bool, Any]:
"""Return ``(True, value)`` on hit (and mark MRU), else ``(False, None)``.
A single round-trip so callers never race a has()/read() pair against a
concurrent eviction or flush.
"""
entry = self._index.get(tag)
if entry is None:
return (False, None)
filename, _nbytes = entry
path = self._dir / filename
try:
_tag, data = _loads(path.read_bytes())
except Exception: # noqa: BLE001 — file vanished or corrupt → treat as miss
self._size_bytes -= self._index.pop(tag, (None, 0))[1]
return (False, None)
self._index.move_to_end(tag)
return (True, data)
def has(self, tag: str) -> bool:
return tag in self._index
def evict(self, tag: str) -> bool:
"""Delete the entry for *tag*. Returns True if something was removed."""
entry = self._index.pop(tag, None)
if entry is None:
return False
filename, nbytes = entry
self._size_bytes -= nbytes
self._unlink(filename)
return True
def flush(self) -> None:
"""Clear ALL entries: delete every pickle and reset the index."""
for filename, _nbytes in self._index.values():
self._unlink(filename)
self._index.clear()
self._size_bytes = 0
def keys(self) -> list[str]:
return list(self._index.keys())
def size_bytes(self) -> int:
return self._size_bytes
def get_config(self) -> dict:
"""Return the static ``cache:`` config (for the local is_cached() check)."""
return dict(self._config)
# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------
def _unlink(self, filename: str) -> None:
try:
(self._dir / filename).unlink()
except FileNotFoundError:
pass
# ---------------------------------------------------------------------------
# Module functions
# ---------------------------------------------------------------------------
def _load_cache_config(yaml_path: str) -> dict:
"""Parse the ``cache:`` section of *yaml_path* into ``{func: True | [tags]}``.
Mirrors :meth:`chia.base.bypass.Bypass._load_yaml` but for the ``cache:``
key. Functions with ``cache: false`` (or absent) are omitted.
"""
with open(yaml_path) as f:
cfg = yaml.safe_load(f) or {}
config: dict = {}
for func_name, spec in cfg.get("cache", {}).items():
if isinstance(spec, bool):
cache_on, tags = spec, None
elif isinstance(spec, dict):
cache_on = spec.get("cache", False)
tags = spec.get("tags")
else:
continue
if not cache_on:
continue
if tags is not None:
if isinstance(tags, str):
tags = [tags]
config[func_name] = tags
else:
config[func_name] = True
logger.info("Cache: %s tags=%s", func_name, tags)
return config
[docs]
def start_cache(
size: float,
cache_dir_path: str,
units: str = "B",
yaml_path: Optional[str] = None,
namespace: Optional[str] = None,
):
"""Create (or find) the head-pinned cache actor. Idempotent.
Call from the driver after ``ray.init()``. Mirrors
:func:`chia.trace.profiler.start_collector`.
The actor is created ``detached`` so it is reachable from workers in a
*different* Ray job (e.g. a bypass provider running on a remote worker) — a
plain named actor is only visible within its creating job. Pair it with an
explicit ``namespace`` so cross-job lookups resolve it. ``stop_cache`` still
tears it down; cross-run reuse remains the on-disk pickles + warm start, not
a lingering live actor.
Args:
size: Numeric budget; multiplied by ``UNITS[units]`` to get bytes.
cache_dir_path: Directory for the pickle files (on the head node).
units: One of :data:`UNITS` (default ``"B"``).
yaml_path: Optional path to the bypass/cache YAML; the ``cache:`` section
opts functions into automatic caching.
namespace: Optional Ray namespace for the named actor.
Returns:
The ``Cache`` actor handle.
"""
# Idempotent: if the named actor already exists, return it as-is.
existing = get_active_cache(namespace)
if existing is not None:
return existing
units_key = units.upper()
if units_key not in UNITS:
raise ValueError(f"start_cache: unknown units {units!r}; choose from {sorted(UNITS)}")
config = _load_cache_config(yaml_path) if yaml_path else {}
budget = int(size * UNITS[units_key])
# Pin the actor to the node running the driver (where the files live).
node_id = ray.get_runtime_context().get_node_id()
scheduling = NodeAffinitySchedulingStrategy(node_id=node_id, soft=False)
opts = dict(
name=_CACHE_ACTOR_NAME,
get_if_exists=True,
num_cpus=0,
scheduling_strategy=scheduling,
# Detached so a bypass provider on a worker in a different Ray job can
# still reach the cache by name (a plain named actor is job-scoped).
lifetime="detached",
)
if namespace:
opts["namespace"] = namespace
handle = chia_actor(Cache.options(**opts).remote(cache_dir_path, budget, config))
# Block until the actor is live and responding.
get(handle.size_bytes.remote())
return handle
[docs]
def stop_cache(namespace: Optional[str] = None) -> None:
"""Kill the cache actor. Idempotent.
Cross-run reuse comes from the on-disk pickles + warm-start scan, not from a
persisted live actor, so killing the actor loses nothing on disk.
"""
handle = get_active_cache(namespace)
if handle is not None:
try:
# get_active_cache returns a ChiaActorHandle; ray.kill needs the raw
# Ray handle. getattr(..., "actor", handle) also tolerates a raw one.
ray.kill(getattr(handle, "actor", handle))
except Exception: # noqa: BLE001 — best-effort teardown
pass
[docs]
def get_active_cache(namespace: Optional[str] = None):
"""Return the cache actor handle, or None if no cache was started.
Always resolves the current named actor via ``ray.get_actor`` — no
process-local handle is cached. A worker (e.g. inside a bypass provider or a
nested dispatch) reaches the cache with no state threading, and a cache that
was restarted/replaced is picked up automatically rather than serving a
stale, dead handle.
"""
lookup_kwargs = {"namespace": namespace} if namespace else {}
try:
return chia_actor(ray.get_actor(_CACHE_ACTOR_NAME, **lookup_kwargs))
except ValueError:
return None
[docs]
def is_cached(func_name: str, tag: Optional[str] = None) -> bool:
"""Should *func_name*'s output be cached on this call?
True only when the function is configured ``cache: true`` and, if ``tags:``
patterns are present, *tag* ``re.fullmatch``es one. Mirrors
:meth:`chia.base.bypass.Bypass.is_bypassed`. The config is read from the
actor (the source of truth) so it never goes stale across a cache restart;
only tagged dispatches reach this path (see ``_maybe_wrap_cache``).
"""
cache = get_active_cache()
if cache is None:
return False
config = get(cache.get_config.remote())
spec = config.get(func_name)
if not spec:
return False
if isinstance(spec, list):
if tag is None:
return False
return any(re.fullmatch(p, tag) for p in spec)
return True