Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Trainer refactoring #66

Merged
merged 92 commits into from
Jun 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
0ce953a
Start refactoring trainer and experiment
justusschock Feb 22, 2019
722cd55
make new torch trainer and experiment working
justusschock Feb 25, 2019
0ee1c08
Merge branch 'master' into trainer_refactoring
justusschock Feb 26, 2019
1dfa7ce
Update pytorch_trainer.py
justusschock Feb 26, 2019
ec0c11a
Update pytorch_trainer.py
justusschock Feb 26, 2019
a4b8743
Merge branch 'trainer_refactoring' of https://github.com/justusschock…
mibaumgartner Mar 11, 2019
151a3df
First version of auc metric implementation
mibaumgartner Mar 11, 2019
da4f474
Resolve DImension Error
justusschock Mar 12, 2019
71387c7
Move `_is_better_val_score` from predictor to abstract trainer
justusschock Mar 12, 2019
17478ce
Fix error in torchvision datasets due to latest torchvision release
justusschock Mar 12, 2019
1a8cca0
Fix error due to missing brackets for file extension checks
justusschock Mar 12, 2019
e467013
Add docstrings
justusschock Mar 12, 2019
f23b52a
Update pytorch_trainer docstrings
justusschock Mar 12, 2019
af66224
Merge remote-tracking branch 'upstream/trainer_refactoring' into trai…
mibaumgartner Mar 12, 2019
503ce15
Finished auc metric and added doc strings
mibaumgartner Mar 12, 2019
1575430
Added unit test for metric wrapper and auc metric
mibaumgartner Mar 12, 2019
06a9be9
Remove dummy experiment file
mibaumgartner Mar 12, 2019
bd25321
Remove todo
mibaumgartner Mar 12, 2019
667733f
Added docstring
mibaumgartner Mar 12, 2019
1c9ace3
Updated docstring
mibaumgartner Mar 12, 2019
ec04ddf
Update test_metrics.py
justusschock Mar 13, 2019
cd1b686
Merge branch 'master' into trainer_refactoring
justusschock Mar 13, 2019
775b684
Merge pull request #73 from mibaumgartner/trainer_refactoring
justusschock Mar 13, 2019
59046e4
resolve merge conflicts and merge
ORippler Mar 22, 2019
abc91b9
update TfExperiment.test, change test signature
ORippler Mar 22, 2019
b4c7773
Make consistent logging
justusschock Apr 7, 2019
85566f6
Continue merging trainers
justusschock Apr 7, 2019
d43cbc6
update tests
justusschock Apr 7, 2019
5ca4024
fix feed errors by using default placeholder
ORippler Apr 9, 2019
c645c76
Change list outputs to dict outputs
justusschock Apr 16, 2019
23dd11d
Change networks to return dicts
justusschock Apr 16, 2019
7fba516
remove duplicate argument in tf trainer
justusschock Apr 16, 2019
55c1ed9
deprecate models in favor of to be announced repo
justusschock May 9, 2019
506960b
Merge branch 'master' into trainer_refactoring
justusschock May 9, 2019
444b6b8
start updating notebooks
justusschock May 9, 2019
8c8c9cd
Start Experiment refactoring
justusschock May 9, 2019
3b83c6b
Update Experiment Tests and Fix Bugs
justusschock May 9, 2019
7f9bea4
Fix Naming issue
justusschock May 9, 2019
a9760fc
continue test fixing
justusschock May 10, 2019
c0b7553
Merge branch 'trainer_refactoring' of https://github.com/justusschock…
justusschock May 10, 2019
23436ed
add docstrings
justusschock May 10, 2019
e1189b7
fix minor bugs, change descriptor during validation phase and introdu…
justusschock May 10, 2019
866b0cf
Fix some bugs, torch tests should complete now
mibaumgartner May 11, 2019
f189575
Add resume to tf trainer
justusschock May 11, 2019
2b56ed4
merge master into trainer_refactoring
justusschock May 12, 2019
2a5cfb2
fix merge
justusschock May 12, 2019
3c48c98
fix indent tf trainer
justusschock May 12, 2019
1053486
switch to suitable permutation
justusschock May 12, 2019
424318c
specify `data` in signature as dict type
ORippler May 12, 2019
e9a221d
switch from list to dicts for tf
ORippler May 12, 2019
829e241
pep8
ORippler May 12, 2019
6f023a8
change to dicts
ORippler May 12, 2019
399fd7d
Adress requested changes: update start_epoch for tf_trainer, support …
justusschock May 13, 2019
3068ab9
fix indent causing swigpy pickling error
justusschock May 13, 2019
efa61da
fix tf experiment test and type annotations
justusschock May 13, 2019
6a3e5c2
detach network outputs, remove params from test in TF Experiment
justusschock May 13, 2019
a9383de
key 'prediction' to 'pred'
ORippler May 13, 2019
487b41b
prediction to pred
ORippler May 13, 2019
4e0ff57
formulate default fn explicitly
justusschock May 13, 2019
12f4919
rename argument in docstrings
justusschock May 13, 2019
de014d7
Merge branch 'trainer_refactoring' of https://github.com/justusschock…
justusschock May 13, 2019
19c954c
revert move to predictions key
justusschock May 13, 2019
fd99cf6
check arrays for zero-size and reshape if necessary
justusschock May 13, 2019
3fb15ed
fix start_epoch
justusschock May 14, 2019
562453f
added additional checks
justusschock May 17, 2019
291d89c
Make dict a lookup config to support nested dicts via `nested_get`
justusschock May 27, 2019
a5655d8
Make item concat working for nested dicts
justusschock May 27, 2019
f4531f5
Pin scipy requirement
justusschock May 27, 2019
6f26eda
Remove unnecessary imports
justusschock May 27, 2019
45cedd4
Merge branch 'trainer_refactoring' of https://github.com/justusschock…
justusschock May 27, 2019
1babcb4
Merge branch 'master' into trainer_refactoring
justusschock May 27, 2019
36ed4a0
Update docs
justusschock May 27, 2019
7c2b81a
Add kwargs in predictor for prepare_batch_fn
mibaumgartner Jun 1, 2019
47e2556
Convert batchdict back to numpy for metric calculation and introduce …
justusschock Jun 3, 2019
44ef716
Merged origin commits into offline branch
justusschock Jun 3, 2019
b510ffb
fix generator behavior
justusschock Jun 3, 2019
13d3adb
make experiment.test return the first generator item
justusschock Jun 3, 2019
22c7113
Merge branch 'master' into trainer_refactoring
justusschock Jun 3, 2019
d123c1e
add kwargs to overwritten predict_data_mgr functions
justusschock Jun 3, 2019
7007cd7
merge Parallel_master
justusschock Jun 4, 2019
bc32d5c
PEP-8 Auto-Fix
Jun 4, 2019
b63587c
Merge GAN
justusschock Jun 4, 2019
3325ba8
Merge PEP-8 Autofix
justusschock Jun 4, 2019
2e264f5
Add style fixes and common function to search for previous checkpoints
justusschock Jun 4, 2019
4ef61c9
fix infinite recursion by hard type checking instead of instance chec…
justusschock Jun 4, 2019
27fd0e4
fix pep8
justusschock Jun 4, 2019
bcee731
Remove TrixiExperiment as Experiment Baseclass
justusschock Jun 4, 2019
ac1c6e9
correct indent in tf trainer
justusschock Jun 4, 2019
6b9ac0d
shallow copy
justusschock Jun 5, 2019
12ac946
initialize uninitialized members in TfExeriment.Test
justusschock Jun 5, 2019
5077864
remove param argument from test for TfExperiment.test
justusschock Jun 5, 2019
74313e1
add predictor to docs
justusschock Jun 5, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions delira/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def get_backends():
"""
Return List of currently available backends

Returns
-------
list
list of strings containing the currently installed backends

"""

