"""Variable system for MMIRAGE pipeline with multimodal support."""
from __future__ import annotations

import abc
import os
from dataclasses import dataclass
from types import MappingProxyType
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence

import jmespath
from PIL import Image
# Compiled JMESPath expressions keyed by their query text, so each distinct
# query is parsed once rather than once per sample.
_JMESPATH_CACHE: Dict[str, jmespath.parser.ParsedResult] = {}


def _get_compiled_query(key: str) -> jmespath.parser.ParsedResult:
    """Return the compiled JMESPath expression for ``key``, memoized.

    The first request for a given ``key`` compiles it with
    ``jmespath.compile`` and stores the result in ``_JMESPATH_CACHE``;
    subsequent requests return the stored object without recompiling.

    Args:
        key: JMESPath query text.

    Returns:
        The compiled expression object produced by ``jmespath.compile``.
    """
    cached = _JMESPATH_CACHE.get(key)
    if cached is not None:
        return cached
    # Cache miss: compile once and remember the result for later samples.
    cached = _JMESPATH_CACHE[key] = jmespath.compile(key)
    return cached
@dataclass()
class BaseVar(abc.ABC):
    """Common root class for every variable kind in the MMIRAGE pipeline.

    Attributes:
        name: Identifier of the variable; an empty string when not given.
    """

    name: str = ""
@dataclass
class OutputVar(BaseVar):
    """A variable whose value is produced by a processor (e.g. an LLM).

    Output variables may depend on input variables as well as on output
    variables that were computed earlier. This class is abstract: concrete
    subclasses must implement ``is_computable``.

    Attributes:
        name: Identifier of the variable (inherited from ``BaseVar``).
        type: Identifier of the processor that generates this variable.
    """

    type: str = ""

    @abc.abstractmethod
    def is_computable(self, vars: Sequence[BaseVar]) -> bool:
        """Return whether this output variable can be computed from ``vars``.

        The exact criterion is defined by each concrete subclass.

        Args:
            vars: The variables currently available.

        Returns:
            True if this variable can be computed now, False otherwise.

        Raises:
            NotImplementedError: Always, in this abstract implementation.
        """
        raise NotImplementedError
def _to_abs_path(path: str) -> str:
    """Return ``path`` as a normalized absolute filesystem path.

    Relative inputs are interpreted against the current working directory,
    following ``os.path.abspath`` semantics.
    """
    absolute = os.path.abspath(path)
    return absolute
