A challenge that machine learning practitioners often face is how to deal with skewed classes in classification problems. This situation occurs when one class is heavily over-represented in the data set. A common example is fraud detection: a very large part of the data set, usually 90% or more, describes normal activities, and only a small fraction of the records should be classified as fraud. In such a case, a model that always predicts “normal” is correct 90% or more of the time. At first, the accuracy and with it the quality of such a model might seem surprisingly good, but as soon as you dig a little deeper, it becomes obvious that the model is completely useless. I recently stumbled upon this issue in the context of multiclass classification, which makes the original problem even harder. After trying different approaches, doing some research, and giving it some thought, I figured out a couple of workarounds and solutions, which I would like to share with you in this post.
1. Precision & Recall
This isn’t really a solution to the problem, but it helps when evaluating the final model. As described in the fraud example, “accuracy” might not be the best metric for determining the quality of the model. Fortunately, the metrics “precision” and “recall” can help us out. Precision describes how many of the records that were classified as fraud actually represent fraudulent activities. Recall, on the other hand, is the share of all fraud records in the data set that the model classified correctly. In case you are attracted to formulas:
Precision = true positives / (true positives + false positives)
Recall = true positives / (true positives + false negatives)
Here a true positive is a fraud record classified as fraud, a false positive is a normal record classified as fraud, and a false negative is a fraud record classified as normal.
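To make the difference to accuracy concrete, here is a minimal sketch in plain Python. The data is made up (980 “normal” and 20 “fraud” records), the function name `precision_recall` is my own, and reporting 0.0 when a denominator is zero is a convention I chose for this example.

```python
def precision_recall(y_true, y_pred, positive="fraud"):
    """Return (precision, recall) for the given positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    # Convention chosen for this sketch: report 0.0 if a denominator is zero.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Made-up toy data: 980 normal records and 20 fraud records.
y_true = ["normal"] * 980 + ["fraud"] * 20
# A "model" that always predicts the dominant class.
always_normal = ["normal"] * len(y_true)

accuracy = sum(t == p for t, p in zip(y_true, always_normal)) / len(y_true)
precision, recall = precision_recall(y_true, always_normal)
print("accuracy:", accuracy, "precision:", precision, "recall:", recall)
```

The always-“normal” model reaches 98% accuracy on this toy data, while its precision and recall for “fraud” are both zero, which is exactly the kind of model you want to detect.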
For multiclass classification, the usefulness of precision and recall as metrics is limited, as you would have to compute and analyze them for every class. In my opinion, it is better to take a quick look at the confusion matrix of the classified test set and check whether at least some of the non-dominant classes were classified correctly.
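A confusion matrix is easy to build by hand. The following sketch uses three hypothetical classes “a”, “b” and “c” with invented predictions; the helper names are mine and not taken from any library.

```python
from collections import Counter


def confusion_matrix(y_true, y_pred, labels):
    """Rows are true classes, columns are predicted classes."""
    counts = Counter(zip(y_true, y_pred))
    return [[counts[(t, p)] for p in labels] for t in labels]


def per_class_recall(matrix, labels):
    """Recall per class: diagonal entry divided by the sum of its row."""
    return {
        label: (row[i] / sum(row) if sum(row) else 0.0)
        for i, (label, row) in enumerate(zip(labels, matrix))
    }


# Hypothetical class names and invented predictions; class "a" dominates.
labels = ["a", "b", "c"]
y_true = ["a"] * 90 + ["b"] * 6 + ["c"] * 4
y_pred = ["a"] * 90 + ["a"] * 4 + ["b"] * 2 + ["a"] * 4

matrix = confusion_matrix(y_true, y_pred, labels)
for label, row in zip(labels, matrix):
    print(label, row)
print(per_class_recall(matrix, labels))
```

In this made-up example the rows for “b” and “c” show at a glance that most of their records ended up in the column of the dominant class “a”, even though overall accuracy is 92%.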
2. Stratified Sampling
Once you split the data into train, validation and test sets, it is very likely that your already skewed data becomes even more unbalanced in at least one of the three resulting sets. Think about it: let’s say your data set contains 1000 records and 20 of them are labelled as “fraud”. As soon as you split the data set, it is possible that the validation or test set contains the majority of the “fraud” records, or maybe even all of them. In that case the machine learning algorithm has no chance at all to learn the hidden patterns of “fraud” records. The opposite case is a problem too: if nearly all “fraud” records end up in the train set, you can no longer evaluate how well the model detects fraud. This is why you shouldn’t use random sampling when constructing the three data sets, but stratified sampling instead. It ensures that the train, validation and test sets each preserve the class proportions of the full data set. That way the existing problem of skewed classes is not intensified, which is what you want when creating high quality models. For R, the function “strata” of the package “sampling” provides functionality for stratified sampling.
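The idea behind stratified sampling can be sketched in a few lines of plain Python. This is not the R function mentioned above, just an illustration: the function name `stratified_split`, the 60/20/20 fractions and the fixed seed are assumptions I made for the example.

```python
import random
from collections import defaultdict


def stratified_split(labels, fractions=(0.6, 0.2, 0.2), seed=0):
    """Split record indices into train/validation/test, class by class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train, validation, test = [], [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = round(fractions[0] * len(indices))
        n_validation = round(fractions[1] * len(indices))
        train += indices[:n_train]
        validation += indices[n_train:n_train + n_validation]
        # Whatever is left of this class goes to the test set.
        test += indices[n_train + n_validation:]
    return train, validation, test


# The example from the text: 1000 records, 20 of them labelled "fraud".
labels = ["normal"] * 980 + ["fraud"] * 20
train, validation, test = stratified_split(labels)
for name, part in (("train", train), ("validation", validation), ("test", test)):
    n_fraud = sum(1 for i in part if labels[i] == "fraud")
    print(name, len(part), "records,", n_fraud, "fraud")
```

Because every class is shuffled and cut separately, each of the three sets receives its share of the 20 “fraud” records instead of leaving that to chance.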
3. Limiting the Over-Represented Class
If you have a lot of data, then simply limit the number of “not fraud” samples allowed in your data set. Unfortunately, in machine learning you can never have enough data, and you usually have less than you desperately need. However, this method is also useful for smaller data sets, as in most cases the records of the dominant class that are left out wouldn’t add a substantial amount of information to the ones already included.
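A minimal sketch of this kind of undersampling could look as follows. The function name `undersample`, the cap of 200 majority records and the seed are assumptions for illustration; which cap makes sense depends on your data.

```python
import random


def undersample(records, labels, majority="normal", max_majority=200, seed=0):
    """Keep every minority record but at most `max_majority` majority records."""
    rng = random.Random(seed)
    majority_idx = [i for i, label in enumerate(labels) if label == majority]
    keep = {i for i, label in enumerate(labels) if label != majority}
    # Draw the majority records to keep at random, without replacement.
    keep.update(rng.sample(majority_idx, min(max_majority, len(majority_idx))))
    kept = sorted(keep)
    return [records[i] for i in kept], [labels[i] for i in kept]


# Made-up data: record ids 0..999, the last 20 labelled "fraud".
records = list(range(1000))
labels = ["normal"] * 980 + ["fraud"] * 20

small_records, small_labels = undersample(records, labels)
print(len(small_records), "records kept,", small_labels.count("fraud"), "fraud")
```

On the toy data this keeps all 20 “fraud” records and 200 randomly chosen “normal” ones, so the share of fraud rises from 2% to roughly 9%.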
4. Weights and Costs
With most machine learning algorithms you can assign higher costs to “false positives” or “false negatives”. The higher the cost of one type of error, the more the model is geared towards avoiding it, i.e. towards predicting the other class. This way, you can influence either precision or recall and tune the model towards the metric that matters most for your problem. However, as one of the two metrics gets better, the other one usually gets worse. It’s a tradeoff. This principle of costs can also be applied to a multiclass classification scenario, where it is possible to assign different weights to specific classes. The higher the weight of a class, the more likely the model is to predict it. For SVM in R, the argument “class.weights” of the “svm” function in the e1071 package allows for such an adjustment of the importance of each class. One more hint to ease understanding: to my knowledge, weights work like inverted costs.
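How you choose the weights is up to you. One common heuristic, which is my assumption here and not something prescribed by a particular library, is to weight each class inversely to its frequency. The sketch below computes such weights and shows how they change the error of the always-“normal” model; all names are hypothetical.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight per class: n_records / (n_classes * class_count)."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {label: n / (k * count) for label, count in counts.items()}


def weighted_error(y_true, y_pred, weights):
    """Share of the total weight that belongs to misclassified records."""
    total = sum(weights[t] for t in y_true)
    wrong = sum(weights[t] for t, p in zip(y_true, y_pred) if t != p)
    return wrong / total


# Made-up toy data and a model that always predicts the dominant class.
y_true = ["normal"] * 980 + ["fraud"] * 20
always_normal = ["normal"] * len(y_true)

weights = inverse_frequency_weights(y_true)
plain_error = sum(t != p for t, p in zip(y_true, always_normal)) / len(y_true)
print(weights)
print(plain_error, weighted_error(y_true, always_normal, weights))
```

With these weights each “fraud” record counts 25 times, so the trivial model’s error jumps from 2% unweighted to about 50% weighted. A learning algorithm that minimizes the weighted error therefore can no longer get away with ignoring the rare class.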
5. Anomaly Detection
In anomaly detection, the basic idea is to estimate, for every record, how likely it is to be normal. This is done by comparing the feature values of an unseen record with those of the records that are known to be normal. If the resulting probability turns out to be below a certain threshold, for example 5%, then the record is classified as an anomaly, e.g. fraud. Anomaly detection using a multivariate Gaussian distribution looks especially promising to me. However, anomaly detection cannot be applied directly to multiclass classification settings with skewed classes. The algorithm would only be able to tell which records don’t belong to any of the labeled classes and therefore should be classified as something like “other”. Unfortunately, I don’t have any practical experience with anomaly detection algorithms yet, so I know neither the performance nor the pitfalls of such an approach. For further information, I can recommend the anomaly detection chapter of the Machine Learning class on Coursera.
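To give a rough idea of how this works, here is a simplified sketch. It treats the features as independent, which corresponds to a Gaussian with a diagonal covariance matrix rather than a full multivariate one. The data is invented, and taking the lowest density among the known normal records as the threshold is only an assumption for this example; in practice the threshold would be tuned on a validation set.

```python
import math


def fit_gaussian(rows):
    """Estimate per-feature mean and variance from records known to be normal."""
    n, dims = len(rows), len(rows[0])
    means = [sum(row[j] for row in rows) / n for j in range(dims)]
    variances = [
        sum((row[j] - means[j]) ** 2 for row in rows) / n for j in range(dims)
    ]
    return means, variances


def density(row, means, variances):
    """Product of univariate Gaussian densities (features treated as independent)."""
    p = 1.0
    for x, mu, var in zip(row, means, variances):
        p *= math.exp(-((x - mu) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)
    return p


def is_anomaly(row, means, variances, epsilon):
    """Flag a record whose density under the fitted model is below epsilon."""
    return density(row, means, variances) < epsilon


# Invented "normal" records with two features, scattered around (10, 5).
normal_rows = [(10.0 + 0.1 * i, 5.0 - 0.05 * i) for i in range(-10, 11)]
means, variances = fit_gaussian(normal_rows)

# Assumed threshold: the lowest density observed among the normal records.
epsilon = min(density(row, means, variances) for row in normal_rows)

print(is_anomaly((10.1, 5.0), means, variances, epsilon))
print(is_anomaly((15.0, 1.0), means, variances, epsilon))
```

A record close to the bulk of the normal data gets a high density and passes, while a record far away from it falls below the threshold and is flagged.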
Hopefully this little list helps you when you have to deal with skewed classes. For me, applying the second and third approaches did the trick. If you know other effective ways of tackling this problem, I would be more than happy to read about them in the comments below.
Update: A lot of insightful comments were also posted on Hacker News.