Source code for mmirage.core.loader.base
"""Base classes and registry for data loaders in MMIRAGE."""
from __future__ import annotations
import abc
from typing import Any, Callable, Generic, Optional, Type, TypeVar
from dataclasses import dataclass
from datasets import Dataset, DatasetDict
[docs]
@dataclass()
class BaseDataLoaderConfig:
    """Common configuration shared by every data loader.

    Each loader-specific configuration subclasses this dataclass; the
    ``type`` field carries the identifier of the loader kind.

    Attributes:
        type: Identifier of the loader kind (e.g. "JSONL", "loadable").
        output_dir: Directory that receives the processed output shards.
        image_base_path: Optional directory against which relative image
            paths in this dataset are resolved; ``None`` when not set.
    """

    # Required fields (no defaults) must come before the defaulted one.
    type: str
    output_dir: str
    # Optional; left as None unless the dataset needs image path resolution.
    image_base_path: Optional[str] = None
# What a loader hands back: either a Hugging Face Dataset or a DatasetDict.
DatasetLike = Dataset | DatasetDict
# Configuration type parameter for BaseDataLoader; the bound restricts it to
# BaseDataLoaderConfig and its subclasses.
C = TypeVar("C", bound=BaseDataLoaderConfig)
[docs]
class BaseDataLoader(abc.ABC, Generic[C]):
    """Abstract interface shared by every MMIRAGE data loader.

    A concrete loader reads one kind of source (for example JSONL files or
    Hugging Face datasets) and hands it back as a Hugging Face ``Dataset``
    or ``DatasetDict``.

    Type Parameters:
        C: Configuration class consumed by the loader.

    Methods:
        from_config: Build the dataset described by a configuration object.
    """

    @abc.abstractmethod
    def from_config(self, ds_config: C) -> Optional[DatasetLike]:
        """Build the dataset described by ``ds_config``.

        Subclasses must override this; because it is an abstract method,
        a subclass that does not override it cannot be instantiated.

        Args:
            ds_config: Loader-specific configuration instance.

        Returns:
            The loaded ``Dataset`` or ``DatasetDict``, or ``None`` when
            loading fails.

        Raises:
            NotImplementedError: If this base implementation is reached,
                e.g. through ``super().from_config(...)``.
        """
        raise NotImplementedError
[docs]
class DataLoaderRegistry:
    """Registry for managing and accessing available data loaders.

    Provides a centralized registry for data loader classes and their
    associated configuration classes, allowing dynamic loader instantiation
    based on type names.

    Attributes:
        _registry: Mapping from loader name to registered loader class.
        _config_registry: Mapping from loader name to its configuration class.
    """

    # Both mappings are class attributes, so every registration is shared
    # by all users of DataLoaderRegistry rather than stored per instance.
    _registry: dict[str, Type[BaseDataLoader]] = {}
    _config_registry: dict[str, Type[BaseDataLoaderConfig]] = {}

    @staticmethod
    def _lookup(mapping: dict, name: str) -> Any:
        """Return ``mapping[name]`` or raise ValueError listing the known names."""
        if name not in mapping:
            raise ValueError(
                f"Loader {name} not registered. Available loaders are {list(mapping.keys())}"
            )
        return mapping[name]

    @classmethod
    def register(cls, name: str, config_cls: Type[BaseDataLoaderConfig]) -> Callable:
        """Register a data loader class under ``name``.

        Intended to be used as a class decorator::

            @DataLoaderRegistry.register("my_loader", MyLoaderConfig)
            class MyLoader(BaseDataLoader[MyLoaderConfig]):
                ...

        Registering a name that is already present silently replaces the
        previous loader and configuration class.

        Args:
            name: String identifier for the loader.
            config_cls: Configuration class associated with this loader.

        Returns:
            Decorator that records the decorated class (and ``config_cls``)
            under ``name`` and returns the class unchanged.
        """

        def inner_register(clazz: Any) -> Any:
            cls._registry[name] = clazz
            cls._config_registry[name] = config_cls
            # A class decorator's return value is what gets bound to the
            # class name; returning None would replace the class with None.
            return clazz

        return inner_register

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseDataLoader]:
        """Get a registered loader class by name.

        Args:
            name: String identifier of the loader.

        Returns:
            The registered loader class.

        Raises:
            ValueError: If no loader is registered under the given name.
        """
        return cls._lookup(cls._registry, name)

    @classmethod
    def get_config_cls(cls, name: str) -> Type[BaseDataLoaderConfig]:
        """Get a registered configuration class by loader name.

        Args:
            name: String identifier of the loader.

        Returns:
            The registered configuration class.

        Raises:
            ValueError: If no loader is registered under the given name.
        """
        return cls._lookup(cls._config_registry, name)
[docs]
class AutoDataLoader:
    """Factory that resolves data loader classes from their registered names."""

    @classmethod
    def from_name(cls, name: str) -> Type[BaseDataLoader]:
        """Look up the loader class registered as ``name``.

        Thin front door over ``DataLoaderRegistry.get_processor``.

        Args:
            name: Name the loader was registered under.

        Returns:
            The registered loader class.

        Raises:
            ValueError: If nothing is registered under ``name``.
        """
        loader_cls = DataLoaderRegistry.get_processor(name)
        return loader_cls