Adding a new modality¶
Structure of the repository:
src
└── multimeditron
├── cli
├── config
├── dataset
│ └── loader
│ └── image
├── model
│ ├── modalities
│ └── projectors
├── train
├── utils
└── verl
In order to add a new modality, we must first understand how the training pipeline processes raw modalities:
1. Modality loading: This step loads the modality from the dataset and transforms it into a raw modality format (for instance image bytes).
2. Modality preprocessing: This step transforms the raw modality into a torch.Tensor.
3. Modality embedding: This step is the forward step of your modality embedder. It forwards the torch.Tensor object from the preprocessing step to create a torch.Tensor: the modality embedding.
Note that:
Step 1 is model agnostic: every model uses the same loading functions.
Steps 2 and 3 are model dependent.
This means that if you implement a model for an existing modality, you don’t need to implement the modality loading step.
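For reference, here is what a raw dataset entry for an image modality looks like before step 1 (a minimal sketch; only the "type" and "value" keys are prescribed, and the nested "bytes" layout follows the image loader shown below):

# Hypothetical dataset entry for an image modality.
# "type" must match the modality_type declared in the modality configuration,
# and "value" holds the raw payload consumed by the modality loader (step 1).
sample = {
    "type": "image",
    "value": {"bytes": open("example.png", "rb").read()},
}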
Implementation example¶
To create a new modality embedder, you need to implement the following classes:

BaseModalityLoader (only if implementing a new modality type): the modality loader that loads the modality from the dataset
BaseModalityConfig: the configuration class for both the processor and the modality model
BaseModalityProcessor: the processor class that preprocesses your modalities
BaseModality: the modality model that forwards your modalities
In this walkthrough, we will show how to load images and how to create a simple modality embedder.
Modality loader¶
Here is an example that loads images from bytes:
from typing import Dict, Any
from multimeditron.dataset.loader import BaseModalityLoader, AutoModalityLoader
from multimeditron.model.constants import MODALITY_VALUE_KEY
import PIL.Image
import io


@AutoModalityLoader.register("raw-image")
class RawImageLoader(BaseModalityLoader):
    def __init__(self):
        super().__init__()

    def load(self, sample: Dict[str, Any]) -> PIL.Image.Image:
        # Read the raw bytes stored under the "value" key and decode them into an RGB image
        image_bytes = sample[MODALITY_VALUE_KEY]["bytes"]
        image = PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return image
A modality loader should always inherit from BaseModalityLoader and be registered using the AutoModalityLoader.register() decorator.
The load function has the following signature:
Input: a dictionary that contains a key "value", i.e. {"value": <something>}. This is the case for every modality. The actual format of the value field depends on the dataset format. See TODO.
Output: the raw modality (here a PIL.Image.Image).
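As a quick sanity check, the loader can be called directly on such a sample (a minimal usage sketch; the sample is built by hand here, whereas in practice it comes from the dataset):

loader = RawImageLoader()

with open("example.png", "rb") as f:
    sample = {"value": {"bytes": f.read()}}

image = loader.load(sample)   # -> PIL.Image.Image in RGB mode
print(image.size)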
Modality configuration¶
The configuration, processor, and model architecture follow the same philosophy as Hugging Face custom models.
The configuration class configures both the processor and the modality model:
from multimeditron.model.modalities.base import BaseModalityConfig


class ImageConfig(BaseModalityConfig):
    def __init__(
        self,
        hidden_size: int = 4096,
        max_batch_size: int = 32,
        clip_name: str = "openai/clip-vit-large-patch14",
        projection_type: str = "mlp",
        **kwargs,
    ):
        super().__init__(
            max_batch_size=max_batch_size,
            modality_type="image",
            hidden_size=hidden_size,
            **kwargs,
        )
        self.clip_name = clip_name
        self.projection_type = projection_type
Every configuration needs to inherit from BaseModalityConfig and call the __init__ function of BaseModalityConfig with the arguments:
max_batch_size: the maximum number of modalities that can be processed in a single batch by the forward function of the modality embedder
modality_type: the modality type that this processor/modality pair handles. This field should match the "type" field in the dataset. See TODO.
hidden_size: the projected size of the modality embeddings (i.e. the size of an LLM token embedding)
This configuration can be arbitrarily expanded with any JSON-serializable attributes. See Hugging Face custom models.
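For instance, the configuration can be instantiated and extended like any Hugging Face-style config (a sketch; the projector_dropout attribute is purely illustrative and not part of MultiMeditron):

config = ImageConfig(
    hidden_size=4096,
    clip_name="openai/clip-vit-large-patch14",
    projection_type="mlp",
)
# Hypothetical extra attribute: any JSON-serializable field can be added
config.projector_dropout = 0.1

print(config.clip_name)        # "openai/clip-vit-large-patch14"
print(config.projection_type)  # "mlp"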
Modality (pre)processor¶
A modality processor preprocesses modalities, transforming the raw modality from the loading step (here a PIL.Image.Image) into a torch.Tensor. This processing is applied during the collator phase (unlike the forward pass of the BaseModality).
from multimeditron.model.constants import NUM_EMBEDDINGS_KEY, MODALITY_VALUE_KEY
from multimeditron.model.modalities.base import BaseModalityProcessor
from transformers import AutoImageProcessor, AutoConfig
from typing import Dict, Any


class ImageProcessor(BaseModalityProcessor):
    def __init__(self, config):
        super().__init__(config)
        assert config.clip_name is not None, "clip_name must be specified in the config"

        self.image_processor = AutoImageProcessor.from_pretrained(config.clip_name)

        # The number of patch embeddings produced per image is fixed by the vision backbone
        feature_extractor_config = AutoConfig.from_pretrained(config.clip_name, trust_remote_code=True)
        self._num_patches_per_entry = (
            feature_extractor_config.vision_config.image_size
            // feature_extractor_config.vision_config.patch_size
        ) ** 2

    def process(self, modality: Dict[str, Any]) -> Dict[str, Any]:
        processed_modality = modality.copy()

        # Convert the PIL image into a pixel tensor and record the number of embeddings
        image = modality[MODALITY_VALUE_KEY]
        processed_modality[MODALITY_VALUE_KEY] = self.image_processor(images=image, return_tensors="pt")["pixel_values"][0]
        processed_modality[NUM_EMBEDDINGS_KEY] = self._num_patches_per_entry

        return processed_modality
Each processor must inherit from BaseModalityProcessor (which inherits from ProcessorMixin).
The modality processor must implement the process() function. This function takes:
Input: a Dict, which is exactly the output of the previous loading phase
Output: the exact same Dict with the preprocessed modality in the "value" key
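Chained with the loader from above, the collator effectively runs something like this for every image entry (a sketch; the sample construction is illustrative):

from multimeditron.model.constants import MODALITY_VALUE_KEY, NUM_EMBEDDINGS_KEY

processor = ImageProcessor(ImageConfig())

with open("example.png", "rb") as f:
    raw_image = RawImageLoader().load({MODALITY_VALUE_KEY: {"bytes": f.read()}})

processed = processor.process({MODALITY_VALUE_KEY: raw_image})
print(processed[MODALITY_VALUE_KEY].shape)   # pixel tensor, e.g. [3, 224, 224] for CLIP ViT-L/14
print(processed[NUM_EMBEDDINGS_KEY])         # 256 patch embeddings for a 224px image with 14px patches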
Modality modeling¶
Lastly, we implement the modality model. This is the model that performs the forward pass during training. To optimize GPU throughput, it should only contain operations that can be parallelized on the GPU.
A modality class must inherit from BaseModality and is typically composed of 2 main modules:

A pretrained modality embedder (like a CLIP model): this module produces meaningful embeddings for the given modalities.
A tunable projection module (usually a simple MLP or a linear layer): this module maps embeddings from the modality embedder to the LLM embedding space. The dimension of this embedding space is given by the hidden_size attribute of BaseModalityConfig.
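MultiMeditron ships its own projection modules under model/projectors (the MLPProjector used below). As an illustration only, such a module boils down to something like this (a sketch, not the actual implementation):

import torch
import torch.nn as nn

class SimpleMLPProjector(nn.Module):
    """Illustrative two-layer MLP mapping modality embeddings to the LLM embedding space."""

    def __init__(self, in_dim: int, out_dim: int, dtype=torch.float32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, out_dim, dtype=dtype),
            nn.GELU(),
            nn.Linear(out_dim, out_dim, dtype=dtype),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_patches, embedding_size) -> (batch, num_patches, hidden_size)
        return self.net(x)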
from multimeditron.model.constants import NUM_EMBEDDINGS_KEY, MODALITY_VALUE_KEY
from multimeditron.model.modalities.base import BaseModality
from transformers import AutoModel
from typing import Dict, Any, List
import torch

# AutoModality (the modality registry) and MLPProjector (see model/projectors)
# are also needed here; import them from the corresponding multimeditron modules.


@AutoModality.register("meditron_clip")
class ImageModality(BaseModality):
    config_class = ImageConfig
    preprocessor_class = ImageProcessor

    def __init__(self, config: ImageConfig):
        super().__init__(config)
        self.vision_tower_name = config.clip_name
        assert self.vision_tower_name is not None, "vision_tower_name must be specified in the config"

        # Pretrained modality embedder (CLIP vision tower)
        self.feature_extractor = AutoModel.from_pretrained(self.vision_tower_name, trust_remote_code=True)
        self.embedding_size = self.feature_extractor.vision_embed_dim
        self._num_patches_per_entry = (
            self.feature_extractor.vision_model.config.image_size
            // self.feature_extractor.vision_model.config.patch_size
        ) ** 2

        # Tunable projection module mapping CLIP embeddings to the LLM embedding space
        self.projector = MLPProjector(self.embedding_size, config.hidden_size, dtype=self.dtype)

    def forward(self, inputs: List[torch.Tensor]) -> torch.FloatTensor:
        inputs = torch.stack(inputs, dim=0)
        inputs = inputs.to(self.feature_extractor.device)

        # Drop the CLS token and keep only the patch embeddings before projecting
        image_features = self.feature_extractor.vision_model(inputs).last_hidden_state[:, 1:, :]
        projected = self.projector(image_features)
        return projected

    def freeze_modality_embedder(self):
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False

    def unfreeze_modality_embedder(self):
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = True

    def unfreeze_projection(self):
        for parameters in self.projector.parameters():
            parameters.requires_grad = True
A modality class must implement 4 functions:

forward(): the definition of the forward pass (which includes the forward of both the modality embedder and the projection module)
freeze_modality_embedder(): freezes the parameters of the modality embedder only
unfreeze_modality_embedder(): unfreezes the parameters of the modality embedder
unfreeze_projection(): unfreezes the parameters of the projection module
These “freezing” functions are used to train different parts of the whole MultiMeditron architecture, which helps ensure training stability.
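For example, a typical alignment stage trains only the projection module while keeping the pretrained embedder frozen (a sketch; the actual staging is driven by the training configuration):

modality = ImageModality(ImageConfig())

# Stage 1 (alignment): freeze the pretrained CLIP embedder, train only the projector
modality.freeze_modality_embedder()
modality.unfreeze_projection()

# Stage 2 (end-to-end fine-tuning): also unfreeze the modality embedder
modality.unfreeze_modality_embedder()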
TODO: Redirect to creating dataset + launching training