This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Trainer refactoring #66

Merged
92 commits, merged Jun 5, 2019
Commits
0ce953a
Start refactoring trainer and experiment
justusschock Feb 22, 2019
722cd55
make new torch trainer and experiment working
justusschock Feb 25, 2019
0ee1c08
Merge branch 'master' into trainer_refactoring
justusschock Feb 26, 2019
1dfa7ce
Update pytorch_trainer.py
justusschock Feb 26, 2019
ec0c11a
Update pytorch_trainer.py
justusschock Feb 26, 2019
a4b8743
Merge branch 'trainer_refactoring' of https://github.com/justusschock…
mibaumgartner Mar 11, 2019
151a3df
First version of auc metric implementation
mibaumgartner Mar 11, 2019
da4f474
Resolve Dimension Error
justusschock Mar 12, 2019
71387c7
Move `_is_better_val_score` from predictor to abstract trainer
justusschock Mar 12, 2019
17478ce
Fix error in torchvision datasets due to latest torchvision release
justusschock Mar 12, 2019
1a8cca0
Fix error due to missing brackets for file extension checks
justusschock Mar 12, 2019
e467013
Add docstrings
justusschock Mar 12, 2019
f23b52a
Update pytorch_trainer docstrings
justusschock Mar 12, 2019
af66224
Merge remote-tracking branch 'upstream/trainer_refactoring' into trai…
mibaumgartner Mar 12, 2019
503ce15
Finished auc metric and added docstrings
mibaumgartner Mar 12, 2019
1575430
Added unit test for metric wrapper and auc metric
mibaumgartner Mar 12, 2019
06a9be9
Remove dummy experiment file
mibaumgartner Mar 12, 2019
bd25321
Remove todo
mibaumgartner Mar 12, 2019
667733f
Added docstring
mibaumgartner Mar 12, 2019
1c9ace3
Updated docstring
mibaumgartner Mar 12, 2019
ec04ddf
Update test_metrics.py
justusschock Mar 13, 2019
cd1b686
Merge branch 'master' into trainer_refactoring
justusschock Mar 13, 2019
775b684
Merge pull request #73 from mibaumgartner/trainer_refactoring
justusschock Mar 13, 2019
59046e4
resolve merge conflicts and merge
ORippler Mar 22, 2019
abc91b9
update TfExperiment.test, change test signature
ORippler Mar 22, 2019
b4c7773
Make consistent logging
justusschock Apr 7, 2019
85566f6
Continue merging trainers
justusschock Apr 7, 2019
d43cbc6
update tests
justusschock Apr 7, 2019
5ca4024
fix feed errors by using default placeholder
ORippler Apr 9, 2019
c645c76
Change list outputs to dict outputs
justusschock Apr 16, 2019
23dd11d
Change networks to return dicts
justusschock Apr 16, 2019
7fba516
remove duplicate argument in tf trainer
justusschock Apr 16, 2019
55c1ed9
deprecate models in favor of to be announced repo
justusschock May 9, 2019
506960b
Merge branch 'master' into trainer_refactoring
justusschock May 9, 2019
444b6b8
start updating notebooks
justusschock May 9, 2019
8c8c9cd
Start Experiment refactoring
justusschock May 9, 2019
3b83c6b
Update Experiment Tests and Fix Bugs
justusschock May 9, 2019
7f9bea4
Fix Naming issue
justusschock May 9, 2019
a9760fc
continue test fixing
justusschock May 10, 2019
c0b7553
Merge branch 'trainer_refactoring' of https://github.com/justusschock…
justusschock May 10, 2019
23436ed
add docstrings
justusschock May 10, 2019
e1189b7
fix minor bugs, change descriptor during validation phase and introdu…
justusschock May 10, 2019
866b0cf
Fix some bugs, torch tests should complete now
mibaumgartner May 11, 2019
f189575
Add resume to tf trainer
justusschock May 11, 2019
2b56ed4
merge master into trainer_refactoring
justusschock May 12, 2019
2a5cfb2
fix merge
justusschock May 12, 2019
3c48c98
fix indent tf trainer
justusschock May 12, 2019
1053486
switch to suitable permutation
justusschock May 12, 2019
424318c
specify `data` in signature as dict type
ORippler May 12, 2019
e9a221d
switch from list to dicts for tf
ORippler May 12, 2019
829e241
pep8
ORippler May 12, 2019
6f023a8
change to dicts
ORippler May 12, 2019
399fd7d
Address requested changes: update start_epoch for tf_trainer, support …
justusschock May 13, 2019
3068ab9
fix indent causing swigpy pickling error
justusschock May 13, 2019
efa61da
fix tf experiment test and type annotations
justusschock May 13, 2019
6a3e5c2
detach network outputs, remove params from test in TF Experiment
justusschock May 13, 2019
a9383de
key 'prediction' to 'pred'
ORippler May 13, 2019
487b41b
prediction to pred
ORippler May 13, 2019
4e0ff57
formulate default fn explicitly
justusschock May 13, 2019
12f4919
rename argument in docstrings
justusschock May 13, 2019
de014d7
Merge branch 'trainer_refactoring' of https://github.com/justusschock…
justusschock May 13, 2019
19c954c
revert move to predictions key
justusschock May 13, 2019
fd99cf6
check arrays for zero-size and reshape if necessary
justusschock May 13, 2019
3fb15ed
fix start_epoch
justusschock May 14, 2019
562453f
added additional checks
justusschock May 17, 2019
291d89c
Make dict a lookup config to support nested dicts via `nested_get`
justusschock May 27, 2019
a5655d8
Make item concat working for nested dicts
justusschock May 27, 2019
f4531f5
Pin scipy requirement
justusschock May 27, 2019
6f26eda
Remove unnecessary imports
justusschock May 27, 2019
45cedd4
Merge branch 'trainer_refactoring' of https://github.com/justusschock…
justusschock May 27, 2019
1babcb4
Merge branch 'master' into trainer_refactoring
justusschock May 27, 2019
36ed4a0
Update docs
justusschock May 27, 2019
7c2b81a
Add kwargs in predictor for prepare_batch_fn
mibaumgartner Jun 1, 2019
47e2556
Convert batchdict back to numpy for metric calculation and introduce …
justusschock Jun 3, 2019
44ef716
Merged origin commits into offline branch
justusschock Jun 3, 2019
b510ffb
fix generator behavior
justusschock Jun 3, 2019
13d3adb
make experiment.test return the first generator item
justusschock Jun 3, 2019
22c7113
Merge branch 'master' into trainer_refactoring
justusschock Jun 3, 2019
d123c1e
add kwargs to overwritten predict_data_mgr functions
justusschock Jun 3, 2019
7007cd7
merge Parallel_master
justusschock Jun 4, 2019
bc32d5c
PEP-8 Auto-Fix
Jun 4, 2019
b63587c
Merge GAN
justusschock Jun 4, 2019
3325ba8
Merge PEP-8 Autofix
justusschock Jun 4, 2019
2e264f5
Add style fixes and common function to search for previous checkpoints
justusschock Jun 4, 2019
4ef61c9
fix infinite recursion by hard type checking instead of instance chec…
justusschock Jun 4, 2019
27fd0e4
fix pep8
justusschock Jun 4, 2019
bcee731
Remove TrixiExperiment as Experiment Baseclass
justusschock Jun 4, 2019
ac1c6e9
correct indent in tf trainer
justusschock Jun 4, 2019
6b9ac0d
shallow copy
justusschock Jun 5, 2019
12ac946
initialize uninitialized members in TfExperiment.test
justusschock Jun 5, 2019
5077864
remove param argument from test for TfExperiment.test
justusschock Jun 5, 2019
74313e1
add predictor to docs
justusschock Jun 5, 2019
8 changes: 4 additions & 4 deletions delira/models/abstract_network.py
@@ -52,7 +52,7 @@ def __call__(self, *args, **kwargs):

     @staticmethod
     @abc.abstractmethod
-    def closure(model, data_dict: dict, optimizers: dict, criterions={},
+    def closure(model, data_dict: dict, optimizers: dict, losses={},
                 metrics={}, fold=0, **kwargs):
         """
         Function which handles prediction from batch, logging, loss calculation
@@ -65,8 +65,8 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
             dictionary containing the data
         optimizers : dict
             dictionary containing all optimizers to perform parameter update
-        criterions : dict
-            Functions or classes to calculate criterions
+        losses : dict
+            Functions or classes to calculate losses
         metrics : dict
             Functions or classes to calculate other metrics
         fold : int
@@ -79,7 +79,7 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
         dict
             Metric values (with same keys as input dict metrics)
         dict
-            Loss values (with same keys as input dict criterions)
+            Loss values (with same keys as input dict losses)
         list
             Arbitrary number of predictions
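
The hunk above only renames the keyword, but it also fixes the contract every backend implements. As a rough orientation (not part of this diff; the "data"/"label" keys and the "default" optimizer name are assumptions), a minimal closure honoring the renamed signature could look like:

import torch

def closure(model, data_dict: dict, optimizers: dict, losses={},
            metrics={}, fold=0, **kwargs):
    # minimal sketch: `losses` replaces the old `criterions` argument
    preds = model(data_dict["data"])

    total_loss = 0
    loss_vals = {}
    for key, loss_fn in losses.items():
        # every entry contributes to one accumulated backward pass
        _loss_val = loss_fn(preds, data_dict["label"])
        loss_vals[key] = _loss_val.detach()
        total_loss += _loss_val

    if optimizers:
        # "default" is an assumed key, not part of the abstract interface
        optimizers["default"].zero_grad()
        total_loss.backward()
        optimizers["default"].step()

    metric_vals = {k: fn(preds.detach(), data_dict["label"])
                   for k, fn in metrics.items()}
    return metric_vals, loss_vals, [preds.detach()]
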
16 changes: 8 additions & 8 deletions delira/models/classification/classification_network.py
@@ -63,7 +63,7 @@ def forward(self, input_batch: torch.Tensor):

     @staticmethod
     def closure(model: AbstractPyTorchNetwork, data_dict: dict,
-                optimizers: dict, criterions={}, metrics={},
+                optimizers: dict, losses={}, metrics={},
                 fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -77,9 +77,9 @@ def closure(model: AbstractPyTorchNetwork, data_dict: dict,
             dictionary containing the data
         optimizers : dict
             dictionary of optimizers to optimize model's parameters
-        criterions : dict
-            dict holding the criterions to calculate errors
-            (gradients from different criterions will be accumulated)
+        losses : dict
+            dict holding the losses to calculate errors
+            (gradients from different losses will be accumulated)
         metrics : dict
             dict holding the metrics to calculate
         fold : int
@@ -92,19 +92,19 @@ def closure(model: AbstractPyTorchNetwork, data_dict: dict,
         dict
             Metric values (with same keys as input dict metrics)
         dict
-            Loss values (with same keys as input dict criterions)
+            Loss values (with same keys as input dict losses)
         list
             Arbitrary number of predictions as torch.Tensor

         Raises
         ------
         AssertionError
-            if optimizers or criterions are empty or the optimizers are not
+            if optimizers or losses are empty or the optimizers are not
             specified

         """

-        assert (optimizers and criterions) or not optimizers, \
+        assert (optimizers and losses) or not optimizers, \
             "Criterion dict cannot be emtpy, if optimizers are passed"

         loss_vals = {}
@@ -125,7 +125,7 @@ def closure(model: AbstractPyTorchNetwork, data_dict: dict,

         if data_dict:

-            for key, crit_fn in criterions.items():
+            for key, crit_fn in losses.items():
                 _loss_val = crit_fn(preds, *data_dict.values())
                 loss_vals[key] = _loss_val.detach()
                 total_loss += _loss_val
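
The loop at the end of the hunk is what the "(gradients from different losses will be accumulated)" wording refers to: all loss terms are summed into total_loss and backpropagated once. A self-contained sketch of that pattern (the concrete model and loss entries here are illustrative, not delira code):

import torch
import torch.nn as nn

model = nn.Linear(10, 3)
optimizer = torch.optim.Adam(model.parameters())

# two loss entries in the style the closure expects
losses = {
    "ce": nn.CrossEntropyLoss(),
    "l1_act": lambda preds, labels: 1e-4 * preds.abs().mean(),
}

preds = model(torch.randn(4, 10))
labels = torch.randint(0, 3, (4,))

total_loss = 0
loss_vals = {}
for key, crit_fn in losses.items():
    _loss_val = crit_fn(preds, labels)
    loss_vals[key] = _loss_val.detach()
    total_loss += _loss_val  # accumulate; exactly one backward() below

optimizer.zero_grad()
total_loss.backward()
optimizer.step()
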
18 changes: 9 additions & 9 deletions delira/models/gan/generative_adversarial_network.py
@@ -89,7 +89,7 @@ def forward(self, real_image_batch):

     @staticmethod
     def closure(model, data_dict: dict,
-                optimizers: dict, criterions={}, metrics={},
+                optimizers: dict, losses={}, metrics={},
                 fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -102,9 +102,9 @@ def closure(model, data_dict: dict,
             dictionary containing data
         optimizers : dict
             dictionary of optimizers to optimize model's parameters
-        criterions : dict
-            dict holding the criterions to calculate errors
-            (gradients from different criterions will be accumulated)
+        losses : dict
+            dict holding the losses to calculate errors
+            (gradients from different losses will be accumulated)
         metrics : dict
             dict holding the metrics to calculate
         fold : int
@@ -117,14 +117,14 @@ def closure(model, data_dict: dict,
         dict
             Metric values (with same keys as input dict metrics)
         dict
-            Loss values (with same keys as input dict criterions)
+            Loss values (with same keys as input dict losses)
         list
             Arbitrary number of predictions as torch.Tensor

         Raises
         ------
         AssertionError
-            if optimizers or criterions are empty or the optimizers are not
+            if optimizers or losses are empty or the optimizers are not
             specified

         """
@@ -149,14 +149,14 @@ def closure(model, data_dict: dict,
         fake_image_batch, discr_pred_fake, discr_pred_real = model(batch)

         # train discr with prediction from real image
-        for key, crit_fn in criterions.items():
+        for key, crit_fn in losses.items():
             _loss_val = crit_fn(discr_pred_real,
                                 torch.ones_like(discr_pred_real))
             loss_vals[key + "_discr_real"] = _loss_val.detach()
             total_loss_discr_real += _loss_val

         # train discr with prediction from fake image
-        for key, crit_fn in criterions.items():
+        for key, crit_fn in losses.items():
             _loss_val = crit_fn(discr_pred_fake,
                                 torch.zeros_like(discr_pred_fake))
             loss_vals[key + "_discr_fake"] = _loss_val.detach()
@@ -175,7 +175,7 @@ def closure(model, data_dict: dict,
             optimizers["discr"].step()

         # calculate adversarial loss for generator update
-        for key, crit_fn in criterions.items():
+        for key, crit_fn in losses.items():
             _loss_val = crit_fn(discr_pred_fake,
                                 torch.ones_like(discr_pred_fake))
             loss_vals[key + "_adversarial"] = _loss_val.detach().cpu()
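
As the hunks show, the GAN closure reuses every entry in `losses` three times: real targets and fake targets for the discriminator, then flipped targets for the generator's adversarial term. Pulled out of context as a sketch (optimizer wiring, step ordering, and detach handling omitted), the loss arithmetic is:

import torch

def gan_loss_terms(losses: dict, discr_pred_real, discr_pred_fake):
    # sketch of the three passes over `losses` in the closure above
    loss_vals = {}
    total_discr, total_gen = 0, 0
    for key, crit_fn in losses.items():
        # discriminator: real predictions should approach 1, fakes 0
        real_loss = crit_fn(discr_pred_real, torch.ones_like(discr_pred_real))
        fake_loss = crit_fn(discr_pred_fake, torch.zeros_like(discr_pred_fake))
        loss_vals[key + "_discr_real"] = real_loss.detach()
        loss_vals[key + "_discr_fake"] = fake_loss.detach()
        total_discr += real_loss + fake_loss

        # generator: push the discriminator towards 1 on fake images
        adv_loss = crit_fn(discr_pred_fake, torch.ones_like(discr_pred_fake))
        loss_vals[key + "_adversarial"] = adv_loss.detach()
        total_gen += adv_loss
    return total_discr, total_gen, loss_vals
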
46 changes: 18 additions & 28 deletions delira/models/segmentation/unet.py
@@ -172,7 +172,7 @@ def forward(self, x):
         return x

     @staticmethod
-    def closure(model, data_dict: dict, optimizers: dict, criterions={},
+    def closure(model, data_dict: dict, optimizers: dict, losses={},
                 metrics={}, fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -186,9 +186,9 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
             dictionary containing the data
         optimizers : dict
             dictionary of optimizers to optimize model's parameters
-        criterions : dict
-            dict holding the criterions to calculate errors
-            (gradients from different criterions will be accumulated)
+        losses : dict
+            dict holding the losses to calculate errors
+            (gradients from different losses will be accumulated)
         metrics : dict
             dict holding the metrics to calculate
         fold : int
@@ -201,20 +201,20 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
         dict
             Metric values (with same keys as input dict metrics)
         dict
-            Loss values (with same keys as input dict criterions)
+            Loss values (with same keys as input dict losses)
         list
             Arbitrary number of predictions as torch.Tensor

         Raises
         ------
         AssertionError
-            if optimizers or criterions are empty or the optimizers are not
+            if optimizers or losses are empty or the optimizers are not
             specified

         """

-        assert (optimizers and criterions) or not optimizers, \
-            "Criterion dict cannot be emtpy, if optimizers are passed"
+        assert (optimizers and losses) or not optimizers, \
+            "Loss dict cannot be emtpy, if optimizers are passed"

         loss_vals = {}
         metric_vals = {}
@@ -234,7 +234,7 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},

         if data_dict:

-            for key, crit_fn in criterions.items():
+            for key, crit_fn in losses.items():
                 _loss_val = crit_fn(preds, *data_dict.values())
                 loss_vals[key] = _loss_val.detach()
                 total_loss += _loss_val
@@ -264,11 +264,6 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
            loss_vals = eval_loss_vals
            metric_vals = eval_metrics_vals

-        for key, val in {**metric_vals, **loss_vals}.items():
-            logging.info({"value": {"value": val.item(), "name": key,
-                                    "env_appendix": "_%02d" % fold
-                                    }})
-
         logging.info({'image_grid': {"images": inputs, "name": "input_images",
                                      "env_appendix": "_%02d" % fold}})

@@ -624,7 +619,7 @@ def forward(self, x):
         return x

     @staticmethod
-    def closure(model, data_dict: dict, optimizers: dict, criterions={},
+    def closure(model, data_dict: dict, optimizers: dict, losses={},
                 metrics={}, fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -638,9 +633,9 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
             dictionary containing the data
         optimizers : dict
             dictionary of optimizers to optimize model's parameters
-        criterions : dict
-            dict holding the criterions to calculate errors
-            (gradients from different criterions will be accumulated)
+        losses : dict
+            dict holding the losses to calculate errors
+            (gradients from different losses will be accumulated)
         metrics : dict
             dict holding the metrics to calculate
         fold : int
@@ -653,20 +648,20 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
         dict
             Metric values (with same keys as input dict metrics)
         dict
-            Loss values (with same keys as input dict criterions)
+            Loss values (with same keys as input dict losses)
         list
             Arbitrary number of predictions as torch.Tensor

         Raises
         ------
         AssertionError
-            if optimizers or criterions are empty or the optimizers are not
+            if optimizers or losses are empty or the optimizers are not
             specified

         """

-        assert (optimizers and criterions) or not optimizers, \
-            "Criterion dict cannot be emtpy, if optimizers are passed"
+        assert (optimizers and losses) or not optimizers, \
+            "Loss dict cannot be emtpy, if optimizers are passed"

         loss_vals = {}
         metric_vals = {}
@@ -686,7 +681,7 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},

         if data_dict:

-            for key, crit_fn in criterions.items():
+            for key, crit_fn in losses.items():
                 _loss_val = crit_fn(preds, *data_dict.values())
                 loss_vals[key] = _loss_val.detach()
                 total_loss += _loss_val
@@ -716,11 +711,6 @@ def closure(model, data_dict: dict, optimizers: dict, criterions={},
            loss_vals = eval_loss_vals
            metric_vals = eval_metrics_vals

-        for key, val in {**metric_vals, **loss_vals}.items():
-            logging.info({"value": {"value": val.item(), "name": key,
-                                    "env_appendix": "_%02d" % fold
-                                    }})
-
         slicing_dim = inputs.size(2) // 2  # visualize slice in mid of volume

         logging.info({'image_grid': {"inputs": inputs[:, :, slicing_dim, ...],
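
Note the deletion in both closures: the per-value logging.info loop is gone. In line with the "Make consistent logging" commit, such scalar logging presumably now happens once in shared trainer code instead of being duplicated in every closure; a hypothetical central helper carrying the removed body would be:

import logging

def log_scalar_values(metric_vals: dict, loss_vals: dict, fold: int):
    # hypothetical shared helper; body matches the loop removed above
    for key, val in {**metric_vals, **loss_vals}.items():
        logging.info({"value": {"value": val.item(), "name": key,
                                "env_appendix": "_%02d" % fold}})
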
1 change: 0 additions & 1 deletion delira/training/__init__.py
@@ -7,7 +7,6 @@
 if "TORCH" in get_backends():
     from .experiment import PyTorchExperiment
     from .pytorch_trainer import PyTorchNetworkTrainer
-    from .metrics import AccuracyMetricPyTorch, AurocMetricPyTorch

 if "TF" in get_backends():
     from .experiment import TfExperiment
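
With AccuracyMetricPyTorch and AurocMetricPyTorch dropped from the package exports, and given the commit "Convert batchdict back to numpy for metric calculation", metrics can presumably be supplied as plain functions over numpy arrays. A hedged sketch using sklearn (the dict layout and argument order are assumptions, not the trainer's documented API):

import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# metric dict in plain-function style; inputs are numpy arrays, assuming
# the trainer converts the batch dict back to numpy before metric calculation
metrics = {
    "accuracy": lambda probs, labels: accuracy_score(labels, probs.argmax(axis=1)),
    "auroc": lambda probs, labels: roc_auc_score(labels, probs[:, 1]),
}

probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.1, 0.9], [0.6, 0.4]])
labels = np.array([0, 1, 1, 0])
for name, fn in metrics.items():
    print(name, fn(probs, labels))
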