
Addressing Overfitting

What is overfitting?

When a model has “overfit”, it means that it has learned too much about the specifics of the training data. Models that are overfit have learned to make predictions accurately using details and noise only found in the training data and are unable to generalize what they have learned to unseen data.

Toy Example

For example, let’s say you have a toy dataset of five samples: three in the training set and two in the validation set. Each sample consists of an image and a label, has_dog. For the sake of simplicity, we will avoid discussion of the test set for now.

Here is the training set. These are the samples that the model explicitly learns from.

  1. An image of a cat. The cat is on the left side of the image. has_dog = 0
  2. An image of a dog. The dog is on the right side of the image. has_dog = 1
  3. An image of a dog. The dog is on the right side of the image. has_dog = 1

Here is the validation set. These are the samples we set aside to determine whether or not the model will perform well on “real-world” data in production.

  1. An image of a cat. The cat is on the left side of the image. has_dog = 0
  2. An image of a dog. The dog is on the left side of the image. has_dog = 1

Now, let’s say that we have two models: Model A and Model B. Model A gets strong performance on both the training set and the validation set: in this case, 100% accuracy on both splits. Strong performance on the validation set means that it “generalizes well” to samples beyond the initial training set. Upon analysis of Model A, we see that it uses whether there is a cat or a dog in the image to determine if the label is 0 or 1.

On the other hand, Model B gets 100% accuracy on the training set, but only 50% accuracy on the validation set! Here are its predictions on the validation set:

  1. An image of a cat. The cat is on the left side of the image. has_dog = 0, prediction = 0
  2. An image of a dog. The dog is on the left side of the image. has_dog = 1, prediction = 0

Upon analysis of Model B, we see that it uses not only whether there is a cat or a dog in the image, but also the position of the animal to determine the value of has_dog. The training set only had images of cats on the left side and dogs on the right side, so Model B (incorrectly) learned an association between animal position and has_dog. Upon seeing a dog on the left side for the first time while running inference on the validation set, the model was confused by the conflicting signals and incorrectly predicted has_dog to be 0.

Diagnosing overfitting in practice

Typically, the data we feed into our ML models are much more complex than in the toy example above, which makes it much trickier to diagnose overfitting with sample-by-sample analysis. Luckily, overfitting is easy to identify by inspecting the model’s performance metrics across its training and validation sets.

The canonical symptom of overfitting is increasingly strong performance on the training set alongside stagnant or declining performance on the validation and test sets. An example of this is shown in the learning curves below, from a model trained on the Adult Census dataset, a tabular dataset. The model’s task is to classify each person's income as either below or above $50,000 based on their attributes as collected by the US Census.

Here, we see that the training loss (orange) steadily decreases over time while the validation and test losses (blue and red) increase. What this typically means is that the model has stopped learning generalizable features from the data and has begun exploiting details and noise specific to the samples of the training set to determine income. This exploitation leads to a boost in its accuracy on the training set and is a primary indicator of overfitting.
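To make this kind of inspection concrete, here is a minimal sketch in plain Python of how one might flag the epoch where validation loss starts rising while training loss keeps falling. The loss values and the divergence_epoch helper are purely illustrative, not numbers from the Adult Census run described above.

```python
# Minimal sketch: flag the epoch where validation loss starts rising
# while training loss keeps falling. The loss values below are made up
# purely for illustration.

train_loss = [0.60, 0.48, 0.40, 0.34, 0.29, 0.25, 0.22, 0.20]
val_loss   = [0.62, 0.52, 0.47, 0.45, 0.46, 0.48, 0.51, 0.55]

def divergence_epoch(train, val):
    """Return the first epoch where validation loss rises while training loss falls."""
    for epoch in range(1, len(train)):
        if train[epoch] < train[epoch - 1] and val[epoch] > val[epoch - 1]:
            return epoch
    return None

print(divergence_epoch(train_loss, val_loss))  # -> 4 in this toy example
```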

Models that overfit “can’t see the forest for the trees”: during the training process, overfit models memorize the little details of the training set and lose the ability to extrapolate their insights to unseen data. Preventing overfitting is typically a matter of providing a model with a more diverse training set, or reducing its capacity to memorize those little details.

How to fix overfitting

Fixing overfitting means preventing the model from learning associations that are specific to the training set. This helps ensure that the model only learns the most important features of the data, i.e. the features it will need in order to generalize and perform well on real-world data in production. There are two common ways to fix overfitting: modifying the training set, or regularizing the model.

Modifying the training set

One can modify the training set in order to ensure that the model is best set up for success during the training process. Using the example above, we could have ensured that the training set had more varied samples (such as samples with cats and dogs in various positions) in order to prevent the model from creating erroneous associations. Data augmentation, a technique most commonly seen in computer vision, is an example of a method used to provide the model with varied samples of data during the training process to reduce the likelihood of overfitting.
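As an illustration, here is a minimal sketch of image data augmentation using torchvision transforms; the choice of transforms and their parameters are assumptions for this example and would need to be tuned to the dataset at hand.

```python
# Minimal sketch of data augmentation for images using torchvision.
# The specific transforms and parameters here are illustrative choices.
from torchvision import transforms

train_transforms = transforms.Compose([
    # Randomly flip images horizontally so "dog on the left" and
    # "dog on the right" both appear during training.
    transforms.RandomHorizontalFlip(p=0.5),
    # Randomly crop and resize so animal position and scale vary.
    transforms.RandomResizedCrop(size=224, scale=(0.8, 1.0)),
    # Small color perturbations add further variety.
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
])

# These transforms would typically be passed to a Dataset, e.g.:
# dataset = torchvision.datasets.ImageFolder("train/", transform=train_transforms)
```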

That said, it is often impossible to modify the data in a meaningful way, typically because data acquisition costs are too high. In these cases, we can turn to an approach called “regularization”, which makes tweaks to the model itself to prevent overfitting.

Regularizing the model

Broadly, regularization is a mechanism we can add to the model training process to handicap the model’s ability to memorize the little details that may be found in the training set. This becomes particularly relevant for Deep Neural Networks, which may have billions of parameters and therefore the capacity to memorize massive amounts of data.

Controlling Growth (Trees only)

Decision trees are often regularized by limiting the branch creation process. Typically, a decision tree becomes overfit if its branches become too specific to particular examples in the training set. Below are some ways to prevent this in decision trees:

  1. Enforcing a maximum tree depth in the overall decision tree.
  2. Enforcing a maximum number of leaves in the overall decision tree.
  3. Enforcing a minimum number of data points per leaf.
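For instance, here is a minimal sketch of how these limits might be set with scikit-learn's DecisionTreeClassifier; the specific values below are illustrative starting points rather than recommendations.

```python
# Minimal sketch: limiting tree growth with scikit-learn.
# The specific limits below are illustrative, not recommended defaults.
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(
    max_depth=5,          # 1. cap the depth of the overall tree
    max_leaf_nodes=32,    # 2. cap the total number of leaves
    min_samples_leaf=20,  # 3. require a minimum number of data points per leaf
)

# clf.fit(X_train, y_train) would then grow a tree that cannot carve out
# branches matching only a handful of training samples.
```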

Weight Decay (L1/L2 Regularization)

Both neural networks and decision trees support a regularization method known as “weight decay”. Simply put, weight decay penalizes large weight values within a particular ML model. In general, the values of a model’s weights are how the model formulates its decisions about its inputs. By keeping weights small (L2 regularization) or driving many of them to zero entirely (L1 regularization, which enforces weight “sparsity”), the model is limited in how sensitive it can be to particular features for a given sample. This limitation encourages it to focus on only the most important features across all samples, therefore preventing overfitting.

Typically, weight decay is controlled by a lambda coefficient, which describes how strongly the penalty is enforced during training. Higher values indicate stronger regularization. To adjust this value in Predibase, search for Regularization Lambda in the Parameters tab of the Model Builder. Some good values to try for this parameter are 0.0001, 0.001, 0.01, and 0.1.
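As a sketch of what this lambda coefficient controls, here is how an L2 penalty might be added to a loss by hand in PyTorch; the model, data, and lambda value are placeholders, and in practice the same effect is usually obtained by passing weight_decay to the optimizer.

```python
# Minimal sketch: L2 weight decay in PyTorch. The model, data, lambda value,
# and loss are placeholders for illustration.
import torch
import torch.nn as nn

model = nn.Linear(10, 1)                     # stand-in for a real model
criterion = nn.MSELoss()
reg_lambda = 0.01                            # the "Regularization Lambda" knob

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = criterion(model(x), y)

# Add lambda * sum of squared weights: large weights now cost extra loss,
# so the optimizer keeps them small.
l2_penalty = sum((p ** 2).sum() for p in model.parameters())
loss = loss + reg_lambda * l2_penalty

# In practice this is often done via the optimizer instead:
# torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=reg_lambda)
```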

Dropout

Both neural networks and some types of decision trees (e.g. DART) support another regularization method known as “dropout”. Dropout is a mechanism that temporarily “drops” components of the model architecture at random during training. For neural networks, the components are typically individual neurons. For DART models, the components are typically individual trees.

By dropping these components, the overall model is operating at only a fraction of its total capacity at any given moment. This limitation forces the model to reduce its reliance on any one component and build redundancy into its architecture. This in turn limits its capacity for memorizing too many fine details about samples in the training set, thus preventing overfitting.

Typically, dropout is controlled by the probability of dropping components. Higher values mean that components are dropped more frequently during training, and thus cause stronger regularization. To adjust this value in Predibase, search for Dropout or FC Dropout in the Parameters tab of the Model Builder. Some good values to try for this parameter are 0.1, 0.2, and 0.5.
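As a sketch, this is roughly what dropout looks like inside a small fully connected PyTorch network; the layer sizes and the 0.2 dropout probability are arbitrary placeholders.

```python
# Minimal sketch: dropout in a small fully connected PyTorch network.
# Layer sizes and the dropout probability are arbitrary placeholders.
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    nn.Dropout(p=0.2),   # each neuron's output is zeroed with probability 0.2
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(64, 1),
)

# Dropout is only active in training mode; model.eval() disables it at
# inference time so predictions use the full network.
```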

Other forms of implicit regularization (e.g. normalization)

Neural networks also have mechanisms that prevent overfitting without having been designed for that purpose. An example is normalization: batch, layer, cosine, and other kinds of normalization rescale the values of activations. This in turn means that the parameters of the following layers do not have to take on very large or very small values to account for very large activations. Implicitly, this regularizes the weights and makes gradients less prone to exploding or vanishing through the layers. To adjust this value in Predibase, search for Norm in the Parameters tab of the Model Builder, and try setting it to either batch or layer.
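As an illustration, here is a minimal sketch of inserting normalization layers into the same kind of fully connected PyTorch stack; the layer sizes are arbitrary, and whether batch or layer normalization works better depends on the model and data.

```python
# Minimal sketch: normalization between fully connected layers.
# Layer sizes are arbitrary placeholders.
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(32, 64),
    nn.BatchNorm1d(64),  # keeps activation statistics stable across the batch
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.BatchNorm1d(64),  # nn.LayerNorm(64) would be the per-sample alternative
    nn.ReLU(),
    nn.Linear(64, 1),
)
```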

Overfitting as a part of the modeling process

While overfitting is an indicator that your model is not as performant as it could ultimately be, it is often a good signal during the modeling process that you are on the right track.

One practice commonly seen in the machine learning development lifecycle is actually striving for overfitting as a modeling milestone. Once you have a model that is capable of getting strong performance on the training set, you likely have a model with sufficient expressive capacity to learn the broad strokes required to attain good generalization performance on its validation set and beyond. Careful application of the regularization techniques listed above will typically yield stronger performance in such cases, and bring you closer to your model quality acceptance criteria.