Documentation for BertModelSimplifiedWrapper and TextToTextBertModel

BertModelSimplifiedWrapper Class

Functionality

BertModelSimplifiedWrapper wraps a Hugging Face BERT model to produce text embeddings. It extracts the pooler output from the underlying model.

Inheritance

Inherits from torch.nn.Module, the base for neural network modules in PyTorch.

Motivation

Simplifies usage of BERT by exposing only the essential embedding output, hiding underlying complexity for text embedding tasks.
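
For illustration, the sketch below shows what such a wrapper could look like. It is a minimal reconstruction consistent with the behavior described here, not necessarily the repository's actual implementation; the attribute name model is an assumption.

import torch

class BertModelSimplifiedWrapper(torch.nn.Module):
    """Sketch: expose only the pooler output of a Hugging Face BERT model."""

    def __init__(self, bert_model):
        super().__init__()
        self.model = bert_model  # assumed attribute name

    def forward(self, input_ids, attention_mask):
        # The pooler output is the [CLS] hidden state passed through
        # a linear layer and a tanh activation.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.pooler_output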

Documentation for BertModelSimplifiedWrapper.forward

Functionality

This method runs the underlying BERT model on the provided token IDs and attention mask and returns its pooler output: the final hidden state of the first ([CLS]) token passed through a linear layer and a tanh activation.

Parameters

  • input_ids (Tensor): Tensor containing token IDs of the input text.
  • attention_mask (Tensor): Tensor indicating which tokens should be attended to (1) and which are padding (0).

Returns

  • (Tensor): The pooler output, i.e. the final hidden state of the first ([CLS]) token after a linear transformation and tanh activation.

Usage

  • Purpose: Generate text embeddings using the BERT model's pooler output.

Example

from transformers import AutoModel
import torch

model = AutoModel.from_pretrained("bert-base-uncased")
wrapper = BertModelSimplifiedWrapper(model)
inputs = {
    "input_ids": torch.tensor([[101, 102]]),      # [CLS] and [SEP] token IDs
    "attention_mask": torch.tensor([[1, 1]])
}
emb_output = wrapper(**inputs)  # calling the module invokes forward()
# emb_output has shape (batch_size, hidden_size), e.g. (1, 768) for bert-base-uncased

TextToTextBertModel Class

Functionality

TextToTextBertModel is a wrapper that encapsulates a BERT model and its tokenizer into a unified interface for generating text embeddings. It leverages the Hugging Face transformers AutoModel and AutoTokenizer to compute embeddings from input texts using BERT's pooler output.

Parameters

  • bert_model: A string identifier for a pretrained BERT model or an existing AutoModel instance.
  • bert_tokenizer: The tokenizer for processing texts. If omitted, it is auto-loaded based on the model's configuration.
  • max_length: Maximum number of tokens to consider during tokenization.

Usage

  • Purpose: Provides a simple interface for generating text embeddings for both queries and items, ensuring consistency in usage.
  • Inheritance: Inherits from EmbeddingsModelInterface to conform with the repository's embedding model structure.

Example

from transformers import AutoModel

# Initialize the model using a pretrained BERT model
model = TextToTextBertModel(
    AutoModel.from_pretrained('bert-base-uncased')
)

# Tokenize a query, then generate embeddings using the query model
inputs = model.tokenize("Example query")
embeddings = model.get_query_model()(inputs["input_ids"], inputs["attention_mask"])

Documentation for TextToTextBertModel Methods

get_query_model

Functionality

Returns the model used for query processing. It is the wrapped BERT model shared with item processing. Use this to obtain the query encoder.

Parameters

This method takes no parameters.

Usage

Call this method to get the model for encoding queries.

Example

model_inst = TextToTextBertModel("bert-base-uncased")
inputs = model_inst.tokenize("Example query")
query_model = model_inst.get_query_model()
output = query_model(inputs["input_ids"], inputs["attention_mask"])

get_items_model

Functionality

Returns the wrapped BERT model used for processing items. Since the same model is used for both query and items, it simply returns the model component.

Parameters

None.

Usage

  • Purpose: Retrieve the model used to process item data in the text-to-text BERT implementation.

Example

items_model = text_to_text_bert.get_items_model()
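
Because the same BERT instance backs both queries and items, the two getters should return the same object. A quick sanity check, assuming the shared instance is returned directly:

items_model = text_to_text_bert.get_items_model()
assert items_model is text_to_text_bert.get_query_model()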

get_query_model_params

Functionality

Returns an iterator over the parameters of the query model. This iterator can be used to apply optimizers or other processing steps on the model parameters during training or evaluation.

Parameters

None.

Usage

  • Purpose: Retrieve all query model parameters for further processing, such as in training loops or custom optimization routines.

Example

model = TextToTextBertModel('bert-base-uncased')
params = model.get_query_model_params()
for param in params:
    print(param.shape)
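
As a sketch of the training-loop use case mentioned above, the iterator can be handed straight to a PyTorch optimizer (the learning rate here is an arbitrary illustration):

import torch

model = TextToTextBertModel('bert-base-uncased')
optimizer = torch.optim.AdamW(model.get_query_model_params(), lr=2e-5)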

get_items_model_params

Functionality

Returns an iterator over the parameters of the items model. Since the query and items models are the same, this method returns the same parameters as get_query_model_params.

Parameters

None.

Usage

Use this method to collect model parameters for training or fine-tuning operations where the items model parameters are needed.

Example

model = TextToTextBertModel('bert-base-uncased')
for param in model.get_items_model_params():
    print(param.shape)

is_named_inputs

Functionality

Indicates whether the model uses named inputs. BERT models require inputs named "input_ids" and "attention_mask" to function properly.

Parameters

This property does not accept any parameters.

Usage

  • Purpose: Signals that the model expects named inputs for generating text embeddings.

Example

model = TextToTextBertModel("bert-base-uncased")
print(model.is_named_inputs)  # Output: True
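
In practice, the flag tells callers to pass the tokenizer's output as keyword arguments. A sketch of that pattern, assuming tokenize returns a dict with exactly these two keys:

model = TextToTextBertModel("bert-base-uncased")
if model.is_named_inputs:
    inputs = model.tokenize("Example query")  # {'input_ids': ..., 'attention_mask': ...}
    embedding = model.get_query_model()(**inputs)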

get_query_model_inputs

Functionality

This method returns sample inputs for model tracing by tokenizing predefined text using TEST_INPUT_TEXTS. It filters the output to keep only 'input_ids' and 'attention_mask', and moves the tensors to the specified device.

Parameters

  • device: Device to place the tensors on. If None, the model's device is used.

Usage

Call this method to generate example inputs for model tracing. It is useful for debugging or exporting the model.

Example

import torch

inputs = model.get_query_model_inputs(device=torch.device('cpu'))
print(inputs)
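
Since these inputs are intended for tracing, the sketch below shows one way they might feed torch.jit.trace; the repository's actual tracing flow may differ:

import torch

inputs = model.get_query_model_inputs(device=torch.device('cpu'))
traced = torch.jit.trace(
    model.get_query_model(),
    (inputs['input_ids'], inputs['attention_mask']),
)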

get_items_model_inputs

Functionality

Retrieves example inputs for the items model. This method returns the same inputs as get_query_model_inputs, offering a consistent approach for model tracing and inference.

Parameters

  • device: Optional parameter specifying the target device for tensor placement. If not provided, the model's default device is used.

Usage

  • Purpose: Generate standardized input tensors for the items model for model tracing and inference.

Example

import torch

device = torch.device('cuda')
inputs = model.get_items_model_inputs(device)

get_query_model_inference_manager_class

Functionality

This method returns the class used to manage query model inference in Triton deployments. It supports tracing and serving the model efficiently during inference.

Parameters

None.

Returns

Returns the JitTraceTritonModelStorageManager class which handles inference management for the query model.

Usage

  • Purpose: Acquire the proper inference manager class for preparing the query model for Triton inference.

Example

model = TextToTextBertModel('bert-base-uncased')
manager_class = model.get_query_model_inference_manager_class()
# manager_class is JitTraceTritonModelStorageManager

get_items_model_inference_manager_class

Functionality

This method returns the class used to manage items model inference for Triton. It provides the same inference manager as used for queries, ensuring consistency in model deployment.

Parameters

None.

Usage

  • Purpose: Acquire the inference manager class for the items model in Triton.

Example

# Obtain the inference manager class
manager_class = text_to_text_bert_model.get_items_model_inference_manager_class()

# Create an inference manager instance
inference_manager = manager_class(model_instance)

fix_query_model

Functionality

Freezes the embeddings and a given number of lower encoder layers during fine-tuning. This prevents these layers from updating during training by setting requires_grad to False.

Parameters

  • num_fixed_layers: Number of encoder layers, counted from the bottom, to freeze.

Usage

Call fix_query_model during training to lock the embeddings and the lowest encoder layers.

Example

model = TextToTextBertModel("bert-base-uncased")
model.fix_query_model(3)
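
For intuition, freezing as described above amounts to roughly the following. This is a sketch assuming a standard Hugging Face BertModel is reachable through the wrapper; the attribute names are assumptions:

bert = model.get_query_model().model  # assumed attribute holding the raw BertModel
for param in bert.embeddings.parameters():
    param.requires_grad = False          # freeze the embeddings
for layer in bert.encoder.layer[:3]:     # freeze the three lowest encoder layers
    for param in layer.parameters():
        param.requires_grad = False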

unfix_query_model

Functionality

This method enables all layers of the query model for fine-tuning. It resets the frozen state by setting the requires_grad attribute to True on the model embeddings and every encoder layer. This allows the model to update its weights during training.

Parameters

This method does not take any parameters.

Usage

  • Purpose: Use this method when you want to unfreeze all layers of the query model and allow full training.

Example

model.unfix_query_model()
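
A quick way to confirm that everything is trainable again, assuming get_query_model_params yields all model parameters:

model.unfix_query_model()
assert all(p.requires_grad for p in model.get_query_model_params())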

fix_item_model

Functionality

This method freezes a specified number of layers in the item model during fine-tuning. Since query and item models use the same BERT instance, it simply calls fix_query_model to freeze the corresponding layers.

Parameters

  • num_fixed_layers: An integer giving the number of layers, counted from the bottom of the model, to freeze. It must be less than the total number of layers; otherwise a ValueError is raised.

Usage

  • Purpose: Keep a subset of the model layers unchanged during fine-tuning, preserving what the pretrained layers have learned.

Example

model.fix_item_model(3)

unfix_item_model

Functionality

This method enables gradient updates for all layers of the item model. Since the query and item models share the same BERT instance, it simply calls the underlying unfix_query_model method.

Parameters

None.

Usage

  • Purpose: Re-enable training for the item model by unfreezing all of its layers.

Example

model = TextToTextBertModel('bert-base-uncased')
model.unfix_item_model()

tokenize

Functionality

Tokenizes a text query or a list of queries into a dictionary of tensors. The method uses the underlying BERT tokenizer with specified maximum length, padding, and truncation settings.

Parameters

  • query: A string or a list of strings representing the input text.

Usage

  • Purpose: Prepares text input for the BERT model by generating token IDs and attention masks.

Example

Tokenizing a single query:

tokenized = model.tokenize("Example query")

Tokenizing multiple queries:

queries = ["First query", "Second query"]
tokenized = model.tokenize(queries)
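
For reference, the call is roughly equivalent to invoking a Hugging Face tokenizer directly with the settings named above; the exact padding and truncation flags and the max_length value are assumptions:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(
    ["First query", "Second query"],
    max_length=128,       # assumed value; controlled by the max_length parameter
    padding=True,
    truncation=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape, batch["attention_mask"].shape)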


forward_query

Functionality

This method tokenizes a provided text query and returns its embedding as a tensor. If the input query is empty, a warning is logged.

Parameters

  • query: A non-empty string representing the text query.

Returns

A tensor (FloatTensor or Tensor) containing the query embedding.

Usage

  • Purpose: Convert a text query into an embedding for search or similarity computations.

Example

embedding = model.forward_query("Your query text")
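
A sketch of the similarity use case mentioned above, assuming forward_query returns a (1, hidden_size) tensor and forward_items a (num_items, hidden_size) tensor:

import torch.nn.functional as F

query_emb = model.forward_query("red running shoes")
item_embs = model.forward_items(["crimson sneakers", "blue jeans"])
scores = F.cosine_similarity(query_emb, item_embs)  # one similarity score per item
print(scores)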

forward_items

Functionality

Processes a list of text items through the BERT model and returns their embedding representations. It tokenizes the input texts and computes embeddings using the forward method of the wrapped model.

Parameters

  • items: List[str] containing the text items to encode.

Usage

  • Purpose: Generate embedding vectors for a list of text items.

Example

items = ["hello", "world"]
embeddings = model.forward_items(items)