multimeditron.model.modalities.moe package
Submodules
multimeditron.model.modalities.moe.gating module
- class multimeditron.model.modalities.moe.gating.GatingNetwork(config: GatingNetworkConfig, resnet_path: str | None = None)
Bases: PreTrainedModel
A gating network model that uses a pretrained ResNet50 as its backbone. It outputs a logit for each expert, selects the top-k experts, and computes softmax weights.
- config_class
The configuration class for the model.
- Type:
GatingNetworkConfig
- resnet
The ResNet50 model used as the backbone.
- Type:
nn.Module
- processor
The image processor for preprocessing input images.
- Type:
AutoImageProcessor
- top_k
Number of top-scoring experts selected by the gate.
- Type:
int
- config_class
alias of GatingNetworkConfig
- forward(pixel_values: Tensor) → Tuple[Tensor, Tensor, Tensor]
Forward pass of the GatingNetwork.
- Parameters:
pixel_values (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
- Returns:
logits (torch.Tensor): Logits for each expert, of shape (batch_size, num_classes).
topk_indices (torch.Tensor): Indices of the top-k experts, of shape (batch_size, top_k).
weights (torch.Tensor): Softmax weights for each expert, of shape (batch_size, num_classes).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
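The gating step that produces these three outputs is not shown in this reference. The sketch below is a hypothetical pure-Python mirror of the documented shape contract: per-sample logits over num_classes experts, the indices of the top_k highest logits, and softmax weights over all num_classes experts. It assumes the softmax is taken over every expert logit, as the documented weights shape suggests; the function name `gate` is illustrative and not part of multimeditron. In the actual model these would be `torch.Tensor` operations (for example `torch.topk` and `torch.softmax`).

```python
import math
from typing import List, Tuple


def gate(logits: List[List[float]], top_k: int) -> Tuple[List[List[int]], List[List[float]]]:
    """Hypothetical mirror of the documented gating contract (not the library code).

    logits: batch_size rows, each holding num_classes expert scores.
    Returns (topk_indices, weights):
      topk_indices: batch_size rows of top_k expert indices, highest score first.
      weights: batch_size rows of num_classes softmax weights (assumed over all experts).
    """
    topk_indices: List[List[int]] = []
    weights: List[List[float]] = []
    for row in logits:
        # top-k selection only makes sense for 1 <= top_k <= num_classes.
        if not 1 <= top_k <= len(row):
            raise ValueError("top_k must be between 1 and num_classes")
        # Rank expert indices by logit, highest first, and keep the first top_k.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        topk_indices.append(ranked[:top_k])
        # Numerically stable softmax: subtract the row max before exponentiating.
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return topk_indices, weights
```

For example, `gate([[2.0, 0.5], [-1.0, 3.0]], top_k=1)` returns `[[0], [1]]` as the top-k indices, because expert 0 scores highest for the first sample and expert 1 for the second; each row of the returned weights sums to 1.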
- preprocess_images(images: List[Image]) → Tensor
Preprocesses input images using the image processor.
- Parameters:
images (List[PIL.Image] or torch.Tensor): A list of input images, or an image tensor.
- Returns:
Preprocessed image tensor.
- Return type:
torch.Tensor
- class multimeditron.model.modalities.moe.gating.GatingNetworkConfig(num_classes: int = 2, top_k: int = 1, image_processor_path: str = 'openai/clip-vit-base-patch32', class_names: List[str] = [], **kwargs)
Bases: PretrainedConfig
Configuration class for the Gating Network model.
- num_classes
Number of output classes for the gating network (that is, the number of experts).
- Type:
int
- top_k
Number of top-scoring experts selected by the gate.
- Type:
int
- model_type: str = 'gating_network'
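To make the constructor defaults concrete, the sketch below is a hypothetical stand-in that mirrors GatingNetworkConfig's documented parameters and its `model_type`. It does not subclass Hugging Face's `PretrainedConfig`, so it stays self-contained. The two validation rules are assumptions, not documented behavior: that top_k must lie between 1 and num_classes (top-k selection over num_classes logits requires this), and that a non-empty class_names holds one name per expert. The class name `GatingConfigSketch` and the example expert names are illustrative only.

```python
from dataclasses import dataclass, field
from typing import ClassVar, List


@dataclass
class GatingConfigSketch:
    """Hypothetical stand-in mirroring GatingNetworkConfig's documented defaults."""

    # Class attribute, like GatingNetworkConfig.model_type; not an instance field.
    model_type: ClassVar[str] = "gating_network"

    num_classes: int = 2  # number of experts
    top_k: int = 1  # experts selected per input
    image_processor_path: str = "openai/clip-vit-base-patch32"
    # default_factory gives every instance its own list rather than one shared default.
    class_names: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Assumed constraint: top-k over num_classes logits needs 1 <= top_k <= num_classes.
        if not 1 <= self.top_k <= self.num_classes:
            raise ValueError("top_k must be between 1 and num_classes")
        # Assumed constraint: when names are given, there is exactly one per expert.
        if self.class_names and len(self.class_names) != self.num_classes:
            raise ValueError("class_names must have num_classes entries")


# Usage with hypothetical expert names: three experts, route each input to the best two.
cfg = GatingConfigSketch(num_classes=3, top_k=2, class_names=["expert_a", "expert_b", "expert_c"])
```

Using `field(default_factory=list)` keeps the sketch's default equivalent to the documented `class_names=[]` while avoiding a single list shared across instances.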