Source code for mmirage.config.loading

"""Data loading configuration for MMIRAGE pipeline."""

import os
import re
from dataclasses import dataclass, field
from typing import Union, List, cast

from mmirage.core.loader.base import BaseDataLoaderConfig

# Fallback for LoadingParams.state_dir when the configured value is None or
# blank; the leading "~" is expanded with os.path.expanduser in
# LoadingParams.__post_init__.
DEFAULT_STATE_DIR = "~/.cache/MMIRAGE/state_dir"


# Matches a string that consists solely of one shell-style placeholder,
# either ``$NAME`` or ``${NAME}``. Compiled once at import time rather than on
# every LoadingParams instantiation.
_UNRESOLVED_ENV_VAR_PATTERN = re.compile(
    r"^\$(?:\{[A-Za-z_][A-Za-z0-9_]*\}|[A-Za-z_][A-Za-z0-9_]*)$"
)


def _is_unresolved_env_var(value: str) -> bool:
    """Return True if *value* (ignoring surrounding whitespace) is ``$NAME`` or ``${NAME}``."""
    return bool(_UNRESOLVED_ENV_VAR_PATTERN.fullmatch(value.strip()))


def _parse_int_setting(
    name: str,
    raw: str,
    env_fallback: Union[int, None] = None,
    minimum: Union[int, None] = None,
) -> int:
    """Convert a string setting to ``int``.

    Args:
        name: Field name, used only in the error message.
        raw: The string value exactly as it was supplied.
        env_fallback: Value to return when ``raw`` is an unexpanded ``$NAME`` /
            ``${NAME}`` placeholder. ``None`` means placeholders are rejected.
        minimum: Smallest accepted parsed value. ``None`` disables the check.

    Returns:
        int: The parsed integer, or ``env_fallback`` for an accepted placeholder.

    Raises:
        ValueError: If ``raw`` is not an integer literal (and not an accepted
            placeholder), or if it parses to a value below ``minimum``.
    """
    try:
        parsed = int(raw)
    except ValueError:
        if env_fallback is not None and _is_unresolved_env_var(raw):
            return env_fallback
    else:
        if minimum is None or parsed >= minimum:
            return parsed
    # Always report the original string. The field is never overwritten with a
    # half-validated integer, so the placeholder check above only ever sees a
    # str (it used to receive an int for inputs such as "0" and crashed with
    # AttributeError).
    raise ValueError(f"Invalid value for {name}: {raw!r}")


@dataclass
class LoadingParams:
    """Parameters for loading and distributing datasets across shards.

    Defines how datasets are loaded and processed in a distributed manner,
    supporting sharding for parallel processing.

    Attributes:
        datasets: List of dataset configurations to load.
        state_dir: Shared directory for logical shard state/markers/retry
            tracking. ``None`` or a blank value falls back to
            ``DEFAULT_STATE_DIR``; ``~`` is expanded.
        output_dir: Legacy top-level output directory. Prefer per-dataset
            output_dir.
        num_shards: Total number of shards to split the dataset into. A string
            is parsed as an integer and must be at least 1; an unexpanded
            ``$NAME`` / ``${NAME}`` placeholder resolves to 1.
        shard_id: ID of this shard (0-indexed). A string is parsed as an
            integer; an unexpanded placeholder resolves to 0.
        batch_size: Batch size for processing samples. A string is parsed as
            an integer; the result is clamped to a minimum of 1.

    Raises:
        ValueError: If a string ``num_shards``, ``shard_id`` or ``batch_size``
            cannot be parsed as an integer (and is not an accepted
            placeholder), or if a string ``num_shards`` parses to less than 1.
    """

    datasets: List[BaseDataLoaderConfig] = field(default_factory=list)
    state_dir: str = DEFAULT_STATE_DIR
    output_dir: str = ""
    num_shards: Union[int, str] = 1
    shard_id: Union[int, str] = 0
    batch_size: Union[int, str] = 1

    def __post_init__(self) -> None:
        """Normalize string-valued numeric fields and resolve ``state_dir``."""
        # Only string inputs are parsed and validated; values that are already
        # integers are taken as given.
        if isinstance(self.num_shards, str):
            self.num_shards = _parse_int_setting(
                "num_shards", self.num_shards, env_fallback=1, minimum=1
            )

        if isinstance(self.shard_id, str):
            self.shard_id = _parse_int_setting(
                "shard_id", self.shard_id, env_fallback=0
            )

        if isinstance(self.batch_size, str):
            # No placeholder fallback here: an unexpanded variable is an error.
            self.batch_size = _parse_int_setting("batch_size", self.batch_size)
        self.batch_size = max(self.batch_size, 1)

        # Tolerate None / non-str / whitespace-only values before expanding "~".
        raw_state_dir = "" if self.state_dir is None else str(self.state_dir)
        self.state_dir = raw_state_dir.strip() or DEFAULT_STATE_DIR
        self.state_dir = os.path.expanduser(self.state_dir)

    def get_state_root(self) -> str:
        """Get the state root path.

        Returns:
            str: State root path.
        """
        return self.state_dir

    def get_num_shards(self) -> int:
        """Get the total number of shards.

        Returns:
            int: Total number of shards.
        """
        return cast(int, self.num_shards)

    def get_shard_id(self) -> int:
        """Get the ID of this shard.

        Returns:
            int: Shard ID (0-indexed).
        """
        return cast(int, self.shard_id)

    def get_batch_size(self) -> int:
        """Get the batch size for processing.

        Returns:
            int: Batch size (minimum 1).
        """
        return cast(int, self.batch_size)