Commit adeeb28
Merge pull request #50 from cvjena/fixes/label-sharing-loss
keras_labelsharing: Fix wrong reduction in loss function
cabrust committed Apr 13, 2021
2 parents 9b76fa9 + 397dd25 commit adeeb28
Showing 1 changed file with 15 additions and 9 deletions.

chia/components/classifiers/keras_labelsharing_hc.py (15 additions, 9 deletions)
@@ -54,6 +54,8 @@ def update_embedding(self):
             )
             if self._l2_regularization_coefficient > 0.0
             else None,
+            kernel_initializer="zero",
+            bias_initializer="zero",
         )
 
         update_uids = [
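The two added keyword arguments zero-initialize the layer that update_embedding() rebuilds. Below is a hypothetical sketch of the effect, not taken from the repository: the layer size, activation, and input shape are made up, and it uses "zeros", the standard Keras initializer alias, where the diff uses the older "zero" spelling.

import tensorflow as tf

# Made-up stand-in for the kind of layer the hunk configures: with a
# zero-initialized kernel and bias, every output starts at sigmoid(0) = 0.5
# before any training step.
layer = tf.keras.layers.Dense(
    units=8,               # hypothetical number of labels
    activation="sigmoid",  # assumed: per-label probabilities for binary CE
    kernel_initializer="zeros",
    bias_initializer="zeros",
)
print(layer(tf.zeros((1, 16))))  # tf.Tensor of shape (1, 8), all entries 0.5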
@@ -143,16 +145,20 @@ def deembed_dist(self, embedded_labels):
         ]
 
     def loss(self, feature_batch, ground_truth, weight_batch, global_step):
-        embedded_predictions = self.predict_embedded(feature_batch)
-        embedded_ground_truth = self.embed(ground_truth)
+        embedding = self.embed(ground_truth)
+        prediction = self.predict_embedded(feature_batch)
 
-        loss = tf.reduce_mean(
-            tf.keras.losses.binary_crossentropy(
-                embedded_ground_truth, embedded_predictions
-            )
-            * weight_batch
-        )
-        return loss
+        # Binary cross entropy loss function from keras_idk
+        clipped_probs = tf.clip_by_value(prediction, 1e-7, (1.0 - 1e-7))
+        the_loss = -(
+            embedding * tf.math.log(clipped_probs)
+            + (1.0 - embedding) * tf.math.log(1.0 - clipped_probs)
+        )
+
+        # We can't use tf's binary_crossentropy because it always takes a mean around axis -1,
+        # but we need the sum
+        sum_per_batch_element = tf.reduce_sum(the_loss, axis=1)
+        return tf.reduce_mean(sum_per_batch_element * weight_batch)
 
     def observe(self, samples, gt_resource_id):
         self.maybe_update_embedding()
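To see why the reduction matters: tf.keras.losses.binary_crossentropy averages the per-label cross entropy over the last axis, so the old loss shrank as the number of labels grew; the fix sums over the labels for each sample before taking the weighted mean over the batch. A minimal standalone sketch of the difference follows (not part of the commit; the shapes and values are made up for demonstration):

import tensorflow as tf

y_true = tf.constant([[1.0, 0.0, 1.0, 0.0]])  # one sample, 4 binary labels
y_pred = tf.constant([[0.9, 0.1, 0.8, 0.2]])
weights = tf.constant([1.0])                  # per-sample weight_batch

# Old behaviour: binary_crossentropy reduces with a *mean* over the last
# axis, so the per-sample loss is divided by the number of labels.
per_sample_mean = tf.keras.losses.binary_crossentropy(y_true, y_pred)
old_loss = tf.reduce_mean(per_sample_mean * weights)

# Fixed behaviour: sum the per-label cross entropy for each sample,
# then take the weighted mean over the batch.
eps = 1e-7
p = tf.clip_by_value(y_pred, eps, 1.0 - eps)
per_label = -(y_true * tf.math.log(p) + (1.0 - y_true) * tf.math.log(1.0 - p))
per_sample_sum = tf.reduce_sum(per_label, axis=1)
new_loss = tf.reduce_mean(per_sample_sum * weights)

print(float(old_loss), float(new_loss))  # new_loss is 4x old_loss here (4 labels)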