def _require_file(path: str, missing_msg: str, not_file_msg: str) -> str:
    """Return ``path`` unchanged if it names an existing regular file.

    Args:
        path: Filesystem path to validate.
        missing_msg: Message for the error raised when ``path`` does not exist.
        not_file_msg: Message for the error raised when ``path`` exists but
            is not a regular file.

    Returns:
        ``path`` itself.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        RuntimeError: If ``path`` exists but is not a file.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(missing_msg)
    if not os.path.isfile(path):
        raise RuntimeError(not_file_msg)
    return path


def _resolve_image_input(value: Any, image_base_path: Optional[str] = None) -> Any:
    """Resolve image input to a format SGLang can use.

    In your SGLang setup, local images should be passed as plain filesystem paths
    (NOT file:// URIs). This function:
    - passes through PIL images and any other non-string value unchanged
    - passes through http(s) URLs; the scheme is matched case-insensitively
      and returned in lowercase
    - strips a leading file:// (any letter case) for backward compatibility;
      the remainder is treated as a path (no host handling or percent-decoding)
    - normalizes local paths to absolute paths
    - resolves relative paths against image_base_path when it is set,
      otherwise against the current working directory

    Args:
        value: The image value to resolve (string path/URL, PIL image, etc.).
        image_base_path: Optional base directory for resolving relative image paths.

    Returns:
        A resolved image spec: PIL.Image.Image, http(s) URL string, or absolute path string.

    Raises:
        FileNotFoundError: If the resolved file does not exist.
        RuntimeError: If the resolved path exists but is not a file.
    """
    # PIL images and non-string values are handed back untouched.
    if isinstance(value, Image.Image):
        return value
    if not isinstance(value, str):
        return value

    # URI schemes are case-insensitive (RFC 3986 section 3.1), so "HTTP://"
    # or "File://" must be recognized just like their lowercase forms.
    scheme, sep, rest = value.partition("://")
    scheme = scheme.lower()
    if sep and scheme in ("http", "https"):
        # Lowercase the scheme so consumers that match "http://"/"https://"
        # literally still accept the URL. Lowercase input is returned as-is.
        return f"{scheme}://{rest}"
    if sep and scheme == "file":
        # Plain prefix strip, kept for backward compatibility:
        # "file:///abs/x.png" -> "/abs/x.png", "file://rel/x.png" -> "rel/x.png".
        value = rest

    # Absolute local path. If the path itself is a symlink, it is resolved
    # with realpath before validation.
    if os.path.isabs(value):
        path = os.path.realpath(value) if os.path.islink(value) else value
        path = _to_abs_path(path)
        return _require_file(
            path,
            f"Absolute image path does not exist: {path}",
            f"The provided path exists but is not a file: {path}",
        )

    # Relative path with a base directory: resolve there only; the current
    # working directory is NOT tried as a fallback in this case.
    if image_base_path:
        resolved_path = _to_abs_path(os.path.join(image_base_path, value))
        return _require_file(
            resolved_path,
            f"Resolved image path '{resolved_path}' does not exist "
            f"(base='{image_base_path}', rel='{value}').",
            f"Resolved image path is not a file: {resolved_path}",
        )

    # No base directory configured: resolve against the current working directory.
    cwd_path = _to_abs_path(value)
    return _require_file(
        cwd_path,
        f"Relative image path '{value}' cannot be resolved. "
        "Set image_base_path in the dataset configuration or use an absolute path.",
        f"Relative image path exists but is not a file: {cwd_path}",
    )
class VariableEnvironment:
    """Environment for storing and accessing variables during processing.

    ``with_variable`` returns a new environment rather than modifying this
    one, and ``to_dict`` exposes only a read-only view of the variables.
    """

    def __init__(
        self,
        var_env: Dict[str, Any],
        image_vars: Optional[Iterable[str]] = None,
    ) -> None:
        """Initialize a variable environment.

        Args:
            var_env: Dictionary mapping variable names to their values. It is
                stored by reference, not copied.
            image_vars: Names of the variables that represent images. Any
                iterable of names (set, frozenset, list, ...) is accepted.
                Defaults to no image variables.
        """
        self._vars_env = var_env
        # Copy into a private set instead of aliasing the caller's object.
        # This keeps later changes to the caller's collection from leaking
        # into this environment. It also ensures set operations work even
        # when a list or frozenset was passed in.
        self._image_vars: set = set(image_vars) if image_vars is not None else set()

    def with_variable(self, key: str, value: Any, is_image: bool = False) -> "VariableEnvironment":
        """Create a new environment with an additional variable.

        The current environment is left unchanged. If ``key`` is already
        registered as an image variable, it stays registered even when
        ``is_image`` is False; the flag is only ever added, never cleared.

        Args:
            key: Name of the variable to add.
            value: Value of the variable.
            is_image: Whether the variable represents an image.

        Returns:
            VariableEnvironment: New environment with the added variable.
        """
        # The constructor copies the name collection, so no copy is needed here.
        image_vars = (self._image_vars | {key}) if is_image else self._image_vars
        return VariableEnvironment(self._vars_env | {key: value}, image_vars)

    def to_dict(self) -> MappingProxyType:
        """Get an immutable view of the variable dictionary.

        Returns:
            MappingProxyType providing read-only access to variables.
        """
        return MappingProxyType(self._vars_env)

    def get(self, key: str, default: Any = None) -> Any:
        """Get a variable value by name.

        Args:
            key: Name of the variable to retrieve.
            default: Default value to return if variable is not found.

        Returns:
            Variable value, or default if not found.
        """
        return self._vars_env.get(key, default)

    def is_image_var(self, key: str) -> bool:
        """Check if a variable represents an image.

        Args:
            key: Name of the variable to check.

        Returns:
            bool: True if the variable is an image variable, False otherwise.
        """
        return key in self._image_vars

    def get_image_vars(self) -> set:
        """Get all image variable names.

        Returns:
            set: Copy of the set containing names of all image variables.
        """
        return self._image_vars.copy()

    def get_images(self) -> List[Any]:
        """Get image values ordered by variable name.

        Names flagged as images but absent from the variables are skipped.

        Returns:
            List of image values, sorted by their variable names.
        """
        return [self._vars_env[k] for k in sorted(self._image_vars) if k in self._vars_env]

    def has_images(self) -> bool:
        """Check if the environment contains any image variables.

        Returns:
            bool: True if at least one image variable is present, False otherwise.
        """
        return any(k in self._vars_env for k in self._image_vars)