"""Command-line interface for MMIRAGE pipeline."""
from __future__ import annotations
import argparse
import json
import logging
import os
import subprocess
import sys
from dataclasses import asdict
from typing import List, Optional
from mmirage.cli_utils.runtime import setup_runtime, validate_edf_env_path
from mmirage.cli_utils.slurm import require_slurm, submit_slurm_job, wait_for_slurm_job
from mmirage.cli_utils.status import (
check_failed_shards,
is_retry_budget_exceeded,
shard_state_dir,
get_shard_status,
status_exit_code,
submit_failed_shards,
)
from mmirage.config.config import MMirageConfig
from mmirage.config.utils import load_mmirage_config
from mmirage.merge_shards import MergeReport, merge_from_config, merge_input_dir
logger = logging.getLogger(__name__)
def run_local(config_path: str, shard_id: Optional[int] = None) -> int:
    """Run one shard in the current Python environment.

    Args:
        config_path: Absolute path to the MMIRAGE YAML config file.
        shard_id: Optional shard id to inject via SLURM_ARRAY_TASK_ID.

    Returns:
        Process return code from shard execution.
    """
    cmd = [sys.executable, "-m", "mmirage.shard_process", "--config", config_path]
    child_env = dict(os.environ)
    if shard_id is not None:
        # shard_process reads its shard id from the SLURM env var even off-cluster.
        child_env["SLURM_ARRAY_TASK_ID"] = str(shard_id)
    logger.info("Running local shard processing: %s", " ".join(cmd))
    completed = subprocess.run(cmd, env=child_env, check=False)
    return completed.returncode
def launch_pipeline(
    cfg: MMirageConfig,
    config_path: str,
    force_retry: bool = False,
    require_completion: bool = False,
) -> int:
    """Launch the pipeline according to execution mode and retry settings.

    Args:
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.
        force_retry: If True, enable retry orchestration regardless of config flag.
        require_completion: If True, wait for completion and verify shard
            status before returning success in SLURM mode when auto-retry is off.

    Returns:
        Exit code: 0 on success, 1 on failure.
    """
    auto_retry = force_retry or cfg.execution_params.retry
    if not cfg.execution_params.is_slurm():
        return _launch_local(cfg, config_path, auto_retry)
    return _launch_slurm(cfg, config_path, auto_retry, require_completion)


def _launch_local(cfg: MMirageConfig, config_path: str, auto_retry: bool) -> int:
    """Run shards in the current environment, retrying failures within budget.

    Args:
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.
        auto_retry: If False, run once and return the process exit code.

    Returns:
        Exit code: 0 on success, 1 when shards exhaust their retry budget.
    """
    initial_shard_id = cfg.loading_params.get_shard_id()
    if not auto_retry:
        exit_code = run_local(config_path, initial_shard_id)
        if exit_code == 0:
            logger.info("All shards completed successfully")
        return exit_code
    shard_ids: List[int] = [initial_shard_id]
    attempts_by_shard = {initial_shard_id: 0}
    state_root = cfg.loading_params.get_state_root()
    while True:
        run_exit_codes = {}
        for shard_id in shard_ids:
            attempts_by_shard[shard_id] = attempts_by_shard.get(shard_id, 0) + 1
            run_exit_codes[shard_id] = run_local(config_path, shard_id)
        failed_shards, summary = check_failed_shards(cfg)
        if status_exit_code(failed_shards, summary) == 0:
            logger.info("All shards completed successfully")
            return 0
        # A shard can fail at runtime (non-zero return code) without being
        # recorded as failed in on-disk state; retry the union of both views.
        runtime_failed = [shard_id for shard_id, rc in run_exit_codes.items() if rc != 0]
        candidates = sorted(set(failed_shards) | set(runtime_failed))
        retryable_shards: List[int] = []
        for shard_id in candidates:
            _, state_attempt_count = get_shard_status(shard_state_dir(state_root, shard_id))
            # Persisted attempt counts may lag behind this process's view,
            # so trust whichever count is larger when checking the budget.
            memory_attempt_count = attempts_by_shard.get(shard_id, 0)
            effective_attempt_count = max(state_attempt_count, memory_attempt_count)
            if not is_retry_budget_exceeded(
                effective_attempt_count,
                cfg.execution_params.max_retries,
            ):
                retryable_shards.append(shard_id)
        if not retryable_shards:
            logger.error("Pipeline ended with shards that exceeded max retries")
            return 1
        logger.warning("Retrying failed shards locally: %s", ",".join(map(str, retryable_shards)))
        shard_ids = retryable_shards


