Supervised ML Model Configurations

Predibase supports training custom model architectures for specific tasks using the open-source Ludwig framework. Predibase can train a supervised ML model on any table-like dataset, meaning that every feature has its own column and every example has its own row.
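To make "table-like" concrete, here is a minimal sketch using only Python's standard library. The column names follow the Rotten Tomatoes example used on this page; the two rows of values are made up for illustration, and load_rows is a hypothetical helper rather than anything provided by Predibase.

```python
import csv
import io

# Hypothetical sample data: column names follow the Rotten Tomatoes example,
# the values are invented purely for illustration.
SAMPLE_CSV = """genres,content_rating,top_critic,runtime,review_content,recommended
"Comedy, Drama",PG-13,True,104,"A warm, funny crowd-pleaser.",True
"Horror",R,False,95,"Predictable and dull.",False
"""


def load_rows(text):
    """Parse CSV text into a list of dicts, one dict per example (row)."""
    return list(csv.DictReader(io.StringIO(text)))


rows = load_rows(SAMPLE_CSV)
columns = list(rows[0].keys())

print(columns)    # one entry per feature (column)
print(len(rows))  # one entry per example (row)
```

Each dict in rows is one example, and each key is one feature column, which is the shape the configs below assume.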

Model Training

Once your dataset is connected in Predibase, use the create_model method, which creates a new model repository and trains a model with the config you provide.

Continuing the Rotten Tomatoes example from the Quickstart, here is a basic config that specifies only the model inputs and outputs. We'll let Predibase handle the rest.

rotten_tomatoes.yaml
input_features:
  - name: genres
    type: set
    preprocessing:
      tokenizer: comma
  - name: content_rating
    type: category
  - name: top_critic
    type: binary
  - name: runtime
    type: number
  - name: review_content
    type: text
output_features:
  - name: recommended
    type: binary

This config file tells Predibase that we want to train a model using the following input features:

  • The genres associated with the movie will be used as a set feature
  • The movie's content rating will be used as a category feature
  • Whether the review was done by a top critic or not will be used as a binary feature
  • The movie's runtime will be used as a number feature
  • The review content will be used as a text feature

This config file also tells Predibase that we want our model to have the following output features:

  • The recommended column indicates whether a movie was recommended and is used as a binary feature
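The same input and output structure can be inspected programmatically. The following is a minimal sketch, assuming the config is held as a plain Python dict that mirrors the YAML; describe_features is a hypothetical helper written for this illustration, not part of Predibase or Ludwig.

```python
# The Rotten Tomatoes config written as a plain Python dict.
config = {
    "input_features": [
        {"name": "genres", "type": "set", "preprocessing": {"tokenizer": "comma"}},
        {"name": "content_rating", "type": "category"},
        {"name": "top_critic", "type": "binary"},
        {"name": "runtime", "type": "number"},
        {"name": "review_content", "type": "text"},
    ],
    "output_features": [
        {"name": "recommended", "type": "binary"},
    ],
}


def describe_features(config):
    """Return (role, name, type) tuples for every feature in the config.

    Hypothetical helper for illustration only.
    """
    summary = []
    for role in ("input_features", "output_features"):
        for feature in config.get(role, []):
            summary.append((role, feature["name"], feature["type"]))
    return summary


for role, name, feature_type in describe_features(config):
    print(f"{role}: {name} ({feature_type})")
```

Listing the features this way is a quick check that each column has been assigned the type you intended before training starts.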

Now let's load the same config into Python. Here yaml.safe_load parses the YAML text into a dictionary that we can pass along for training:

import yaml  # provided by the PyYAML package

# Parse the YAML text into a Python dictionary.
rotten_tomatoes_config = yaml.safe_load(
    """
input_features:
  - name: genres
    type: set
    preprocessing:
      tokenizer: comma
  - name: content_rating
    type: category
  - name: top_critic
    type: binary
  - name: runtime
    type: number
  - name: review_content
    type: text
output_features:
  - name: recommended
    type: binary
"""
)
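Before passing the config on, it can be worth checking that every feature it names is actually a column in the dataset. The sketch below is one way to do that under stated assumptions: validate_config is a hypothetical helper (not a Predibase or Ludwig function), the column list is typed by hand, and the config deliberately contains a misspelled feature name to show what a failure looks like.

```python
def validate_config(config, dataset_columns):
    """Return a list of problems found; an empty list means the checks passed.

    Hypothetical helper for illustration only.
    """
    problems = []
    for role in ("input_features", "output_features"):
        features = config.get(role) or []
        if not features:
            problems.append(f"{role} is missing or empty")
        for feature in features:
            name = feature.get("name")
            if name is None or "type" not in feature:
                problems.append(f"{role} entry {feature!r} needs both name and type")
            elif name not in dataset_columns:
                problems.append(f"{name} is not a column in the dataset")
    return problems


# Columns of the example dataset, listed by hand for this sketch.
columns = ["genres", "content_rating", "top_critic", "runtime", "review_content", "recommended"]

# A config with a deliberate typo: "review_text" instead of "review_content".
config_with_typo = {
    "input_features": [
        {"name": "runtime", "type": "number"},
        {"name": "review_text", "type": "text"},
    ],
    "output_features": [{"name": "recommended", "type": "binary"}],
}

print(validate_config(config_with_typo, columns))
# prints ['review_text is not a column in the dataset']
```

Returning a list of problems rather than raising on the first one lets you fix every mismatch in a single pass.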

Finally, let's pass the dataset and config to the create_model method. Here we will also specify a repository name, repository description, and model description for the first model trained. These are all optional, but will help you keep track of your models.

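# pc is assumed to be the Predibase client, and rotten_tomatoes_dataset
# the dataset connected earlier (for example, in the Quickstart).
rotten_tomatoes_model = pc.create_model(
    repository_name="Rotten Tomatoes Recommender",
    dataset=rotten_tomatoes_dataset,
    config=rotten_tomatoes_config,
    repo_description="Predict whether a movie is recommended or not",
    model_description="Baseline Model",
)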

After running this command, you can follow the provided link to track your model training in real time.

Training in the UI

Learn how to train a supervised ML model in the Predibase UI.