This repository has been archived by the owner on Nov 29, 2022. It is now read-only.

artemmavrin/focal-loss

Focal Loss

TensorFlow implementation of focal loss [1]: a loss function that generalizes binary and multiclass cross-entropy loss by down-weighting well-classified examples, so that training focuses on hard-to-classify ones.
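For reference, here is the basic form of the loss as defined in the cited paper, without its optional α class-balancing weight. The package's exact handling of extras such as class weights, logits, or label smoothing is not shown. Let p be the predicted probability of the positive class, y ∈ {0, 1} the label, and γ ≥ 0 the focusing parameter (the gamma argument used in the example below):

```latex
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{if } y = 0 \end{cases}
\qquad
\mathrm{CE}(p_t) = -\log(p_t)
\qquad
\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)
```

Setting γ = 0 recovers cross-entropy. For γ > 0 the modulating factor (1 - p_t)^γ goes to 0 as p_t goes to 1, so confidently correct examples contribute little to the total loss. In the multiclass case, p_t is the predicted probability of the true class.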

The focal_loss package provides functions and classes that can be used as off-the-shelf replacements for tf.keras.losses functions and classes, respectively.

# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)
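To see why focal loss can stand in for cross-entropy, the sketch below computes the per-example binary focal loss in plain Python rather than TensorFlow. The functions binary_focal_loss_ref and binary_cross_entropy_ref are hypothetical reference helpers, not part of the focal_loss API. They assume the prediction is a probability (not a logit) and clip it with an assumed eps to avoid log(0):

```python
import math

def binary_focal_loss_ref(y_true, p, gamma=2.0, eps=1e-7):
    """Hypothetical reference: per-example binary focal loss.

    y_true is a label in {0, 1}; p is the predicted probability of class 1.
    """
    p = min(max(p, eps), 1.0 - eps)      # clip to avoid log(0)
    p_t = p if y_true == 1 else 1.0 - p  # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

def binary_cross_entropy_ref(y_true, p, eps=1e-7):
    """Hypothetical reference: per-example binary cross-entropy."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))

# With gamma = 0, focal loss is exactly binary cross-entropy.
for y, p in [(1, 0.9), (0, 0.3), (1, 0.2)]:
    assert math.isclose(binary_focal_loss_ref(y, p, gamma=0.0),
                        binary_cross_entropy_ref(y, p))

# With gamma = 2, an easy example (p_t = 0.9) keeps (1 - 0.9)**2 = 0.01 of its
# cross-entropy, while a hard example (p_t = 0.2) keeps (1 - 0.2)**2 = 0.64.
easy = binary_focal_loss_ref(1, 0.9) / binary_cross_entropy_ref(1, 0.9)
hard = binary_focal_loss_ref(1, 0.2) / binary_cross_entropy_ref(1, 0.2)
print(round(easy, 4), round(hard, 4))
```

A Keras loss would additionally average these per-example values over the batch. Use the package classes for actual training, since their TensorFlow implementations are differentiable and batched.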

The focal_loss package includes the functions

  • binary_focal_loss
  • sparse_categorical_focal_loss

and wrapper classes

  • BinaryFocalLoss (use like tf.keras.losses.BinaryCrossentropy)
  • SparseCategoricalFocalLoss (use like tf.keras.losses.SparseCategoricalCrossentropy)
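The sparse categorical variant follows the same pattern as SparseCategoricalCrossentropy: labels are integer class indices, and p_t is the predicted probability of the labeled class. The minimal sketch below assumes probability vectors as input (here produced by a local softmax helper). sparse_categorical_focal_loss_ref is a hypothetical illustration, not the package's function:

```python
import math

def softmax(logits):
    """Convert a list of logits to probabilities."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_categorical_focal_loss_ref(y_true, probs, gamma=2.0, eps=1e-7):
    """Hypothetical reference: mean multiclass focal loss over a batch.

    y_true: list of integer class indices, one per example.
    probs:  list of probability vectors (each summing to 1), one per example.
    """
    losses = []
    for label, p in zip(y_true, probs):
        p_t = min(max(p[label], eps), 1.0)  # probability of the labeled class
        losses.append(-((1.0 - p_t) ** gamma) * math.log(p_t))
    return sum(losses) / len(losses)

logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.0, 1.0]]
probs = [softmax(z) for z in logits]
labels = [0, 2, 1]

# With gamma = 0 the result matches mean sparse categorical cross-entropy.
ce = sum(-math.log(p[y]) for y, p in zip(labels, probs)) / len(labels)
assert math.isclose(sparse_categorical_focal_loss_ref(labels, probs, gamma=0.0), ce)

# With gamma > 0 each term is scaled by (1 - p_t) ** gamma <= 1, so the loss never grows.
assert sparse_categorical_focal_loss_ref(labels, probs, gamma=2.0) < ce
```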

Documentation is available at Read the Docs.

[Figure: focal loss plot]
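Assuming the plot shows loss as a function of the true-class probability p_t for several values of γ (as in Figure 1 of the cited paper), the short sketch below tabulates such curves. The gamma and probability grids are arbitrary illustrative choices:

```python
import math

def focal_loss_curve(p_t, gamma):
    """Focal loss as a function of the true-class probability p_t in (0, 1]."""
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

gammas = [0, 0.5, 1, 2, 5]       # gamma = 0 is plain cross-entropy
probs = [0.1, 0.3, 0.5, 0.7, 0.9]

# Print one row per p_t and one column per gamma.
print("p_t  " + "".join(f"g={g:<7}" for g in gammas))
for p in probs:
    print(f"{p:<5}" + "".join(f"{focal_loss_curve(p, g):<9.4f}" for g in gammas))
```

Reading across a row shows larger γ shrinking the loss, and the shrinkage is strongest for well-classified examples (large p_t).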

Installation

The focal_loss package can be installed using the pip utility. For the latest version, install directly from the package's GitHub page:

pip install git+https://github.com/artemmavrin/focal-loss.git

Alternatively, install a recent release from the Python Package Index (PyPI):

pip install focal-loss

Note. To install the project for development (e.g., to make changes to the source code), clone the project repository from GitHub and run make dev:

git clone https://github.com/artemmavrin/focal-loss.git
cd focal-loss
# Optional but recommended: create and activate a new environment first
make dev

This will additionally install the requirements needed to run tests, check code coverage, and produce documentation.

References


  1. T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár. Focal loss for dense object detection. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018. arXiv:1708.02002.