Plugins¶
Embedding Studio supports plugins for fine-tuning models. A plugin is a script that inherits from the
FineTuningMethod
class and implements the upload_initial_model
and get_fine_tuning_builder
methods. Plugins can be of any type; you
can use any libraries and frameworks for model fine-tuning.
The path to the plugins directory is specified in the ES_PLUGINS_PATH
environment variable. By default, it points
to the plugins
directory at the project's root. You can easily change this in the .env
file.
We provide a demonstration plugin named
Default Fine Tuning Method
.
Let's dive into how it works:
class DefaultFineTuningMethod(FineTuningMethod):
meta = PluginMeta(
name="Default Fine Tuning Method",
version="0.0.1",
description="A default fine-tuning plugin",
)
...
The class name can be arbitrary, but it must inherit from FineTuningMethod
. In the meta
field, you specify metadata
about the plugin. This is used by Embedding Studio to determine which plugin to use for fine-tuning. The meta.name
field is essential because it's used to create tasks for fine_tuning_worker
.
Next, let's look at the upload_initial_model
method:
def upload_initial_model(self) -> None:
model = TextToImageCLIPModel(SentenceTransformer("clip-ViT-B-32"))
self.manager.upload_initial_model(model)
In this function, we define the initial model for fine-tuning. In our case, it's the TextToImageCLIPModel
model
composed of the SentenceTransformer
model named clip-ViT-B-32
. We upload it to Mlflow for future use in fine-tuning.
The call to self.manager.upload_initial_model(model)
is mandatory.
Now, let's examine the plugin initialization method. We've tried to describe what each line does in the comments:
def __init__(self):
# uncomment and pass your credentials to use your own s3 bucket
# creds = {
# "role_arn": "arn:aws:iam::123456789012:role/some_data"
# "aws_access_key_id": "TESTACCESSKEIDTEST11",
# "aws_secret_access_key": "QWERTY1232qdsadfasfg5349BBdf30ekp23odk03",
# }
# self.data_loader = AwsS3DataLoader(**creds)
# with empty creds, use anonymous session
creds = {
}
self.data_loader = AWSS3DataLoader(**creds)
self.retriever = TextQueryRetriever()
self.parser = AWSS3ClickstreamParser(
TextQueryItem, SearchResult, DummyEventType
)
self.splitter = ClickstreamSessionsSplitter()
self.normalizer = DatasetFieldsNormalizer("item", "item_id")
self.storage_producer = CLIPItemStorageProducer(self.normalizer)
self.accumulators = [
MetricsAccumulator("train_loss", True, True, True, True),
MetricsAccumulator(
"train_not_irrelevant_dist_shift", True, True, True, True
),
MetricsAccumulator(
"train_irrelevant_dist_shift", True, True, True, True
),
MetricsAccumulator("test_loss"),
MetricsAccumulator("test_not_irrelevant_dist_shift"),
MetricsAccumulator("test_irrelevant_dist_shift"),
]
self.manager = ExperimentsManager(
tracking_uri=settings.MLFLOW_TRACKING_URI,
main_metric="test_not_irrelevant_dist_shift",
accumulators=self.accumulators,
)
self.initial_params = INITIAL_PARAMS
self.initial_params.update(
{
"not_irrelevant_only": [True],
"negative_downsampling": [
0.5,
],
"examples_order": [
[
11,
]
],
}
)
self.settings = FineTuningSettings(
loss_func=CosineProbMarginRankingLoss(),
step_size=35,
test_each_n_sessions=0.5,
num_epochs=3,
)
Finally, let's look at the get_fine_tuning_builder
method:
def get_fine_tuning_builder(
self, clickstream: List[SessionWithEvents]
) -> FineTuningBuilder:
ranking_dataset = prepare_data(
clickstream,
self.parser,
self.splitter,
self.retriever,
self.data_loader,
self.storage_producer,
)
fine_tuning_builder = FineTuningBuilder(
data_loader=self.data_loader,
query_retriever=self.retriever,
clickstream_parser=self.parser,
clickstream_sessions_splitter=self.splitter,
dataset_fields_normalizer=self.normalizer,
item_storage_producer=self.storage_producer,
accumulators=self.accumulators,
experiments_manager=self.manager,
fine_tuning_settings=self.settings,
initial_params=self.initial_params,
ranking_data=ranking_dataset,
initial_max_evals=2,
)
return fine_tuning_builder
In this method, we describe how the model fine-tuning will take place. In our case, we use the
prepare_data
function to transform the clickstream into a dataset suitable for fine-tuning. Then, we create an instance of the
FineTuningBuilder
class, which will perform the fine-tuning. In the constructor, we pass all the necessary components
that will be used during the fine-tuning process.