if not __BACKENDS:
Expand Down
13 changes: 11 additions & 2 deletions delira/data_loading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def train_test_split(self, *args, **kwargs):
"""
split dataset into train and test data

.. deprecated:: 0.3
.. deprecated-removed:: 0.3 0.4
method will be removed in next major release

Parameters
Expand Down Expand Up @@ -692,8 +692,10 @@ def _make_dataset(self, path):


if "TORCH" in get_backends():

from torchvision.datasets import CIFAR10, CIFAR100, EMNIST, MNIST, \
FashionMNIST
import torch

class TorchvisionClassificationDataset(AbstractDataset):
"""
Expand Down Expand Up @@ -802,8 +804,15 @@ def __getitem__(self, index):
"""

data = self.data[index]
label = data[1]

if isinstance(label, torch.Tensor):
label = label.numpy()
elif isinstance(label, int):
label = np.array(label)
data_dict = {"data": np.array(data[0]),
"label": data[1].reshape(1).astype(np.float32)}

"label": label.reshape(1).astype(np.float32)}

if self.one_hot:
# TODO: Remove and refer to batchgenerators transform:
Expand Down
15 changes: 11 additions & 4 deletions delira/data_loading/load_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,14 @@ def norm_zero_mean_unit_std(data):


@make_deprecated("LoadSample")
def _is_valid_image_file(fname, img_extensions, gt_extensions):
def is_valid_image_file(fname, img_extensions, gt_extensions):
"""
Helper Function to check wheter file is image file and has at least
one label file
Parameters

