Skip to content

Commit

Permalink
Fix python 3.7 compatibility bugs (#233)
Browse files Browse the repository at this point in the history
- Limit pytorch version to 1.8.X, from >= 1.8: PyTorch suddenly got
  really picky about type annotations on subclasses of IterableDataset
  for some unexplained reason, and that breaks most of the
  EnvDataset-related wrappers in Sequoia.

- Fix compatibility issues with `lru_cache` taking one required argument
  in python 3.7

- Fix PyTorch-Lightning==1.4.0 compatibility issues due to ModelSummary
  not having a MODE_DEFAULT anymore

- Fix `typing.get_origin` bug with python 3.7 in typed_dict.py

Signed-off-by: Fabrice Normandin <fabrice.normandin@gmail.com>
  • Loading branch information
lebrice committed Jul 29, 2021
1 parent 082e3b0 commit fc767ee
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 17 deletions.
7 changes: 5 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ nngeometry @ git+https://github.com/oleksost/nngeometry.git#egg=nngeometry
pyyaml!=5.4.*,>=5.1
simple_parsing>=0.0.15.post1
matplotlib==3.2.2
torch>=1.8
# NOTE: @lebrice: PyTorch suddenly got really picky about type annotations in 1.9.0 for
# some reason, and they really don't do a great job at evaluating them, so removing it
# for now.
torch>=1.8,<1.9.0
torchvision>=0.9
scikit-learn
tqdm
Expand All @@ -17,7 +20,7 @@ plotly
pandas
# Only for python < 3.8
singledispatchmethod;python_version<'3.8'
# Temporarily fix the pytorch lightning version (issue #134)
# NOTE: PyTorch-Lightning version 1.4.0 is "working" but raises lots of warnings.
pytorch-lightning>=1.3.8
pytorch-lightning-bolts>=0.3.2
# Requirements for running tests:
Expand Down
3 changes: 2 additions & 1 deletion sequoia/common/gym_wrappers/env_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
Dict,
Generator,
Iterator,
Generic,
Iterable,
List,
Expand Down Expand Up @@ -224,7 +225,7 @@ def send(self, action: ActionType) -> RewardType:
self.n_sends_ += 1
return self.reward_

def __iter__(self) -> Iterable[ObservationType]:
def __iter__(self) -> Iterator[ObservationType]:
"""Iterator for an episode in the environment, which uses the 'active
dataset' style with __iter__ and send.
Expand Down
14 changes: 12 additions & 2 deletions sequoia/common/gym_wrappers/policy_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
This policy should then accept the 'state' or something like that.
"""
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Optional, Tuple, TypeVar, Dict, Generic
from typing import (
Any,
Callable,
Iterator,
Optional,
Tuple,
TypeVar,
Dict,
Generic,
Iterable,
)

import gym
from torch.utils.data import IterableDataset
Expand Down Expand Up @@ -169,7 +179,7 @@ def reset(self, *args, **kwargs) -> None:
self._n_steps_in_episode = 0
return self._observation

def __iter__(self) -> Iterable[DatasetItem]:
def __iter__(self) -> Iterator[DatasetItem]:
"""Iterator for an episode/trajectory in the env.
This uses the policy to iteratively perform an episode in the env, and
Expand Down
9 changes: 8 additions & 1 deletion sequoia/common/spaces/typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
S = TypeVar("S")
Dataclass = TypeVar("Dataclass")

try:
from typing import get_origin
except ImportError:
# Python 3.7's typing module doesn't have this `get_origin` function, so get it from
# `typing_inspect`.
from typing_inspect import get_origin


class TypedDictSpace(spaces.Dict, Mapping[str, Space], Generic[M]):
""" Subclass of `spaces.Dict` that allows custom dtypes and uses type annotations.
Expand Down Expand Up @@ -183,7 +190,7 @@ def __init__(
if isclass(type_annotation) and issubclass(type_annotation, gym.Space):
is_space = True
else:
origin = typing.get_origin(type_annotation)
origin = get_origin(type_annotation)
is_space = (
origin is not None
and isclass(origin)
Expand Down
2 changes: 1 addition & 1 deletion sequoia/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def wrap(method_class: Type[Method]) -> Type[Method]:
return wrap(method_class)


@lru_cache
@lru_cache(1)
def get_external_methods() -> Dict[str, Type[Method]]:
""" Returns a dictionary of the Methods defined outside of Sequoia.
Expand Down
8 changes: 4 additions & 4 deletions sequoia/methods/models/base_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,10 +748,10 @@ def shared_modules(self) -> Dict[str, nn.Module]:
shared_modules["output_head"] = self.output_head
return shared_modules

def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
model_summary = ModelSummary(self, mode=mode)
log.debug("\n" + str(model_summary))
return model_summary
# def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
# model_summary = ModelSummary(self, mode=mode)
# log.debug("\n" + str(model_summary))
# return model_summary

def _are_batched(self, observations: IncrementalAssumption.Observations) -> bool:
""" Returns wether these observations are batched. """
Expand Down
4 changes: 3 additions & 1 deletion sequoia/settings/rl/wrappers/measure_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@


class MeasureRLPerformanceWrapper(
MeasurePerformanceWrapper[ActiveEnvironment, EpisodeMetrics]
MeasurePerformanceWrapper
# MeasurePerformanceWrapper[ActiveEnvironment] # python 3.7
# MeasurePerformanceWrapper[ActiveEnvironment, EpisodeMetrics] # python 3.8+
):
def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion sequoia/settings/sl/continual/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class ContinualSLSetting(SLSetting, ContinualAssumption):

def __post_init__(self):
super().__post_init__()
assert not self.has_setup_fit
# assert not self.has_setup_fit
# Test values default to the same as train.
self.test_increment = self.test_increment or self.increment
self.test_initial_increment = (
Expand Down
2 changes: 1 addition & 1 deletion sequoia/settings/sl/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sequoia.settings import Setting
from sequoia.common.transforms import Transforms
from .environment import PassiveEnvironment
from torch.tensor import Tensor
from torch import Tensor


@dataclass
Expand Down
8 changes: 5 additions & 3 deletions sequoia/settings/sl/wrappers/measure_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
TODO: Move this somewhere more appropriate. There's also the RL version of the wrapper
here.
"""
from typing import Any, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Generic, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import wandb
Expand All @@ -30,7 +30,9 @@


class MeasureSLPerformanceWrapper(
MeasurePerformanceWrapper[PassiveEnvironment, ClassificationMetrics]
MeasurePerformanceWrapper,
# MeasurePerformanceWrapper[PassiveEnvironment] # Python 3.7
# MeasurePerformanceWrapper[PassiveEnvironment, ClassificationMetrics] # Python 3.8+
):
def __init__(
self,
Expand Down Expand Up @@ -139,7 +141,7 @@ def get_metrics(self, action: Actions, reward: Rewards) -> Metrics:
wandb.log(log_dict)
return metric

def __iter__(self) -> Iterable[Tuple[Observations, Optional[Rewards]]]:
def __iter__(self) -> Iterator[Tuple[Observations, Optional[Rewards]]]:
if self.__epochs == 1 and self.first_epoch_only:
print(
colorize(
Expand Down

0 comments on commit fc767ee

Please sign in to comment.