"""PID tracking and cancellation for ChiaFunction remote tasks.
Tracks subprocess PIDs spawned during ``@ChiaFunction`` remote execution
and provides :func:`chia_cancel` to kill those process trees before
cancelling the Ray task.
The Popen hook is scoped per-task via ``threading.local`` so concurrent
Ray tasks on the same worker don't interfere with each other.
"""
from __future__ import annotations
import logging
import os
import signal
import subprocess
import threading
import time
from contextlib import contextmanager
from typing import Any
import ray
# Module-level logger for this PID-tracking module.
logger = logging.getLogger(__name__)
# Name + namespace of the named registry actor; _get_registry() looks the
# actor up (or creates it) under these so every process resolves the same one.
_REGISTRY_ACTOR_NAME = "ChiaPidRegistry"
_REGISTRY_NAMESPACE = "chia"
# Max time to wait for a SIGTERM'd process to exit before escalating to
# SIGKILL. The kill helpers poll for actual death and return as soon as
# the target is gone, so this is a ceiling, not a fixed delay. Callers that
# wait on remote kill tasks allow extra margin on top of this value.
_KILL_GRACE_SECONDS = 25.0
# Sleep (seconds) between liveness checks while waiting for targets to exit.
_KILL_POLL_INTERVAL = 0.1
# ---------------------------------------------------------------------------
# Ray actor — centralized PID registry (pinned to the node that creates it)
# ---------------------------------------------------------------------------
class PidRegistryActor:
    """Ray actor that maps task IDs to subprocess PIDs.

    Created lazily on first registration, looked up by workers and the
    driver via :func:`_get_registry`. Each entry records the node the PID
    lives on and whether it should be signalled as a process group.
    """

    def __init__(self):
        # task_id_hex -> [(node_id, pid, is_pgid)]
        self._tasks: dict[str, list[tuple[str, int, bool]]] = {}

    def register(self, task_id: str, node_id: str, pid: int, is_pgid: bool) -> None:
        """Record that ``task_id`` spawned ``pid`` on ``node_id``."""
        self._tasks.setdefault(task_id, []).append((node_id, pid, is_pgid))

    def unregister(self, task_id: str) -> None:
        """Forget every PID recorded for ``task_id`` (no-op if unknown)."""
        self._tasks.pop(task_id, None)

    def get_and_remove(self, task_id: str) -> list[tuple[str, int, bool]]:
        """Pop and return the PIDs recorded for ``task_id`` (``[]`` if none)."""
        return self._tasks.pop(task_id, [])

    def kill_all(self, grace: float = _KILL_GRACE_SECONDS) -> int:
        """Kill all tracked subprocess PIDs across all nodes.

        Clears the registry, then dispatches one ``_kill_pid`` Ray task per
        PID, pinned to the PID's node, and waits for them. Each kill sends
        SIGTERM, polls for the process to exit, and only escalates to
        SIGKILL if it is still alive after ``grace`` seconds. Failures are
        logged at debug level and otherwise ignored (best-effort).

        Returns:
            The number of PIDs targeted (not the number confirmed dead).
        """
        all_pids = [entry for entries in self._tasks.values() for entry in entries]
        self._tasks.clear()
        if not all_pids:
            return 0
        # Wrap once: ray.remote() on every iteration would re-wrap (and
        # re-export) the same function for each PID.
        kill_remote = ray.remote(_kill_pid)
        kill_refs = []
        for node_id, pid, is_pgid in all_pids:
            try:
                scheduling = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                    node_id=node_id, soft=False,
                )
                kill_refs.append(
                    kill_remote.options(
                        num_cpus=0,
                        scheduling_strategy=scheduling,
                    ).remote(pid, is_pgid, grace)
                )
            except Exception:
                logger.debug(
                    "kill_all: failed to dispatch kill for pid %s on node %s",
                    pid, node_id, exc_info=True,
                )
        if kill_refs:
            try:
                # Each _kill_pid blocks up to `grace`; allow a margin for
                # dispatch/scheduling overhead before giving up the wait.
                ray.get(kill_refs, timeout=grace + 10)
            except Exception:
                logger.debug("kill_all: waiting for kill tasks failed", exc_info=True)
        return len(all_pids)
