Implementing Custom Category Selectors in Embedding Studio¶
This tutorial will guide you through implementing custom category selectors for Embedding Studio's query parsing system. Category selectors determine which categories are relevant for a given search query based on various selection algorithms and distance metrics.
Understanding the Selector Architecture¶
Embedding Studio uses a hierarchical architecture for category selectors:
- AbstractSelector - The base interface for all selectors
- DistBasedSelector - Selectors that work with pre-calculated distances
- VectorsBasedSelector - Selectors that work directly with embedding vectors
Let's explore each level and how to implement your own custom selectors.
Base Selector Interface¶
All selectors implement the AbstractSelector interface:
class AbstractSelector(ABC):
    @property
    @abstractmethod
    def vectors_are_needed(self) -> bool:
        """Indicates whether this selector requires access to actual embedding vectors."""
        raise NotImplementedError

    @abstractmethod
    def select(
        self,
        categories: List[ObjectWithDistance],
        query_vector: Optional[torch.Tensor] = None,
    ) -> List[int]:
        """Selects indices of objects that meet selection criteria."""
        raise NotImplementedError
Key methods:
- vectors_are_needed: Indicates whether the selector needs the raw vectors rather than just distances
- select: Returns the indices of the categories that meet the selection criteria
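To make the contract concrete, here is a minimal torch-free sketch. `Candidate` and `SelectorLike` are hypothetical stand-ins for `ObjectWithDistance` and `AbstractSelector`, not the library's classes, and `NearestOnlySelector` is an illustrative name:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Candidate:
    """Hypothetical stand-in for ObjectWithDistance."""
    object_id: str
    distance: float


class SelectorLike(ABC):
    """Hypothetical mirror of the AbstractSelector contract."""

    @property
    @abstractmethod
    def vectors_are_needed(self) -> bool:
        ...

    @abstractmethod
    def select(
        self,
        categories: List[Candidate],
        query_vector: Optional[Sequence[float]] = None,
    ) -> List[int]:
        ...


class NearestOnlySelector(SelectorLike):
    """Selects only the single closest category (illustrative)."""

    @property
    def vectors_are_needed(self) -> bool:
        return False  # pre-calculated distances are enough

    def select(self, categories, query_vector=None):
        if not categories:
            return []
        # Index of the category with the smallest distance.
        best = min(range(len(categories)), key=lambda i: categories[i].distance)
        return [best]
```

The return value is a list of positions in `categories`, not object IDs, which is what callers of `select` appear to expect in the examples that follow.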
Distance-Based Selectors¶
For many applications, you can work with pre-calculated distances using the DistBasedSelector:
class DistBasedSelector(AbstractSelector):
    def __init__(
        self,
        search_index_info: SearchIndexInfo,
        is_similarity: bool = False,
        margin: float = 0.2,
        softmin_temperature: float = 1.0,
        scale_to_one: bool = False,
    ):
        self._search_index_info = search_index_info
        self._is_similarity = is_similarity
        self._margin = margin
        self._softmin_temperature = softmin_temperature
        self._scale_to_one = scale_to_one

    @property
    def vectors_are_needed(self) -> bool:
        return False

    @abstractmethod
    def _calculate_binary_labels(
        self, corrected_values: torch.Tensor
    ) -> torch.Tensor:
        """Calculates binary selection labels (0 or 1) from corrected distance values."""
        raise NotImplementedError

    def select(
        self,
        categories: List[ObjectWithDistance],
        query_vector: Optional[torch.Tensor] = None,
    ) -> List[int]:
        values = self._convert_values(categories)
        positive_threshold_min = 1 - self._margin if self._is_similarity else self._margin
        corrected_values = values - positive_threshold_min
        bin_labels = self._calculate_binary_labels(corrected_values)
        return torch.nonzero(bin_labels).T[0].tolist()
Key features:
- Works with pre-calculated distance values (doesn't need vectors)
- Handles different types of distance metrics and normalization
- Requires subclasses to implement the actual selection logic
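The threshold shift inside `select` is easy to check by hand. The sketch below repeats only that arithmetic on plain floats; `corrected_values` is a hypothetical helper, not a library function:

```python
from typing import List


def corrected_values(values: List[float], margin: float, is_similarity: bool) -> List[float]:
    """Mirror of the threshold shift in DistBasedSelector.select, on plain floats."""
    threshold = 1 - margin if is_similarity else margin
    return [v - threshold for v in values]


# Distance mode, margin 0.2: the threshold is 0.2.
print(corrected_values([0.1, 0.2, 0.5], margin=0.2, is_similarity=False))
# Similarity mode, margin 0.2: the threshold is 1 - 0.2 = 0.8.
print(corrected_values([0.7, 0.9], margin=0.2, is_similarity=True))
```

In both modes the sign of the result says on which side of the threshold a value falls; deciding which sign counts as a match is left to `_calculate_binary_labels`.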
Example: Probability-Based Selector¶
The default ProbsDistBasedSelector uses a sigmoid function to convert distances to probabilities:
class ProbsDistBasedSelector(DistBasedSelector):
    def __init__(
        self,
        search_index_info: SearchIndexInfo,
        is_similarity: bool = False,
        margin: float = 0.2,
        softmin_temperature: float = 1.0,
        scale: float = 10.0,
        prob_threshold: float = 0.5,
        scale_to_one: bool = False,
    ):
        super().__init__(
            search_index_info=search_index_info,
            is_similarity=is_similarity,
            margin=margin,
            softmin_temperature=softmin_temperature,
            scale_to_one=scale_to_one,
        )
        self._scale = scale
        self._prob_threshold = prob_threshold

    def _calculate_binary_labels(
        self, corrected_values: torch.Tensor
    ) -> torch.Tensor:
        return (
            torch.sigmoid(corrected_values * self._scale)
            > self._prob_threshold
        )
This selector:
1. Converts adjusted distance values to probabilities using sigmoid
2. Compares probabilities against a threshold
3. Returns a binary tensor (1 for selected, 0 for not selected)
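It can help to see how `scale` and `prob_threshold` translate back into a cut-off on the corrected value. The following is a small worked sketch using only the standard library; `min_corrected_value` is a hypothetical helper that inverts the sigmoid:

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def min_corrected_value(prob_threshold: float, scale: float) -> float:
    """Corrected value x at which sigmoid(scale * x) equals prob_threshold.

    Solving sigmoid(scale * x) = p for x gives x = ln(p / (1 - p)) / scale.
    """
    return math.log(prob_threshold / (1.0 - prob_threshold)) / scale


# With the defaults (scale=10, prob_threshold=0.5) the cut-off is x = 0,
# so the rule reduces to "corrected value is positive".
print(min_corrected_value(0.5, 10.0))
# Raising prob_threshold to 0.9 moves the cut-off to ln(9) / 10.
print(min_corrected_value(0.9, 10.0))
```

With the default threshold of 0.5, `scale` has no effect on which items are selected, since the sigmoid crosses 0.5 at zero for any positive scale; it only matters once `prob_threshold` moves away from 0.5.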
Implementing Your Own Selector¶
1. Implementing a Threshold-Based Selector¶
Here's a simple threshold-based selector that's less complex than the probability-based approach:
class SimpleThresholdSelector(DistBasedSelector):
    def _calculate_binary_labels(
        self, corrected_values: torch.Tensor
    ) -> torch.Tensor:
        # Simply checks if corrected values are positive
        # (which means they passed the margin threshold)
        return corrected_values > 0
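This selector selects categories whose corrected value is positive. The intended rules are:
- For similarity metrics: similarity > (1 - margin)
- For distance metrics: distance < margin

Note that select computes the corrected value as value - threshold, so corrected_values > 0 yields the distance rule only if _convert_values returns values oriented so that larger means closer. With raw distances it would select distance > margin instead, so check the orientation of your values (or flip the comparison) before relying on it.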
2. Implementing a Top-K Selector¶
This selector always returns the top K categories regardless of absolute distance:
class TopKSelector(DistBasedSelector):
    def __init__(
        self,
        search_index_info: SearchIndexInfo,
        k: int = 3,
        **kwargs
    ):
        super().__init__(search_index_info=search_index_info, **kwargs)
        self.k = k

    def _calculate_binary_labels(
        self, corrected_values: torch.Tensor
    ) -> torch.Tensor:
        if len(corrected_values) <= self.k:
            # If we have fewer values than k, select all
            return torch.ones_like(corrected_values, dtype=torch.bool)
        # Get indices of top-k values
        _, indices = torch.topk(corrected_values, self.k)
        # Create a binary mask with 1s at top-k indices
        mask = torch.zeros_like(corrected_values, dtype=torch.bool)
        mask[indices] = True
        return mask
This selector always selects exactly K categories (or all if there are fewer than K).
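One thing to watch: `torch.topk` returns the largest values by default. If the values reaching `_calculate_binary_labels` are raw distances (an assumption; it depends on what `_convert_values` returns), the K closest categories are the K smallest values, and `torch.topk(..., largest=False)` would be the matching call. A torch-free sketch of that variant, with `k_closest_mask` as a hypothetical helper name:

```python
import heapq
from typing import List


def k_closest_mask(distances: List[float], k: int) -> List[bool]:
    """Boolean mask marking the k smallest distances (all True if len <= k)."""
    if len(distances) <= k:
        return [True] * len(distances)
    # Indices of the k smallest distances.
    closest = heapq.nsmallest(k, range(len(distances)), key=distances.__getitem__)
    mask = [False] * len(distances)
    for i in closest:
        mask[i] = True
    return mask
```

For example, `k_closest_mask([0.5, 0.1, 0.3, 0.2], 2)` marks positions 1 and 3, the two smallest distances.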
3. Implementing a Dynamic Threshold Selector¶
This selector adapts its threshold based on the distribution of values:
class DynamicThresholdSelector(DistBasedSelector):
    def __init__(
        self,
        search_index_info: SearchIndexInfo,
        percentile: float = 75.0,
        min_threshold: float = 0.1,
        **kwargs
    ):
        super().__init__(search_index_info=search_index_info, **kwargs)
        self.percentile = percentile
        self.min_threshold = min_threshold

    def _calculate_binary_labels(
        self, corrected_values: torch.Tensor
    ) -> torch.Tensor:
        if len(corrected_values) == 0:
            return torch.tensor([], dtype=torch.bool)
        # Calculate a threshold at the specified percentile
        threshold = max(
            torch.quantile(corrected_values, self.percentile / 100.0).item(),
            self.min_threshold
        )
        # Select values above the threshold
        return corrected_values > threshold
This selector dynamically adjusts based on the distribution of distances, selecting categories above a percentile threshold.
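To get a feel for how the percentile and the `min_threshold` floor interact, here is a torch-free sketch of the same rule. `linear_quantile` and `dynamic_mask` are hypothetical helpers; the quantile uses linear interpolation between ranks, which is also the default interpolation of `torch.quantile`:

```python
import math
from typing import List


def linear_quantile(values: List[float], q: float) -> float:
    """Quantile with linear interpolation between the two nearest ranks."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def dynamic_mask(values: List[float], percentile: float, min_threshold: float) -> List[bool]:
    """Select values strictly above max(percentile threshold, min_threshold)."""
    if not values:
        return []
    threshold = max(linear_quantile(values, percentile / 100.0), min_threshold)
    return [v > threshold for v in values]


# Five corrected values at the 75th percentile: the threshold is 0.3,
# so only the last value is selected.
print(dynamic_mask([0.0, 0.1, 0.2, 0.3, 0.4], percentile=75.0, min_threshold=0.1))
```

Because the comparison is strict, a value equal to the threshold is not selected, and when every value sits below `min_threshold` the result is empty.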
Vector-Based Selectors¶
For more advanced scenarios, you can work directly with embedding vectors using VectorsBasedSelector:
import torch.nn.functional as F  # needed for F.cosine_similarity below

class CustomVectorSelector(VectorsBasedSelector):
    def __init__(
        self,
        search_index_info: SearchIndexInfo,
        reference_vectors: List[torch.Tensor],
        **kwargs
    ):
        super().__init__(search_index_info=search_index_info, **kwargs)
        self.reference_vectors = reference_vectors

    def _calculate_binary_labels(
        self, corrected_values: torch.Tensor
    ) -> torch.Tensor:
        # Your custom selection logic
        return corrected_values > 0

    def select(
        self,
        categories: List[ObjectWithDistance],
        query_vector: Optional[torch.Tensor] = None,
    ) -> List[int]:
        if query_vector is None:
            return []
        # Get tensor of category vectors
        category_vectors = self._get_categories_tensor(categories)
        # Compare with reference vectors for additional context
        reference_similarities = torch.stack([
            F.cosine_similarity(query_vector, ref.unsqueeze(0))
            for ref in self.reference_vectors
        ])
        # Use reference vector information to inform selection
        # (This is just an example - implement your own logic)
        if torch.max(reference_similarities) > 0.8:
            # If query is very similar to a reference vector,
            # be more selective in category matching
            margin = self._margin * 0.8
        else:
            # Otherwise use standard margin
            margin = self._margin
        # Calculate distances, then apply the adjusted margin as the threshold
        values = self._calculate_distance(
            query_vector,
            category_vectors,
            self._softmin_temperature,
            self._is_similarity,
        )
        positive_threshold_min = 1 - margin if self._is_similarity else margin
        corrected_values = values - positive_threshold_min
        bin_labels = self._calculate_binary_labels(corrected_values)
        return torch.nonzero(bin_labels).T[1].tolist()
This selector:
1. Takes reference vectors for additional context
2. Dynamically adjusts its selection threshold based on query similarity to references
3. Works directly with vector embeddings for more sophisticated matching
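The margin adjustment is the only part of this selector that does not depend on library internals, so it can be tested in isolation. Below is a torch-free sketch of that step; `cosine_similarity` and `adjusted_margin` are hypothetical helpers, and the `cutoff` and `factor` defaults simply repeat the 0.8 values from the example above:

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def adjusted_margin(
    query: Sequence[float],
    references: Sequence[Sequence[float]],
    margin: float,
    cutoff: float = 0.8,
    factor: float = 0.8,
) -> float:
    """Shrink the margin when the query is close to any reference vector."""
    best = max((cosine_similarity(query, r) for r in references), default=0.0)
    return margin * factor if best > cutoff else margin
```

With an empty reference list the margin is returned unchanged, which keeps the selector usable before any references have been configured.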
Registering Your Custom Selector¶
To use your custom selector, you need to register it with the category selection system:
# In your plugin initialization code
from embedding_studio.embeddings.selectors.dist_based_selector import DistBasedSelector
from my_custom_selectors import SimpleThresholdSelector

class MyPlugin:
    def get_category_selector(self) -> AbstractSelector:
        # Create and return your custom selector
        return SimpleThresholdSelector(
            search_index_info=self.search_index_info,
            is_similarity=False,  # Using distance metrics (lower is better)
            margin=0.25,  # Select categories with distance < 0.25
        )
Working with Different Distance Metrics¶
Your selector implementation needs to consider the distance metric being used:
Cosine Distance¶
class CosineSelector(DistBasedSelector):
    def __init__(self, **kwargs):
        super().__init__(
            is_similarity=False,  # We're working with distance, not similarity
            **kwargs
        )

    def _convert_values(self, categories: List[ObjectWithDistance]) -> torch.Tensor:
        values = []
        for category in categories:
            # For cosine, smaller distance means more similar
            value = category.distance
            values.append(value)
        return torch.tensor(values)
Dot Product¶
class DotProductSelector(DistBasedSelector):
    def __init__(self, **kwargs):
        super().__init__(
            is_similarity=True,  # For dot product, higher is more similar
            **kwargs
        )

    def _convert_values(self, categories: List[ObjectWithDistance]) -> torch.Tensor:
        values = []
        for category in categories:
            # For dot product, we negate the value since our system expects distances
            value = -category.distance
            values.append(value)
        return torch.tensor(values)
Advanced Techniques¶
1. Combining Multiple Selectors¶
You can create meta-selectors that combine the results from multiple selection strategies:
class CombinedSelector(AbstractSelector):
    def __init__(self, selectors: List[AbstractSelector]):
        self.selectors = selectors

    @property
    def vectors_are_needed(self) -> bool:
        # We need vectors if any selector needs them
        return any(selector.vectors_are_needed for selector in self.selectors)

    def select(
        self,
        categories: List[ObjectWithDistance],
        query_vector: Optional[torch.Tensor] = None,
    ) -> List[int]:
        # Get selections from each selector
        all_selections = []
        for selector in self.selectors:
            selections = selector.select(categories, query_vector)
            all_selections.extend(selections)
        # Remove duplicates and sort (union of all selections)
        return sorted(set(all_selections))
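The combined selector above takes the union of the individual results. Depending on the use case, an intersection (keep only what every selector agrees on) may be preferable. A minimal sketch of both merge modes on plain index lists; `combine_selections` and its `mode` argument are hypothetical, not part of the library:

```python
from typing import Iterable, List


def combine_selections(selections: Iterable[List[int]], mode: str = "union") -> List[int]:
    """Merge index lists from several selectors into one sorted, duplicate-free list."""
    sets = [set(s) for s in selections]
    if not sets:
        return []
    if mode == "union":
        merged = set().union(*sets)
    elif mode == "intersection":
        merged = set.intersection(*sets)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return sorted(merged)


print(combine_selections([[0, 2], [2, 3]]))                  # union
print(combine_selections([[0, 2], [2, 3]], "intersection"))  # intersection
```

Union favors recall, since any single selector can admit a category; intersection favors precision, since one strict selector can veto it.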
2. Context-Aware Selectors¶
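This selector adjusts its margin based on keywords found in the query text. Because select does not receive the raw query, the caller must set query_text on the selector before each call: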
class ContextAwareSelector(DistBasedSelector):
    def __init__(
        self,
        search_index_info: SearchIndexInfo,
        strict_keywords: List[str],
        relaxed_keywords: List[str],
        strict_margin: float = 0.15,
        relaxed_margin: float = 0.3,
        **kwargs
    ):
        super().__init__(search_index_info=search_index_info, **kwargs)
        self.strict_keywords = [k.lower() for k in strict_keywords]
        self.relaxed_keywords = [k.lower() for k in relaxed_keywords]
        self.strict_margin = strict_margin
        self.relaxed_margin = relaxed_margin
        self.query_text = ""

    def select(
        self,
        categories: List[ObjectWithDistance],
        query_vector: Optional[torch.Tensor] = None,
    ) -> List[int]:
        # Determine which margin to use based on query
        effective_margin = self._margin  # Default
        query_lower = self.query_text.lower()
        # Use strict matching for certain queries
        if any(keyword in query_lower for keyword in self.strict_keywords):
            effective_margin = self.strict_margin
        # Use relaxed matching for other queries
        elif any(keyword in query_lower for keyword in self.relaxed_keywords):
            effective_margin = self.relaxed_margin
        # Then proceed with selection using the context-appropriate margin
        values = self._convert_values(categories)
        positive_threshold_min = 1 - effective_margin if self._is_similarity else effective_margin
        corrected_values = values - positive_threshold_min
        bin_labels = self._calculate_binary_labels(corrected_values)
        return torch.nonzero(bin_labels).T[0].tolist()
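The keyword logic can be pulled out into a pure function, which makes it testable without a search index. A sketch under that assumption; `pick_margin` is a hypothetical helper, and its default margins repeat the values used in the class above:

```python
from typing import Sequence


def pick_margin(
    query_text: str,
    strict_keywords: Sequence[str],
    relaxed_keywords: Sequence[str],
    default_margin: float = 0.2,
    strict_margin: float = 0.15,
    relaxed_margin: float = 0.3,
) -> float:
    """Choose a margin from keywords in the query; strict keywords win over relaxed ones."""
    query_lower = query_text.lower()
    if any(k.lower() in query_lower for k in strict_keywords):
        return strict_margin
    if any(k.lower() in query_lower for k in relaxed_keywords):
        return relaxed_margin
    return default_margin


print(pick_margin("Exact part number 4471", ["exact"], ["similar"]))
print(pick_margin("something similar to oak", ["exact"], ["similar"]))
```

The check is a plain substring match, so the keyword "exact" also fires on "inexact"; tokenizing the query first would avoid that if it matters for your data.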
Testing Your Selector¶
Before deploying a custom selector, test it thoroughly:
def test_selector(selector, categories, query_vector=None):
    """Test a selector with sample data."""
    selected_indices = selector.select(categories, query_vector)
    print(f"Selected {len(selected_indices)} out of {len(categories)} categories")
    for idx in selected_indices:
        category = categories[idx]
        print(f"- {category.object_id}: distance={category.distance:.4f}")
    return selected_indices

# Create test data
test_categories = [
    ObjectWithDistance(object_id="category1", distance=0.1, payload={"name": "Category 1"}),
    ObjectWithDistance(object_id="category2", distance=0.2, payload={"name": "Category 2"}),
    ObjectWithDistance(object_id="category3", distance=0.3, payload={"name": "Category 3"}),
    ObjectWithDistance(object_id="category4", distance=0.4, payload={"name": "Category 4"}),
    ObjectWithDistance(object_id="category5", distance=0.5, payload={"name": "Category 5"}),
]

# Test your selector
# (mock_search_info is not defined here: supply a SearchIndexInfo instance built for the test)
test_selector(SimpleThresholdSelector(
    search_index_info=mock_search_info,
    margin=0.35
), test_categories)
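Printing results is useful while exploring, but a few automatic checks catch regressions earlier. The sketch below validates only the shape of a selector's output, so it works with any selector; `check_selection` is a hypothetical helper name:

```python
from typing import List


def check_selection(indices: List[int], n_categories: int) -> None:
    """Raise ValueError if a selector's output breaks basic expectations."""
    if not all(isinstance(i, int) for i in indices):
        raise ValueError("indices must be ints")
    if not all(0 <= i < n_categories for i in indices):
        raise ValueError("index out of range")
    if len(set(indices)) != len(indices):
        raise ValueError("duplicate indices")


# Typical use: validate whatever selector.select(...) returned.
check_selection([0, 2], n_categories=5)
check_selection([], n_categories=0)  # empty input should yield an empty selection
```

Running it on the empty list and on a single-category list covers two of the edge cases that are easy to miss.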
Debugging and Troubleshooting¶
When implementing custom selectors, these debugging tips can help:
Visualizing Selection Boundaries¶
import matplotlib.pyplot as plt

def visualize_selector_boundary(selector, max_distance=1.0, points=100):
    """Visualize the selection boundary for a distance-based selector."""
    distances = torch.linspace(0, max_distance, points)
    # Convert distances to the format expected by the selector
    categories = [
        ObjectWithDistance(object_id=f"test_{i}", distance=float(d), payload={})
        for i, d in enumerate(distances)
    ]
    # Get binary selections
    selected = selector.select(categories)
    selected_mask = torch.zeros(points, dtype=torch.bool)
    for idx in selected:
        selected_mask[idx] = True
    # Plot
    plt.figure(figsize=(10, 4))
    plt.scatter(distances[selected_mask], torch.ones_like(distances[selected_mask]),
                label='Selected', color='green')
    plt.scatter(distances[~selected_mask], torch.ones_like(distances[~selected_mask]),
                label='Not Selected', color='red')
    plt.axvline(x=selector._margin, linestyle='--', color='blue',
                label=f'Margin = {selector._margin}')
    plt.xlabel('Distance')
    plt.yticks([])
    plt.legend()
    plt.title('Selector Decision Boundary')
    plt.show()
Tracing Selector Logic¶
Add logging statements to your selector to trace its decision process:
import logging

logger = logging.getLogger(__name__)

class TracedSelector(DistBasedSelector):
    def select(self, categories, query_vector=None):
        logger.debug(f"Selecting from {len(categories)} categories")
        values = self._convert_values(categories)
        logger.debug(f"Converted values: {values}")
        threshold = 1 - self._margin if self._is_similarity else self._margin
        logger.debug(f"Using threshold: {threshold}")
        corrected_values = values - threshold
        logger.debug(f"Corrected values: {corrected_values}")
        bin_labels = self._calculate_binary_labels(corrected_values)
        logger.debug(f"Binary labels: {bin_labels}")
        indices = torch.nonzero(bin_labels).T[0].tolist()
        logger.debug(f"Selected indices: {indices}")
        return indices
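Debug messages stay invisible until the logger is configured to emit them. A minimal setup sketch follows; the logger name `my_custom_selectors` is an assumption and should match the module where your selector lives:

```python
import logging

# Hypothetical logger name: use the module name of your selector code.
logger = logging.getLogger("my_custom_selectors")
logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
)
logger.addHandler(handler)

values = [0.1, 0.4]
# %-style arguments defer string formatting until the record is actually emitted,
# which avoids building large tensor strings when DEBUG is off.
logger.debug("Converted values: %s", values)
```

In production the level would typically be raised back to INFO or WARNING so that per-query traces do not flood the logs.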
Best Practices¶
When implementing your own selectors, follow these guidelines:
- Start Simple: Begin with a simple selector and add complexity only as needed
- Test Thoroughly: Test with a variety of input data and boundary cases
- Consider Performance: Optimize computation-heavy operations for production use
- Document Behavior: Clearly document how your selector works and its parameters
- Handle Edge Cases: Properly handle empty inputs, single categories, etc.
- Use Appropriate Metrics: Ensure your selector works correctly with the chosen distance metric
Conclusion¶
Custom category selectors give you fine-grained control over how Embedding Studio matches queries to categories. By implementing your own selectors, you can:
- Adjust selection thresholds based on your specific use case
- Implement domain-specific logic for category matching
- Combine different selection strategies for optimal results
- Create context-aware selectors that adapt to different query types