Source code for mmirage.shard_process

"""Main script for processing dataset shards with MMIRAGE.

Supports both text-only and multimodal (vision-language) processing.
"""

import argparse
import logging
import sys
import traceback
from typing import Any, Dict, List, Optional

from mmirage.config.utils import load_mmirage_config
from mmirage.core.loader.base import DatasetLike
from mmirage.core.loader.utils import load_datasets_from_configs
from mmirage.core.process.mapper import MMIRAGEMapper
from mmirage.core.writer.renderer import TemplateRenderer
from mmirage.shard_utils import (
    _cleanup_old_shard_data,
    _count_rows,
    _dataset_out_dir,
    _mark_failure,
    _mark_running,
    _mark_success,
    _remove_columns,
    _save_dataset_atomic,
    _shard_dataset,
    _shard_state_dir,
)

logger = logging.getLogger(__name__)


def rewrite_batch(
    batch: Dict[str, List[Any]],
    mapper: MMIRAGEMapper,
    renderer: TemplateRenderer,
    image_base_path: Optional[str] = None,
) -> Dict[str, List[Any]]:
    """Rewrite a batch of samples by applying transformations.

    Args:
        batch: Dictionary mapping column names to lists of values.
        mapper: MMIRAGEMapper for processing transformations.
        renderer: TemplateRenderer for generating output.
        image_base_path: Optional base directory for resolving relative image paths.

    Returns:
        Dictionary mapping output keys to lists of rendered values.

    Raises:
        ValueError: If variables are not computable given the configuration.
    """
    if not mapper.validate_vars():
        raise ValueError(
            "Uncomputable variables detected. Verify your configuration and "
            "make sure there are no undefined variables."
        )
    batch_environment = mapper.rewrite_batch(batch, image_base_path)
    rendered_batch = renderer.batch_render(batch_environment)
    return rendered_batch
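
# Illustrative batch shape (column names are hypothetical): with
# datasets.Dataset.map(batched=True), rewrite_batch receives columnar input
# such as
#
#     {"question": ["Q1", "Q2"], "image": ["img/a.png", "img/b.png"]}
#
# and returns a dict of equal-length lists, one entry per rendered output key.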

def main():
    """Process a single shard of the dataset.

    Loads the configuration and datasets, processes the shard with MMIRAGE
    transformations (including multimodal ones), and saves the result to disk.
    """
    ap = argparse.ArgumentParser(
        description="Process dataset shards using MMIRAGE with SGLang."
    )
    ap.add_argument(
        "--config",
        help="YAML config for MMIRAGE pipeline.",
        required=True,
    )
    args = ap.parse_args()

    cfg = load_mmirage_config(args.config)
    loading_params = cfg.loading_params
    processing_params = cfg.processing_params

    datasets_config = loading_params.datasets
    if not datasets_config:
        raise ValueError("No datasets provided in config.loading_params.datasets")

    shard_id = loading_params.get_shard_id()
    num_shards = loading_params.get_num_shards()
    last_shard_id = num_shards - 1
    if not (0 <= shard_id < num_shards):
        raise ValueError(f"Invalid shard_id={shard_id}, num_shards={num_shards}")

    state_dir = _shard_state_dir(shard_id, loading_params.get_state_root())

    try:
        retry_count = _mark_running(state_dir, shard_id, datasets_config)
        logger.info(f"Starting shard {shard_id}/{last_shard_id} (attempt #{retry_count})")

        # On a retry, drop any partial output left by a previous attempt.
        if retry_count > 1:
            for ds_config in datasets_config:
                out_dir = _dataset_out_dir(shard_id, ds_config)
                _cleanup_old_shard_data(out_dir)

        # Load every configured dataset, then slice out this logical shard.
        ds_all = load_datasets_from_configs(datasets_config)
        total_rows = sum(_count_rows(ds) for ds in ds_all)
        ds_all_shard = [_shard_dataset(ds, num_shards, shard_id) for ds in ds_all]
        shard_rows = sum(_count_rows(ds) for ds in ds_all_shard)
        logger.info(
            f"Loaded {len(datasets_config)} dataset(s): {datasets_config} "
            f"→ {total_rows} total rows; this logical shard has {shard_rows} rows."
        )

        mapper = MMIRAGEMapper(
            cfg.processors,
            processing_params.inputs,
            processing_params.outputs,
        )
        renderer = TemplateRenderer(processing_params.output_schema)

        ds_processed_all: List[DatasetLike] = []
        for ds_idx, ds_shard in enumerate(ds_all_shard):
            ds_config = datasets_config[ds_idx]
            remove_columns = (
                _remove_columns(ds_shard) if processing_params.remove_columns else []
            )
            logger.info(
                f"Processing dataset {ds_idx} for shard {shard_id}: "
                f"path={ds_config.path}, output_dir={ds_config.output_dir}"
            )
            ds_processed = ds_shard.map(
                rewrite_batch,
                batched=True,
                batch_size=loading_params.get_batch_size(),
                load_from_cache_file=False,
                desc=f"Shard {shard_id}/{last_shard_id} dataset {ds_idx}",
                fn_kwargs={
                    "mapper": mapper,
                    "renderer": renderer,
                    "image_base_path": ds_config.image_base_path,
                },
                remove_columns=remove_columns,
            )
            ds_processed_all.append(ds_processed)

        for ds_idx, (ds_config, ds_processed) in enumerate(
            zip(datasets_config, ds_processed_all)
        ):
            out_dir = _dataset_out_dir(shard_id, ds_config)
            _save_dataset_atomic(ds_processed, out_dir)
            logger.info(f"✅ Saved dataset {ds_idx} shard in: {out_dir}")

        _mark_success(state_dir)
        logger.info(f"✅ Logical shard {shard_id} completed successfully")
    except Exception as e:
        error_msg = f"{type(e).__name__}: {e}"
        logger.error(f"❌ Shard {shard_id} failed: {error_msg}")
        logger.error(traceback.format_exc())
        _mark_failure(state_dir, error_msg)
        sys.exit(1)

if __name__ == "__main__":
    main()
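
# A minimal config sketch, assuming the YAML keys mirror the attributes and
# getters accessed in main() (exact key names are unverified assumptions):
#
#     loading_params:
#       shard_id: 0
#       num_shards: 8
#       batch_size: 32
#       state_root: /tmp/mmirage_state
#       datasets:
#         - path: data/train.parquet
#           output_dir: out/train
#           image_base_path: data/images
#     processing_params:
#       remove_columns: true
#       inputs: [...]
#       outputs: [...]
#       output_schema: {...}
#     processors: [...]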