# ---------------------------------------------------------------------------
# Cached actor handle (same pattern as profiler.py get_profiler / get_collector)
# ---------------------------------------------------------------------------
# Handle to the PidRegistryActor; filled in by _get_registry() and cleared by
# _reset_registry(). None until a lookup/creation has succeeded.
_registry_handle = None
# Serializes the slow path of _get_registry() and the reset in _reset_registry().
_registry_lock = threading.Lock()
# Set by _install_driver_cleanup() so the SIGTERM handler is installed at most
# once per process.
_cleanup_installed = False
def _install_driver_cleanup():
    """Install a SIGTERM handler on the driver so ``ray job stop`` kills tracked PIDs.

    Only effective on the main thread: ``signal.signal`` raises elsewhere.
    The "installed" flag is recorded only after the handler is actually in
    place, so a failed attempt can be retried later.

    On SIGTERM the handler asks the registry actor to ``kill_all`` tracked
    subprocesses, then chains to the previously-installed SIGTERM handler
    if that was a Python callable; otherwise (``SIG_DFL``, ``SIG_IGN`` or
    ``None``) it raises ``SystemExit(128 + signum)``.
    """
    global _cleanup_installed
    if _cleanup_installed:
        return
    prev_handler = signal.getsignal(signal.SIGTERM)
    # kill_all itself waits up to grace + 10s for its kill tasks; give the
    # actor call a little longer than that so we don't abandon it just
    # before it completes.
    wait_timeout = _KILL_GRACE_SECONDS + 15

    def _sigterm_cleanup(signum, frame):
        # Read the cached handle directly — avoid _get_registry() which
        # takes a lock that the main thread may already hold.
        handle = _registry_handle
        if handle is not None:
            try:
                n = ray.get(handle.kill_all.remote(), timeout=wait_timeout)
                if n:
                    logger.info("SIGTERM cleanup: killed %d tracked subprocess(es)", n)
            except Exception:
                logger.debug("SIGTERM cleanup: kill_all failed", exc_info=True)
        # SIG_DFL / SIG_IGN are plain enum values (not callable) and None
        # means "not installed from Python", so only a real callable chains.
        if callable(prev_handler):
            prev_handler(signum, frame)
        else:
            raise SystemExit(128 + signum)

    signal.signal(signal.SIGTERM, _sigterm_cleanup)
    # Mark as installed only once signal.signal() has succeeded.
    _cleanup_installed = True
def _reset_registry():
    """Drop the cached registry actor handle.

    Clears the module-level handle while holding ``_registry_lock``; the
    next :func:`_get_registry` call then performs a fresh lookup/creation.
    """
    global _registry_handle
    with _registry_lock:
        _registry_handle = None
def _lookup_or_create_registry():
    """Return a handle to the named registry actor, creating it if missing.

    A newly created actor uses ``num_cpus=0`` and is pinned to the current
    node when that node's ID is available. Returns ``None`` (after logging
    at debug level) if the actor can be neither found nor created.
    """
    try:
        return ray.get_actor(_REGISTRY_ACTOR_NAME, namespace=_REGISTRY_NAMESPACE)
    except ValueError:
        pass  # Actor doesn't exist yet — create it below.
    except Exception:
        logger.debug("Failed to get/create PID registry actor", exc_info=True)
        return None

    opts = {
        "name": _REGISTRY_ACTOR_NAME,
        "namespace": _REGISTRY_NAMESPACE,
        "num_cpus": 0,
    }
    try:
        local_node_id = ray.get_runtime_context().get_node_id()
        opts["scheduling_strategy"] = (
            ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                node_id=local_node_id, soft=False,
            )
        )
    except Exception:
        pass  # No node affinity — fall back to Ray's default placement.

    try:
        return ray.remote(PidRegistryActor).options(**opts).remote()
    except Exception:
        # Another process created the actor first — just look it up.
        try:
            return ray.get_actor(_REGISTRY_ACTOR_NAME, namespace=_REGISTRY_NAMESPACE)
        except Exception:
            logger.debug("Failed to get/create PID registry actor", exc_info=True)
            return None


