Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
Closes #3
Browse files Browse the repository at this point in the history
Updated use of bidict to take advantage of new API.
  • Loading branch information
capcarr committed Apr 20, 2016
1 parent 1e0edaa commit bf69177
Showing 1 changed file with 9 additions and 58 deletions.
67 changes: 9 additions & 58 deletions biosppy/biometrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,55 +57,6 @@ def __str__(self):
return str("Combination of empty array.")


class SubjectDict(bidict):
"""Adaptation of bidirectional dictionary to return default values
on KeyError.
Attributes
----------
LEFT : hashable
Left default token.
Right : hashable
Right default token.
"""

LEFT = ''
RIGHT = ''

def __getitem__(self, keyorslice):
"""Get an item; based on the bidict source."""

try:
start = keyorslice.start
stop = keyorslice.stop
step = keyorslice.step
except AttributeError:
# keyorslice is a key, e.g. b[key]
try:
return self._fwd[keyorslice]
except KeyError:
return self.RIGHT

# keyorslice is a slice
if (not ((start is None) ^ (stop is None))) or step is not None:
raise TypeError('Slice must only specify either start or stop')

if start is not None:
# forward lookup (by key), e.g. b[key:]
try:
return self._fwd[start]
except KeyError:
return self.RIGHT

# inverse lookup (by val), e.g. b[:val]
assert stop is not None
try:
return self._bwd[stop]
except KeyError:
return self.LEFT


class BaseClassifier(object):
"""Base biometric classifier class.
Expand Down Expand Up @@ -133,7 +84,7 @@ class BaseClassifier(object):
def __init__(self):
# generic self things
self.is_trained = False
self._subject2label = SubjectDict()
self._subject2label = bidict()
self._nbSubjects = 0
self._thresholds = {}
self._autoThresholds = None
Expand Down Expand Up @@ -296,7 +247,7 @@ def list_subjects(self):
"""

subjects = [self._subject2label[:i] for i in xrange(self._nbSubjects)]
subjects = self._subject2label.keys()

return subjects

Expand Down Expand Up @@ -448,7 +399,7 @@ def update_thresholds(self, fraction=1.):

# gather data to test
data = {}
for subject, label in self._subject2label.items():
for subject, label in self._subject2label.iteritems():
# select a random fraction of the training data
aux = self.io_load(label)
indx = range(len(aux))
Expand All @@ -460,7 +411,7 @@ def update_thresholds(self, fraction=1.):
_, res = self.evaluate(data, ths)

# choose thresholds at EER
for subject, label in self._subject2label.items():
for subject, label in self._subject2label.iteritems():
EER_auth = res['subject'][subject]['authentication']['rates']['EER']
self.set_auth_thr(label, EER_auth[self.EER_IDX, 0], ready=True)

Expand Down Expand Up @@ -653,7 +604,7 @@ def identify(self, data, threshold=None):
labels = self._identify(aux, threshold)

# translate class labels
subjects = [self._subject2label[:item] for item in labels]
subjects = [self._subject2label.inv.get(item, '') for item in labels]

return subjects

Expand Down Expand Up @@ -1216,7 +1167,7 @@ class SVM(BaseClassifier):
Degree of the polynomial kernel function (‘poly’). Ignored by all other
kernels.
gamma : float, optional
Kernel coefficient for ‘rbf’, ‘poly’ and ‘sigmoid’. If gamma is 0.0
Kernel coefficient for ‘rbf’, ‘poly’ and ‘sigmoid’. If gamma is 'auto'
then 1/n_features will be used instead.
coef0 : float, optional
Independent term in kernel function. It is only significant in ‘poly’
Expand Down Expand Up @@ -1246,7 +1197,7 @@ def __init__(self,
C=1.0,
kernel='linear',
degree=3,
gamma=0.0,
gamma='auto',
coef0=0.0,
shrinking=True,
tol=0.001,
Expand Down Expand Up @@ -1843,7 +1794,7 @@ def get_subject_results(results=None,
Classifier thresholds.
subjects : list
Target subject classes.
subject_dict : SubjectDict
subject_dict : bidict
Subject-label conversion dictionary.
subject_idx : list
Subject index.
Expand Down Expand Up @@ -1927,7 +1878,7 @@ def get_subject_results(results=None,
misses = np.logical_not(np.logical_or(hits, rejects))
nmisses = ns - (nhits + nrejects)
missCounts = {
subject_dict[:ms]: np.sum(res == ms)
subject_dict.inv[ms]: np.sum(res == ms)
for ms in np.unique(res[misses])
}

Expand Down

0 comments on commit bf69177

Please sign in to comment.