Supervised ML Model Configurations
Predibase supports training custom model architectures for specific tasks using the open-source Ludwig framework. Predibase can train a supervised ML model on any table-like dataset, meaning that every feature has its own column and every example has its own row.
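To make "table-like" concrete, here is a minimal sketch using only Python's standard csv module. The column names are borrowed from the Rotten Tomatoes example used later on this page, and the two rows are made up for illustration. The only point is that each row is one example and each column is one feature.

```python
import csv
import io

# Hypothetical sample data: one column per feature, one row per example.
SAMPLE_CSV = """genres,content_rating,runtime,recommended
"Comedy, Drama",PG-13,105,True
Horror,R,92,False
"""


def load_rows(text):
    """Parse CSV text into a list of dicts, one dict per example (row)."""
    return list(csv.DictReader(io.StringIO(text)))


rows = load_rows(SAMPLE_CSV)
columns = list(rows[0].keys())

print(columns)    # feature names, one per column
print(len(rows))  # number of examples, one per row
```

Any dataset that can be read into this shape, whatever its storage format, fits the description above.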
Model Training
Once your dataset is connected in Predibase, use the create_model method, which creates a new model repository and trains a model with the config you provide.
Continuing the Rotten Tomatoes example from the Quickstart, here is a basic config that specifies the model inputs and outputs. We'll let Predibase handle the rest.
input_features:
  - name: genres
    type: set
    preprocessing:
      tokenizer: comma
  - name: content_rating
    type: category
  - name: top_critic
    type: binary
  - name: runtime
    type: number
  - name: review_content
    type: text
output_features:
  - name: recommended
    type: binary
This config file tells Predibase that we want to train a model using the following input features:
- The genres associated with the movie will be used as a set feature
- The movie's content rating will be used as a category feature
- Whether the review was done by a top critic or not will be used as a binary feature
- The movie's runtime will be used as a number feature
- The review content will be used as a text feature
This config file also tells Predibase that we want our model to have the following output features:
- The recommended column, which indicates whether a movie was recommended, will be used as a binary feature
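As a rough illustration of the genres feature above: with tokenizer: comma, a raw string such as "Comedy, Drama, Romance" is split on commas and the resulting tokens are treated as a set. The helper below is a hypothetical stand-in written for this page, not Ludwig's actual tokenizer, and it assumes surrounding whitespace is stripped and empty tokens are dropped.

```python
def comma_tokenize(value):
    """Split a comma-separated string into a set of stripped, non-empty tokens.

    Hypothetical helper for illustration only; not Ludwig's implementation.
    """
    return {token.strip() for token in value.split(",") if token.strip()}


genres = comma_tokenize("Comedy, Drama, Romance")
print(sorted(genres))  # ['Comedy', 'Drama', 'Romance']
```

Because the result is a set, order does not matter and a repeated genre counts once, which is why genres is declared as a set feature rather than as text.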
Now let's load the config into Python so that we can pass it to Predibase for training:
import yaml  # provided by the PyYAML package

# Parse the YAML text into a plain Python dict.
rotten_tomatoes_config = yaml.safe_load(
    """
input_features:
  - name: genres
    type: set
    preprocessing:
      tokenizer: comma
  - name: content_rating
    type: category
  - name: top_critic
    type: binary
  - name: runtime
    type: number
  - name: review_content
    type: text
output_features:
  - name: recommended
    type: binary
"""
)
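Since yaml.safe_load returns ordinary dicts and lists, the parsed config can be sanity-checked before training. The sketch below writes out the equivalent dict by hand (so it runs without PyYAML) and checks that every feature named in the config exists as a dataset column. The missing_columns helper and the example column list are hypothetical, not part of the Predibase SDK.

```python
# Equivalent of the parsed YAML config above, as plain dicts and lists.
config = {
    "input_features": [
        {"name": "genres", "type": "set", "preprocessing": {"tokenizer": "comma"}},
        {"name": "content_rating", "type": "category"},
        {"name": "top_critic", "type": "binary"},
        {"name": "runtime", "type": "number"},
        {"name": "review_content", "type": "text"},
    ],
    "output_features": [{"name": "recommended", "type": "binary"}],
}


def missing_columns(config, dataset_columns):
    """Return config feature names that are absent from the dataset's columns."""
    features = config["input_features"] + config["output_features"]
    available = set(dataset_columns)
    return [f["name"] for f in features if f["name"] not in available]


# Hypothetical column list; in practice, use your dataset's real header.
columns = ["genres", "content_rating", "top_critic", "runtime", "review_content"]
print(missing_columns(config, columns))  # ['recommended']
```

An empty list means every feature in the config maps to a column. Catching a misspelled feature name locally is cheaper than discovering it after a training job has been submitted.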
Finally, let's pass the dataset and config to the create_model method. Here we will also specify a repository name, repository description, and model description for the first model trained. These are all optional, but will help you keep track of your models.
# Assumes `pc` (the Predibase client) and `rotten_tomatoes_dataset`
# were already set up, e.g. in the Quickstart.
rotten_tomatoes_model = pc.create_model(
    repository_name="Rotten Tomatoes Recommender",
    dataset=rotten_tomatoes_dataset,
    config=rotten_tomatoes_config,
    repo_description="Predict whether a movie is recommended or not",
    model_description="Baseline Model",
)
After running this command, follow the provided link to track your model's training in real time.
Training in the UI
Learn how to train a supervised ML model in the Predibase UI.