def _running_on_driver() -> bool:
    """True when the caller is not executing inside a Ray task.

    ``RuntimeContext.get_task_id()`` returns ``None`` outside a task in
    current Ray releases; some versions may raise ``RuntimeError`` instead.
    Both outcomes are treated as "driver".
    """
    try:
        task_id = ray.get_runtime_context().get_task_id()
    except RuntimeError:
        return True
    return task_id is None


def _get_registry():
    """Return a handle to the PID registry actor, creating it if needed.

    Returns ``None`` if Ray is not initialized or the actor can be neither
    found nor created. The handle is cached module-wide (drop it with
    :func:`_reset_registry`). Inherits the driver's ``runtime_env`` and
    pins a newly created actor to the current node (same pattern as
    :func:`chia.trace.profiler.start_collector`).

    The first successful lookup made from the driver's main thread also
    installs the SIGTERM cleanup handler (:func:`_install_driver_cleanup`).
    """
    global _registry_handle
    # Fast path: no locking once the handle is cached.
    if _registry_handle is not None:
        return _registry_handle
    with _registry_lock:
        if _registry_handle is not None:
            return _registry_handle
        if not ray.is_initialized():
            return None
        handle = _lookup_or_create_registry()
        if handle is None:
            return None
        _registry_handle = handle
        # On the driver, install a SIGTERM handler so ``ray job stop``
        # kills all tracked subprocesses before exiting. Signal handlers
        # can only be installed from the main thread.
        if (threading.current_thread() is threading.main_thread()
                and _running_on_driver()):
            try:
                _install_driver_cleanup()
            except Exception:
                logger.debug("Failed to install driver SIGTERM cleanup", exc_info=True)
        return handle
# ---------------------------------------------------------------------------
# Thread-local Popen hook
# ---------------------------------------------------------------------------
# Per-thread tracking state written by _pid_tracking_scope() and read by the
# Popen hook: pid_task_id, pid_node_id, tracked_pids.
_tls = threading.local()
# True once subprocess.Popen.__init__ has been wrapped in this process.
_hook_installed = False
# Guards the one-time wrapping in _install_popen_hook().
_hook_lock = threading.Lock()
# The unwrapped Popen.__init__, saved so the wrapper can delegate to it.
_original_popen_init = None
def _install_popen_hook():
"""Wrap ``subprocess.Popen.__init__`` to track PIDs. Idempotent."""
global _hook_installed, _original_popen_init
if _hook_installed:
return
with _hook_lock:
if _hook_installed:
return
_original_popen_init = subprocess.Popen.__init__
def _tracked_popen_init(self, *args, **kwargs):
_original_popen_init(self, *args, **kwargs)
task_id = getattr(_tls, "pid_task_id", None)
if task_id is not None:
is_pgid = kwargs.get("start_new_session", False)
# Store locally so the trampoline can kill on BaseException
# (e.g. ray job stop) without needing the registry.
local_list = getattr(_tls, "tracked_pids", None)
if local_list is not None:
local_list.append((self.pid, is_pgid))
# Also register with the central actor for chia_cancel().
node_id = getattr(_tls, "pid_node_id", "")
registry = _get_registry()
if registry is not None:
try:
registry.register.remote(task_id, node_id, self.pid, is_pgid)
except Exception:
pass
subprocess.Popen.__init__ = _tracked_popen_init
_hook_installed = True
# ---------------------------------------------------------------------------
# Process liveness / local-kill helpers and the trampoline context manager
# ---------------------------------------------------------------------------
def _proc_state_and_pgrp(pid: int) -> tuple[str, int] | None:
"""Return ``(state, pgrp)`` from ``/proc/<pid>/stat``, or ``None`` if gone.
The ``comm`` field is wrapped in parentheses and may itself contain
spaces or parentheses, so we split on everything after the final
``)`` — state is then the first field, pgrp the third.
"""
try:
with open(f"/proc/{pid}/stat", "rb") as f:
data = f.read()
except OSError:
return None
rparen = data.rfind(b")")
if rparen == -1:
return None
fields = data[rparen + 2:].split()
try:
return fields[0].decode(), int(fields[2])
except (IndexError, ValueError, UnicodeDecodeError):
return None
def _is_running(pid: int, is_pgid: bool) -> bool:
    """Report whether a live, non-zombie process exists for ``pid``.

    Zombies (state ``Z``: exited but not yet reaped) count as dead, since
    they no longer execute. ``os.kill(pid, 0)`` would report them as alive,
    hence the ``/proc`` reads. With ``is_pgid`` the check is whether ANY
    non-zombie process belongs to process group ``pid``. If ``/proc`` can't
    be read, everything is reported as not running.
    """
    if not is_pgid:
        single = _proc_state_and_pgrp(pid)
        if single is None:
            return False
        return single[0] != "Z"

    try:
        proc_entries = os.listdir("/proc")
    except OSError:
        return False

    for name in (entry for entry in proc_entries if entry.isdigit()):
        stat = _proc_state_and_pgrp(int(name))
        if stat is not None and stat[1] == pid and stat[0] != "Z":
            return True
    return False
