pytorch_manager

Documentation for PytorchTritonModelStorageManager

Functionality

This class manages the storage of PyTorch models for deployment with Triton. It saves the model's state dictionary and generates a Python script that recreates the model structure during initialization, ensuring that dynamic control flow is maintained.
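
For reference, with Triton's standard model-repository layout the saved artifacts sit alongside the generated configuration roughly as follows (the repository name, model name, and version number are illustrative placeholders):

model_repository/
    my_model/
        config.pbtxt
        1/
            model.pt    # serialized state dictionary
            model.py    # generated reconstruction script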

Purpose

  • Save the model parameters (state dictionary) separately from the model definition.
  • Generate a reconstruction script for model loading in Triton.
  • Support both standard and sequential PyTorch models.

Motivation

By decoupling the model weights from its architecture, this approach offers flexibility in handling models with dynamic control flow and varying structures. It simplifies updates and allows easier integration with Triton's LibTorch backend.

Inheritance

This class inherits from TritonModelStorageManager, extending its functionality specifically for PyTorch models.


Method: PytorchTritonModelStorageManager._get_model_artifacts

Functionality

This method returns the list of artifact filenames expected in the model directory for a PyTorch model: the model's state dictionary and the Python initialization script.

Parameters

None.

Returns

  • List[str]: A list with two filenames:
      • "model.pt": The file containing the model's state dictionary.
      • "model.py": The Python script that initializes the model.

Usage

  • Purpose: To identify the required files for deploying the model with Triton Inference Server.

Example

artifacts = pytorch_manager._get_model_artifacts()
print(artifacts)
# Output: ['model.pt', 'model.py']
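
Given the documented return value, a plausible implementation is a one-liner (a sketch, not the verified source):

from typing import List

def _get_model_artifacts(self) -> List[str]:
    # Both files must be present in the model directory: the saved
    # state dictionary and the generated reconstruction script.
    return ["model.pt", "model.py"]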

Method: PytorchTritonModelStorageManager._generate_triton_config_model_info

Functionality

Generates the model-information lines of the Triton configuration. It returns a list of strings that set the model name, backend, maximum batch size, and Python runtime script.

Parameters

None.

Usage

Use this method to obtain the base Triton configuration for a PyTorch model storage manager. The configuration includes the model name, backend, max_batch_size, and the runtime script path.

Example

manager = PytorchTritonModelStorageManager(storage_info, do_dynamic_batching)
config_lines = manager._generate_triton_config_model_info()
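
Printing the lines makes the structure concrete. The exact values depend on storage_info; the output below is only a plausible shape, assuming the runtime script is referenced through the config's runtime field:

print("\n".join(config_lines))
# Plausible output (values are illustrative):
# name: "my_model"
# backend: "pytorch"
# max_batch_size: 8
# runtime: "model.py"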

Method: PytorchTritonModelStorageManager._save_model

Functionality

This method saves a PyTorch model by serializing its state dictionary and creating a Python script that reinitializes the model for Triton. It generates different scripts for regular modules and nn.Sequential models.

Parameters

  • model: A PyTorch model (nn.Module) whose parameters are saved.
  • example_inputs: A dict mapping input names to example tensors, illustrating the model's expected input structure.

Usage

  • Purpose: Stores the model weights and produces an initialization script for Triton inference serving.

Example

For a given model and example inputs, save the model with:

manager._save_model(my_model, {"input": tensor_data})
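
A minimal sketch of what this method plausibly does, assuming a hypothetical self._version_dir attribute for the output directory and, in the Sequential case, layers whose repr() is an evaluable constructor call (the real implementation may differ):

import os
import torch
from torch import nn

def _save_model(self, model: nn.Module, example_inputs: dict) -> None:
    # Persist the parameters only; the architecture is rebuilt by model.py.
    torch.save(model.state_dict(), os.path.join(self._version_dir, "model.pt"))

    if isinstance(model, nn.Sequential):
        # Sequential models are rebuilt layer by layer; this relies on each
        # layer's repr() being a valid constructor call, which holds for
        # simple built-in layers such as nn.Linear or nn.ReLU.
        layers = ",\n    ".join(f"nn.{layer!r}" for layer in model)
        body = f"model = nn.Sequential(\n    {layers}\n)"
    else:
        # Regular modules are re-imported and instantiated by name, which
        # assumes a no-argument constructor.
        cls = type(model)
        body = f"from {cls.__module__} import {cls.__name__}\nmodel = {cls.__name__}()"

    # example_inputs documents the expected input structure; a fuller
    # implementation might embed it in the generated script.
    with open(os.path.join(self._version_dir, "model.py"), "w") as f:
        f.write("import torch\nfrom torch import nn\n\n" + body + "\n")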