pytorch_manager
Documentation for PytorchTritonModelStorageManager¶
Functionality¶
This class manages the storage of PyTorch models for deployment with Triton. It saves the model's state dictionary and generates a Python script that recreates the model structure during initialization, ensuring that dynamic control flow is maintained.
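As an illustration, the decoupled save pattern described above boils down to something like the following sketch (the file names match the artifacts documented below; the generated script contents are an assumption, not the exact output):
import torch

def sketch_save(model: torch.nn.Module, model_dir: str) -> None:
    # Persist only the parameters, not a pickled architecture.
    torch.save(model.state_dict(), f"{model_dir}/model.pt")
    # Emit a script that rebuilds the architecture at load time, so any
    # dynamic control flow in forward() is preserved.
    with open(f"{model_dir}/model.py", "w") as f:
        f.write(
            "import torch\n"
            "# Hypothetical reconstruction script; the real one inlines the model class.\n"
            "def get_model():\n"
            "    model = MyModel()\n"
            "    model.load_state_dict(torch.load('model.pt'))\n"
            "    return model\n"
        )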
Purpose¶
- Save the model parameters separately from the definition.
- Generate a reconstruction script for model loading in Triton.
- Support both standard and sequential PyTorch models.
Motivation¶
By decoupling the model weights from its architecture, this approach offers flexibility in handling models with dynamic control flows and varying structures. It simplifies updates and allows easier integration with Triton's LibTorch backend.
Inheritance¶
This class inherits from TritonModelStorageManager, extending its functionality specifically for PyTorch models.
Method: PytorchTritonModelStorageManager._get_model_artifacts¶
Functionality¶
This method retrieves the list of artifact filenames that are expected to be located in the model directory for a PyTorch model. It ensures that both the model's state dictionary and the Python initialization script are present.
Parameters¶
None.
Returns¶
- List[str]: A list with two filenames:
  - "model.pt": The file containing the model's state dictionary.
  - "model.py": The Python script to initialize the model.
Usage¶
- Purpose: To identify the required files for deploying the model with Triton Inference Server.
Example¶
artifacts = pytorch_manager._get_model_artifacts()
print(artifacts)
# Output: ['model.pt', 'model.py']
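A minimal sketch of what this override could look like, assuming the behavior described above (not the actual source):
from typing import List

def _get_model_artifacts(self) -> List[str]:
    # Both the serialized state dictionary and the reconstruction
    # script must be present in the model directory.
    return ["model.pt", "model.py"]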
Method: PytorchTritonModelStorageManager._generate_triton_config_model_info¶
Functionality¶
Generates the model-info portion of the Triton configuration. It creates a list of strings that configure the model name, platform, maximum batch size, and Python runtime script.
Parameters¶
None.
Usage¶
Use this method to obtain the base Triton configuration for a PyTorch model storage manager. The configuration includes the model name, backend, max_batch_size, and the runtime script path.
Example¶
manager = PytorchTritonModelStorageManager(storage_info, do_dynamic_batching)
config_lines = manager._generate_triton_config_model_info()
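The exact output depends on the manager's configuration, but the lines plausibly resemble a config.pbtxt fragment like the following (an illustration only; the key names and values here are assumptions):
print("\n".join(config_lines))
# Plausible output:
# name: "my_model"
# backend: "pytorch"
# max_batch_size: 8
# runtime: "model.py"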
Method: PytorchTritonModelStorageManager._save_model¶
Functionality¶
This method saves a PyTorch model by serializing its state dictionary and creating a Python script that reinitializes the model for Triton. It generates different scripts for standard and Sequential models.
Parameters¶
- model: A PyTorch model (nn.Module) whose parameters are saved.
- example_inputs: A dict mapping input names to example tensors that illustrate the model's expected input structure.
Usage¶
- Purpose: Stores the model weights and produces an initialization script for Triton inference serving.
Example¶
For a given model and example inputs, save the model with:
manager._save_model(my_model, {"input": tensor_data})
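A fuller sketch of the call (the model and inputs below are placeholders; for a torch.nn.Sequential model the manager emits a layer-by-layer reconstruction script rather than the standard one):
import torch

# Hypothetical model and example inputs, for illustration only.
my_model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 1),
)
tensor_data = torch.randn(1, 4)
manager._save_model(my_model, {"input": tensor_data})
# Afterwards the model directory contains model.pt (the weights)
# and model.py (the reconstruction script).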