def _wait_for_death(
    targets: list[tuple[int, bool]],
    grace: float = _KILL_GRACE_SECONDS,
    poll_interval: float = _KILL_POLL_INTERVAL,
) -> list[tuple[int, bool]]:
    """Block until all ``(pid, is_pgid)`` targets have exited, or time runs out.

    Re-checks the still-running targets every ``poll_interval`` seconds and
    returns early as soon as none remain. Whatever is still running once
    ``grace`` seconds have passed is returned, so the caller can escalate
    (e.g. SIGKILL); an empty list means everything exited in time.
    """
    give_up_at = time.monotonic() + grace
    remaining = list(targets)
    while remaining:
        remaining = [target for target in remaining if _is_running(*target)]
        if remaining and time.monotonic() < give_up_at:
            time.sleep(poll_interval)
        else:
            break
    return remaining
def _kill_local_pids(pids: list[tuple[int, bool]]) -> None:
    """Terminate tracked subprocesses on this node, escalating if needed.

    ``pids`` holds ``(pid, is_pgid)`` pairs; ``is_pgid`` selects
    ``os.killpg`` over ``os.kill``. Every target first gets SIGTERM
    (targets that can't be signalled are treated as already gone). The
    function then waits until they have all exited, and sends SIGKILL only
    to those still running after ``_KILL_GRACE_SECONDS``.

    Used by the trampoline's tracking scope when the task is interrupted
    (e.g. ``ray job stop``). It is the on-node counterpart to
    :func:`chia_cancel`'s remote-kill path.
    """
    terminated: list[tuple[int, bool]] = []
    for pid, is_pgid in pids:
        try:
            (os.killpg if is_pgid else os.kill)(pid, signal.SIGTERM)
        except OSError:
            continue  # nothing to signal — already exited
        terminated.append((pid, is_pgid))

    for pid, is_pgid in _wait_for_death(terminated):
        try:
            (os.killpg if is_pgid else os.kill)(pid, signal.SIGKILL)
        except OSError:
            pass  # exited between the last poll and now
@contextmanager
def _pid_tracking_scope():
    """Track subprocess PIDs for the current Ray task.

    Sets thread-local state so the Popen hook registers PIDs under this
    task's ID. On normal exit, removes the task's registry entry. On
    ``BaseException`` (e.g. ``ray job stop`` / ``ray.cancel``), kills every
    subprocess recorded by this scope locally before re-raising.

    Scopes may nest on one thread. The enclosing scope's thread-local state
    is saved on entry and restored on exit, and only the outermost scope
    for a given task ID unregisters it. If the registry is unavailable, or
    there is no current task ID, the scope tracks nothing.
    """
    _install_popen_hook()
    registry = _get_registry()
    if registry is None:
        yield
        return
    try:
        ctx = ray.get_runtime_context()
        task_id = ctx.get_task_id()
        node_id = ctx.get_node_id()
    except Exception:
        yield
        return
    if task_id is None:
        # No task context: the Popen hook ignores a None task ID, so there
        # is nothing to track or unregister.
        yield
        return

    prev_task_id = getattr(_tls, "pid_task_id", None)
    prev_node_id = getattr(_tls, "pid_node_id", None)
    prev_tracked = getattr(_tls, "tracked_pids", None)
    # Hold our own reference so an inner scope replacing _tls.tracked_pids
    # can't hide the PIDs recorded by this scope.
    tracked: list[tuple[int, bool]] = []
    _tls.pid_task_id = task_id
    _tls.pid_node_id = node_id
    _tls.tracked_pids = tracked
    try:
        yield
    except BaseException:
        # ray job stop / ray.cancel delivers the cancellation here.
        # Kill subprocesses locally — we're on the same node.
        if tracked:
            logger.info("Task interrupted — killing %d tracked subprocess(es)", len(tracked))
            try:
                _kill_local_pids(list(tracked))
            except Exception:
                # Never let cleanup mask the exception being propagated.
                logger.debug("Local subprocess cleanup failed", exc_info=True)
        raise
    finally:
        _tls.pid_task_id = prev_task_id
        _tls.pid_node_id = prev_node_id
        _tls.tracked_pids = prev_tracked if prev_tracked is not None else []
        # An enclosing scope for the same task still owns the registry
        # entry; only the outermost one removes it.
        if prev_task_id != task_id:
            try:
                registry.unregister.remote(task_id)
            except Exception:
                logger.debug("Failed to unregister task %s", task_id, exc_info=True)
