multimeditron.model.modalities.moe package

Submodules

multimeditron.model.modalities.moe.gating module

class multimeditron.model.modalities.moe.gating.GatingNetwork(config: GatingNetworkConfig, resnet_path: str | None = None)

Bases: PreTrainedModel

A Gating Network model that uses a pretrained ResNet50 as the backbone. This model outputs logits for each expert, selects the top-k experts, and computes softmax weights.

config_class

The configuration class for the model.

Type:

GatingNetworkConfig

resnet

The ResNet50 model used as the backbone.

Type:

nn.Module

processor

The image processor for preprocessing input images.

Type:

AutoImageProcessor

top_k

Number of top-scoring experts to select for gating.

Type:

int

config_class

alias of GatingNetworkConfig

forward(pixel_values: Tensor) → Tuple[Tensor, Tensor, Tensor]

Forward pass of the GatingNetwork.

Parameters:

pixel_values (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).

Returns:

A tuple (logits, topk_indices, weights), where logits (torch.Tensor) holds the logits for each expert, of shape (batch_size, num_classes); topk_indices (torch.Tensor) holds the indices of the top-k experts, of shape (batch_size, top_k); and weights (torch.Tensor) holds the softmax weights for each expert, of shape (batch_size, num_classes).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
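The three outputs can be illustrated with a framework-free sketch. It assumes, as the documented shapes suggest, that topk_indices picks the top_k largest logits in each row and that weights is a softmax over all num_classes logits. The helper names softmax and gate are hypothetical; the real forward operates on torch.Tensor values, where torch.topk and torch.softmax would be the natural equivalents.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of expert logits."""
    m = max(logits)  # subtract the row max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gate(
    batch_logits: Sequence[Sequence[float]], top_k: int
) -> Tuple[List[List[int]], List[List[float]]]:
    """Hypothetical gating step returning (topk_indices, weights).

    topk_indices has shape (batch_size, top_k) and weights has shape
    (batch_size, num_classes), mirroring the documented outputs.
    """
    topk_indices: List[List[int]] = []
    weights: List[List[float]] = []
    for row in batch_logits:
        # Expert indices sorted by logit, highest first; keep the first top_k.
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        topk_indices.append(order[:top_k])
        # Softmax over every expert, not only the selected ones (assumption).
        weights.append(softmax(row))
    return topk_indices, weights
```

For example, gate([[2.0, 0.5], [-1.0, 3.0]], top_k=1) returns [[0], [1]] as the indices, and each row of the weights sums to 1. Some mixture-of-experts designs renormalize the softmax over only the selected experts; the documented (batch_size, num_classes) shape of weights suggests this model keeps all experts, so the sketch does the same.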

preprocess_images(images: List[Image]) → Tensor

Preprocesses input images using the image processor.

Parameters:

images (List[PIL.Image] or torch.Tensor): List of input images or a tensor.

Returns:

Preprocessed image tensor.

Return type:

torch.Tensor

class multimeditron.model.modalities.moe.gating.GatingNetworkConfig(num_classes: int = 2, top_k: int = 1, image_processor_path: str = 'openai/clip-vit-base-patch32', class_names: List[str] = [], **kwargs)

Bases: PretrainedConfig

Configuration class for the Gating Network model.

num_classes

Number of output classes for the gating network (number of experts).

Type:

int

top_k

Number of top-scoring experts to select for gating.

Type:

int

model_type: str = 'gating_network'
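To show how the documented defaults fit together, here is a standard-library stand-in for the configuration. GatingConfigSketch and expert_names are hypothetical names, and both validation checks are assumptions (in particular, that class_names, when given, holds one name per expert); none of this is confirmed as GatingNetworkConfig's actual behavior.

```python
from dataclasses import dataclass, field
from typing import ClassVar, List


@dataclass
class GatingConfigSketch:
    """Hypothetical stand-in mirroring GatingNetworkConfig's documented defaults."""

    model_type: ClassVar[str] = "gating_network"  # class-level, like the documented attribute

    num_classes: int = 2  # number of experts
    top_k: int = 1  # experts selected per input
    image_processor_path: str = "openai/clip-vit-base-patch32"
    class_names: List[str] = field(default_factory=list)  # fresh list per instance

    def __post_init__(self) -> None:
        # Assumed sanity checks; not part of the documented API.
        if not 1 <= self.top_k <= self.num_classes:
            raise ValueError("top_k must be between 1 and num_classes")
        if self.class_names and len(self.class_names) != self.num_classes:
            raise ValueError("class_names must have one entry per expert")

    def expert_names(self, indices: List[int]) -> List[str]:
        """Map expert indices (e.g. one row of topk_indices) to class names."""
        return [self.class_names[i] for i in indices]
```

For instance, GatingConfigSketch(num_classes=3, top_k=2, class_names=["a", "b", "c"]).expert_names([2, 0]) returns ["c", "a"]. The sketch uses field(default_factory=list) because a literal [] default, as in the documented signature, is created once and shared by every call that omits the argument, so mutating it in one instance would affect others.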

Module contents