Source code for mmirage.cli_utils.status
"""Shard status and retry helpers for the MMIRAGE CLI."""
from __future__ import annotations
import json
import logging
import os
import sys
from dataclasses import dataclass
from typing import List, Literal, Sequence, Tuple
from mmirage.config.config import MMirageConfig
from mmirage.cli_utils.slurm import submit_slurm_job
from mmirage.shard_utils import ShardStatus
logger = logging.getLogger(__name__)
[docs]
@dataclass
class ShardSummary:
    """Aggregate shard counts produced by one status scan.

    A plain data holder with integer tallies only. ``check_failed_shards``
    in this module is the visible producer and fills the fields as
    described in the per-field comments below.
    """

    # Number of shards covered by the scan.
    total: int
    # Shards whose status was "success".
    successful: int
    # Shards whose status was "running".
    running: int
    # Shards that are neither successful nor running and are still
    # within the retry budget (i.e. eligible for resubmission).
    failed: int
    # Shards that are neither successful nor running and whose attempt
    # counter is past the retry budget.
    max_retries_exceeded: int
[docs]
def max_allowed_attempts(max_retries: int) -> int:
    """Compute the total number of runs a shard is allowed.

    The first run is not a retry, so the ceiling is one more than
    ``max_retries``.
    """
    initial_attempt = 1
    total_attempts = max_retries + initial_attempt
    return total_attempts
[docs]
def is_retry_budget_exceeded(attempt_count: int, max_retries: int) -> bool:
    """Tell whether ``attempt_count`` is past the allowed number of attempts.

    The comparison is strict: a count equal to
    ``max_allowed_attempts(max_retries)`` is still within budget, and only
    a larger count is reported as exceeded.
    """
    ceiling = max_allowed_attempts(max_retries)
    return attempt_count > ceiling
[docs]
def shard_state_dir(state_root: str, shard_id: int) -> str:
    """Build the path of a shard's state directory.

    The directory is named ``shard_<shard_id>`` and lives directly under
    ``state_root``. Nothing is created on disk.
    """
    leaf_name = "shard_{}".format(shard_id)
    return os.path.join(state_root, leaf_name)