# ---------------------------------------------------------------------------
# Remote kill function — dispatched to the target node
# ---------------------------------------------------------------------------
def _kill_pid(pid: int, is_pgid: bool, grace: float = _KILL_GRACE_SECONDS) -> None:
    """Terminate one process (or one process group) on the local node.

    Dispatched as a Ray task pinned to the node that owns ``pid``. Sends
    SIGTERM, waits up to ``grace`` seconds for the target to exit, and only
    sends SIGKILL if it is still running after that. Signal failures (e.g.
    the target already exited) are ignored.
    """
    send_signal = os.killpg if is_pgid else os.kill
    try:
        send_signal(pid, signal.SIGTERM)
    except OSError:
        return  # target no longer exists
    if not _wait_for_death([(pid, is_pgid)], grace):
        return  # exited within the grace period
    try:
        send_signal(pid, signal.SIGKILL)
    except OSError:
        pass
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def chia_cancel(ref: ray.ObjectRef, force: bool = False) -> None:
    """Cancel a running ChiaFunction task, killing its subprocesses first.

    Looks up any subprocess PIDs spawned by the task and kills them on the
    correct remote nodes. Entries registered as process-group leaders
    (e.g. ``start_new_session``) get a process-group kill. It waits for
    those kills to finish, then calls ``ray.cancel()``. PID cleanup is
    best-effort: failures are logged at debug level and ``ray.cancel()``
    is always called.

    Args:
        ref: The ``ObjectRef`` returned by ``chia_remote()``.
        force: Passed through to ``ray.cancel()``. If ``True``, the Ray
            worker is killed; if ``False`` (default), a
            ``TaskCancelledError`` is raised cooperatively.
    """
    registry = _get_registry()
    if registry is not None:
        try:
            task_id = ref.task_id().hex()
            pids = ray.get(registry.get_and_remove.remote(task_id))
            if pids:
                # Wrap once rather than per PID.
                kill_remote = ray.remote(_kill_pid)
                kill_refs = []
                for node_id, pid, is_pgid in pids:
                    # Dispatch each kill independently: these PIDs are
                    # already removed from the registry, so one bad entry
                    # must not abandon the rest.
                    try:
                        scheduling = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                            node_id=node_id, soft=False,
                        )
                        kill_refs.append(
                            kill_remote.options(
                                num_cpus=0,
                                scheduling_strategy=scheduling,
                            ).remote(pid, is_pgid)
                        )
                    except Exception:
                        logger.debug(
                            "Failed to dispatch kill for pid %s on node %s",
                            pid, node_id, exc_info=True,
                        )
                if kill_refs:
                    # Wait for kills to complete before cancelling the task;
                    # each kill blocks up to _KILL_GRACE_SECONDS, plus margin.
                    try:
                        ray.get(kill_refs, timeout=_KILL_GRACE_SECONDS + 10)
                    except Exception:
                        logger.debug("Waiting for subprocess kills failed", exc_info=True)
        except Exception:
            logger.debug("PID cleanup failed, falling back to ray.cancel only",
                         exc_info=True)
    ray.cancel(ref, force=force)