# Source code for mmirage.core.process.variables

"""Variable system for MMIRAGE pipeline with multimodal support."""

from __future__ import annotations

import abc
import os
from dataclasses import dataclass
from types import MappingProxyType
from typing import Any, Dict, List, Literal, Optional, Sequence

import jmespath
from PIL import Image

# Module-level memo of compiled JMESPath expressions, keyed by the raw query
# string, so a query is compiled once rather than once per sample.
_JMESPATH_CACHE: Dict[str, jmespath.parser.ParsedResult] = {}


def _get_compiled_query(key: str):
    """Return the compiled JMESPath expression for ``key``, compiling on first use.

    Compiled expressions are memoized in the module-level ``_JMESPATH_CACHE``
    so the same query string is not recompiled for every sample.

    Args:
        key: JMESPath query string.

    Returns:
        The object produced by ``jmespath.compile(key)``.
    """
    cached = _JMESPATH_CACHE.get(key)
    if cached is not None:
        return cached

    # Cache miss: compile once and remember the result for later samples.
    compiled = jmespath.compile(key)
    _JMESPATH_CACHE[key] = compiled
    return compiled


@dataclass
class BaseVar(abc.ABC):
    """Common root for every variable kind used by the MMIRAGE pipeline.

    Attributes:
        name: Identifier of the variable; defaults to the empty string.
    """

    name: str = ""
@dataclass
class InputVar(BaseVar):
    """Variable read directly out of a source-dataset sample.

    Attributes:
        key: JMESPath query used to pull the value out of a sample.
        type: Modality of the value, either "text" or "image".
    """

    key: str = ""
    type: Literal["text", "image"] = "text"

    def is_image(self) -> bool:
        """Tell whether this input variable holds an image.

        Returns:
            bool: True when ``type`` equals "image", False otherwise.
        """
        return self.type == "image"