.. deprecated-removed:: 0.3.4 0.3.5

Parameters
----------
fname : str
filename of image path
Expand Down Expand Up @@ -99,6 +102,8 @@ def default_load_fn_2d(img_file, *label_files, img_shape, n_channels=1):
"""
loading single 2d sample with arbitrary number of samples

.. deprecated-removed:: 0.3.4 0.3.5

Parameters
----------
img_file : string
Expand Down Expand Up @@ -136,14 +141,16 @@ def default_load_fn_2d(img_file, *label_files, img_shape, n_channels=1):


class LoadSample:
"""
Provides a callable to load a single sample from multiple files in a folder
"""

def __init__(self,
sample_ext: dict,
sample_fn: collections.abc.Callable,
dtype={}, normalize=(), norm_fn=norm_range('-1,1'),
**kwargs):
"""
Provides a callable to load a single sample from multiple files in a
folder

Parameters
----------
Expand Down
4 changes: 4 additions & 0 deletions delira/data_loading/nii.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def load_sample_nii(files, label_load_cls):
"""
Load sample from multiple ITK files

.. deprecated-removed:: 0.3.4 0.4

Parameters
----------
files : dict with keys `img` and `label`
Expand Down Expand Up @@ -72,6 +74,8 @@ class BaseLabelGenerator(object):
"""
Base Class to load labels from json files

.. deprecated-removed: 0.3.3 0.3.5

"""

@make_deprecated('dict containing labels')
Expand Down
3 changes: 2 additions & 1 deletion delira/logging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .multistream_handler import MultiStreamHandler
from .trixi_handler import TrixiHandler
from .trixi_handler import TrixiHandler, VisdomLoggingHandler, \
TensorboardXLoggingHandler
56 changes: 56 additions & 0 deletions delira/logging/trixi_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,59 @@ def emit(self, record):
kwargs = {}

getattr(self._logger, _prefix + key)(*args, **kwargs)


class TensorboardXLoggingHandler(TrixiHandler):
"""
Logging Handler to log with TensorboardX (via Trixi)
"""

def __init__(self, log_dir, level=NOTSET, **kwargs):
"""

Parameters
----------
log_dir : str
path to log to
level : int (default: NOTSET)
logging level
**kwargs :
additional keyword arguments

"""

from trixi.logger.tensorboard import TensorboardXLogger

super().__init__(TensorboardXLogger, level=level,
target_dir=log_dir, **kwargs)


class VisdomLoggingHandler(TrixiHandler):
"""
Logging Handler to log with Visdom (via Trixi)

"""

def __init__(self, exp_name, server="http://localhost", port=8080,
level=NOTSET, **kwargs):
"""

