Prompt Prefix Caching

Private serverless deployments allow you to optionally enable prompt prefix caching when creating your deployment.

This feature is particularly useful for speeding up inference when issuing multiple prompts over a single long document or context (e.g., RAG), or for long multi-turn chat conversations.

When enabled, the Predibase LoRAX deployment will store the KV cache as a radix tree, allowing previously computed prefixes to be reused on subsequent prompts, significantly speeding up the time to first token.
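To illustrate the idea (this is a simplified sketch, not the LoRAX implementation), a radix-style lookup keyed on token IDs lets a new prompt reuse the longest previously computed prefix instead of recomputing it. The `PrefixCache` class and token values below are hypothetical:

```python
# Illustrative sketch of prefix reuse via a radix-style trie over token IDs.
# This is NOT LoRAX internals, just the core idea of prefix caching.

class PrefixCache:
    def __init__(self):
        self.root = {}  # token_id -> child node (each node is a dict)

    def insert(self, tokens):
        """Record a prompt's tokens so future prompts can reuse the prefix."""
        node = self.root
        for t in tokens:
            node = node.setdefault(t, {})

    def longest_cached_prefix(self, tokens):
        """Return how many leading tokens are already in the cache."""
        node, matched = self.root, 0
        for t in tokens:
            if t not in node:
                break
            node = node[t]
            matched += 1
        return matched

cache = PrefixCache()
doc_tokens = [101, 7, 7, 42, 9, 13]       # first prompt: document + question 1
cache.insert(doc_tokens)

next_prompt = [101, 7, 7, 42, 9, 55, 56]  # second prompt: same document, new question
reused = cache.longest_cached_prefix(next_prompt)
print(reused)  # 5 leading tokens reused; only the new suffix is computed
```

The longer the shared document context, the larger the fraction of each prompt that skips recomputation, which is where the time-to-first-token savings come from.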

note

Currently, dynamic adapter loading is not supported when using prompt prefix caching. For now, you will need to use a single adapter on your deployment by setting `--adapter-id` during initialization. Dynamic adapter loading will be fully supported with prefix caching in an upcoming release.

Enabling Prefix Caching

Set the `--prefix-caching true` custom argument when creating your deployment, as shown:

```python
from predibase import Predibase, DeploymentConfig

pb = Predibase(api_token="<YOUR API TOKEN>")

pb.deployments.create(
    name="llama-3-1-8b-prefix-cache",
    config=DeploymentConfig(
        base_model="llama-3-1-8b-instruct",
        min_replicas=1,
        max_replicas=1,
        custom_args=["--prefix-caching", "true"],
    ),
)
```
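As noted above, prefix caching currently requires pinning a single adapter at deployment time with `--adapter-id`. A sketch of what that might look like, where `"my-adapter-repo/1"` is a placeholder adapter ID and passing the flag through `custom_args` is assumed:

```python
# Sketch: pin one adapter alongside prefix caching at deployment time.
# "my-adapter-repo/1" is a placeholder; substitute your own adapter ID.
pb.deployments.create(
    name="llama-3-1-8b-prefix-cache",
    config=DeploymentConfig(
        base_model="llama-3-1-8b-instruct",
        min_replicas=1,
        max_replicas=1,
        custom_args=[
            "--adapter-id", "my-adapter-repo/1",
            "--prefix-caching", "true",
        ],
    ),
)
```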

Usage

Using the Predibase Python SDK to prompt your deployment:

```python
# Get a handle to your private deployment with prefix caching enabled
lorax_client = pb.deployments.client("llama-3-1-8b-prefix-cache")

# Define your prefix containing the long context you wish to cache
prefix = f"""The following is a document to be used as context when answering the question:

{document_text}

Answer the following question using only the above as context:
"""

# Helper function to apply the prompt template. Note that `prefix` is
# captured from the enclosing scope, so every prompt shares the same
# cacheable leading context.
def format_prompt(prompt: str) -> str:
    return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n {prefix}\n{prompt} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

# Start asking questions
response1 = lorax_client.generate(format_prompt("Summarize this document as a bulleted list."))

# Subsequent prompts should be faster, as the prefix will already be cached
response2 = lorax_client.generate(format_prompt("Who is the author of this document?"))
```
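Prefix caching only reuses the KV cache for tokens that are identical at the very start of each prompt, so the shared context must come before the per-question text. A quick, self-contained sanity check (using a stand-in document rather than a live deployment) that the template above keeps the context as a common leading substring of every prompt:

```python
import os

# Stand-in context; in the real script this comes from your document.
prefix = """The following is a document to be used as context when answering the question:

(long document text here)

Answer the following question using only the above as context:
"""

def format_prompt(prompt: str) -> str:
    return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n {prefix}\n{prompt} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

a = format_prompt("Summarize this document as a bulleted list.")
b = format_prompt("Who is the author of this document?")

# The two prompts diverge only after the shared document context, so the
# cached prefix covers the chat header and the entire document.
shared = os.path.commonprefix([a, b])
print(shared.endswith(prefix + "\n"))  # True
```

If this check failed (e.g., because the question was placed before the document), each prompt would have a different start and the cache would provide little benefit.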