[docs]
def get_shard_status(state_dir: str) -> Tuple[str, int]:
    """Read the current status and attempt counter for a shard.

    Looks for ``status.json`` inside ``state_dir``. A missing or damaged
    status file never raises; the problem is logged as a warning and a
    placeholder result is returned so a scan over many shards can continue.

    Args:
        state_dir: Directory expected to contain the shard's ``status.json``.

    Returns:
        A ``(status, attempt_count)`` tuple:

        * ``("missing", 0)`` when the status file does not exist;
        * ``("unknown", 0)`` when the file cannot be read, is not valid
          UTF-8 JSON, does not hold a JSON object, or cannot be converted
          into a ``ShardStatus``;
        * otherwise the ``status`` and ``retry_count`` attributes of the
          parsed ``ShardStatus``.
    """
    status_file = os.path.join(state_dir, "status.json")
    try:
        with open(status_file, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError:
        # Open directly instead of checking existence first, so a file that
        # disappears between the check and the open is still "missing".
        return ("missing", 0)
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and the
        # UnicodeDecodeError raised for bytes that are not valid UTF-8.
        logger.warning("Failed to read shard status from %s: %s", status_file, exc)
        return ("unknown", 0)

    if not isinstance(data, dict):
        logger.warning("Invalid shard status format in %s; expected object", status_file)
        return ("unknown", 0)

    try:
        parsed = ShardStatus.from_dict(data)
    except (KeyError, TypeError, ValueError) as exc:
        # NOTE(review): assumes from_dict signals a schema-invalid dict with
        # one of these exception types -- confirm against ShardStatus.
        logger.warning("Failed to read shard status from %s: %s", status_file, exc)
        return ("unknown", 0)
    return (parsed.status, parsed.retry_count)
[docs]
def check_failed_shards(cfg: MMirageConfig) -> Tuple[List[int], ShardSummary]:
    """Scan every shard's status and work out which ones can be retried.

    Shard ids ``0 .. num_shards - 1`` are inspected through
    ``get_shard_status``. Each shard lands in exactly one bucket:

    * status ``"success"``: counted as successful;
    * status ``"running"``: counted as running;
    * any other status with the retry budget exceeded: counted as
      exhausted and reported with a warning;
    * any other status still within budget: added to the retry list.

    Args:
        cfg: Configuration supplying the state root, the shard count and
            ``execution_params.max_retries``.

    Returns:
        ``(retryable_shard_ids, summary)`` where the ids are in ascending
        order and ``summary.failed`` equals the length of that list.
    """
    state_root = cfg.loading_params.get_state_root()
    num_shards = cfg.loading_params.get_num_shards()
    max_retries = cfg.execution_params.max_retries
    attempt_ceiling = max_allowed_attempts(max_retries)

    retryable: List[int] = []
    n_success = 0
    n_running = 0
    n_exhausted = 0

    for shard_id in range(num_shards):
        shard_dir = shard_state_dir(state_root, shard_id)
        status, attempts = get_shard_status(shard_dir)

        if status == "success":
            n_success += 1
            continue
        if status == "running":
            n_running += 1
            continue
        if not is_retry_budget_exceeded(attempts, max_retries):
            retryable.append(shard_id)
            continue

        # Out of retries: count it and tell the operator why it is skipped.
        n_exhausted += 1
        logger.warning(
            "Shard %s exceeded retry budget (attempts=%s, max_allowed_attempts=%s)",
            shard_id,
            attempts,
            attempt_ceiling,
        )

    return retryable, ShardSummary(
        total=num_shards,
        successful=n_success,
        running=n_running,
        failed=len(retryable),
        max_retries_exceeded=n_exhausted,
    )
[docs]
def confirm_retry(count: int, confirm_mode: Literal["prompt", "yes"]) -> bool:
    """Return whether retry submission is confirmed.

    Modes:
        - prompt: ask the user interactively
        - yes: submit without prompting

    Args:
        count: Number of shards that would be retried; shown in the prompt.
        confirm_mode: ``"yes"`` to skip the prompt, anything else to ask.

    Returns:
        ``True`` when the retry should go ahead. ``False`` when the user
        declines, when stdin is absent or not a TTY (an error is logged
        pointing at ``--yes``), or when the prompt hits end-of-file.
    """
    if confirm_mode == "yes":
        return True

    stdin = sys.stdin
    # sys.stdin can be None (e.g. a detached process); treat that the same
    # as a non-interactive stream rather than failing on ``None.isatty()``.
    if stdin is None or not stdin.isatty():
        logger.error("Interactive confirmation requested but stdin is not a TTY; use --yes")
        return False

    try:
        response = input(f"Retry {count} shard(s)? (y/N) ")
    except EOFError:
        # Ctrl-D / closed stdin at the prompt: the default answer is "No".
        return False
    return response.strip().lower() in ("y", "yes")
[docs]
def status_exit_code(failed_shards: Sequence[int], summary: ShardSummary) -> int:
    """Translate a status scan into a process exit code.

    Returns ``0`` only when there is nothing to retry, no shard is past
    its retry budget, no shard is still running, and the number of
    successful shards equals the total. Every other situation yields ``1``.
    """
    if failed_shards:
        return 1
    if summary.max_retries_exceeded != 0:
        return 1
    if summary.running != 0:
        return 1
    if summary.successful != summary.total:
        return 1
    return 0
[docs]
def submit_failed_shards(
    cfg: MMirageConfig,
    config_path: str,
    failed_shards: Sequence[int],
    confirm_mode: Literal["prompt", "yes"],
) -> int:
    """Resubmit the given shards as a SLURM job after confirmation.

    Args:
        cfg: Configuration forwarded to ``submit_slurm_job``.
        config_path: Config file path forwarded to ``submit_slurm_job``.
        failed_shards: Shard ids to retry; an empty sequence is a no-op.
        confirm_mode: Passed to ``confirm_retry`` to decide whether to ask.

    Returns:
        ``0`` when there was nothing to submit or ``submit_slurm_job``
        returned a job id; ``1`` when confirmation was not given or
        ``submit_slurm_job`` returned ``None``.
    """
    if not failed_shards:
        return 0

    confirmed = confirm_retry(len(failed_shards), confirm_mode)
    if confirmed:
        job_id = submit_slurm_job(cfg, config_path, failed_shards)
        return 0 if job_id is not None else 1
    return 1