Embedding Models

This guide will walk you through the process of deploying a private embedding model using the Predibase SDK and querying it via the embed endpoint.

Supported Embedding Models

Predibase supports various embedding models, including:

  • BERT-based models (e.g., WhereIsAI/UAE-Large-V1)
  • DistilBERT and other DistilBERT-based models
  • MRL Qwen-based models (e.g., dunzhang/stella_en_1.5B_v5)

Deploying an Embedding Model

To deploy an embedding model, use the pb.deployments.create() method. The process is similar to deploying other types of models in Predibase.

from predibase import Predibase, DeploymentConfig

pb = Predibase(api_token="<PREDIBASE_API_TOKEN>")

# Create a deployment for an embedding model
pb.deployments.create(
    name="my-embedding-model",
    config=DeploymentConfig(
        base_model="WhereIsAI/UAE-Large-V1",  # Replace with your chosen embedding model
        min_replicas=0,
        max_replicas=1,
        accelerator="a10_24gb_100",
        lorax_image_tag="c754415",
    ),
)

Note: For MRL Qwen-based models, you need to specify the embedding dimension. For example:

pb.deployments.create(
    name="my-qwen-embedding-model",
    config=DeploymentConfig(
        base_model="dunzhang/stella_en_1.5B_v5",
        min_replicas=0,
        max_replicas=1,
        custom_args=["--embedding-dim", "1024"],  # Specify the embedding dimension
        accelerator="a100_80gb_100",
        lorax_image_tag="c754415",
    ),
)

Note: We pin the LoRAX image version for embedding models to a commit that is optimized for the currently supported models. Please use the image tag c754415 for embedding use cases.

Creating a Client for the Embed Endpoint

After deploying your embedding model, you can create a client to interact with it:

# Create a client for the embedding model
lorax_client = pb.deployments.client("my-embedding-model")

Using the Embed Endpoint

Once you have a client, you can use the embed() method to generate embeddings for your input text.

The embed() method returns an EmbedResponse object that contains the following field:

  • embeddings: A list of floating-point numbers representing the embedding vector. This vector captures the semantic meaning of your input text in a high-dimensional space. The length of the list depends on the model; for example, UAE-Large-V1 produces 1024-dimensional embeddings.

# Generate embeddings for a single input
input_text = "This is a sample text for embedding."
response = lorax_client.embed(input_text)

# Access the embeddings
embeddings = response.embeddings

Handling Multiple Inputs

If you need to generate embeddings for multiple inputs, you can do so by calling embed() multiple times or by using a list comprehension:

input_texts = [
    "This is the first sample text.",
    "Here's another sample for embedding.",
    "And one more for good measure.",
]

# Generate embeddings for multiple inputs
embeddings_list = [lorax_client.embed(text).embeddings for text in input_texts]
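
Once you have the vectors, a common next step is to compare them. The snippet below is a minimal sketch that scores how semantically close two of the inputs are using cosine similarity computed with NumPy; this is standard vector math, not a Predibase-specific API.

import numpy as np

# Cosine similarity between two embedding vectors: values near 1.0 mean the
# texts point in the same semantic direction, values near 0 mean unrelated.
def cosine_similarity(a, b):
    a, b = np.asarray(a), np.asarray(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

score = cosine_similarity(embeddings_list[0], embeddings_list[1])
print(f"Similarity between the first two texts: {score:.3f}")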

Best Practices and Limitations

  1. Input size: The maximum input size depends on the specific model configuration in Hugging Face. Be sure to check the model's documentation for any limitations; a token-counting sketch follows this list.

  2. Batch size: The embed() method processes one input at a time. If you need to process large batches, consider implementing your own batching logic to avoid overloading the server, as in the second sketch after this list.

  3. Deployment configuration: The default deployment configuration (min_replicas=0, max_replicas=1) should work well for most use cases. Adjust these values if you need higher throughput or always-on availability.

  4. Model selection: Choose an embedding model that best fits your use case. BERT-based models like UAE-Large-V1 offer a good balance of latency and embedding quality for many applications.

  5. Embedding dimension: For MRL Qwen-based models, make sure to specify the correct embedding dimension using the custom_args parameter during deployment.
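
For the input-size check in item 1, one option is to count tokens client-side with the Hugging Face tokenizer for the same base model. This is a sketch that assumes the transformers library is installed; most BERT-based models, including UAE-Large-V1, accept at most 512 tokens.

from transformers import AutoTokenizer

# Load the tokenizer that matches the deployed base model so token counts
# agree with what the server will see.
tokenizer = AutoTokenizer.from_pretrained("WhereIsAI/UAE-Large-V1")

num_tokens = len(tokenizer.encode(input_text))
if num_tokens > tokenizer.model_max_length:
    print(f"Input is {num_tokens} tokens; the model accepts at most {tokenizer.model_max_length}.")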
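
For item 2, the sketch below bounds client-side concurrency with a small thread pool so large batches do not flood the deployment. It assumes lorax_client.embed() can safely be called from a few threads at once; max_workers=4 is an illustrative value, not a tuned recommendation.

from concurrent.futures import ThreadPoolExecutor

def embed_batch(client, texts, max_workers=4):
    # Map embed() over the inputs with at most max_workers requests in flight.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return [response.embeddings for response in pool.map(client.embed, texts)]

embeddings_list = embed_batch(lorax_client, input_texts)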

Remember to monitor your deployment's performance and adjust the configuration as needed to optimize for your specific use case.