
Documentation for EmbeddingsModelInterface

Overview

EmbeddingsModelInterface is an abstract class that standardizes the interface for embedding models used in fine-tuning tasks. It handles both query and item representations, supporting multi-domain scenarios.

Main Purposes

  • Provide a consistent way to access model parameters for query and item components.
  • Facilitate fine-tuning by requiring concrete implementations of core model input and parameter methods.
  • Offer a base structure that enforces uniform method signatures across various model implementations.

Motivation

This interface simplifies managing models that operate on two entities: queries and items. By defining standard abstract methods, it ensures that any subclass implements the functionality needed for training, inference, and fine-tuning within the PyTorch Lightning framework.

Inheritance

EmbeddingsModelInterface inherits from pytorch_lightning.LightningModule, which integrates it into the PyTorch Lightning ecosystem, making it easier to leverage training, validation, and distributed execution features.
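
To make the contract concrete, below is a minimal sketch of a two-tower subclass. The class name, constructor, and import path are illustrative assumptions; only the overridden method names come from the interface itself.

from typing import Iterator

import torch
from torch.nn import Parameter

# import path is an assumption; adjust to wherever EmbeddingsModelInterface lives
from embedding_studio.embeddings.models.interface import EmbeddingsModelInterface


class TwoTowerModel(EmbeddingsModelInterface):  # hypothetical subclass
    def __init__(self, query_model: torch.nn.Module, items_model: torch.nn.Module):
        super().__init__()
        self.query_model = query_model  # encoder for queries
        self.items_model = items_model  # encoder for items

    def get_query_model_params(self) -> Iterator[Parameter]:
        return self.query_model.parameters()

    def get_items_model_params(self) -> Iterator[Parameter]:
        return self.items_model.parameters()

    # ... the remaining abstract methods documented below would be implemented here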

Method Documentation

get_query_model_params

Functionality: This method provides an iterator over the parameters of the query model. It is used to access the model's parameters for fine-tuning and inference configuration.

Parameters: None.

Usage: Retrieve model parameters for the query branch.

Example:

def get_query_model_params(self) -> Iterator[Parameter]:
    return self.query_model.parameters()


get_items_model_params

Functionality: This method returns an iterator over the parameters of the items model. It provides access to the parameters that are used for fine-tuning the items component of an embedding model.

Parameters: None.

Usage: Retrieve items model parameters for optimization or tracing.

Example:

def get_items_model_params(self) -> Iterator[Parameter]:
    return self.items_model.parameters()
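
A common use of both parameter accessors is building separate optimizer parameter groups, for example with a lower learning rate for the items branch. A minimal sketch (the learning rates and the `model` instance are illustrative):

import torch

optimizer = torch.optim.AdamW(
    [
        {"params": model.get_query_model_params(), "lr": 1e-5},
        {"params": model.get_items_model_params(), "lr": 5e-6},
    ]
)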


is_named_inputs

Functionality: Returns a boolean indicating if the model expects named inputs.

Parameters: None (this is a property, not a method).

Usage: Use this property to check if the model requires named inputs.

Example:

@property
def is_named_inputs(self) -> bool:
    return True  # When inputs are specified as a named dictionary
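
One plausible use of this property is deciding how sample inputs are passed to the underlying model: as keyword arguments when inputs are named, positionally otherwise. A sketch, assuming `model` is a concrete subclass:

inputs = model.get_query_model_inputs()
if model.is_named_inputs:
    embedding = model.query_model(**inputs)          # named dictionary inputs
else:
    embedding = model.query_model(*inputs.values())  # positional tensor inputs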


get_query_model_inputs

Functionality: Returns a dictionary of input tensors for the query model, typically used for model tracing.

Parameters: - device: The device to place tensors on. If None, the model's device is used.

Usage: Generate sample inputs for the query model during tracing.

Example:

def get_query_model_inputs(self, device=None) -> Dict[str, Tensor]:
    inputs = self.tokenizer("example query", return_tensors="pt")
    device = device if device else self.device
    return {k: v.to(device) for k, v in inputs.items()}
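
These sample inputs are what a tracing step would consume. A minimal sketch using torch.jit.trace (the example_kwarg_inputs argument requires PyTorch >= 2.0; on older versions, pass a positional tuple of tensors instead):

import torch

model.eval()  # trace with inference-time behavior (dropout off, etc.)
inputs = model.get_query_model_inputs(device="cpu")
traced = torch.jit.trace(
    model.query_model,
    example_kwarg_inputs=dict(inputs),  # PyTorch >= 2.0
    strict=False,
)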


get_items_model_inputs

Functionality: This method provides example inputs for the items model, mainly used for model tracing.

Parameters: - device: (Optional) Device to place the tensors on. If None, the model's device is used.

Usage: Generate a dictionary of input tensors for tracing the items model.

Example:

def example_usage(model):
    inputs = model.get_items_model_inputs(device="cuda")
    print(inputs)
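
For symmetry with get_query_model_inputs, an implementation might look like the following. This is a sketch assuming a Hugging Face-style tokenizer attribute and hard-coded sample item texts:

def get_items_model_inputs(self, device=None) -> Dict[str, Tensor]:
    inputs = self.tokenizer(
        ["example item one", "example item two"],  # sample texts for tracing
        padding=True,
        return_tensors="pt",
    )
    device = device if device else self.device
    return {k: v.to(device) for k, v in inputs.items()}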


get_query_model_inference_manager_class

Functionality: This method returns the Triton model storage manager class for query model inference.

Parameters: None.

Usage: Specify the Triton model storage manager class used for query model inference.

Example:

def get_query_model_inference_manager_class(self) -> Type[TritonModelStorageManager]:
    return JitTraceTritonModelStorageManager


get_items_model_inference_manager_class

Functionality: Returns the class for managing items model inference in Triton.

Parameters: None.

Usage: Provide the items model inference manager class for Triton.

Example:

def get_items_model_inference_manager_class(self) -> Type[TritonModelStorageManager]:
    return SomeTritonModelStorageManager  # placeholder: substitute a concrete manager class


fix_query_model

Functionality: Fixes a specific number of layers in the query model by freezing them.

Parameters: - num_fixed_layers: Number of layers to freeze from the bottom of the model.

Usage: Freeze layers to prevent weight updates during training.

Example:

def fix_query_model(self, num_fixed_layers: int):
    if len(self.query_model.encoder.layers) <= num_fixed_layers:
        raise ValueError(
            f"Number of fixed layers ({num_fixed_layers}) >= "
            f"number of existing layers ({len(self.query_model.encoder.layers)})"
        )
    # requires_grad_(False) freezes every parameter of the module;
    # plain attribute assignment (module.requires_grad = False) is a silent no-op
    self.query_model.embeddings.requires_grad_(False)
    for i in range(num_fixed_layers):
        self.query_model.encoder.layers[i].requires_grad_(False)
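
After freezing, only parameters whose requires_grad is still True should reach the optimizer. A short usage sketch (the layer count and learning rate are illustrative):

model.fix_query_model(num_fixed_layers=4)  # freeze embeddings + bottom 4 layers

optimizer = torch.optim.AdamW(
    (p for p in model.get_query_model_params() if p.requires_grad),
    lr=1e-5,
)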


unfix_query_model

Functionality: This method unfreezes all layers of the query model by enabling gradients.

Parameters: None.

Usage: Use this method to re-enable gradient computation for all layers of the query model.

Example:

def unfix_query_model(self):
    # re-enable gradient updates for all previously frozen modules
    self.query_model.embeddings.requires_grad_(True)
    for layer in self.query_model.encoder.layers:
        layer.requires_grad_(True)


fix_item_model

Functionality: Freezes a given number of layers in the item model by setting their parameters' requires_grad attributes to False.

Parameters: - num_fixed_layers: Number of layers to freeze from the bottom of the model.

Usage: Freeze the initial layers of the item model during fine-tuning.

Example:

def fix_item_model(self, num_fixed_layers: int):
    if len(self.items_model.encoder.layers) <= num_fixed_layers:
        raise ValueError(
            f"Number of fixed layers ({num_fixed_layers}) is greater than or "
            f"equal to the total number of layers "
            f"({len(self.items_model.encoder.layers)})"
        )
    self.items_model.embeddings.requires_grad_(False)
    for i in range(num_fixed_layers):
        self.items_model.encoder.layers[i].requires_grad_(False)


unfix_item_model

Functionality: Unfixes all layers of the item model by enabling gradients.

Parameters: None.

Usage: Enable fine-tuning by unfreezing the item model.

Example:

def unfix_item_model(self):
    self.items_model.embeddings.requires_grad_(True)
    for layer in self.items_model.encoder.layers:
        layer.requires_grad_(True)


forward_query

Functionality: Processes a query through the query model to produce an embedding tensor.

Parameters: - query: Input query which may be text, features, or another format.

Returns: A FloatTensor containing the embedding of the query.

Usage: Converts the given query into an embedding vector.

Example:

def forward_query(self, query: str) -> FloatTensor:
    if len(query) == 0:
        logger.warning("Provided query is empty")  # logger: module-level logging.Logger
    tokenized = self.tokenize(query)
    # assumes query_model returns the embedding tensor directly
    return self.query_model(
        input_ids=tokenized["input_ids"].to(self.device),
        attention_mask=tokenized["attention_mask"].to(self.device),
    )


forward_items

Functionality: Processes a list of items through the items model to generate embedding tensors.

Parameters: - items: List of items, must not be empty.

Usage: Compute embeddings in a single forward pass.

Example:

def forward_items(self, items: List[str]) -> FloatTensor:
    if len(items) == 0:
        raise ValueError("items list must not be empty")
    tokenized = self.tokenize(items)
    return self.items_model(
        input_ids=tokenized["input_ids"].to(self.device),
        attention_mask=tokenized["attention_mask"].to(self.device)
    )