
Large-Margin Softmax Loss, Angular Softmax Loss, Additive Margin Softmax, ArcFaceLoss And FocalLoss In Tensorflow


HiKapok/tf.extra_losses


This repository contains the core code for TensorFlow reimplementations of the following papers:

If your goal is to reproduce the results of the original papers, please use the official code:

To use these Ops on your own machine:

  • Copy the header file "cuda_config.h" from "your_python_path/site-packages/external/local_config_cuda/cuda/cuda/cuda_config.h" to "your_python_path/site-packages/tensorflow/include/tensorflow/stream_executor/cuda/cuda_config.h".

  • Run the following commands:

     mkdir build
     cd build && cmake ..
     make
  • Run "test_op.py" and check the numerical errors to verify your installation.

  • Follow the code snippets below to integrate these Ops into your own code:

    • For Large Margin Softmax Loss:
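     import tensorflow as tf
     from tensorflow.python.framework import ops

     # so_lib_path: path to the shared library produced by the build step above
     op_module = tf.load_op_library(so_lib_path)
     large_margin_softmax = op_module.large_margin_softmax
    
     @ops.RegisterGradient("LargeMarginSoftmax")
     def _large_margin_softmax_grad(op, grad, _):
       '''The gradients for `LargeMarginSoftmax`.

       Only the gradient w.r.t. the logits output (`grad`) is propagated;
       the gradient for the second output (the current lambda) is ignored.
       '''
       inputs_features = op.inputs[0]
       inputs_weights = op.inputs[1]
       inputs_labels = op.inputs[2]
       cur_lambda = op.outputs[1]
       margin_order = op.get_attr('margin_order')
    
       grads = op_module.large_margin_softmax_grad(inputs_features, inputs_weights, inputs_labels, grad, cur_lambda[0], margin_order)
       # Gradients flow to features and weights only; labels and global_step get None.
       return [grads[0], grads[1], None, None]
    
     # initial_value, features, labels and global_step come from your own model.
     var_weights = tf.Variable(initial_value, trainable=True, name='lsoftmax_weights')
     # 4 is margin_order; the remaining floats presumably set the lambda annealing schedule.
     result = large_margin_softmax(features, var_weights, labels, global_step, 4, 1000., 0.000025, 35., 0.)
     # result[0] holds the margin-adjusted logits, used with the usual softmax cross-entropy.
     loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=result[0]))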
    • For Angular Softmax Loss:
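     import tensorflow as tf
     from tensorflow.python.framework import ops

     # so_lib_path: path to the shared library produced by the build step above
     op_module = tf.load_op_library(so_lib_path)
     angular_softmax = op_module.angular_softmax
    
     @ops.RegisterGradient("AngularSoftmax")
     def _angular_softmax_grad(op, grad, _):
       '''The gradients for `AngularSoftmax`.

       Only the gradient w.r.t. the logits output (`grad`) is propagated;
       the gradient for the second output (the current lambda) is ignored.
       '''
       inputs_features = op.inputs[0]
       inputs_weights = op.inputs[1]
       inputs_labels = op.inputs[2]
       cur_lambda = op.outputs[1]
       margin_order = op.get_attr('margin_order')
    
       grads = op_module.angular_softmax_grad(inputs_features, inputs_weights, inputs_labels, grad, cur_lambda[0], margin_order)
       # Gradients flow to features and weights only; labels and global_step get None.
       return [grads[0], grads[1], None, None]
    
     # initial_value, features, labels and global_step come from your own model.
     var_weights = tf.Variable(initial_value, trainable=True, name='asoftmax_weights')
     # A-Softmax expects unit-norm class weights, so L2-normalize them (along axis 1 here).
     normed_var_weights = tf.nn.l2_normalize(var_weights, 1, 1e-10, name='weights_normed')
     # 4 is margin_order; the remaining floats presumably set the lambda annealing schedule.
     result = angular_softmax(features, normed_var_weights, labels, global_step, 4, 1000., 0.000025, 35., 0.)
     loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=result[0]))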
    • For the other losses, refer to this script.
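The integer 4 passed to `large_margin_softmax` and `angular_softmax` in the snippets above is the `margin_order` (m in the L-Softmax paper). The plain-Python sketch below illustrates what such an op computes for the target class: the margin function ψ(θ) = (-1)^k cos(mθ) - 2k for θ in [kπ/m, (k+1)π/m], blended with the plain logit through a lambda that decays as training progresses. It assumes the four float arguments are the `base`, `gamma`, `power` and `lambda_min` of the original Caffe L-Softmax schedule, i.e. λ = max(base·(1 + gamma·step)^(-power), lambda_min); that mapping and the function names (`annealed_lambda`, `psi`, `target_logit`) are assumptions for illustration, not this repository's kernel. For A-Softmax the class weights are unit-normalized, so ‖W‖ = 1 in the same formula.

```python
import math


def annealed_lambda(step, base=1000.0, gamma=0.000025, power=35.0, lambda_min=0.0):
    """Assumed Caffe-style schedule: lambda decays from `base` towards `lambda_min`."""
    return max(base * (1.0 + gamma * step) ** (-power), lambda_min)


def psi(theta, m):
    """L-Softmax margin function: (-1)^k * cos(m * theta) - 2k on [k*pi/m, (k+1)*pi/m]."""
    # Pick the interval index k, clamping theta == pi into the last interval.
    k = min(int(theta * m / math.pi), m - 1)
    return (-1) ** k * math.cos(m * theta) - 2 * k


def target_logit(w_norm, x_norm, cos_theta, m, lam):
    """Blend the plain logit |W||x|cos(theta) with the margin logit |W||x|psi(theta)."""
    # Clamp to guard against rounding slightly outside [-1, 1].
    theta = math.acos(max(-1.0, min(1.0, cos_theta)))
    plain = w_norm * x_norm * cos_theta
    margin = w_norm * x_norm * psi(theta, m)
    return (lam * plain + margin) / (1.0 + lam)


if __name__ == "__main__":
    # Early in training lambda is large, so the target logit stays close to the plain one;
    # as lambda decays, the margin term psi(theta) dominates.
    for step in (0, 10000, 100000):
        lam = annealed_lambda(step)
        print(step, lam, target_logit(1.0, 10.0, math.cos(0.3), 4, lam))
```

Because ψ decreases monotonically from 1 at θ = 0 to -(2m - 1) at θ = π, it is continuous across the interval boundaries, which is what lets the margin be applied over the whole angle range rather than only on [0, π/m].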

All the code was tested with TensorFlow 1.6, Python 3.5 and CUDA 8.0 on Ubuntu 16.04. The outputs of the C++ Ops have been compared with those of the original Caffe code, and the differences are negligible. The gradients of these Ops have been checked with tf.test.compute_gradient_error and tf.test.compute_gradient. The other losses are implemented as Python Ops, following the official implementations.
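`tf.test.compute_gradient_error` compares an analytic Jacobian with a finite-difference estimate and reports the largest absolute difference. The same idea can be sketched without TensorFlow. Below, a hypothetical `am_softmax_loss` computes an Additive Margin Softmax loss from cosine similarities (target logit s·(cos θ_y - m), other logits s·cos θ_j); its analytic gradient with respect to the cosines is s·(softmax - one_hot), and `max_gradient_error` checks it with central differences. The scale s = 30 and margin m = 0.35 are illustrative values, not taken from this repository, and none of these functions are part of its API.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability.
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def am_softmax_logits(cos, label, s=30.0, m=0.35):
    # Subtract the additive margin from the target cosine only, then scale everything by s.
    return [s * (c - m) if j == label else s * c for j, c in enumerate(cos)]


def am_softmax_loss(cos, label, s=30.0, m=0.35):
    probs = softmax(am_softmax_logits(cos, label, s, m))
    return -math.log(probs[label])


def am_softmax_grad(cos, label, s=30.0, m=0.35):
    # Chain rule: d(loss)/d(cos_j) = s * (p_j - [j == label]).
    probs = softmax(am_softmax_logits(cos, label, s, m))
    return [s * (p - (1.0 if j == label else 0.0)) for j, p in enumerate(probs)]


def max_gradient_error(cos, label, eps=1e-6):
    """Largest |analytic - numerical| over all inputs, using central differences."""
    analytic = am_softmax_grad(cos, label)
    worst = 0.0
    for j in range(len(cos)):
        plus = list(cos)
        plus[j] += eps
        minus = list(cos)
        minus[j] -= eps
        numeric = (am_softmax_loss(plus, label) - am_softmax_loss(minus, label)) / (2 * eps)
        worst = max(worst, abs(numeric - analytic[j]))
    return worst


if __name__ == "__main__":
    print(max_gradient_error([0.6, 0.1, -0.2, 0.3], label=0))
```

A small maximum error (far below the gradient magnitudes, which are on the order of s) indicates the analytic gradient is consistent with the loss, which is the same criterion a `compute_gradient_error` test applies to a registered gradient.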

If you encounter linking problems when building or loading the *.so file, it is strongly recommended to read this section of the official tutorial to make sure you are using the same C++ ABI version as your TensorFlow build.
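One way to see which C++ ABI your TensorFlow build expects is to inspect the flags returned by `tf.sysconfig.get_compile_flags()`, which the official custom-op guide uses when compiling ops; builds made with the older ABI include `-D_GLIBCXX_USE_CXX11_ABI=0`. The helper `cxx11_abi_from_flags` below is a hypothetical convenience function, not part of TensorFlow, and the TensorFlow call only runs if TensorFlow is importable.

```python
ABI_MACRO = "-D_GLIBCXX_USE_CXX11_ABI="


def cxx11_abi_from_flags(flags):
    """Return the _GLIBCXX_USE_CXX11_ABI value found in a list of compiler flags, or None."""
    for flag in flags:
        if flag.startswith(ABI_MACRO):
            return int(flag[len(ABI_MACRO):])
    return None


if __name__ == "__main__":
    try:
        import tensorflow as tf
    except ImportError:
        print("TensorFlow is not installed")
    else:
        flags = tf.sysconfig.get_compile_flags()
        print("compile flags:", " ".join(flags))
        print("_GLIBCXX_USE_CXX11_ABI =", cxx11_abi_from_flags(flags))
```

The reported value can then be matched when compiling the *.so, so that the library and TensorFlow agree on the standard-library ABI.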

Contributions to this repo are welcome.

MIT License