def _launch_slurm(
    cfg: MMirageConfig,
    config_path: str,
    auto_retry: bool,
    require_completion: bool,
) -> int:
    """Submit SLURM array jobs, optionally waiting and resubmitting failures.

    Args:
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.
        auto_retry: If True, wait for each job and resubmit failed shards.
        require_completion: If True (and auto_retry is off), wait for the job
            and verify shard status before reporting success.

    Returns:
        Exit code: 0 on success, non-zero on failure.
    """
    shard_ids: List[int] = []  # empty list submits the full array
    while True:
        job_id = submit_slurm_job(cfg, config_path, shard_ids)
        if job_id is None:
            return 1
        logger.info("Submitted SLURM job %s for shard ids: %s", job_id, shard_ids or "ALL")
        if not auto_retry:
            if not require_completion:
                # Fire-and-forget: submission succeeded, caller does not need status.
                return 0
            wait_for_slurm_job(job_id, cfg)
            failed_shards, summary = check_failed_shards(cfg)
            status_code = status_exit_code(failed_shards, summary)
            if status_code == 0:
                logger.info("All shards completed successfully")
            else:
                logger.error("SLURM run completed with failed shards; merge will not start")
            return status_code
        wait_for_slurm_job(job_id, cfg)
        failed_shards, summary = check_failed_shards(cfg)
        if status_exit_code(failed_shards, summary) == 0:
            logger.info("All shards completed successfully")
            return 0
        if not failed_shards:
            # Non-zero status with no retryable shards means budgets are spent.
            logger.error("Pipeline ended with shards that exceeded max retries")
            return 1
        logger.warning("Retrying failed shards: %s", ",".join(map(str, failed_shards)))
        shard_ids = failed_shards
def add_shared_arguments(parser: argparse.ArgumentParser) -> None:
    """Attach common CLI arguments to a subcommand parser.

    Args:
        parser: Subcommand parser receiving shared arguments.
    """
    verbosity_levels = ["DEBUG", "INFO", "WARNING", "ERROR"]
    parser.add_argument("--config", required=True, help="Path to a MMIRAGE YAML config file")
    parser.add_argument(
        "--log-level",
        choices=verbosity_levels,
        default="INFO",
        help="Log verbosity",
    )
