/
metric_spec.py
431 lines (365 loc) · 16 KB
/
metric_spec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The metric spec class to flexibly connect models and metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
def _assert_named_args(sentinel):
if sentinel is not None:
raise ValueError(
'`metric_fn` requires named args: '
'`labels`, `predictions`, and optionally `weights`.')
def _args(fn):
"""Get argument names for function-like object.
Args:
fn: Function, or function-like object (e.g., result of `functools.partial`).
Returns:
`tuple` of string argument names.
"""
if hasattr(fn, 'func') and hasattr(fn, 'keywords'):
# Handle functools.partial and similar objects.
return tuple(
[arg for arg in _args(fn.func) if arg not in set(fn.keywords.keys())])
# Handle function.
return tuple(tf_inspect.getargspec(fn).args)
_CANONICAL_LABELS_ARG = 'labels'
_LABELS_ARGS = set((_CANONICAL_LABELS_ARG, 'label', 'targets', 'target'))
_CANONICAL_PREDICTIONS_ARG = 'predictions'
_PREDICTIONS_ARGS = set((_CANONICAL_PREDICTIONS_ARG, 'prediction',
'logits', 'logit'))
_CANONICAL_WEIGHTS_ARG = 'weights'
_WEIGHTS_ARGS = set((_CANONICAL_WEIGHTS_ARG, 'weight'))
def _matching_arg(
fn_name, fn_args, candidate_args, canonical_arg, is_required=False):
"""Find single argument in `args` from `candidate_args`.
Args:
fn_name: Function name, only used for error string.
fn_args: String argument names to `fn_name` function.
candidate_args: Candidate argument names to find in `args`.
canonical_arg: Canonical argument name in `candidate_args`. This is only
used to log a warning if a non-canonical match is found.
is_required: Whether function is required to have an arg in
`candidate_args`.
Returns:
String argument name if found, or `None` if not found.
Raises:
ValueError: if 2 candidates are found, or 0 are found and `is_required` is
set.
"""
assert canonical_arg in candidate_args # Sanity check.
matching_args = candidate_args.intersection(fn_args)
if len(matching_args) > 1:
raise ValueError(
'Ambiguous arguments %s, must provide only one of %s.' % (
matching_args, candidate_args))
matching_arg = matching_args.pop() if matching_args else None
if matching_arg:
if matching_arg != canonical_arg:
logging.warning(
'Canonical arg %s missing from %s(%s), using %s.',
canonical_arg, fn_name, fn_args, matching_arg)
elif is_required:
raise ValueError(
'%s missing from %s(%s).' % (candidate_args, fn_name, fn_args))
return matching_arg
def _fn_name(fn):
if hasattr(fn, '__name__'):
return fn.__name__
if hasattr(fn, 'func') and hasattr(fn.func, '__name__'):
return fn.func.__name__ # If it's a functools.partial.
return str(fn)
def _adapt_metric_fn(
metric_fn, metric_fn_name, is_labels_required, is_weights_required):
"""Adapt `metric_fn` to take only named args.
This returns a function that takes only named args `labels`, `predictions`,
and `weights`, and invokes `metric_fn` according to the following rules:
- If `metric_fn` args include exactly one of `_LABELS_ARGS`, that arg is
passed (usually by name, but positionally if both it and `predictions` need
to be passed positionally). Otherwise, `labels` are omitted.
- If `metric_fn` args include exactly one of `_PREDICTIONS_ARGS`, that arg is
passed by name. Otherwise, `predictions` are passed positionally as the
first non-label argument.
- If exactly one of `_WEIGHTS_ARGS` is provided, that arg is passed by
name.
Args:
metric_fn: Metric function to be wrapped.
metric_fn_name: `metric_fn` name, only used for logging.
is_labels_required: Whether `labels` is a required arg.
is_weights_required: Whether `weights` is a required arg.
Returns:
Function accepting only named args `labels, `predictions`, and `weights`,
and passing those to `metric_fn`.
Raises:
ValueError: if one of the following is true:
- `metric_fn` has more than one arg of `_LABELS_ARGS`, `_PREDICTIONS_ARGS`,
or `_WEIGHTS_ARGS`
- `is_labels_required` is true, and `metric_fn` has no arg from
`_LABELS_ARGS`.
- `is_weights_required` is true, and `metric_fn` has no arg from
`_WEIGHTS_ARGS`.
"""
args = _args(metric_fn)
labels_arg = _matching_arg(
metric_fn_name, args, _LABELS_ARGS, _CANONICAL_LABELS_ARG,
is_labels_required)
predictions_arg = _matching_arg(
metric_fn_name, args, _PREDICTIONS_ARGS, _CANONICAL_PREDICTIONS_ARG)
weights_arg = _matching_arg(
metric_fn_name, args, _WEIGHTS_ARGS, _CANONICAL_WEIGHTS_ARG,
is_weights_required)
# pylint: disable=invalid-name
if labels_arg:
if predictions_arg:
# Both labels and predictions are named args.
def _named_metric_fn(
_sentinel=None, labels=None, predictions=None, weights=None):
_assert_named_args(_sentinel)
kwargs = {
labels_arg: labels,
predictions_arg: predictions,
}
if weights is not None:
kwargs[weights_arg] = weights
return metric_fn(**kwargs)
return _named_metric_fn
if labels_arg == args[0]:
# labels is a named arg, and first. predictions is not a named arg, so we
# want to pass it as the 2nd positional arg (i.e., the first non-labels
# position), which means passing both positionally.
def _positional_metric_fn(
_sentinel=None, labels=None, predictions=None, weights=None):
_assert_named_args(_sentinel)
# TODO(ptucker): Should we support metrics that take only labels?
# Currently, if you want streaming mean of a label, you have to wrap it
# in a fn that takes discards predictions.
if weights is None:
return metric_fn(labels, predictions)
return metric_fn(labels, predictions, **{weights_arg: weights})
return _positional_metric_fn
# labels is a named arg, and not first, so we pass predictions positionally
# and labels by name.
def _positional_predictions_metric_fn(
_sentinel=None, labels=None, predictions=None, weights=None):
_assert_named_args(_sentinel)
kwargs = {
labels_arg: labels,
}
if weights is not None:
kwargs[weights_arg] = weights
return metric_fn(predictions, **kwargs)
return _positional_predictions_metric_fn
if predictions_arg:
# No labels, and predictions is named, so we pass the latter as a named arg.
def _named_no_labels_metric_fn(
_sentinel=None, labels=None, predictions=None, weights=None):
del labels
_assert_named_args(_sentinel)
kwargs = {
predictions_arg: predictions,
}
# TODO(ptucker): Should we allow weights with no labels?
if weights is not None:
kwargs[weights_arg] = weights
return metric_fn(**kwargs)
return _named_no_labels_metric_fn
# Neither labels nor predictions are named, so we just pass predictions as the
# first arg.
def _positional_no_labels_metric_fn(
_sentinel=None, labels=None, predictions=None, weights=None):
del labels
_assert_named_args(_sentinel)
if weights is None:
return metric_fn(predictions)
# TODO(ptucker): Should we allow weights with no labels?
return metric_fn(predictions, **{weights_arg: weights})
return _positional_no_labels_metric_fn
class MetricSpec(object):
"""MetricSpec connects a model to metric functions.
The MetricSpec class contains all information necessary to connect the
output of a `model_fn` to the metrics (usually, streaming metrics) that are
used in evaluation.
It is passed in the `metrics` argument of `Estimator.evaluate`. The
`Estimator` then knows which predictions, labels, and weight to use to call a
given metric function.
When building the ops to run in evaluation, an `Estimator` will call
`create_metric_ops`, which will connect the given `metric_fn` to the model
as detailed in the docstring for `create_metric_ops`, and return the metric.
Example:
Assuming a model has an input function which returns inputs containing
(among other things) a tensor with key "input_key", and a labels dictionary
containing "label_key". Let's assume that the `model_fn` for this model
returns a prediction with key "prediction_key".
In order to compute the accuracy of the "prediction_key" prediction, we
would add
```
"prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
prediction_key="prediction_key",
label_key="label_key")
```
to the metrics argument to `evaluate`. `prediction_accuracy_fn` can be either
a predefined function in metric_ops (e.g., `streaming_accuracy`) or a custom
function you define.
If we would like the accuracy to be weighted by "input_key", we can add that
as the `weight_key` argument.
```
"prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
prediction_key="prediction_key",
label_key="label_key",
weight_key="input_key")
```
An end-to-end example is as follows:
```
estimator = tf.contrib.learn.Estimator(...)
estimator.fit(...)
_ = estimator.evaluate(
input_fn=input_fn,
steps=1,
metrics={
'prediction accuracy':
metric_spec.MetricSpec(
metric_fn=prediction_accuracy_fn,
prediction_key="prediction_key",
label_key="label_key")
})
```
"""
def __init__(self,
metric_fn,
prediction_key=None,
label_key=None,
weight_key=None):
"""Constructor.
Creates a MetricSpec.
Args:
metric_fn: A function to use as a metric. See `_adapt_metric_fn` for
rules on how `predictions`, `labels`, and `weights` are passed to this
function. This must return either a single `Tensor`, which is
interpreted as a value of this metric, or a pair
`(value_op, update_op)`, where `value_op` is the op to call to
obtain the value of the metric, and `update_op` should be run for
each batch to update internal state.
prediction_key: The key for a tensor in the `predictions` dict (output
from the `model_fn`) to use as the `predictions` input to the
`metric_fn`. Optional. If `None`, the `model_fn` must return a single
tensor or a dict with only a single entry as `predictions`.
label_key: The key for a tensor in the `labels` dict (output from the
`input_fn`) to use as the `labels` input to the `metric_fn`.
Optional. If `None`, the `input_fn` must return a single tensor or a
dict with only a single entry as `labels`.
weight_key: The key for a tensor in the `inputs` dict (output from the
`input_fn`) to use as the `weights` input to the `metric_fn`.
Optional. If `None`, no weights will be passed to the `metric_fn`.
"""
self._metric_fn_name = _fn_name(metric_fn)
self._metric_fn = _adapt_metric_fn(
metric_fn=metric_fn,
metric_fn_name=self._metric_fn_name,
is_labels_required=label_key is not None,
is_weights_required=weight_key is not None)
self._prediction_key = prediction_key
self._label_key = label_key
self._weight_key = weight_key
@property
def prediction_key(self):
return self._prediction_key
@property
def label_key(self):
return self._label_key
@property
def weight_key(self):
return self._weight_key
@property
def metric_fn(self):
"""Metric function.
This function accepts named args: `predictions`, `labels`, `weights`. It
returns a single `Tensor` or `(value_op, update_op)` pair. See `metric_fn`
constructor argument for more details.
Returns:
Function, see `metric_fn` constructor argument for more details.
"""
return self._metric_fn
def __str__(self):
return ('MetricSpec(metric_fn=%s, ' % self._metric_fn_name +
'prediction_key=%s, ' % self.prediction_key +
'label_key=%s, ' % self.label_key +
'weight_key=%s)' % self.weight_key
)
def create_metric_ops(self, inputs, labels, predictions):
"""Connect our `metric_fn` to the specified members of the given dicts.
This function will call the `metric_fn` given in our constructor as follows:
```
metric_fn(predictions[self.prediction_key],
labels[self.label_key],
weights=weights[self.weight_key])
```
And returns the result. The `weights` argument is only passed if
`self.weight_key` is not `None`.
`predictions` and `labels` may be single tensors as well as dicts. If
`predictions` is a single tensor, `self.prediction_key` must be `None`. If
`predictions` is a single element dict, `self.prediction_key` is allowed to
be `None`. Conversely, if `labels` is a single tensor, `self.label_key` must
be `None`. If `labels` is a single element dict, `self.label_key` is allowed
to be `None`.
Args:
inputs: A dict of inputs produced by the `input_fn`
labels: A dict of labels or a single label tensor produced by the
`input_fn`.
predictions: A dict of predictions or a single tensor produced by the
`model_fn`.
Returns:
The result of calling `metric_fn`.
Raises:
ValueError: If `predictions` or `labels` is a single `Tensor` and
`self.prediction_key` or `self.label_key` is not `None`; or if
`self.label_key` is `None` but `labels` is a dict with more than one
element, or if `self.prediction_key` is `None` but `predictions` is a
dict with more than one element.
"""
def _get_dict(name, dict_or_tensor, key):
"""Get a single tensor or an element of a dict or raise ValueError."""
if key:
if not isinstance(dict_or_tensor, dict):
raise ValueError('MetricSpec with ' + name + '_key specified'
' requires ' +
name + 's dict, got %s.\n' % dict_or_tensor +
'You must not provide a %s_key if you ' % name +
'only have a single Tensor as %ss.' % name)
if key not in dict_or_tensor:
raise KeyError(
'Key \'%s\' missing from %s.' % (key, dict_or_tensor.keys()))
return dict_or_tensor[key]
else:
if isinstance(dict_or_tensor, dict):
if len(dict_or_tensor) != 1:
raise ValueError('MetricSpec without specified ' + name + '_key'
' requires ' + name + 's tensor or single element'
' dict, got %s' % dict_or_tensor)
return six.next(six.itervalues(dict_or_tensor))
return dict_or_tensor
# Get the predictions.
prediction = _get_dict('prediction', predictions, self.prediction_key)
# Get the labels.
label = _get_dict('label', labels, self.label_key)
try:
return self.metric_fn(
labels=label,
predictions=prediction,
weights=inputs[self.weight_key] if self.weight_key else None)
except Exception as ex:
logging.error('Could not create metric ops for %s, %s.' % (self, ex))
raise