#4 - So how do you build a vector embedding model?
This project explores a popular way to build a vector embedding model using your own custom dataset.
This article is part of a blog series demystifying vector embedding models for images:
Part 1. So how do you build a vector embedding model? - (this article) - Introduces vector embedding models and the intuition behind the technologies we can use to build one ourselves.
Part 2. Let's build our image embedding model - Shows a couple of ways to build embedding models - first by using a pre-trained model, and next by fine-tuning a pre-trained model. We use PyTorch to build our feature extractor.
Part 3. Modelling with Metaflow and MLFlow - Uses Metaflow to build our model training workflow, introducing the concept of checkpointing, and MLflow for experiment tracking.
Part 4. From Training to Deployment: A Simple Approach to Serving Embedding Models - Packaging your ML model in a Docker container opens it up to a multitude of model serving options.
Part 5. Putting Our Bird Embedding Model to Work: Introducing the Web Frontend - For our embedding model to be useful to others, we have created a modern frontend that serves similarity search results to our users.
Hi friends,
It’s been an eventful few months at work since my last post, where I wrote about spending a weekend playing with SOMs (self-organising maps). When I’m less busy at work, I tend to work on interesting end-to-end projects involving machine learning. I told myself that Gen AI would be my focus for 2024, and it still is; however, I still have time for ML projects such as this one.
One problem that stumped me recently was an image classification problem with 5,000+ classes. Yeah, I’ve done multi-class classification problems before, but never with more than a dozen classes. I also haven’t done many computer vision ML problems, so this was my chance to dig a little deeper and familiarise myself with the process.
A Hypothetical Problem
I am partnering with Bird Watch, a non-profit organization focused on bird conservation and wildlife monitoring. They need a machine learning model that will help them identify bird species captured by their drones and camera traps. Instead of directly classifying each bird species, I was asked to return the 10 most similar bird species, ranked by similarity, with the most similar species at the top of the list. This is good: instead of forcing the output into predefined categories, the model can still find close matches when the species presented is not part of the dataset, which can be helpful in identifying new species.
Multi-class classification at this magnitude is not recommended, as it requires a large model that is expensive to train.
After reflecting on the problem, where we need to return the ten most similar bird species, I realised this was not really a classification problem after all.
The input is an image, and the model needs to return the 10 most similar images from those stored in a database.
So instead of a classification problem, we can treat this as a data retrieval problem, with the help of embedding vectors and a vector database.
If what we are doing here is similarity search, what technology enables it? Semantic search and vector databases.
And when it comes to vector databases, what underpins them all are the vector embeddings generated by an embedding model.
So, I don’t need a classification model?
A typical bird classification model returns a species prediction: one of the 525 bird species in the dataset. But what do I do if I don’t want to predict the species? What machine learning model do I need when I want to return the 10 most similar species? We need a species similarity search, not a species classifier!
We need to build a vector embedding model!
Although vector embeddings have been around since the 1980s and 1990s, I only recently became aware of them, thanks to their rise in popularity with large language models like ChatGPT towards the end of 2022.
How to build an embedding model
To understand how to build an embedding model, we first need to look at the main parts of a Convolutional Neural Network (CNN) model. A CNN model has three main parts:
Backbone - This is the part of the model that extracts features from the image. It typically consists of the convolutional layers of the model.
Neck - This is the part of the model that reduces the dimensionality of the features extracted by the backbone. It is typically a pooling layer (for example, global average pooling) and is usually the penultimate layer of the model, the one that outputs the features.
Head - This is the part of the model that classifies the image based on the features extracted by the backbone and refined by the neck.
This structure holds true for many of the image models hosted on Hugging Face. To produce an embedding model, we need to remove the head and use the neck’s output, which, I later realised, is THE vector embedding! These embeddings are the features extracted by the backbone and refined by the neck.
It is as easy as that. Just remove the classification head, regardless of whether you are using a pre-trained model or a fine-tuned model, and you have produced an embedding model.
That is exactly what we will do: perform ‘surgery’ on the fine-tuned model that we will build in this series!
Pre-trained image model
There are many pre-trained image models available online: TensorFlow and PyTorch have their own model hubs, and cloud providers like AWS, Google and Azure also have their own AI galleries. For this article, however, we will be using timm (PyTorch Image Models), a PyTorch library that contains a collection of over 1,400 models, along with utility code that can help you build, train or fine-tune your image models. When you use this library, it downloads the model weights from Hugging Face, another well-known model hub that initially became popular for hosting NLP models. These days, you can find almost any type of open-source machine learning model there, to fine-tune or to use outright.
Let’s check out ResNet50
ResNet 50 is the 50-layer member of the Residual Network (ResNet) family, a deep learning architecture that is popular for image classification tasks. ResNet was introduced by Microsoft Research in 2015, in the paper “Deep Residual Learning for Image Recognition” by He et al., and it became the basis of many deep learning architectures after that.
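The idea that gives ResNet its name is the residual (skip) connection: a block learns a residual function F(x) and outputs F(x) + x, which makes very deep networks easier to train. Below is a simplified sketch of a basic residual block in plain PyTorch. Note that ResNet 50 itself uses a three-layer “bottleneck” variant of this block, so this is illustrative rather than a copy of its layers; `BasicResidualBlock` is a hypothetical name.

```python
import torch
from torch import nn


class BasicResidualBlock(nn.Module):
    """Simplified residual block: output = relu(F(x) + x)."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        # F(x): two 3x3 convolutions that preserve spatial size and channel count.
        self.residual = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip connection adds the unchanged input back onto F(x).
        return self.relu(self.residual(x) + x)


block = BasicResidualBlock(channels=64)
x = torch.randn(1, 64, 56, 56)
print(tuple(block(x).shape))  # (1, 64, 56, 56)
```

Because the input is added back unchanged, a block whose F(x) outputs zeros simply passes its (ReLU-activated) input through, which is part of why stacking many such blocks does not hurt training.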
Feature extractor
ResNet 50 is part of the ResNet family of image models and has proven its worth as a pre-trained model for transfer learning, a process where weights learned on a large dataset, like ImageNet, are reused to improve accuracy on a much smaller dataset, like our bird species dataset.
For our embedding model, we will start by training a ResNet 50 classifier. Once we have built this fine-tuned model, we will discard the classification head, leaving just the feature extractor, which is precisely the embedding model we need.
And how about our similarity search?
Once a vector embedding model has been built, we’re halfway to our bird species similarity search engine. We can then build the rest in the following three steps:
Ingest our bird species dataset into a vector database, using the fine-tuned embedding model to generate the embedding of each image.
Using the same embedding model, get the vector embedding of the new bird species image.
Finally, perform a similarity search of the new image’s embedding against all the bird species embeddings ingested into the vector database. There are several similarity criteria to choose from, but the most common are cosine similarity and Euclidean distance, both of which are supported by most popular vector databases today.
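To illustrate the three steps without committing to a particular vector database, here is a minimal in-memory stand-in written with NumPy. The `InMemoryVectorIndex` class and its methods are hypothetical, it does brute-force search rather than the approximate indexes real vector databases use, and the random vectors stand in for embeddings produced by our embedding model. Since the stored and query vectors are normalised to unit length, the squared Euclidean distance equals 2 - 2 × cosine similarity, so both metrics give the same ranking.

```python
import numpy as np


class InMemoryVectorIndex:
    """Hypothetical stand-in for a vector database (brute-force search)."""

    def __init__(self, dim: int) -> None:
        self.vectors = np.empty((0, dim))
        self.labels: list[str] = []

    def add(self, embeddings: np.ndarray, labels: list[str]) -> None:
        # Step 1: ingest embeddings, normalised to unit length.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        self.vectors = np.vstack([self.vectors, embeddings / norms])
        self.labels.extend(labels)

    def search(self, query: np.ndarray, k: int = 10, metric: str = "cosine") -> list[str]:
        # Step 3: rank every stored vector against the query embedding.
        q = query / np.linalg.norm(query)
        if metric == "cosine":
            scores = self.vectors @ q  # higher means more similar
            order = np.argsort(-scores)
        elif metric == "euclidean":
            dists = np.linalg.norm(self.vectors - q, axis=1)  # lower means more similar
            order = np.argsort(dists)
        else:
            raise ValueError(f"unknown metric: {metric}")
        return [self.labels[i] for i in order[:k]]


rng = np.random.default_rng(0)
dim = 2048  # size of ResNet 50's pooled embedding

# Random vectors stand in for embeddings of 525 species reference images.
index = InMemoryVectorIndex(dim)
species = [f"species_{i}" for i in range(525)]
index.add(rng.normal(size=(525, dim)), species)

# Step 2: embed the new image (here simulated as a noisy copy of species_42's vector).
query = index.vectors[42] + 0.01 * rng.normal(size=dim)
top10 = index.search(query, k=10)
print(top10[0])  # species_42
```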
Conclusion
In this article, we have learned that image models like ResNet 50 can be trained and fine-tuned as classification models. But to produce an embedding model, we can simply remove the model’s classification head, leaving behind the feature extractor: a model that produces embedding vectors, our fine-tuned embedding model.
In the next article, we’ll see how to actually build this embedding model, through a Jupyter notebook that we’ll build for this series.
Till then,
JO