@dataclass
class OutputVar(BaseVar):
    """Variable whose value is produced by a processor.

    Processors (e.g., LLMs) create output variables; such a variable can
    depend on input variables and on output variables computed earlier.

    Attributes:
        name: Name of the variable.
        type: Identifier of the processor type that generates this variable.
    """

    type: str = ""

    @abc.abstractmethod
    def is_computable(self, vars: Sequence[BaseVar]) -> bool:
        """Decide whether this variable can be computed given ``vars``.

        Abstract: concrete subclasses must override this method.

        Args:
            vars: Sequence of variables to check against (presumably the ones
                already available in the environment; confirm against the
                concrete subclasses).

        Returns:
            bool: Result of the subclass-specific check.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
def _to_abs_path(path: str) -> str:
    """Return ``path`` as a normalized absolute filesystem path."""
    return os.path.abspath(path)


def _resolve_image_input(value: Any, image_base_path: Optional[str] = None) -> Any:
    """Turn an image reference into something SGLang can consume.

    Per the original author's note, local images are handed to SGLang as
    plain filesystem paths (NOT file:// URIs). Resolution rules, in order:

    - PIL images and any other non-string value are returned untouched.
    - http(s) URLs are returned untouched.
    - A leading ``file://`` is stripped (kept for backward compatibility).
    - Absolute paths are normalized and must point at an existing file.
    - Relative paths are joined onto ``image_base_path`` when it is given;
      otherwise they are resolved against the current working directory.

    Args:
        value: The image value to resolve (string path/URL, PIL image, etc.).
        image_base_path: Optional base directory for resolving relative image paths.

    Returns:
        A PIL.Image.Image, an http(s) URL string, an absolute path string, or
        the original value when it is not a string.

    Raises:
        FileNotFoundError: If the resolved file does not exist.
        RuntimeError: If the resolved path exists but is not a file.
    """
    # Anything that is not a plain string (PIL images included) passes through.
    if isinstance(value, Image.Image) or not isinstance(value, str):
        return value

    # Remote images are left for the consumer to fetch.
    if value.startswith(("http://", "https://")):
        return value

    # Accept legacy file:// URIs by reducing them to a local path.
    value = value.removeprefix("file://")

    if os.path.isabs(value):
        # Follow the link only when the path itself is a symlink.
        target = value
        if os.path.islink(value):
            target = os.path.realpath(value)
        target = _to_abs_path(target)
        if not os.path.exists(target):
            raise FileNotFoundError(f"Absolute image path does not exist: {target}")
        if not os.path.isfile(target):
            raise RuntimeError(f"The provided path exists but is not a file: {target}")
        return target

    if image_base_path:
        # A configured base directory is authoritative: no CWD fallback here.
        joined = _to_abs_path(os.path.join(image_base_path, value))
        if not os.path.exists(joined):
            raise FileNotFoundError(
                f"Resolved image path '{joined}' does not exist "
                f"(base='{image_base_path}', rel='{value}')."
            )
        if not os.path.isfile(joined):
            raise RuntimeError(f"Resolved image path is not a file: {joined}")
        return joined

    # No base directory configured: try the current working directory.
    from_cwd = _to_abs_path(value)
    if not os.path.exists(from_cwd):
        raise FileNotFoundError(
            f"Relative image path '{value}' cannot be resolved. "
            "Set image_base_path in the dataset configuration or use an absolute path."
        )
    if not os.path.isfile(from_cwd):
        raise RuntimeError(f"Relative image path exists but is not a file: {from_cwd}")
    return from_cwd
class VariableEnvironment:
    """Environment for storing and accessing variables during processing.

    Holds a name -> value mapping together with the set of names whose values
    are images. ``with_variable`` returns a new environment instead of
    mutating the current one.
    """

    def __init__(self, var_env: Dict[str, Any], image_vars: Optional[set] = None) -> None:
        """Initialize a variable environment.

        Args:
            var_env: Dictionary mapping variable names to their values.
            image_vars: Set of variable names that represent images.
                Defaults to empty set.
        """
        self._vars_env = var_env
        self._image_vars = image_vars or set()

    def with_variable(self, key: str, value: Any, is_image: bool = False) -> "VariableEnvironment":
        """Create a new environment with an additional (or rebound) variable.

        Args:
            key: Name of the variable to add.
            value: Value of the variable.
            is_image: Whether the variable represents an image.

        Returns:
            VariableEnvironment: New environment with the added variable. The
            current environment is left unchanged.
        """
        new_image_vars = self._image_vars.copy()
        if is_image:
            new_image_vars.add(key)
        else:
            # Rebinding an existing image variable to a non-image value must
            # clear its image flag; otherwise get_images()/has_images() would
            # keep reporting the new, non-image value as an image.
            new_image_vars.discard(key)
        return VariableEnvironment(self._vars_env | {key: value}, new_image_vars)

    def to_dict(self) -> MappingProxyType:
        """Get an immutable view of the variable dictionary.

        Returns:
            MappingProxyType providing read-only access to variables.
        """
        return MappingProxyType(self._vars_env)

    def get(self, key: str, default: Any = None) -> Any:
        """Get a variable value by name.

        Args:
            key: Name of the variable to retrieve.
            default: Default value to return if variable is not found.

        Returns:
            Variable value, or default if not found.
        """
        return self._vars_env.get(key, default)

    def is_image_var(self, key: str) -> bool:
        """Check if a variable represents an image.

        Args:
            key: Name of the variable to check.

        Returns:
            bool: True if the variable is an image variable, False otherwise.
        """
        return key in self._image_vars

    def get_image_vars(self) -> set:
        """Get all image variable names.

        Returns:
            set: Copy of the set containing names of all image variables.
        """
        return self._image_vars.copy()

    def get_images(self) -> List[Any]:
        """Get image values in a deterministic order.

        Returns:
            Values of the image variables present in the environment, ordered
            by sorted variable name.
        """
        return [self._vars_env[k] for k in sorted(self._image_vars) if k in self._vars_env]

    def has_images(self) -> bool:
        """Check if the environment contains any image variables.

        Returns:
            bool: True if at least one image variable is present, False otherwise.
        """
        return any(k in self._vars_env for k in self._image_vars)

    @staticmethod
    def from_input_variables(
        sample: Dict[str, Any],
        input_vars: List[InputVar],
        image_base_path: Optional[str] = None,
    ) -> "VariableEnvironment":
        """Create a variable environment from a single sample.

        Args:
            sample: Dictionary containing the data for one sample.
            input_vars: List of input variable definitions to extract.
            image_base_path: Optional base directory for resolving relative image paths.

        Returns:
            VariableEnvironment: Environment populated with extracted variables.

        Raises:
            ValueError: If a required input variable is not found in the sample.
        """
        ret: Dict[str, Any] = {}
        image_vars: set = set()

        for input_var in input_vars:
            compiled_query = _get_compiled_query(input_var.key)
            value = compiled_query.search(sample)
            # A query result of None is treated as "missing".
            if value is None:
                raise ValueError(
                    f"Input variable '{input_var.name}' with key '{input_var.key}' not found in the sample."
                )
            if input_var.is_image():
                value = _resolve_image_input(value, image_base_path)
                image_vars.add(input_var.name)
            ret[input_var.name] = value

        return VariableEnvironment(ret, image_vars)

    @staticmethod
    def from_batch_input_variables(
        batch: Dict[str, List[Any]],
        input_vars: List[InputVar],
        image_base_path: Optional[str] = None,
    ) -> List["VariableEnvironment"]:
        """Extract input variables from a batch of samples.

        Args:
            batch: Dictionary mapping column names to lists of values.
            input_vars: List of input variable definitions.
            image_base_path: Optional base directory for resolving relative image paths.

        Returns:
            List of VariableEnvironments, one for each sample in the batch.
            An empty batch (no columns) yields an empty list.
        """
        # A batch without columns holds no samples; bail out before probing
        # the first column (which would otherwise raise StopIteration).
        if not batch:
            return []

        columns = list(batch)
        # The first column defines the batch size.
        batch_size = len(batch[columns[0]])

        return [
            VariableEnvironment.from_input_variables(
                {column: batch[column][row] for column in columns},
                input_vars,
                image_base_path,
            )
            for row in range(batch_size)
        ]