def build_argparser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    Returns:
        Configured top-level argparse parser.
    """

    def _add_yes_flag(sub: argparse.ArgumentParser) -> None:
        # Shared -y/--yes flag for subcommands that may submit retries.
        sub.add_argument(
            "-y",
            "--yes",
            dest="confirm_mode",
            action="store_const",
            const="yes",
            help="Submit retries without prompting.",
        )
        sub.set_defaults(confirm_mode="prompt")

    parser = argparse.ArgumentParser(description="MMIRAGE command-line interface")
    subparsers = parser.add_subparsers(dest="command", required=True)

    submit_cmd = subparsers.add_parser("submit", help="Submit one SLURM array job")
    add_shared_arguments(submit_cmd)
    submit_cmd.add_argument(
        "--shard-ids",
        help="Comma-separated shard ids to submit instead of the full array",
    )
    submit_cmd.add_argument("--wait", action="store_true", help="Wait for the submitted job")

    check_cmd = subparsers.add_parser("check", help="Inspect shard status")
    add_shared_arguments(check_cmd)
    check_cmd.add_argument(
        "--retry",
        dest="retry",
        action="store_true",
        help="Submit a retry job for failed shards.",
    )
    check_cmd.set_defaults(retry=False)
    _add_yes_flag(check_cmd)

    retry_cmd = subparsers.add_parser("retry", help="Submit only failed shards")
    add_shared_arguments(retry_cmd)
    _add_yes_flag(retry_cmd)

    run_cmd = subparsers.add_parser(
        "run",
        help="Run according to execution_params.mode and execution_params.retry",
    )
    add_shared_arguments(run_cmd)
    run_cmd.add_argument(
        "--force-retry",
        action="store_true",
        help="Enable retry orchestration even if execution_params.retry is false",
    )
    run_cmd.add_argument(
        "--shard-id",
        type=int,
        default=None,
        help="Run a single shard locally (overrides execution mode)",
    )

    merge_cmd = subparsers.add_parser(
        "merge",
        help="Merge shard outputs listed in config.loading_params.datasets",
    )
    add_shared_arguments(merge_cmd)
    merge_cmd.add_argument(
        "--output-dir",
        "--output-root",
        dest="output_dir",
        default=None,
        help=(
            "Optional root directory for merged outputs. MMIRAGE creates one "
            "subdirectory per configured dataset under this root. If omitted, "
            "each dataset is merged into <dataset.output_dir>/merged"
        ),
    )

    # merge-dir is standalone: it does not take --config, so it defines its
    # own --log-level instead of using add_shared_arguments.
    merge_dir_cmd = subparsers.add_parser(
        "merge-dir",
        help="Merge shards directly from an input directory into an output directory",
    )
    merge_dir_cmd.add_argument(
        "--input-dir",
        required=True,
        help=(
            "Input directory containing one dataset with shard_* folders, or "
            "multiple dataset subdirectories each containing shard_* folders"
        ),
    )
    merge_dir_cmd.add_argument(
        "--output-dir",
        required=True,
        help="Output directory for merged dataset(s)",
    )
    merge_dir_cmd.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Log verbosity",
    )
    return parser
def log_merge_reports(reports: List[MergeReport]) -> None:
    """Log merge summary for one or more datasets."""
    for rep in reports:
        # Aggregate the two skip categories for the headline count.
        total_skipped = rep.skipped_invalid_dirs + rep.skipped_zero_rows
        logger.info(
            "Merged dataset %s: shards=%d rows=%d output=%s skipped=%d "
            "(invalid=%d, zero_rows=%d)",
            rep.dataset_name,
            rep.used_shards,
            rep.merged_rows,
            rep.output_dir,
            total_skipped,
            rep.skipped_invalid_dirs,
            rep.skipped_zero_rows,
        )
def parse_shard_ids(raw_value: Optional[str], num_shards: Optional[int] = None) -> List[int]:
    """Parse a comma-separated shard id list.

    Args:
        raw_value: Comma-separated shard ids, or None/empty for full array.
        num_shards: Optional upper bound used for range validation.

    Returns:
        Parsed shard ids (input order preserved, duplicates kept).

    Raises:
        ValueError: If an entry is not a non-negative integer, or falls
            outside [0, num_shards) when ``num_shards`` is given.
    """
    if not raw_value:
        return []
    shard_ids: List[int] = []
    for raw_shard_id in raw_value.split(","):
        candidate = raw_shard_id.strip()
        if not candidate:
            # Tolerate stray commas such as "1,,2" or a trailing comma.
            continue
        # BUG FIX: the previous str.isdigit() gate accepted characters such
        # as "²" that int() rejects, leaking a raw "invalid literal"
        # ValueError instead of this clearer message. Use EAFP int() and
        # keep an explicit negative check so "-1" is still rejected.
        try:
            shard_id = int(candidate)
        except ValueError:
            raise ValueError(f"Invalid shard id {candidate!r}; expected integers") from None
        if shard_id < 0:
            raise ValueError(f"Invalid shard id {candidate!r}; expected integers")
        if num_shards is not None and shard_id >= num_shards:
            raise ValueError(f"Invalid shard id {shard_id}; expected 0 <= shard_id < {num_shards}")
        shard_ids.append(shard_id)
    return shard_ids
def handle_run(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int:
    """Handle the canonical run command.

    Args:
        args: Parsed CLI namespace.
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.

    Returns:
        Exit code for the run operation.
    """
    if args.shard_id is not None:
        # Explicit shard id short-circuits the configured execution mode.
        return run_local(config_path, args.shard_id)
    merge_requested = cfg.execution_params.merge
    exit_code = launch_pipeline(
        cfg,
        config_path,
        force_retry=args.force_retry,
        require_completion=merge_requested,
    )
    if exit_code != 0:
        return exit_code
    if merge_requested:
        logger.info("Execution_params.merge is true; merging shard outputs")
        log_merge_reports(merge_from_config(cfg))
    return 0
def handle_submit(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int:
    """Submit a SLURM array job and optionally wait.

    Args:
        args: Parsed CLI namespace.
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.

    Returns:
        Exit code for submission/wait outcome.
    """
    if require_slurm(cfg, "submit") != 0:
        return 1
    shard_ids = parse_shard_ids(args.shard_ids, cfg.loading_params.get_num_shards())
    job_id = submit_slurm_job(cfg, config_path, shard_ids)
    if job_id is None:
        return 1
    logger.info(f"Submitted SLURM job {job_id} for shard ids: {shard_ids or 'ALL'}")
    if not args.wait:
        # Fire-and-forget submission: report success immediately.
        return 0
    wait_for_slurm_job(job_id, cfg)
    failed_shards, summary = check_failed_shards(cfg)
    status_code = status_exit_code(failed_shards, summary)
    if status_code == 0:
        logger.info("All shards completed successfully")
    return status_code
def handle_check(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int:
    """Inspect shard status and optionally submit retries.

    Args:
        args: Parsed CLI namespace.
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.

    Returns:
        Exit code based on shard status and optional retry submission.
    """
    failed_shards, summary = check_failed_shards(cfg)
    print(json.dumps(asdict(summary), indent=2))
    status_code = status_exit_code(failed_shards, summary)
    # Retries only make sense in SLURM mode, when requested, and when there
    # is at least one failed shard to resubmit.
    can_retry = cfg.execution_params.is_slurm() and args.retry and bool(failed_shards)
    if not can_retry:
        return status_code
    return submit_failed_shards(
        cfg=cfg,
        config_path=config_path,
        failed_shards=failed_shards,
        confirm_mode=args.confirm_mode,
    )
def handle_retry(args: argparse.Namespace, cfg: MMirageConfig, config_path: str) -> int:
    """Submit retries for failed shards only.

    Args:
        args: Parsed CLI namespace.
        cfg: Parsed MMIRAGE configuration object.
        config_path: Absolute path to the MMIRAGE YAML config file.

    Returns:
        Exit code for retry submission outcome.
    """
    if require_slurm(cfg, "retry") != 0:
        return 1
    failed_shards, summary = check_failed_shards(cfg)
    print(json.dumps(asdict(summary), indent=2))
    if failed_shards:
        return submit_failed_shards(
            cfg=cfg,
            config_path=config_path,
            failed_shards=failed_shards,
            confirm_mode=args.confirm_mode,
        )
    # No retryable shards: distinguish "everything succeeded" from
    # "some shards exhausted their retry budget".
    if summary.max_retries_exceeded > 0:
        logger.error("No retryable shards remain")
        return 1
    logger.info("All shards already succeeded.")
    return 0
def handle_merge(args: argparse.Namespace, cfg: MMirageConfig, _config_path: str) -> int:
    """Merge shard outputs defined in config.loading_params.datasets.

    Args:
        args: Parsed CLI namespace.
        cfg: Parsed MMIRAGE configuration object.
        _config_path: Absolute path to the MMIRAGE YAML config file (not needed here).

    Returns:
        Exit code for merge outcome.
    """
    merged_reports = merge_from_config(cfg, output_root=args.output_dir)
    log_merge_reports(merged_reports)
    return 0
def handle_merge_dir(args: argparse.Namespace) -> int:
    """Merge shard outputs directly from input/output directory arguments.

    Args:
        args: Parsed CLI namespace.

    Returns:
        Exit code for merge outcome.
    """
    log_merge_reports(merge_input_dir(args.input_dir, args.output_dir))
    return 0
def main() -> None:
    """CLI entry point.

    Parses arguments, configures logging, dispatches to the subcommand
    handler, and exits with the handler's return code (1 on any exception,
    2 on an unknown command).
    """
    parser = build_argparser()
    args = parser.parse_args()
    # BUG FIX: the original called `configure_logging(args.log_level)`, but
    # that name is neither defined nor imported in this module, so every
    # invocation crashed with NameError before doing any work. Configure the
    # root logger directly instead. NOTE(review): if a `configure_logging`
    # helper exists elsewhere in the package, restoring its import is the
    # alternative fix — TODO confirm.
    logging.basicConfig(
        level=getattr(logging, args.log_level, logging.INFO),
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    try:
        if args.command == "merge-dir":
            # merge-dir needs no config file, so handle it before config load.
            sys.exit(handle_merge_dir(args))
        config_path = os.path.abspath(args.config)
        cfg = load_mmirage_config(config_path)
        setup_runtime(cfg, args.log_level)
        validate_edf_env_path(cfg)
        handlers = {
            "run": handle_run,
            "submit": handle_submit,
            "check": handle_check,
            "retry": handle_retry,
            "merge": handle_merge,
        }
        handler = handlers.get(args.command)
        if handler is None:
            logger.error("Unknown command: %s", args.command)
            sys.exit(2)
        sys.exit(handler(args, cfg, config_path))
    except Exception as exc:
        # sys.exit raises SystemExit (BaseException), so the success paths
        # above are not swallowed by this handler.
        logger.error("Error: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
        sys.exit(1)


if __name__ == "__main__":
    main()