"""SLURM helpers for the MMIRAGE CLI."""
from __future__ import annotations
import logging
import os
import shlex
import subprocess
import time
from typing import Optional, Sequence
from mmirage.config.config import MMirageConfig
from mmirage.cli_utils.runtime import create_directories, expand_path, get_project_root
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _bash_double_quote(value: str) -> str:
    """Wrap *value* in double quotes so it can be pasted into a bash script.

    Backslashes and double quotes are backslash-escaped. The dollar sign is
    deliberately left alone so that ``$VAR`` / ``${VAR}`` references coming
    from the config (e.g. ``$SCRATCH``) are expanded by the shell that runs
    the generated script.

    Raises:
        ValueError: if *value* contains a backtick or ``$(``, i.e. shell
            command substitution syntax, which is refused to avoid command
            injection.
    """
    # Command substitution is refused; plain variable expansion stays allowed.
    if any(marker in value for marker in ("`", "$(")):
        raise ValueError(
            "Config value contains unsupported shell command substitution "
            "(` or '$('). Command substitution is not allowed in SLURM-generated scripts."
        )
    # Single pass: each backslash and each double quote gets a backslash prefix.
    escape_table = {ord("\\"): "\\\\", ord('"'): '\\"'}
    return '"' + value.translate(escape_table) + '"'
def _shell_path(value: str, project_root: str) -> str:
    """Normalise a configured path for use inside the generated job script.

    Surrounding whitespace is stripped and a leading ``~`` is expanded. A
    blank value is returned as an empty string. Paths that begin with ``$``
    are returned as-is (they are assumed to expand on the compute node), as
    are absolute paths; any other path is joined onto *project_root*.
    """
    trimmed = value.strip()
    if trimmed == "":
        return trimmed

    expanded = os.path.expanduser(trimmed)
    # Shell-variable paths and absolute paths need no anchoring.
    if expanded.startswith("$") or os.path.isabs(expanded):
        return expanded
    return os.path.join(project_root, expanded)
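# NOTE(review): removed stray "[docs]" link text here; apparently residue from a rendered docs page, not module code.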
def build_sbatch_script(cfg: MMirageConfig, config_path: str) -> str:
    """Render the bash script that sbatch executes for each array task.

    The script exports ``PYTHONPATH``, ``SHARD_PROCESS``, ``HF_HOME`` and
    ``MMIRAGE_CONFIG``, creates the HF cache and state directories, collects
    the ``srun`` arguments (CPU count, optional environment file, account,
    optional reservation) and finally launches ``shard_process.py`` through
    ``srun`` after checking that a Python >= 3.10 interpreter is available.

    Args:
        cfg: Loaded MMIRAGE configuration.
        config_path: Config file path exported to the job as ``MMIRAGE_CONFIG``.

    Returns:
        The script text, one command per line, terminated by a newline.

    Raises:
        ValueError: if ``execution_params.account`` is empty, or if one of the
            quoted values contains shell command substitution syntax.
    """
    exec_params = cfg.execution_params
    root = get_project_root(cfg)

    hf_cache_dir = _shell_path(exec_params.hf_home, root)
    state_dir = _shell_path(cfg.loading_params.get_state_root(), root)
    source_dir = os.path.join(root, "src")
    entry_point = os.path.join(source_dir, "mmirage", "shard_process.py")

    quoted_source = _bash_double_quote(source_dir)
    quoted_entry = _bash_double_quote(entry_point)
    quoted_hf = _bash_double_quote(hf_cache_dir)
    quoted_config = _bash_double_quote(config_path)
    quoted_state = _bash_double_quote(state_dir)

    script = [
        "#!/bin/bash",
        "set -euo pipefail",
        f"export PYTHONPATH={quoted_source}:${{PYTHONPATH:-}}",
        f"export SHARD_PROCESS={quoted_entry}",
        f"export HF_HOME={quoted_hf}",
        f"export MMIRAGE_CONFIG={quoted_config}",
        f"mkdir -p {quoted_hf}",
        f"mkdir -p {quoted_state}",
        "srun_args=(--cpus-per-task ${SLURM_CPUS_PER_TASK:-1} --wait 60)",
    ]

    # Optional environment file for srun, resolved against the project root.
    edf_setting = exec_params.edf_env
    if edf_setting:
        resolved_edf = expand_path(edf_setting, root)
        script.append(f"srun_args+=(--environment={shlex.quote(resolved_edf)})")

    # The account is mandatory in slurm mode.
    account = exec_params.account
    if not account:
        raise ValueError("execution_params.account must be set in slurm mode")
    script.append(f"srun_args+=(-A {shlex.quote(account)})")

    reservation = exec_params.reservation
    if reservation:
        script.append(f"srun_args+=(--reservation={shlex.quote(reservation)})")

    # One srun invocation: pick python3 (or python), verify the interpreter
    # version, then exec the shard worker. The pieces below are adjacent
    # string literals that concatenate into a single script line.
    srun_command = (
        "srun \"${srun_args[@]}\" bash -c '"
        "if command -v python3 >/dev/null 2>&1; then PYTHON_CMD=python3; "
        "elif command -v python >/dev/null 2>&1; then PYTHON_CMD=python; "
        "else echo \"python3/python not found in PATH\" >&2; exit 127; fi; "
        "echo \"Using Python: ${PYTHON_CMD} ($(${PYTHON_CMD} --version 2>&1))\"; "
        "${PYTHON_CMD} -c \"import sys; raise SystemExit(0 if sys.version_info >= (3, 10) else 2)\" "
        "|| { echo \"MMIRAGE requires Python >= 3.10 on compute nodes\" >&2; exit 2; }; "
        "exec ${PYTHON_CMD} \"$SHARD_PROCESS\" --config \"$MMIRAGE_CONFIG\"'"
    )
    script.append(srun_command)
    script.append("echo \"Shard ${SLURM_ARRAY_TASK_ID:-0} completed\"")

    return "".join(f"{line}\n" for line in script)
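# NOTE(review): removed stray "[docs]" link text here; apparently residue from a rendered docs page, not module code.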
def submit_slurm_job(
    cfg: MMirageConfig,
    config_path: str,
    shard_ids: Optional[Sequence[int]] = None,
) -> Optional[int]:
    """Submit a SLURM array job and return its job ID.

    The report directory is created, an ``sbatch --parsable`` command line is
    assembled from ``cfg.execution_params``, and the script produced by
    ``build_sbatch_script`` is fed to ``sbatch`` on stdin.

    Args:
        cfg: Loaded MMIRAGE configuration.
        config_path: Config file path forwarded to the generated job script.
        shard_ids: Explicit array task IDs to run. When omitted or empty, the
            array covers ``0 .. get_num_shards() - 1``.

    Returns:
        The numeric job ID, or ``None`` when there is nothing to submit,
        ``sbatch`` cannot be launched, ``sbatch`` exits non-zero, or its
        output cannot be parsed. Every ``None`` path logs an error.

    Raises:
        ValueError: propagated from ``build_sbatch_script`` (e.g. missing
            account).
    """
    project_root = get_project_root(cfg)
    report_dir = expand_path(cfg.execution_params.report_dir, project_root)
    create_directories([report_dir])

    command = [
        "sbatch",
        "--parsable",
        f"--job-name={cfg.execution_params.job_name}",
        f"--chdir={project_root}",
        f"--output={os.path.join(report_dir, 'R-%x.%A_%a.out')}",
        f"--error={os.path.join(report_dir, 'R-%x.%A_%a.err')}",
        f"--nodes={cfg.execution_params.nodes}",
        f"--ntasks-per-node={cfg.execution_params.ntasks_per_node}",
        f"--gres=gpu:{cfg.execution_params.gpus}",
        f"--cpus-per-task={cfg.execution_params.cpus_per_task}",
        f"--time={cfg.execution_params.time_limit}",
        f"--account={cfg.execution_params.account}",
    ]
    if cfg.execution_params.reservation:
        command.append(f"--reservation={cfg.execution_params.reservation}")

    requested_shards = list(shard_ids or [])
    if requested_shards:
        command.append(f"--array={','.join(str(shard_id) for shard_id in requested_shards)}")
    else:
        num_shards = cfg.loading_params.get_num_shards()
        if num_shards <= 0:
            # "0-<negative>" is not a valid --array range; fail early with a
            # clear message instead of letting sbatch reject the request.
            logger.error(
                "Cannot submit SLURM array job: no shards to process (num_shards=%s)",
                num_shards,
            )
            return None
        command.append(f"--array=0-{num_shards - 1}")

    # Build the script before announcing the submission so a configuration
    # error surfaces before anything claims a job is being submitted.
    script = build_sbatch_script(cfg, config_path)

    # shlex.join quotes arguments so the logged command can be copy-pasted.
    logger.info("Submitting SLURM job: %s", shlex.join(command))
    try:
        result = subprocess.run(
            command,
            input=script,
            text=True,
            capture_output=True,
            check=False,
        )
    except OSError as exc:
        # Raised when sbatch is missing or cannot be executed; report it the
        # same way as any other submission failure.
        logger.error("Unable to run sbatch: %s", exc)
        return None

    if result.returncode != 0:
        logger.error("sbatch failed: %s", result.stderr.strip())
        return None

    # --parsable prints "<job_id>" or "<job_id>;<cluster>".
    raw_job_id = result.stdout.strip().split(";", 1)[0]
    try:
        return int(raw_job_id)
    except ValueError:
        logger.error("Unable to parse job id from sbatch output: %s", result.stdout.strip())
        return None
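# NOTE(review): removed stray "[docs]" link text here; apparently residue from a rendered docs page, not module code.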
def wait_for_slurm_job(job_id: int, cfg: MMirageConfig) -> None:
    """Wait for a SLURM job array to leave the queue.

    ``squeue -h -j <job_id>`` is polled every
    ``execution_params.poll_interval_seconds`` seconds. The wait ends when
    squeue succeeds with empty output, or when squeue reports that the job ID
    is no longer known to the controller. Any other squeue failure is logged
    and retried. Afterwards the function sleeps for
    ``execution_params.settle_time_seconds`` if that value is positive.

    Args:
        job_id: ID of the submitted array job.
        cfg: Loaded MMIRAGE configuration (supplies the poll/settle timings).
    """
    logger.info("Waiting for SLURM job %s", job_id)
    while True:
        result = subprocess.run(
            ["squeue", "-h", "-j", str(job_id)],
            capture_output=True,
            text=True,
            check=False,
        )
        if result.returncode == 0:
            if not result.stdout.strip():
                # No remaining tasks are listed for this job.
                break
        elif "Invalid job id" in result.stderr:
            # Once a finished job's record has been purged by the controller,
            # squeue exits non-zero with "Invalid job id specified". Treat
            # that as "job gone"; otherwise this loop would never terminate.
            logger.info("SLURM job %s is no longer known to the controller", job_id)
            break
        else:
            # Possibly transient (e.g. controller unreachable): keep polling.
            logger.warning(
                "squeue failed for job %s (exit code %s): %s",
                job_id,
                result.returncode,
                result.stderr.strip(),
            )
        time.sleep(cfg.execution_params.poll_interval_seconds)

    if cfg.execution_params.settle_time_seconds > 0:
        logger.info("Waiting %ss for state files to settle", cfg.execution_params.settle_time_seconds)
        time.sleep(cfg.execution_params.settle_time_seconds)
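# NOTE(review): removed stray "[docs]" link text here; apparently residue from a rendered docs page, not module code.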
def require_slurm(cfg: MMirageConfig, command_name: str) -> int:
    """Check that the configuration selects SLURM mode.

    Returns 0 when ``cfg.execution_params.is_slurm()`` is truthy. Otherwise an
    error naming *command_name* is logged and 1 is returned.
    """
    if not cfg.execution_params.is_slurm():
        logger.error("%s requires execution_params.mode=slurm", command_name)
        return 1
    return 0