Skip to content

Documentation for differentiable_mean_small_values

Functionality

The differentiable_mean_small_values function computes a differentiable approximation of the mean of tensor elements that are below a specified threshold. It employs a soft indicator to weigh values that are near the threshold, resulting in a weighted mean that supports gradient flow during backpropagation, making it ideal for applications in neural network training.

Parameters

  • x: A torch.FloatTensor containing the input tensor values.
  • threshold: A float that specifies the cutoff value for determining which elements are considered "small."
  • steepness: An integer that controls the sharpness of the soft indicator function, affecting how quickly the weights transition around the threshold.

Usage

Purpose

Use this function to derive a mean of small values in a way that is differentiable, facilitating the training of neural networks.

Example

import torch
from embedding_studio.embeddings.models.utils.differentiable_mean import differentiable_mean_small_values

x = torch.tensor([0.005, 0.02, 0.009])
result = differentiable_mean_small_values(x, threshold=0.01)
print(result)