A Review of Focal Loss at Women in Data Science Blacksburg

Summary: I enjoyed listening to talks, meeting Virginia Tech students, and giving a tutorial on deep learning at the Women in Data Science (WiDS) Blacksburg conference. WiDS events all over the world are happening now to encourage and support current and future women in this field. Some of the material from my tutorial on focal loss, intended for people with a basic background in machine learning, is included below with context for an accompanying Jupyter notebook.

Last week I had the opportunity to attend and present at Women in Data Science Blacksburg, the first WiDS regional event at Virginia Tech, hosted by Dr. Eileen Martin, one of my classmates from grad school. For those of you unfamiliar with WiDS, the first Women in Data Science conference was organized and run at Stanford, led by Dr. Margot Gerritsen, who was the director of my department at the time. One of the things I love about WiDS is how from early on there were efforts to have it reach beyond Silicon Valley by encouraging people around the world to host their own WiDS events. At this point, I’ve attended WiDS conferences at Stanford; Cambridge, MA; Washington, D.C.; and now Blacksburg, VA. Video clips and images from various regional events are compiled and broadcast, so for me, there is a sense of a much broader community extending beyond the people in my locality. Speaking of video clips, the Virginia Tech College of Science has already put together a brief video about WiDS Blacksburg. I hope they continue to support this event in the future.

If you are a woman or male ally interested in data science and machine learning, WiDS Blacksburg was held earlier than most, so there may still be time to register for a WiDS event in your region or participate remotely via the livestream from Stanford.

The tutorial session that I presented at this WiDS focused on focal loss, a variant of the cross-entropy loss function commonly used to train neural networks for classification. The paper Focal Loss for Dense Object Detection was posted as a preprint on arXiv in mid-2017, so it has been around for a bit, but many people are still not familiar with this simple but effective technique. To prepare an interactive example that students could run easily, naturally the first thing I did was search GitHub. While I could write my own from scratch, let’s be real: I have a day job and a life outside of work, and I strongly believe in minimizing duplication of effort. I found a great example Jupyter notebook and accompanying blog post by user Tony607, forked the repository, and started making changes. I ended up changing a fair amount in order to approach the problem in the way that made the most sense to me and to emphasize certain aspects of how focal loss works. My version of the notebook is available here, although I encourage you to read more of this post before trying it out. (Yes, it’s a toy example with a teeny tiny neural net, and it’s what made the most sense for a live demo.)

In my presentation, I tried to break down the main ideas from the focal loss paper to be more intuitive and digestible for people with less experience in deep learning. Read on for the full explanation, intended for people with a basic background in machine learning, or skip to the last paragraph for a couple sentences’ worth of closing thoughts.

First, let’s take a step back and ask “what problem are we trying to solve?” Say you want to classify each sample from a dataset as one of two classes, and to add a slight complication, the class distribution is imbalanced. (Don’t worry, we can extend focal loss to N classes, but I’m using two for simplicity.) To make this example more concrete, let’s say that the problem is detecting fraudulent financial transactions in a dataset with a large proportion of normal transactions and relatively few fraudulent transactions. In fact, this is the problem used in the Jupyter notebook. You build a neural network and train it for a bit, and it quickly attains the ability to distinguish between normal and fraudulent transactions at a basic level. From this point on, most of the training examples are not doing much to improve your performance because the model is already doing a decent job on them. We will call those “well-classified” or “easy” examples. There is a smaller subset of “hard” examples in the training dataset that are more informative to the model, and focal loss allows us to place more emphasis on those examples.

How does focal loss achieve this objective? Figure 1 from the paper illustrates this well. I’ve taken the original figure from the paper and added my own annotations below. This plot shows curves for the standard cross-entropy loss function and a few variations of the focal loss function, where the variations use different values for the hyperparameter gamma. On the x-axis is the input p_t, the predicted probability for the true class, and on the y-axis is the corresponding loss. Consider what happens with a well-classified example: say a training example with a true label of “normal” has a predicted “normal” score of 0.8. Looking at the cross-entropy function, the loss is small and, more importantly, the gradient is small. If we compare that to a hard example, such as a “normal” example with a score of 0.2, where we are not doing well at all, the gradient for the hard example is larger. This is good: the standard cross-entropy loss function already has some built-in ability to place more emphasis on examples where the predictions are further from the truth labels.
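To make the curves concrete, here is a minimal sketch of the two loss functions evaluated at the two scores from the example above. It uses the basic form of focal loss from the paper, the cross-entropy term scaled by a modulating factor (1 - p_t)^gamma, without the optional class-balancing weight. The function names and the choice of gamma = 2 are mine for illustration and are not taken from the notebook.

```python
import math

def cross_entropy(p_t):
    """Cross-entropy loss given p_t, the predicted probability of the true class."""
    return -math.log(p_t)

def focal_loss(p_t, gamma=2.0):
    """Cross-entropy scaled by the modulating factor (1 - p_t) ** gamma."""
    return (1.0 - p_t) ** gamma * cross_entropy(p_t)

# Compare an easy example (p_t = 0.8) with a hard one (p_t = 0.2).
for p_t in (0.8, 0.2):
    ce = cross_entropy(p_t)
    fl = focal_loss(p_t, gamma=2.0)
    print(f"p_t={p_t}: CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.2f}")
```

With gamma = 2, the modulating factor is 0.04 for the easy example and 0.64 for the hard one, so the easy example's loss is cut far more sharply. Setting gamma = 0 makes the factor equal to 1 and recovers plain cross-entropy.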

However, if we go through the same thought exercise with one of the focal loss curves, we see that the gradient for the well-classified example is even smaller and the gradient for the hard example is even larger. We could interpret this difference in the shape of the loss functions by saying that a model trained with standard cross-entropy loss will continue trying to push scores for the well-classified examples further and further all the way to 1.0, whereas a model trained with focal loss will not care too much about the well-classified examples and will instead work more towards improving on the hard examples. This effect is evident in the Jupyter notebook, and this is a good point to take a look at it and see for yourself.
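The same comparison can be checked numerically. The sketch below computes the slope of each curve with respect to the predicted probability of the true class (called p_t in the code) at scores of 0.8 and 0.2, using the product rule on the focal loss formula, and confirms the analytic slope against a finite difference. The function names and gamma = 2 are my own choices for illustration. Note that this is the slope of the plotted curve only; during training the gradient also passes through the sigmoid or softmax, so the values that reach the weights will differ.

```python
import math

def focal_loss(p_t, gamma=2.0):
    """Focal loss -(1 - p_t) ** gamma * log(p_t); gamma = 0 gives cross-entropy."""
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

def focal_slope(p_t, gamma=2.0):
    """Analytic derivative of the focal loss with respect to p_t (product rule)."""
    modulating = (1.0 - p_t) ** gamma
    d_modulating = -gamma * (1.0 - p_t) ** (gamma - 1.0)
    return -(d_modulating * math.log(p_t) + modulating / p_t)

def numeric_slope(p_t, gamma=2.0, eps=1e-6):
    """Central finite difference, used only as a sanity check."""
    return (focal_loss(p_t + eps, gamma) - focal_loss(p_t - eps, gamma)) / (2 * eps)

for p_t in (0.8, 0.2):
    ce = focal_slope(p_t, gamma=0.0)  # gamma = 0 is plain cross-entropy
    fl = focal_slope(p_t, gamma=2.0)
    print(f"p_t={p_t}: CE slope={ce:.3f}  FL slope={fl:.3f}  "
          f"(numeric check: {numeric_slope(p_t, 2.0):.3f})")
```

At these two points the cross-entropy slopes are -1.25 and -5.0, a ratio of 4 between hard and easy, while the focal loss slopes are roughly -0.14 and -5.78, a ratio of about 41. That widening gap is what shifts the emphasis toward the hard example. The comparison is specific to these scores and this gamma; at other points on the curve the focal loss slope can be smaller in magnitude than the cross-entropy slope.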

There are a couple more key parts of the focal loss paper that I want to discuss. First, the application we were just considering is a pure classification problem. Where does object detection fit in? A common design for deep learning object detection models uses a grid with several template boxes or “anchor boxes” with different aspect ratios at each cell within the grid, and the model learns to classify each anchor box as either having an object of interest (a ground truth box that roughly matches the anchor box) or having nothing of interest, i.e. belonging to the “background” class. In this example image from SSD: Single Shot MultiBox Detector, only two anchor boxes match the cat, and one matches the dog. The vast majority of anchor boxes do not match a ground truth box, so we have a situation with potentially a large number of easy background examples and a smaller subset of hard examples. For a model like SSD, there is a regression component of the architecture where the model predicts the deltas between the truth bounding box and the anchor box, but focal loss does not directly impact that pathway, and that’s all we’ll say about it here.

Figure 1 from SSD: Single Shot MultiBox Detector
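To see how lopsided the anchor labels become, here is a rough sketch that lays anchor boxes over a grid and labels each one by its overlap with a ground truth box. Everything in it is hypothetical: the grid size, the anchor shapes, the two ground truth boxes, and the overlap threshold of 0.5 are values I picked for illustration, and real detectors use more elaborate matching rules.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x_min, y_min, x_max, y_max)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def make_anchors(grid=8, shapes=((0.2, 0.2), (0.3, 0.15), (0.15, 0.3))):
    """One anchor per (cell, shape), centered on each cell of a grid over the unit square."""
    anchors = []
    for row in range(grid):
        for col in range(grid):
            cx, cy = (col + 0.5) / grid, (row + 0.5) / grid
            for w, h in shapes:
                anchors.append((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2))
    return anchors

# Hypothetical ground truth boxes in normalized image coordinates.
ground_truth = [(0.10, 0.30, 0.28, 0.62), (0.55, 0.15, 0.85, 0.30)]

anchors = make_anchors()
labels = [
    "object" if any(iou(anchor, gt) >= 0.5 for gt in ground_truth) else "background"
    for anchor in anchors
]
n_object = labels.count("object")
n_background = labels.count("background")
print(f"{len(anchors)} anchors: {n_object} object, {n_background} background")
```

Even in this tiny setup, only a handful of the 192 anchors overlap a ground truth box well enough to count as objects, and the rest are background. Most of those background anchors are nowhere near an object, which is what makes them easy examples.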

Finally, the focal loss paper also mentions the use of a prior probability for the rare class. Based on my team’s experiments, I can say that adjusting the prior is not necessary (and it is not used in the Jupyter notebook example), but it can help improve performance. The general idea here is that if we know ahead of time that a certain class is very rare (or conversely that a certain class will be overwhelmingly represented), we can initialize the last layer leading up to classification (in the paper, its bias term) such that the model is biased toward predicting the rare class with low probability (or predicting the common class with high probability) instead of predicting each class with equal probability. The final layer will begin training already able to predict the correct label for most of the examples and just needs to learn to recognize the rare class(es).
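As a small illustration of the prior, the paper sets the bias of the final classification layer to b = -log((1 - pi) / pi), so that a sigmoid output starts out at roughly pi for the rare class. The sketch below just verifies that relationship. The helper names are mine, and pi = 0.01 is the value the paper reports using; the notebook does not use a prior at all.

```python
import math

def sigmoid(x):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))

def prior_bias(pi):
    """Bias b such that sigmoid(b) equals the prior probability pi."""
    return -math.log((1.0 - pi) / pi)

pi = 0.01  # assumed prior probability of the rare class
b = prior_bias(pi)
print(f"bias = {b:.4f}, initial rare-class probability = {sigmoid(b):.4f}")
```

With the other final-layer weights initialized near zero, every example starts with a rare-class probability close to pi, so the many common-class examples begin with a small loss instead of dominating the first training steps. How you set that initial bias depends on the framework you are using.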

I think focal loss was a great topic for this setting because it is easy to implement, general enough to apply to many situations, and based on straightforward reasoning about gradients. Maybe I’m an idealist, but I think you or I could come up with an idea like this, too.

References

  1. Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2018). Focal loss for dense object detection. IEEE Transactions on Pattern Analysis and Machine Intelligence.
  2. Liu, W., Anguelov, D., Erhan, D., Szegedy, C., Reed, S., Fu, C. Y., & Berg, A. C. (2016, October). SSD: Single shot multibox detector. In European Conference on Computer Vision (pp. 21-37). Springer, Cham.