Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will investigate shape errors tomorrow. Furthermore, experiment.stratified_kfold_predict has to be adapted (might also be related to the shape errors)
…/delira into trainer_refactoring
I just added two things to the
EDIT: |
I think we're almost good to go. The only remaining issue is the one with tensorflow not finding some resources in our tests. According to this issue it is most likely due to a thread that is spawned anywhere... Any Ideas on why this happens now and didn't happen before? |
Make shallow copy of batchdict to retain keys, which might get popped in `prepare_batch`
The trainer is now good to be merged. The only failing test is python 3.7 which is due to trixi dependencies but completely unrelated to this PR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implementation looks good 👍
We should include the Predictor
inside the training.rst
so it is displayed correctly in the documentation.
I think that should be a quick fix, so i approve the changes.
The file already existed, I just forgot to include it into the root file. Done now |
This is a first draft to refactor trainer (combine code where possible) and introduce a predictor.
The metric logging was moved from the networks closure to the trainer.
@ORippler : maybe we could also merge experiments and rename the
AbstractTrainer
toBaseTrainer
since it is not abstract anymore?Also some tf tests fail due to shape missmatch. Have you any idea why?
@mibaumgartner : does this match your idea of a predictor? Do you have any improvements compared to your own prototype?
fix #39
fix #46 ?
EDIT: Docstrings are still missing