Improve ProbabilisticModel base class using abstract base class #37

Open
tom-andersson opened this issue Aug 18, 2023 · 3 comments

@tom-andersson
Collaborator

The ProbabilisticModel base class (below) defines the interface that custom model classes must implement to work in deepsensor (such as .mean, .stddev). This should be updated to be an abstract base class using collections.abc: https://docs.python.org/3/library/collections.abc.html

https://github.com/tom-andersson/deepsensor/blob/e349aaa48f6ed673c721b13b64b9a7b422425876/deepsensor/model/model.py#L95-L108
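A minimal sketch of what the conversion could look like, using Python's standard `abc` module. Note that `collections.abc` provides ready-made container ABCs such as `Mapping`, while the `ABC` base class and the `abstractmethod` decorator themselves live in `abc`. Only `.mean` and `.stddev` are named in this issue; the `task` argument, the return values, and the two toy subclasses are assumptions for illustration, not deepsensor's actual signatures.

```python
from abc import ABC, abstractmethod


class ProbabilisticModel(ABC):
    """Interface sketch: subclasses must implement every abstract method."""

    @abstractmethod
    def mean(self, task):
        """Return the predictive mean for `task` (argument name assumed)."""

    @abstractmethod
    def stddev(self, task):
        """Return the predictive standard deviation for `task`."""


class ConstantModel(ProbabilisticModel):
    """Hypothetical toy subclass that satisfies the interface."""

    def mean(self, task):
        return 0.0

    def stddev(self, task):
        return 1.0


class IncompleteModel(ProbabilisticModel):
    """Hypothetical subclass that forgets to implement `stddev`."""

    def mean(self, task):
        return 0.0
```

With this pattern, `ConstantModel()` constructs normally, whereas `IncompleteModel()` raises `TypeError` at instantiation, so a missing method is reported up front rather than when it is first called.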

cc @patel-zeel

@tom-andersson tom-andersson added this to the v0.3.0 milestone Aug 18, 2023
@patel-zeel
Contributor

patel-zeel commented Aug 22, 2023

I agree, @tom-andersson. Should I add this via a PR?

Given that the train_epoch function uses model.loss_fn, should loss_fn be a mandatory method in the abstract class?
https://github.com/tom-andersson/deepsensor/blob/e349aaa48f6ed673c721b13b64b9a7b422425876/deepsensor/train/train.py#L40-L48

@tom-andersson tom-andersson self-assigned this Aug 22, 2023
@tom-andersson
Collaborator Author

Thanks for the PR offer @patel-zeel, in this instance I'm happy to implement this myself and will update this issue with my progress.

Great point about model.loss_fn being required by train_epoch. While it may be strange to associate a single loss function with a model in this way, IMO it's tidier to encapsulate the loss within the model class (rather than, say, a standalone loss_fn(model, task)). So I think yes, let's include loss_fn in the abstract base class.
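As a rough illustration of the encapsulated-loss design, the sketch below adds `loss_fn` as a third abstract method and shows a generic helper that needs nothing from the model except `model.loss_fn(task)`. Everything beyond the names `mean`, `stddev`, and `loss_fn` is hypothetical: the `GaussianToyModel` class, the `mean_epoch_loss` helper, and the assumption that a task is simply a list of observed floats. It is not deepsensor's `train_epoch`.

```python
import math
from abc import ABC, abstractmethod


class ProbabilisticModel(ABC):
    """Interface sketch with the loss encapsulated in the model."""

    @abstractmethod
    def mean(self, task):
        """Return the predictive mean for `task`."""

    @abstractmethod
    def stddev(self, task):
        """Return the predictive standard deviation for `task`."""

    @abstractmethod
    def loss_fn(self, task):
        """Return a scalar training loss for `task`."""


class GaussianToyModel(ProbabilisticModel):
    """Hypothetical model: a fixed Gaussian over scalar observations."""

    def __init__(self, mu=0.0, sigma=1.0):
        self.mu = mu
        self.sigma = sigma

    def mean(self, task):
        return self.mu

    def stddev(self, task):
        return self.sigma

    def loss_fn(self, task):
        # Assumption: `task` is a non-empty list of observed floats.
        # Loss is the mean Gaussian negative log-likelihood.
        nll = [
            0.5 * math.log(2 * math.pi * self.sigma ** 2)
            + (y - self.mu) ** 2 / (2 * self.sigma ** 2)
            for y in task
        ]
        return sum(nll) / len(nll)


def mean_epoch_loss(model, tasks):
    """Average `model.loss_fn` over tasks; relies only on the interface."""
    return sum(model.loss_fn(task) for task in tasks) / len(tasks)
```

Because the helper only calls `model.loss_fn(task)`, any subclass that satisfies the abstract interface can be passed in, which is the tidiness argument made above. A standalone `loss_fn(model, task)` would instead have to know about each model's internals.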

Note that train_epoch is currently intended as a simple Adam-based training algorithm for ConvNP. It's not guaranteed to work with other models (e.g. GPs, or models not based on PyTorch/TensorFlow). Users who want more training customisation may prefer to develop their own training code.

@tom-andersson
Collaborator Author

I didn't get around to implementing this and won't have time for a while, so I've unassigned myself in case anyone wants to pick this up.
