
Merge pull request #53 from cvjena/fixes/fergus-labelsharing-bugs
keras_labelsharing: Fixed hardcoded kappa value
cabrust committed May 5, 2021
2 parents 285661b + 8c222de commit c0bc766
Showing 1 changed file with 10 additions and 18 deletions.
chia/components/classifiers/keras_labelsharing_hc.py
@@ -22,7 +22,7 @@ def __init__(self, kb, l2=5e-5, kappa=10.0):
 
         # Configuration
         self._l2_regularization_coefficient = l2
-        self._kappa = 10
+        self._kappa = kappa
 
         self.last_observed_concept_count = len(
             self.kb.concepts(flags={knowledge.ConceptFlag.PREDICTION_TARGET})
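
The fix itself is a one-word change: the kappa argument in the constructor signature above was previously shadowed by the hardcoded literal 10, so any value passed by a caller was silently ignored. A minimal sketch of the bug pattern, using a hypothetical Example class rather than the real classifier:

class Example:
    def __init__(self, kappa=10.0):
        # Before the fix the argument was silently ignored:
        #     self._kappa = 10
        # After the fix it is actually stored:
        self._kappa = kappa

# Only with the fix does a non-default value take effect:
assert Example(kappa=2.5)._kappa == 2.5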
@@ -135,31 +135,20 @@ def embed(self, labels):
 
     def deembed_dist(self, embedded_labels):
         return [
-            [
-                (uid, embedded_label[dim] / embedded_label_sum)
-                for uid, dim in self.uid_to_dimension.items()
-            ]
-            for (embedded_label, embedded_label_sum) in zip(
-                embedded_labels, np.sum(embedded_labels, axis=1)
-            )
+            [(uid, embedded_label[dim]) for uid, dim in self.uid_to_dimension.items()]
+            for embedded_label in embedded_labels
         ]
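
The removed division by embedded_label_sum is redundant once every embedded label already sums to one, which the affinity matrix normalization further down guarantees. A minimal sketch with made-up values of why the division drops out:

import numpy as np

# Hypothetical embedded labels that already sum to one per sample.
embedded_labels = np.array(
    [
        [0.7, 0.2, 0.1],
        [0.1, 0.6, 0.3],
    ]
)

sums = np.sum(embedded_labels, axis=1)          # -> [1. 1.]
old_style = embedded_labels / sums[:, None]     # what the removed code divided by
assert np.allclose(old_style, embedded_labels)  # the division is now a no-op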

     def loss(self, feature_batch, ground_truth, weight_batch, global_step):
         embedding = self.embed(ground_truth)
         prediction = self.predict_embedded(feature_batch)
 
-        # Binary cross entropy loss function from keras_idk
-        clipped_probs = tf.clip_by_value(prediction, 1e-7, (1.0 - 1e-7))
-        the_loss = -(
-            embedding * tf.math.log(clipped_probs)
-            + (1.0 - embedding) * tf.math.log(1.0 - clipped_probs)
-        )
+        # We can use categorical cross-entropy because A is appropriately normalized
+        return tf.reduce_mean(
+            tf.keras.losses.categorical_crossentropy(embedding, prediction)
+            * weight_batch
+        )
 
-        # We can't use tf's binary_crossentropy because it always takes a mean around axis -1,
-        # but we need the sum
-        sum_per_batch_element = tf.reduce_sum(the_loss, axis=1)
-        return tf.reduce_mean(sum_per_batch_element * weight_batch)
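
For readers unfamiliar with the Keras loss helpers: categorical_crossentropy returns one loss value per sample (it already sums over the class axis), so the per-sample weighting and the final mean compose exactly as in the new return statement. A toy illustration with made-up numbers, assuming both arguments are already valid probability distributions:

import tensorflow as tf

embedding = tf.constant([[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]])   # normalized targets
prediction = tf.constant([[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]])  # predicted distributions
weight_batch = tf.constant([1.0, 0.5])                        # per-sample weights

# Shape [batch]: -sum(target * log(prediction)) over the class axis.
per_sample = tf.keras.losses.categorical_crossentropy(embedding, prediction)

loss = tf.reduce_mean(per_sample * weight_batch)
print(float(loss))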

     def observe(self, samples, gt_resource_id):
         self.maybe_update_embedding()

@@ -232,6 +221,9 @@ def _update_affinity_matrix(self):
                 self.affinity_matrix[i, j] = affinity_ij
                 self.affinity_matrix[j, i] = affinity_ij
 
+        # Normalize the affinity matrix so that embeddings sum to one
+        self.affinity_matrix /= self.affinity_matrix.sum(axis=1)
+
         if not np.all(self.affinity_matrix > 0.0):
             raise ValueError("Affinity matrix contains zero entry!")
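
The new normalization divides the symmetric affinity matrix by its axis-1 sums so that the label embeddings built from it are already probability-like, which is what allows the simplified deembed_dist and the switch to categorical cross-entropy above. A small NumPy sketch of the broadcasting behaviour, with made-up affinity values; it illustrates the normalization only, not how the affinities themselves are computed:

import numpy as np

# Hypothetical symmetric affinity matrix between three concepts.
affinity = np.array(
    [
        [1.0, 0.5, 0.2],
        [0.5, 1.0, 0.4],
        [0.2, 0.4, 1.0],
    ]
)

# Same operation as in the commit: broadcasting divides entry (i, j)
# by the sum of row j, so for a symmetric matrix every column of the
# result sums to one.
normalized = affinity / affinity.sum(axis=1)
print(normalized.sum(axis=0))  # -> [1. 1. 1.]

# An explicit row-wise normalization would use keepdims instead:
row_normalized = affinity / affinity.sum(axis=1, keepdims=True)
print(row_normalized.sum(axis=1))  # -> [1. 1. 1.]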