Parameters
----------
exp_name : str
experiment name
server : str
address of visdom server
port : int
port of visdom server
level : int (default: NOTSET)
logging level
**kwargs :
additional keyword arguments

"""

from trixi.logger.visdom import NumpyVisdomLogger

super().__init__(NumpyVisdomLogger, level=level, exp_name=exp_name,
server=server, port=port, **kwargs)
46 changes: 27 additions & 19 deletions delira/models/abstract_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __call__(self, *args, **kwargs):

@staticmethod
@abc.abstractmethod
def closure(model, data_dict: dict, optimizers: dict, criterions={},
def closure(model, data_dict: dict, optimizers: dict, losses={},
metrics={}, fold=0, **kwargs):
"""
Function which handles prediction from batch, logging, loss calculation
Expand All @@ -66,8 +66,8 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
dictionary containing the data
optimizers : dict
dictionary containing all optimizers to perform parameter update
criterions : dict
Functions or classes to calculate criterions
losses : dict
Functions or classes to calculate losses
metrics : dict
Functions or classes to calculate other metrics
fold : int
Expand All @@ -80,8 +80,8 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
dict
Metric values (with same keys as input dict metrics)
dict
Loss values (with same keys as input dict criterions)
list
Loss values (with same keys as input dict losses)
dict
Arbitrary number of predictions

Raises
Expand Down Expand Up @@ -256,21 +256,23 @@ def __init__(self, sess=tf.Session, **kwargs):
"""
AbstractNetwork.__init__(self, **kwargs)
self._sess = sess()
self.inputs = None
self.outputs_train = None
self.outputs_eval = None
self.inputs = {}
self.outputs_train = {}
self.outputs_eval = {}
self._losses = None
self._optims = None
self.training = True

def __call__(self, *args):
def __call__(self, *args, **kwargs):
"""
Wrapper for calling self.run in eval setting

Parameters
----------
*args :
positional arguments (passed to `self.run`)
**kwargs:
keyword arguments (passed to `self.run`)

Returns
-------
Expand All @@ -279,7 +281,7 @@ def __call__(self, *args):

"""
self.training = False
return self.run(*args)
return self.run(*args, **kwargs)

def _add_losses(self, losses: dict):
"""
Expand All @@ -304,27 +306,33 @@ def _add_optims(self, optims: dict):
"""
raise NotImplementedError()

def run(self, *args):
def run(self, *args, **kwargs):
"""
Evaluates `self.outputs_train` or `self.outputs_eval` based on
`self.training`

Parameters
----------
*args :
arguments to feed as `self.inputs`. Must have same length as
`self.inputs`
currently unused, exist for compatibility reasons
**kwargs :
kwargs used to feed as ``self.inputs``. Same keys as for
``self.inputs`` must be used

Returns
-------
np.ndarray or list
based on len(self.outputs*), returns either list or np.ndarray
dict
sames keys as outputs_train or outputs_eval,
containing evaluated expressions as values

"""
if isinstance(self.inputs, tf.Tensor):
_feed_dict = dict(zip([self.inputs], args))
else:
_feed_dict = dict(zip(self.inputs, args))
_feed_dict = {}

for feed_key, feed_value in kwargs.items():
assert feed_key in self.inputs.keys(), \
"{} not found in self.inputs".format(feed_key)
_feed_dict[self.inputs[feed_key]] = feed_value

if self.training:
return self._sess.run(self.outputs_train, feed_dict=_feed_dict)
else:
Expand Down
3 changes: 3 additions & 0 deletions delira/models/classification/ResNet18.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tensorflow as tf
from delira.utils.decorators import make_deprecated

conv2d = tf.keras.layers.Conv2D
maxpool2d = tf.keras.layers.MaxPool2D
Expand Down Expand Up @@ -81,7 +82,9 @@ def call(self, inputs, training=None):


class ResNet18(tf.keras.Model):
@make_deprecated("own repository to be announced")
def __init__(self, num_classes=None, bias=False):

super(ResNet18, self).__init__()

_image_format, _axis = get_image_format_and_axis()
Expand Down