By Sandeep Uttamchandani, Ph.D., both a Product/Software Builder (VP of Engg) and a leader in operating enterprise-wide Data/AI initiatives (CDO)
ML model training is the most time-consuming and resource-intensive part of the overall model-building journey. Training is iterative by definition, and somewhere along the iterations, mistakes seep in. In this article, I share ten deadly sins of ML model training: the mistakes that are the most common and also the easiest to overlook.
Ten Deadly Sins of ML Model Training
1. Blindly increasing the number of epochs when the model is not converging
During model training, the loss-epoch graph sometimes keeps bouncing around and does not seem to converge, no matter how many epochs you add. There is no silver bullet, as there are multiple root causes to investigate: bad training examples, missing ground truths, changing data distributions, and a learning rate that is too high. The most common one I have seen is bad training examples caused by a combination of anomalous data and incorrect labels.
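One of these root causes, a learning rate that is too high, is easy to reproduce in a toy setting. The sketch below is an illustration, not the author's code: it assumes a one-parameter quadratic loss f(x) = x² (gradient 2x), and the helper name `gradient_descent` and the learning rates tried are hypothetical choices.

```python
def gradient_descent(lr, epochs=20, x0=1.0):
    """Minimize the toy loss f(x) = x**2 and return the loss after each epoch."""
    x = x0
    losses = []
    for _ in range(epochs):
        x -= lr * 2 * x  # the gradient of x**2 is 2x
        losses.append(x * x)
    return losses


for lr in (0.1, 1.0, 1.1):
    losses = gradient_descent(lr)
    print(f"lr={lr}: loss after epoch 1 = {losses[0]:.4g}, after epoch 20 = {losses[-1]:.4g}")
```

With lr=0.1 the loss shrinks every epoch. With lr=1.0 the parameter flips sign on every step and the loss stays at 1.0 forever, and with lr=1.1 each step overshoots further, so the loss grows. Adding epochs fixes neither of the last two cases; lowering the learning rate does.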
2. Not shuffling the training dataset
Sometimes the model seems to be converging, but then the loss suddenly rises sharply: it decreases for a number of epochs and then increases significantly. There are multiple possible reasons for this kind of exploding loss. The most common one I have seen is outliers that are not evenly distributed across the data because the data was not shuffled. Shuffling is an important step in general, including when the loss shows a repeating step-function pattern.
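A minimal sketch of per-epoch shuffling in plain Python, assuming an in-memory list of examples. `epoch_batches` is a hypothetical helper and the dataset is made up; in practice, framework data loaders usually expose this directly (for example, PyTorch's `DataLoader(..., shuffle=True)`).

```python
import random


def epoch_batches(data, batch_size, rng=None):
    """Split data into mini-batches; if an rng is given, shuffle the order first.

    Reusing the same rng across epochs yields a fresh order every epoch.
    """
    order = list(range(len(data)))
    if rng is not None:
        rng.shuffle(order)
    return [
        [data[i] for i in order[start:start + batch_size]]
        for start in range(0, len(order), batch_size)
    ]


# Hypothetical dataset whose outliers were appended at the end (e.g., sorted by date).
data = [1.0] * 90 + [100.0] * 10

unshuffled = epoch_batches(data, batch_size=10)
print("outliers in last batch (no shuffle):", unshuffled[-1].count(100.0))

rng = random.Random(42)
for epoch in range(3):
    shuffled = epoch_batches(data, batch_size=10, rng=rng)
    print(f"epoch {epoch}: max outliers in any batch =",
          max(batch.count(100.0) for batch in shuffled))
```

Without shuffling, every outlier lands in the final batch of each epoch, which can produce exactly the sudden, repeating loss spikes described above; shuffling spreads them across batches.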
3. In multiclass classification, not prioritizing per-class accuracy metrics
For multiclass prediction problems, instead of tracking just the overall classification accuracy, it is often useful to prioritize the accuracy of specific classes and iteratively work on improving the model class by class. For instance, in classifying different forms of fraudulent transactions, focus on increasing the recall of specific classes (such as foreign transactions) based on business needs.
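A small sketch of tracking recall per class alongside overall accuracy. The fraud-type labels are hypothetical; scikit-learn offers the same breakdown via `recall_score(..., average=None)` or `classification_report`, but plain Python keeps the example self-contained.

```python
from collections import Counter


def per_class_recall(y_true, y_pred):
    """Recall per class: correctly predicted examples / actual examples of that class."""
    actual = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: correct[label] / n for label, n in actual.items()}


# Hypothetical fraud-type labels.
y_true = ["domestic"] * 6 + ["foreign"] * 4
y_pred = ["domestic"] * 6 + ["domestic", "domestic", "domestic", "foreign"]

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"overall accuracy: {accuracy:.2f}")  # overall accuracy: 0.70
print(per_class_recall(y_true, y_pred))     # {'domestic': 1.0, 'foreign': 0.25}
```

A respectable-looking 70% overall accuracy hides that only one in four foreign transactions is caught, which is the class the business may care about most.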
4. Assuming specificity will lead to lower model accuracy
Instead of building a generic model, consider building a model for a specific geographic region or a specific user persona. Specificity makes the data sparser but can lead to better accuracy on those specific problems. It is important to explore the trade-off between specificity and sparsity during tuning.
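One way to make the specificity/sparsity trade-off explicit is to train segment-specific models only where a segment has enough data and fall back to the generic model elsewhere. In this sketch, a per-segment mean stands in for a real model; the regions, targets, helper name `fit_segmented_means`, and the `min_samples` threshold are all hypothetical.

```python
from statistics import mean


def fit_segmented_means(rows, min_samples):
    """Fit a per-segment mean predictor, falling back to the global mean for sparse segments.

    rows: list of (segment, target) pairs.
    """
    global_mean = mean(target for _, target in rows)
    by_segment = {}
    for segment, target in rows:
        by_segment.setdefault(segment, []).append(target)
    segment_means = {
        segment: mean(targets)
        for segment, targets in by_segment.items()
        if len(targets) >= min_samples  # only trust segments with enough data
    }
    return lambda segment: segment_means.get(segment, global_mean)


rows = [("us", 10), ("us", 12), ("us", 11), ("us", 13),
        ("eu", 20), ("eu", 22), ("eu", 21),
        ("apac", 40)]
predict = fit_segmented_means(rows, min_samples=3)
for region in ("us", "eu", "apac"):
    print(region, predict(region))
```

Here "us" and "eu" get their own, more specific predictions, while "apac" has a single example and falls back to the global mean. Tuning `min_samples` is one concrete knob for exploring the trade-off.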
5. Ignoring prediction bias
Prediction bias is the difference between the average of the predictions and the average of the labels in the dataset. It serves as an early indicator of model issues: a large nonzero prediction bias indicates a bug somewhere in the model. There's an interesting Facebook paper on this in the context of ad click-through rate (CTR) prediction. Typically, it is useful to measure the bias across prediction buckets.
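A minimal sketch of computing prediction bias overall and per bucket, assuming the model outputs probabilities in [0, 1]. The equal-width bucketing and the toy predictions are assumptions; the article does not specify how to bucket.

```python
def prediction_bias(preds, labels):
    """Prediction bias: average prediction minus average label."""
    return sum(preds) / len(preds) - sum(labels) / len(labels)


def bucketed_bias(preds, labels, num_buckets=10):
    """Prediction bias computed separately for equal-width buckets of predicted probability."""
    buckets = {}
    for p, y in zip(preds, labels):
        b = min(int(p * num_buckets), num_buckets - 1)  # p == 1.0 falls in the last bucket
        ps, ys = buckets.setdefault(b, ([], []))
        ps.append(p)
        ys.append(y)
    return {b: prediction_bias(ps, ys) for b, (ps, ys) in sorted(buckets.items())}


# Hypothetical CTR predictions (probabilities) and click labels (0/1).
preds = [0.25, 0.25, 0.75, 0.75]
labels = [0, 0, 1, 1]

print("overall bias:", prediction_bias(preds, labels))                 # overall bias: 0.0
print("per-bucket bias:", bucketed_bias(preds, labels, num_buckets=2))  # per-bucket bias: {0: 0.25, 1: -0.25}
```

The overall bias is exactly zero, yet the model overpredicts in the low bucket and underpredicts in the high one. This is why measuring across buckets catches issues that the single average misses.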
6. Calling it a success just on model accuracy numbers
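Accuracy of 95% means 95 of 100 predictions were correct. But accuracy is a misleading metric when the dataset has class imbalance: if only 5% of the examples are positive, a model that always predicts the negative class is 95% accurate without catching a single positive. Instead, dig deeper into metrics such as precision and recall, and how they correlate with the overall user-facing metrics in applications such as spam detection and tumor classification.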
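The always-negative trap can be checked in a few lines. The dataset is hypothetical, and reporting precision as 0 when the model makes no positive predictions is a convention chosen here (scikit-learn exposes this choice through the `zero_division` parameter of `precision_score`).

```python
def precision_recall(y_true, y_pred, positive=1):
    """Precision and recall for one positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0  # no positive predictions: report 0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Hypothetical imbalanced spam dataset: 5 spam messages (1) among 100.
y_true = [1] * 5 + [0] * 95
y_pred = [0] * 100  # a "model" that never flags spam

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
precision, recall = precision_recall(y_true, y_pred)
print(f"accuracy={accuracy:.2f} precision={precision:.2f} recall={recall:.2f}")
# accuracy=0.95 precision=0.00 recall=0.00
```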
7. Not understanding the impact of regularization lambda
Lambda is a key parameter for striking a balance between model simplicity and fit to the training data. High lambda → simple model → possible underfitting. Low lambda → complex model → possible overfitting (the model won't generalize to new data). The ideal value of lambda is the one that generalizes well to previously unseen data; it is data-dependent and requires analysis.
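To make the role of lambda concrete, here is a sketch using L2-regularized least squares with a single feature and no intercept, whose closed-form weight is w = Σxy / (Σx² + λ). This toy ridge model stands in for whatever regularized model you train; the data, candidate lambdas, and helper names are hypothetical. Lambda is chosen by validation error, not training error.

```python
def ridge_weight(xs, ys, lam):
    """Closed-form L2-regularized least squares for y ~ w * x (one feature, no intercept)."""
    return sum(x * y for x, y in zip(xs, ys)) / (sum(x * x for x in xs) + lam)


def mse(w, xs, ys):
    """Mean squared error of the prediction w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def select_lambda(candidates, train, val):
    """Pick the lambda whose fitted weight has the lowest validation error."""
    return min(candidates, key=lambda lam: mse(ridge_weight(*train, lam), *val))


# Hypothetical toy data: the true relation is y = 2x; the second training label is noisy.
train = ([1.0, 2.0], [2.0, 5.0])
val = ([3.0], [6.0])

for lam in (0.0, 1.0, 10.0):
    w = ridge_weight(*train, lam)
    print(f"lambda={lam}: w={w:.2f}, train MSE={mse(w, *train):.2f}, val MSE={mse(w, *val):.2f}")
print("selected lambda:", select_lambda([0.0, 1.0, 10.0], train, val))
```

Lambda=0 fits the noisy training point best but does worse on validation; lambda=10 shrinks the weight so much that the model underfits. In this toy data the intermediate value wins, but the best lambda is always specific to the data at hand.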
8. Using the same test set over and over
The more often the same data is used to tune parameters and hyperparameters, the less confidence you can have that the results will actually generalize. It is important to collect more data and keep adding to the test and validation sets.
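One common way to keep growing the validation and test sets without reshuffling everything is a deterministic, hash-based split. This is an assumption on our part rather than a method the article prescribes, and it assumes each example has a stable unique ID; `assign_split` and the percentages are hypothetical.

```python
import hashlib


def assign_split(example_id, val_pct=10, test_pct=10):
    """Deterministically map an example ID to 'train', 'validation', or 'test'.

    Hashing the ID (instead of shuffling the whole dataset) keeps every existing
    assignment stable, so newly collected examples can be added to the validation
    and test sets without moving old test examples into training.
    """
    digest = hashlib.sha256(str(example_id).encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % 100  # pseudo-uniform bucket in [0, 100)
    if bucket < test_pct:
        return "test"
    if bucket < test_pct + val_pct:
        return "validation"
    return "train"


existing = {i: assign_split(i) for i in range(1000)}
new_batch = {i: assign_split(i) for i in range(1000, 1200)}  # freshly collected data

# Re-running the assignment never moves an existing example to a different split.
assert all(assign_split(i) == split for i, split in existing.items())
print(sum(s == "test" for s in new_batch.values()), "new test examples")
```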
9. Not paying attention to initialization values in neural networks
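Because training a neural network is a non-convex optimization problem, initialization matters: different starting weights can converge to different local minima.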
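A toy illustration of why starting points matter in non-convex optimization, followed by one widely used initialization scheme. The loss f(x) = x⁴ − 2x² (minima at −1 and 1) is a hypothetical stand-in for a network's loss surface. Glorot/Xavier uniform initialization is shown as one common choice, not as the article's recommendation; frameworks ship it ready-made, for example PyTorch's `torch.nn.init.xavier_uniform_`.

```python
import math
import random


def descend(x, lr=0.05, steps=200):
    """Gradient descent on the non-convex toy loss f(x) = x**4 - 2*x**2."""
    for _ in range(steps):
        x -= lr * (4 * x ** 3 - 4 * x)  # f'(x) = 4x^3 - 4x
    return x


for x0 in (-0.5, 0.0, 0.5):
    print(f"start at {x0:+.1f} -> ends at {descend(x0):+.3f}")


def glorot_uniform(fan_in, fan_out, rng):
    """Glorot/Xavier uniform: sample from U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out))."""
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-limit, limit) for _ in range(fan_out)] for _ in range(fan_in)]


weights = glorot_uniform(256, 128, random.Random(0))
```

Starting at −0.5 or 0.5 lands in different minima, and starting exactly at 0 never moves because the gradient there is zero, loosely mirroring how symmetric or all-zero weights can stall training. Scaling the random range by layer size is meant to keep activations and gradients from vanishing or exploding as they pass through layers.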
10. Assuming wrong labels always need to be fixed
When wrong labels are detected, it is tempting to jump in and fix them. It is important to first analyze the misclassified examples for the root cause. Often, incorrect labels account for only a very small percentage of the errors. There may be a bigger opportunity in training better for specific data slices that turn out to be the predominant root cause.
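Before relabeling, it helps to tally root causes across the misclassified examples. This sketch assumes each error has already been reviewed and tagged; the keys `label_is_wrong` and `slice`, the slice names, and the counts are all hypothetical.

```python
from collections import Counter


def error_breakdown(errors):
    """Share of misclassified examples attributed to each root cause.

    Each error is a dict with hypothetical keys 'label_is_wrong' (bool, from manual review)
    and 'slice' (the data slice the example belongs to).
    """
    causes = Counter(
        "wrong_label" if e["label_is_wrong"] else f"slice:{e['slice']}" for e in errors
    )
    return {cause: n / len(errors) for cause, n in causes.most_common()}


# Hypothetical review of 10 misclassified transactions.
errors = (
    [{"label_is_wrong": True, "slice": "domestic"}]
    + [{"label_is_wrong": False, "slice": "foreign"}] * 7
    + [{"label_is_wrong": False, "slice": "domestic"}] * 2
)
print(error_breakdown(errors))
# {'slice:foreign': 0.7, 'slice:domestic': 0.2, 'wrong_label': 0.1}
```

In this made-up review, wrong labels explain only 10% of errors, while one slice accounts for 70%, so effort is better spent on that slice than on relabeling.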
To summarize, avoiding these mistakes puts you significantly ahead of most other teams. Incorporate these as a checklist in your process.
Bio: Sandeep Uttamchandani, Ph.D.: Data + AI/ML — Both a Product/Software Builder (VP of Engg) & Leader in operating enterprise-wide Data/AI initiatives (CDO) | O’Reilly Book Author | Founder – DataForHumanity (non-profit)
Original. Reposted with permission.