Embedding Models
This guide walks you through deploying a private embedding model with the Predibase SDK and querying it via the embed endpoint.
Supported Embedding Models
Predibase supports various embedding models, including:
- BERT-based models (e.g., WhereIsAI/UAE-Large-V1)
- DistilBERT and DistilBERT-based models
- MRL Qwen-based models (e.g., dunzhang/stella_en_1.5B_v5)
Deploying an Embedding Model
To deploy an embedding model, use the pb.deployments.create() method. The process is similar to deploying other types of models in Predibase.
from predibase import Predibase, DeploymentConfig

pb = Predibase(api_token="<PREDIBASE_API_TOKEN>")

# Create a deployment for an embedding model
pb.deployments.create(
    name="my-embedding-model",
    config=DeploymentConfig(
        base_model="WhereIsAI/UAE-Large-V1",  # Replace with your chosen embedding model
        min_replicas=0,
        max_replicas=1,
        accelerator="a10_24gb_100",
        lorax_image_tag="c754415",
    ),
)
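With min_replicas=0, the deployment scales to zero when idle, so the first request after a cold start takes longer while a replica spins up. To inspect a deployment after creating it, you can fetch it by name; this sketch assumes pb.deployments.get accepts the deployment name (check your SDK version for the exact accessor):

# Fetch the deployment record to inspect its configuration and status
deployment = pb.deployments.get("my-embedding-model")
print(deployment)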
Note: For MRL Qwen-based models, you need to specify the embedding dimension. For example:
pb.deployments.create(
    name="my-qwen-embedding-model",
    config=DeploymentConfig(
        base_model="dunzhang/stella_en_1.5B_v5",
        min_replicas=0,
        max_replicas=1,
        custom_args=["--embedding-dim", "1024"],  # Specify the embedding dimension
        accelerator="a100_80gb_100",
        lorax_image_tag="c754415",
    ),
)
Note: For embedding models, we pin the LoRAX image to a commit that is optimized for the currently supported models. Please use the image tag c754415 for embedding use cases.
Creating a Client for the Embed Endpoint
After deploying your embedding model, you can create a client to interact with it:
# Create a client for the embedding model
lorax_client = pb.deployments.client("my-embedding-model")
Using the Embed Endpoint
Once you have a client, you can use the embed() method to generate embeddings for your input text. embed() returns an EmbedResponse object that contains a single field:
- embeddings: A list of floating-point numbers representing the embedding vector. This vector captures the semantic meaning of your input text in a high-dimensional space. The length of the list depends on the model; for example, UAE-Large-V1 produces 1024-dimensional embeddings.
# Generate embeddings for a single input
input_text = "This is a sample text for embedding."
response = lorax_client.embed(input_text)
# Access the embeddings
embeddings = response.embeddings
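As a quick sanity check, you can confirm the returned vector has the dimensionality you expect (1024 for UAE-Large-V1, as noted above):

# UAE-Large-V1 produces 1024-dimensional embeddings
print(len(embeddings))  # -> 1024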
Handling Multiple Inputs
If you need to generate embeddings for multiple inputs, you can call embed() multiple times, for example with a list comprehension:
input_texts = [
"This is the first sample text.",
"Here's another sample for embedding.",
"And one more for good measure."
]
# Generate embeddings for multiple inputs
embeddings_list = [lorax_client.embed(text).embeddings for text in input_texts]
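Because embeddings capture semantic meaning, a common next step is comparing them. The sketch below computes cosine similarity between two of the vectors above using only the standard library; the cosine_similarity helper is illustrative and not part of the Predibase SDK.

import math

def cosine_similarity(a, b):
    # cos(a, b) = dot(a, b) / (||a|| * ||b||)
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Values closer to 1.0 indicate more semantically similar inputs
print(cosine_similarity(embeddings_list[0], embeddings_list[1]))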
Best Practices and Limitations
- Input size: The maximum input size depends on the specific model configuration on Hugging Face, so check the model's documentation for any limitations.
- Batch size: The embed() method processes one input at a time. If you need to process large batches, implement your own batching logic to avoid overloading the server (see the sketch after this list).
- Deployment configuration: The default settings (min_replicas=0, max_replicas=1) work well for most use cases. Adjust them if you need higher throughput or always-on availability.
- Model selection: Choose an embedding model that fits your use case. BERT-based models like UAE-Large-V1 offer a good balance of performance and quality for many applications.
- Embedding dimension: For MRL Qwen-based models, make sure to specify the correct embedding dimension using the custom_args parameter during deployment.
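For the batching advice above, here is a minimal client-side sketch that uses a thread pool to throttle concurrent requests. It reuses the lorax_client created earlier and assumes the client is safe to call from multiple threads; the worker count of 4 is an arbitrary starting point to tune for your deployment.

from concurrent.futures import ThreadPoolExecutor

def embed_batch(texts, max_workers=4):
    # Cap in-flight requests so the deployment is not overloaded
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        responses = pool.map(lambda text: lorax_client.embed(text), texts)
    return [response.embeddings for response in responses]

batch_embeddings = embed_batch(input_texts)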
Remember to monitor your deployment's performance and adjust the configuration as needed to optimize for your specific use case.