# Source code for mmirage.core.process.processors.llm.config
"""Configuration for LLM processor in MMIRAGE."""
from dataclasses import dataclass, field
import logging
import os
from typing import Dict, Optional, Sequence, Type, Any, List
from pydantic import BaseModel, create_model
from mmirage.core.process.variables import BaseVar, OutputVar
from mmirage.core.process.base import BaseProcessorConfig
from mmirage.core.process.base import ProcessorRegistry
from jinja2 import Environment, meta
logger = logging.getLogger(__name__)
env = Environment()
def _parse_tp_size_from_env() -> int:
"""Parse tensor parallelism size from SLURM_GPUS_ON_NODE environment variable.
Defensively parses the environment variable, handling invalid values:
- Returns 1 if the variable is None or empty
- Strips whitespace before parsing
- Returns 1 for non-integer values
- Returns 1 for values <= 0
Returns:
Tensor parallelism size (>= 1), defaults to 1 on any parsing error.
"""
env_value = os.environ.get("SLURM_GPUS_ON_NODE")
if not env_value:
return 1
try:
tp_size = int(env_value.strip())
# Ensure tp_size is positive (must be >= 1)
if tp_size <= 0:
logger.warning(
f"Invalid SLURM_GPUS_ON_NODE value '{env_value}' (must be > 0), defaulting tp_size to 1"
)
return 1
return tp_size
except ValueError:
# ValueError: invalid integer format
logger.warning(
f"Invalid SLURM_GPUS_ON_NODE value '{env_value}', defaulting tp_size to 1"
)
return 1
@dataclass
class SGLangServerArgs:
    """Server arguments for the SGLang engine.

    Attributes:
        model_path: Path to the model or HuggingFace model ID.
        tp_size: Tensor parallelism size; defaults to the value parsed
            from SLURM_GPUS_ON_NODE (>= 1).
        trust_remote_code: Whether to trust remote code from HuggingFace.
        disable_custom_all_reduce: Whether to disable custom all reduce.
    """

    # "none" is a placeholder; callers are expected to supply a real path.
    model_path: str = "none"
    # Derived per-process from the SLURM environment rather than hard-coded.
    tp_size: int = field(default_factory=_parse_tp_size_from_env)
    trust_remote_code: bool = True
    disable_custom_all_reduce: bool = False
@dataclass
class SGLangLLMConfig(BaseProcessorConfig):
    """Configuration for the LLM processor backed by SGLang.

    Supports both text-only and multimodal (vision-language) models.

    Attributes:
        type: Type identifier (must be "llm").
        server_args: SGLang server arguments including model path and TP size.
        default_sampling_params: Default sampling parameters for generation.
        chat_template: Chat template name for vision-language models
            (e.g., "qwen2-vl"); empty string means use the tokenizer's default.
    """

    server_args: SGLangServerArgs = field(default_factory=SGLangServerArgs)
    default_sampling_params: Dict[str, Any] = field(default_factory=dict)
    chat_template: str = ""  # Empty means use tokenizer's default
@dataclass
class LLMOutputVar(OutputVar):
    """Output variable generated by LLM processor.

    Uses Jinja2 templating for prompts and supports both plain text
    and structured JSON outputs.

    Attributes:
        name: Name of the variable.
        type: Type identifier (must be "llm").
        prompt: Jinja2 template for the LLM prompt.
        output_schema: List of field names for JSON output (empty for plain text).
        output_type: Output format - "JSON" or "plain".
    """

    prompt: str = ""
    output_schema: List[str] = field(default_factory=list)
    output_type: str = ""

    def get_output_schema(self) -> Optional[Type[BaseModel]]:
        """Generate a Pydantic model for JSON output validation.

        Returns:
            A Pydantic BaseModel class if output_type is "JSON" and
            output_schema is non-empty, otherwise None.
        """
        # Guard clause: only JSON output with declared fields gets a model.
        if self.output_type != "JSON" or not self.output_schema:
            return None
        # Every declared field becomes a required string in the model.
        fields: Dict[str, Any] = {var: (str, ...) for var in self.output_schema}
        return create_model("OutputSchema", **fields)

    def is_computable(self, vars: Sequence[BaseVar]) -> bool:
        """Check if all variables referenced in the prompt are available.

        Args:
            vars: Sequence of currently available variables.

        Returns:
            True if all template variables are declared, False otherwise.
        """
        parsed_content = env.parse(self.prompt)
        template_vars = meta.find_undeclared_variables(parsed_content)
        # Set comprehension reads clearer than set(map(lambda ...)).
        available_names = {v.name for v in vars}
        undeclared_vars = template_vars - available_names
        if undeclared_vars:
            # Lazy %-args defer string formatting until the record is emitted.
            logger.info(
                "⚠️ Undeclared variables found for %s: %s", self.name, undeclared_vars
            )
            return False
        return True
# Register the "llm" processor type, binding its config and output-variable classes.
ProcessorRegistry.register_types("llm", SGLangLLMConfig, LLMOutputVar)