Documentation for differentiable_mean_small_values
¶
Functionality¶
The differentiable_mean_small_values
function computes a differentiable approximation of the mean of tensor elements that are below a specified threshold. It employs a soft indicator to weigh values that are near the threshold, resulting in a weighted mean that supports gradient flow during backpropagation, making it ideal for applications in neural network training.
Parameters¶
x
: Atorch.FloatTensor
containing the input tensor values.threshold
: A float that specifies the cutoff value for determining which elements are considered "small."steepness
: An integer that controls the sharpness of the soft indicator function, affecting how quickly the weights transition around the threshold.
Usage¶
Purpose¶
Use this function to derive a mean of small values in a way that is differentiable, facilitating the training of neural networks.
Example¶
import torch
from embedding_studio.embeddings.models.utils.differentiable_mean import differentiable_mean_small_values
x = torch.tensor([0.005, 0.02, 0.009])
result = differentiable_mean_small_values(x, threshold=0.01)
print(result)