More loss functions, exposing mutual information and epochs
Yves-Laurent committed Apr 10, 2022
1 parent 04b1401 commit b7c3d4c
Showing 8 changed files with 91 additions and 16 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -1 +1,6 @@
Fixing some package incompatibilities.

## Changes

* Adding more (robust) mutual information loss functions.
* Exposing the learned total mutual information between principal features and target as an attribute of PFS.
* Exposing the number of epochs as a parameter of PFS' fit.
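Taken together, the last two changes surface a small new API. A minimal sketch of the greedy path, assuming `PFS` is importable from `kxy.pfs` and using illustrative synthetic data (variable names are not from the library):

```python
import numpy as np
from kxy.pfs import PFS

# Synthetic regression problem: y depends on one direction of x plus noise.
x = np.random.randn(10000, 5)
y = x @ np.array([0.5, 0.3, 0.0, 0.0, 0.2]) + 0.1 * np.random.randn(10000)

selector = PFS()
# 'epochs' is now forwarded to the underlying learners' fit.
directions = selector.fit(x, y, epochs=25)

# The learned total mutual information is now exposed as an attribute.
print(directions.shape, selector.mutual_information)
```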
2 changes: 2 additions & 0 deletions kxy/misc/tf/__init__.py
@@ -24,7 +24,9 @@
except:
import logging
logging.warning('You need tensorflow version 2.8 or higher to estimate mutual information or copula entropy locally.')

from .generators import *
from .ops import *
from .layers import *
from .losses import *
from .models import *
6 changes: 3 additions & 3 deletions kxy/misc/tf/learners.py
@@ -11,7 +11,7 @@

from .generators import CopulaBatchGenerator, PFSBatchGenerator
from .models import CopulaModel, PFSModel, PFSOneShotModel
from .losses import MINDLoss
from .losses import MINDLoss, ApproximateMINDLoss, RectifiedMINDLoss


class CopulaLearner(object):
@@ -53,7 +53,7 @@ def __init__(self, dx, dy=1, dox=0, doy=0, beta_1=0.9, beta_2=0.999, epsilon=1e-
self.model = PFSModel(x_ixs, y_ixs, ox_ixs=ox_ixs, oy_ixs=oy_ixs)
self.opt = Adam(beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, amsgrad=amsgrad, \
name=name, lr=lr)
self.loss = MINDLoss()
self.loss = RectifiedMINDLoss() # MINDLoss()
self.model.compile(optimizer=self.opt, loss=self.loss)
self.mutual_information = None
self.feature_direction = None
@@ -110,7 +110,7 @@ def __init__(self, dx, dy=1, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=Fa
self.model = PFSOneShotModel(x_ixs, y_ixs, p=p)
self.opt = Adam(beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, amsgrad=amsgrad, \
name=name, lr=lr)
self.loss = MINDLoss()
self.loss = RectifiedMINDLoss() # MINDLoss()
self.model.compile(optimizer=self.opt, loss=self.loss)
self.mutual_information = None
self.feature_direction = None
31 changes: 30 additions & 1 deletion kxy/misc/tf/losses.py
@@ -14,14 +14,43 @@
from tensorflow.python.ops import math_ops
from tensorflow.keras.losses import Loss

from .ops import rectified_exp, d_rectified_exp


class MINDLoss(Loss):
'''
MIND loss function: :math:`-E_P(T(x, y)^T\theta) + \log E_Q(e^{T(x, y)^T\theta})`.
'''
def call(self, y_true, y_pred):
''' '''
p_samples = y_pred[:, 0]
q_samples = y_pred[:, 1]
mi = -tf.reduce_mean(p_samples) + math_ops.log(tf.reduce_mean(math_ops.exp(q_samples)))
return mi


class ApproximateMINDLoss(Loss):
'''
MIND loss function with a gentler version of the exponential: :math:`-E_P(T(x, y)^T\theta) + \log E_Q(r_{\exp}(T(x, y)^T\theta))`, where :math:`r_{\exp}(t) = e^t` if :math:`t<0` and :math:`r_{\exp}(t) = 1+t+(1/2)t^2+(1/6)t^3` otherwise.
'''
def call(self, y_true, y_pred):
''' '''
p_samples = y_pred[:, 0]
q_samples = y_pred[:, 1]
mi = -tf.reduce_mean(p_samples) + math_ops.log(tf.reduce_mean(rectified_exp(q_samples)))
return mi


class RectifiedMINDLoss(Loss):
'''
Rectified-MIND loss function: :math:`-E_P(\log dr_{\exp}(T(x, y)^T\theta)) + \log E_Q(dr_{\exp}(T(x, y)^T\theta))`, where :math:`dr_{\exp}(t) = e^t` if :math:`t<0` and :math:`dr_{\exp}(t) = 1+t+(1/2)t^2` otherwise.
'''
def call(self, y_true, y_pred):
''' '''
p_samples = y_pred[:, 0]
q_samples = y_pred[:, 1]
mi = -tf.reduce_mean(math_ops.log(d_rectified_exp(p_samples))) + math_ops.log(tf.reduce_mean(d_rectified_exp(q_samples)))
return mi



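For intuition, the three losses can be evaluated side by side on a toy two-column batch (column 0 holds the statistic on samples from P, column 1 on samples from Q). A minimal sketch, not part of the commit, assuming the TensorFlow requirement noted above is met:

```python
import tensorflow as tf
from kxy.misc.tf.losses import MINDLoss, ApproximateMINDLoss, RectifiedMINDLoss

# Column 0: statistic T evaluated on samples from P;
# column 1: T evaluated on samples from Q.
y_pred = tf.constant([[1.2, -0.3],
                      [0.8,  2.5],
                      [1.5,  0.1]], dtype=tf.float64)

for loss in (MINDLoss(), ApproximateMINDLoss(), RectifiedMINDLoss()):
    # y_true is unused by these losses; call() only reads y_pred.
    print(type(loss).__name__, float(loss.call(None, y_pred)))
```

Because the rectified variants replace :math:`e^t` with a cubic (resp. quadratic) polynomial for :math:`t \geq 0`, a large positive statistic such as 2.5 inflates the Q-term far less than under the plain exponential.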
2 changes: 1 addition & 1 deletion kxy/misc/tf/models.py
@@ -212,7 +212,7 @@ def call(self, inputs):
t_q = self.statistics(q_samples)

t = concatenate([t_p, t_q], axis=1)
t = clip(t, -200., 200.)
t = clip(t, -200., 400.)
return t


35 changes: 35 additions & 0 deletions kxy/misc/tf/ops.py
@@ -0,0 +1,35 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Custom math operations.
"""
from multiprocessing import Pool, cpu_count
import numpy as np

import tensorflow as tf
tf.keras.backend.set_floatx('float64')
tf.config.threading.set_inter_op_parallelism_threads(2)
tf.config.threading.set_intra_op_parallelism_threads(8)
tf.config.set_soft_device_placement(True)
from tensorflow.python.ops import math_ops

def rectified_exp(t):
'''
:math:`r_{\exp}(t) = e^t` if :math:`t<0` and :math:`r_{\exp}(t) = 1+t+(1/2)t^2+(1/6)t^3` otherwise.
'''
exp = math_ops.exp(t)
approx_exp = 1.+t+(1./2.)*tf.math.pow(t, 2.)+(1./6.)*tf.math.pow(t, 3.)
condition = tf.greater(t, 0.0)
r_exp = tf.where(condition, x=approx_exp, y=exp)
return r_exp


def d_rectified_exp(t):
'''
:math:`dr_{\exp}(t) = e^t` if :math:`t<0` and :math:`dr_{\exp}(t) = 1+t+(1/2)t^2` otherwise; this is the derivative of :math:`r_{\exp}`.
'''
dexp = math_ops.exp(t)
approx_dexp = 1.+t+(1./2.)*tf.math.pow(t, 2.)
condition = tf.greater(t, 0.0)
dr_exp = tf.where(condition, x=approx_dexp, y=dexp)
return dr_exp
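Note that `d_rectified_exp` is exactly the derivative of `rectified_exp`, and the two branches of each helper agree in value and low-order derivatives at :math:`t = 0`, so the patched exponential stays smooth where they meet. A quick numerical check, offered as a sketch rather than part of the commit:

```python
import tensorflow as tf
from kxy.misc.tf.ops import rectified_exp, d_rectified_exp

t = tf.constant([-2.0, 0.0, 2.0], dtype=tf.float64)

# For t <= 0 both helpers coincide with exp; for t > 0 they grow polynomially.
print(rectified_exp(t).numpy())    # [exp(-2), 1., 1 + 2 + 2 + 4/3]
print(d_rectified_exp(t).numpy())  # [exp(-2), 1., 1 + 2 + 2]

# Central-difference check that d_rectified_exp is the derivative of rectified_exp.
h = 1e-6
fd = (rectified_exp(t + h) - rectified_exp(t - h)) / (2 * h)
print(tf.reduce_max(tf.abs(fd - d_rectified_exp(t))).numpy())  # ~ 0
```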
20 changes: 12 additions & 8 deletions kxy/pfs/pfs_selector.py
@@ -12,7 +12,7 @@



def learn_principal_direction(y, x, ox=None, oy=None):
def learn_principal_direction(y, x, ox=None, oy=None, epochs=20):
"""
Learn the i-th principal feature when using :math:`x` to predict :math:`y`.
@@ -36,7 +36,7 @@ def learn_principal_direction(y, x, ox=None, oy=None):
doy = 0 if oy is None else 1 if len(oy.shape) == 1 else oy.shape[1]

learner = PFSLearner(dx, dy=dy, dox=dox, doy=doy)
learner.fit(x, y, ox=ox, oy=oy)
learner.fit(x, y, ox=ox, oy=oy, epochs=epochs)

mi = learner.mutual_information
w = learner.feature_direction
@@ -47,7 +47,7 @@



def learn_principal_directions_one_shot(y, x, p):
def learn_principal_directions_one_shot(y, x, p, epochs=20):
"""
Jointly learn p principal features.
@@ -67,10 +67,11 @@ def learn_principal_directions_one_shot(y, x, p):
"""
dx = 1 if len(x.shape) == 1 else x.shape[1]
learner = PFSOneShotLearner(dx, p=p)
learner.fit(x, y)
learner.fit(x, y, epochs=epochs)
w = learner.feature_directions
mi = learner.mutual_information

return w
return w, mi



@@ -79,7 +80,7 @@ class PFS(object):
"""
Principal Feature Selection.
"""
def fit(self, x, y, p=None, mi_tolerance=0.0001, max_duration=None):
def fit(self, x, y, p=None, mi_tolerance=0.0001, max_duration=None, epochs=20):
"""
Perform Principal Feature Selection using :math:`x` to predict :math:`y`.
@@ -124,7 +125,7 @@ def fit(self, x, y, p=None, mi_tolerance=0.0001, max_duration=None):
ox = None
oy = None
for i in range(d):
w, mi, ox, oy = learn_principal_direction(t, x, ox=ox, oy=oy)
w, mi, ox, oy = learn_principal_direction(t, x, ox=ox, oy=oy, epochs=epochs)

if mi-old_mi < mi_tolerance:
logging.info('The mutual information %.4f after %d rounds has not increased by more than %.4f: stopping.' % (
@@ -146,9 +147,12 @@ def fit(self, x, y, p=None, mi_tolerance=0.0001, max_duration=None):
rows += [w.copy()]

self.feature_directions = np.array(rows)
self.mutual_information = old_mi
else:
# Learn all p principal features jointly.
self.feature_directions = learn_principal_directions_one_shot(y, x, p)
feature_directions, mi = learn_principal_directions_one_shot(y, x, p, epochs=epochs)
self.feature_directions = feature_directions
self.mutual_information = mi

return self.feature_directions

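The one-shot path (`p` specified) now records mutual information as well, since `learn_principal_directions_one_shot` returns `(directions, mi)` and `fit` stores the latter. A sketch under the same assumptions as the CHANGELOG example above:

```python
import numpy as np
from kxy.pfs import PFS

x = np.random.randn(10000, 5)
y = x @ np.array([0.5, 0.3, 0.0, 0.0, 0.2]) + 0.1 * np.random.randn(10000)

# Learn p = 2 principal features jointly instead of greedily.
selector = PFS()
directions = selector.fit(x, y, p=2, epochs=30)

print(directions.shape)               # expected: (2, 5), one row per feature
print(selector.mutual_information)    # populated on this path too
```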
4 changes: 2 additions & 2 deletions tests/test_pfs.py
@@ -112,7 +112,7 @@ def test_pfs_accuracy():

# Run PFS
selector = PFS()
selector.fit(x, y)
selector.fit(x, y, epochs=21)

# Learned principal directions
F = selector.feature_directions
@@ -126,6 +126,6 @@ def test_pfs_accuracy():
e = np.linalg.norm(true_f_1-learned_f_1)

assert e <= 0.10

assert selector.mutual_information > 1.0

