Skip to content

Commit

Permalink
Add is_pred_labels to struct labeler and increment version (#435)
Browse files Browse the repository at this point in the history
* feat: add is_pred_labels to struct labeler

* feat: increment version
  • Loading branch information
JGSweets committed Jan 28, 2022
1 parent 4cbf4e7 commit 3f3fa87
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 11 deletions.
30 changes: 20 additions & 10 deletions dataprofiler/labelers/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,10 +830,10 @@ def _word_level_argmax(self, data, predictions, label_mapping,
is_end = (idx == len(sample)-1 and start_idx > 0)

if not is_separator:
label = entities_in_sample[idx]
label = entities_in_sample[idx]
if label not in label_count:
label_count[label] = 0
label_count[label] += 1
label_count[label] += 1

if is_separator or is_end:

Expand Down Expand Up @@ -867,7 +867,7 @@ def _word_level_argmax(self, data, predictions, label_mapping,
label_count = {background_label: 0}
if char_pred[idx] == background_label and \
sample[idx] in separator_dict:
continue
continue
word_level_predictions.append(entities_in_sample)

return word_level_predictions
Expand Down Expand Up @@ -1220,7 +1220,8 @@ class StructCharPostprocessor(BaseDataPostprocessor,
metaclass=AutoSubRegistrationMeta):

def __init__(self, default_label='UNKNOWN', pad_label='PAD',
flatten_separator="\x01"*5, random_state=None):
flatten_separator="\x01"*5, is_pred_labels=True,
random_state=None):
"""
Initialize the StructCharPostprocessor class
Expand All @@ -1231,6 +1232,9 @@ def __init__(self, default_label='UNKNOWN', pad_label='PAD',
:param flatten_separator: separator used to put between flattened
samples.
:type flatten_separator: str
:param is_pred_labels: (default: true) if true, will convert the model
indexes to the label strings given the label_mapping
:type is_pred_labels: bool
:param random_state: random state setting to be used for randomly
selecting a prediction when two labels have equal opportunity for
a given sample.
Expand All @@ -1256,6 +1260,7 @@ def __init__(self, default_label='UNKNOWN', pad_label='PAD',
super().__init__(default_label=default_label,
pad_label=pad_label,
flatten_separator=flatten_separator,
is_pred_labels=is_pred_labels,
random_state=random_state)

def __eq__(self, other):
Expand All @@ -1276,7 +1281,9 @@ def __eq__(self, other):
or self._parameters["pad_label"] != \
other._parameters["pad_label"]\
or self._parameters["flatten_separator"] != \
other._parameters["flatten_separator"]:
other._parameters["flatten_separator"] \
or self._parameters["is_pred_labels"] != \
other._parameters["is_pred_labels"]:
return False
return True

Expand All @@ -1303,6 +1310,8 @@ def _validate_parameters(self, parameters):
if param in ['default_label', 'pad_label', 'flatten_separator'] \
and not isinstance(value, str):
errors.append("`{}` must be a string.".format(param))
if param in ['is_pred_labels'] and not isinstance(value, bool):
errors.append("`{}` must be a boolean.".format(param))
if param == 'random_state' and not isinstance(value, random.Random):
errors.append('`{}` must be a random.Random.'.format(param))
elif param not in allowed_parameters:
Expand Down Expand Up @@ -1483,6 +1492,7 @@ def process(self, data, results, label_mapping):
flatten_separator = self._parameters['flatten_separator']
default_label = self._parameters['default_label']
pad_label = self._parameters['pad_label']
is_pred_labels = self._parameters['is_pred_labels']

# Format predictions
# FORMER DEEPCOPY, SHALLOW AS ONLY INTERNAL
Expand All @@ -1494,11 +1504,11 @@ def process(self, data, results, label_mapping):
default_label=default_label,
pad_label=pad_label)

reverse_label_mapping = {v: k for k, v in label_mapping.items()}
rev_label_map_vec_func = np.vectorize(
lambda x: reverse_label_mapping.get(x, None))

results['pred'] = rev_label_map_vec_func(results['pred'])
if is_pred_labels:
reverse_label_mapping = {v: k for k, v in label_mapping.items()}
rev_label_map_vec_func = np.vectorize(
lambda x: reverse_label_mapping.get(x, None))
results['pred'] = rev_label_map_vec_func(results['pred'])
return results

def _save_processor(self, dirpath):
Expand Down
20 changes: 20 additions & 0 deletions dataprofiler/tests/labelers/test_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,13 +1852,15 @@ def test_get_parameters(self):
self.assertDictEqual(dict(default_label='UNKNOWN',
pad_label='PAD',
flatten_separator='\x01'*5,
is_pred_labels=True,
random_state=random_state),
processor.get_parameters())

# test set params
params = dict(default_label='test default',
pad_label='test pad',
flatten_separator='test',
is_pred_labels=False,
random_state=random_state)
processor = StructCharPostprocessor(**params)
self.assertDictEqual(params, processor.get_parameters())
Expand All @@ -1867,6 +1869,7 @@ def test_get_parameters(self):
params = dict(default_label='test default',
pad_label='test pad',
flatten_separator='test',
is_pred_labels=False,
random_state=random_state)
processor = StructCharPostprocessor(**params)
self.assertDictEqual(
Expand Down Expand Up @@ -1911,7 +1914,24 @@ def test_process(self):
self.assertIn('pred', output)
self.assertTrue((expected_output['pred'] == output['pred']).all())

# test with is_pred_labels = False
processor = StructCharPostprocessor(
default_label='UNKNOWN',
pad_label='PAD',
is_pred_labels=False,
flatten_separator='\x01' * 5)
expected_output_ints = dict(pred=np.array([2, 3, 1, 3, 2]))
output = processor.process(data, results, label_mapping)

self.assertIn('pred', output)
self.assertTrue((expected_output_ints['pred'] == output['pred']).all())

# with confidences
processor = StructCharPostprocessor(
default_label='UNKNOWN',
pad_label='PAD',
is_pred_labels=True,
flatten_separator='\x01' * 5)
confidences = []
for sample in results['pred']:
confidences.append([])
Expand Down
2 changes: 1 addition & 1 deletion dataprofiler/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

MAJOR = 0
MINOR = 7
MICRO = 4
MICRO = 5

VERSION = '%d.%d.%d' % (MAJOR, MINOR, MICRO)

Expand Down

0 comments on commit 3f3fa87

Please sign in to comment.