PyTorch 1.9 Compatibility: IterableDataset Type Annotations #235

Open
lebrice opened this issue Aug 4, 2021 · 5 comments
Labels
enhancement New feature or request

Comments

@lebrice
Owner

lebrice commented Aug 4, 2021

Creating this issue as a follow-up to the discussion with @ejguan on the closed PR #233 :

PyTorch 1.9 checks the type annotations on the IterableDataset subclasses. This is currently causing issues with the annotations of the various subclasses of Sequoia's EnvDataset class, which inherits from IterableDataset.

This EnvDataset class can be found here, but here is a simple version of what it is meant to be like:

Note: The actual EnvDataset class is a bit different from this, and doesn't use a generator, but this is just to illustrate the kind of interaction we have with IterableDataset:

from torch.utils.data import IterableDataset
import gym
import numpy as np
from typing import Generator, Optional, Generic, TypeVar, Tuple, Union, Iterator

ObservationType = TypeVar("ObservationType")
ActionType = TypeVar("ActionType")
RewardType = TypeVar("RewardType")


class EnvDataset(
    gym.Wrapper,
    IterableDataset,
):
    def __init__(self, env: gym.Env):
        super().__init__(env=env)
        self.env = env
        self._iterator: Optional[Generator] = None

    def __iter__(self):
        if self._iterator:
            self._iterator.close()
        self._iterator = self.iterate(self.env)
        return self._iterator

    @staticmethod
    def iterate(env) -> Generator[Union[ObservationType, RewardType], ActionType, None]:
        """Iterator / generator for a gym.Env."""
        try:
            observations = env.reset()
            done = False
            while not done:
                actions = yield observations
                if actions is None:
                    raise RuntimeError("Need to send an action after each observation.")
                observations, rewards, done, info = env.step(actions)
                yield rewards
        except GeneratorExit:
            print("closing")

    def send(self, actions: ActionType) -> RewardType:
        return self._iterator.send(actions)

env = gym.make("CartPole-v0")
env_dataset = EnvDataset(env)

for i, obs in enumerate(env_dataset):
    action = env_dataset.action_space.sample()
    rewards = env_dataset.send(action)
    print(i, obs, action)

@ejguan

ejguan commented Aug 5, 2021

I just tested the EnvDataset class itself, and it should not break anything. Please correct me if I am wrong, but I didn't find any subclass of EnvDataset in your repo.
BTW, can you share your Python version?
Edit: I tried Python 3.7 and Python 3.9 without any error.

@lebrice
Owner Author

lebrice commented Aug 6, 2021

Oh my mistake, I meant subclasses of IterableWrapper, a type of gym.Wrapper used around these EnvDatasets.

Indeed this example doesn't have any errors! I'm pretty sure this issue isn't that hard to fix, I probably need to make sure all the __iter__ methods don't have any type annotation or just have an Iterator[Any] annotation so the errors go away.
Python 3.8 at the moment, but Sequoia is meant to work for 3.7+.
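
One way to find the `__iter__` methods that would need their annotation removed or loosened is to walk the class hierarchy and collect the annotated ones. This is only a sketch: `annotated_iter_methods`, `Base`, `TupleStream` and `AnyStream` are hypothetical names used for illustration, and `Base` stands in for an IterableDataset-like class so the example runs without PyTorch or Sequoia installed.

```python
from typing import Any, Dict, Iterator, Tuple, get_type_hints


def annotated_iter_methods(root: type) -> Dict[str, Any]:
    """Map class name -> return annotation of `__iter__`, for `root` and all of
    its subclasses that define their own annotated `__iter__`."""
    found: Dict[str, Any] = {}
    stack = [root]
    while stack:
        cls = stack.pop()
        stack.extend(cls.__subclasses__())
        # Only look at methods defined on this class, not inherited ones.
        method = cls.__dict__.get("__iter__")
        if method is None:
            continue
        hints = get_type_hints(method)
        if "return" in hints:
            found[cls.__name__] = hints["return"]
    return found


class Base:  # hypothetical stand-in for an IterableDataset-like base class
    def __iter__(self):  # no annotation: nothing for a checker to reject
        return iter(())


class TupleStream(Base):
    def __iter__(self) -> Iterator[Tuple[int, int]]:  # specific annotation
        return iter([(0, 1)])


class AnyStream(Base):
    def __iter__(self) -> Iterator[Any]:  # loosened annotation
        return iter([0, 1])


found = annotated_iter_methods(Base)
# Classes whose annotation is more specific than Iterator[Any].
too_strict = sorted(name for name, hint in found.items() if hint != Iterator[Any])
print(too_strict)  # ['TupleStream']
```

The classes listed in `too_strict` are the candidates for either dropping the annotation or changing it to `Iterator[Any]`.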

@ejguan

ejguan commented Aug 6, 2021

I see. I have a super simple fix for you, remove the bound for EnvType here:

EnvType = TypeVar("EnvType", bound=gym.Env)

It's true that it is too hard to figure out without context.
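
To illustrate what removing the bound changes, here is a minimal sketch using only the standard `typing` module. `Env` is a stand-in for `gym.Env`, and `upper_bound` is a hypothetical helper; the assumption (not confirmed here) is that a runtime check which compares against the TypeVar's bound has nothing to compare against once the bound is gone.

```python
from typing import Any, TypeVar


class Env:
    """Stand-in for gym.Env (keeps the sketch dependency-free)."""


BoundedEnvType = TypeVar("BoundedEnvType", bound=Env)
UnboundedEnvType = TypeVar("UnboundedEnvType")


def upper_bound(tv: TypeVar) -> Any:
    """Return the bound of a TypeVar, treating 'no bound' as Any."""
    return tv.__bound__ if tv.__bound__ is not None else Any


print(upper_bound(BoundedEnvType))    # the Env class
print(upper_bound(UnboundedEnvType))  # typing.Any
```

With the bound in place, anything checked against the TypeVar has to be compatible with `Env`; without it, the effective constraint is `Any`.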

Take MeasureSLPerformanceWrapper as an example:

class MeasureSLPerformanceWrapper(
    MeasurePerformanceWrapper,
    # MeasurePerformanceWrapper[PassiveEnvironment]  # Python 3.7
    # MeasurePerformanceWrapper[PassiveEnvironment, ClassificationMetrics]  # Python 3.8+
):

It inherits from MeasurePerformanceWrapper, where you have specified the type as EnvType, with a bound of gym.Env, as shown below:

class MeasurePerformanceWrapper(
    IterableWrapper[EnvType], Generic[EnvType, MetricsType], ABC
):

We introduced a hard check on the return type of __iter__ in 1.9. However, you annotated the return type as shown below, which is not a subtype of your EnvType IMO, so the error is raised.
def __iter__(self) -> Iterator[Tuple[Observations, Optional[Rewards]]]:

But I agree this check may be too strict. I may try to relax the check for iterators in the next release. Thank you for your feedback!
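
The kind of mismatch described here can be sketched in plain Python. This is a deliberately simplified stand-in and not PyTorch's actual implementation: `check_iter_return`, `Env`, `Strict` and `Loose` are hypothetical names, and the compatibility rule (only plain classes and `Any` are accepted) is an assumption made to keep the example short.

```python
from typing import Any, Iterator, Optional, Tuple, get_args, get_type_hints


class Env:
    """Stand-in for gym.Env (hypothetical, keeps the sketch dependency-free)."""


def check_iter_return(cls: type, expected: Any) -> None:
    """Raise TypeError if `cls.__iter__` is annotated `-> Iterator[X]` and X is
    not compatible with `expected`. Simplified: NOT PyTorch's real logic."""
    hints = get_type_hints(cls.__iter__)
    if "return" not in hints:
        return  # no annotation, nothing to check
    args = get_args(hints["return"])
    item = args[0] if args else Any
    if item is Any or expected is Any:
        return  # Any is compatible with everything
    if isinstance(item, type) and isinstance(expected, type) and issubclass(item, expected):
        return
    raise TypeError(
        f"{cls.__name__}.__iter__ yields {item!r}, expected a subtype of {expected!r}"
    )


class Strict:
    # A tuple annotation, similar in shape to the one quoted above.
    def __iter__(self) -> Iterator[Tuple[int, Optional[float]]]:
        return iter([])


class Loose:
    def __iter__(self) -> Iterator[Any]:
        return iter([])


check_iter_return(Loose, Env)   # passes: Iterator[Any] is accepted
check_iter_return(Strict, Any)  # passes: no bound to compare against
try:
    check_iter_return(Strict, Env)  # a tuple type is not a subtype of Env
except TypeError as err:
    print(err)
```

In this sketch, both suggested remedies (loosening the annotation to `Iterator[Any]`, or removing the bound so the expected type is effectively `Any`) make the check pass.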

BTW, during my testing, I did find something interesting:

class GymDataLoader(
    ActiveEnvironment[ObservationType, ActionType, RewardType], IterableWrapper, Iterable
):

The GymDataLoader inherits from both DataLoader and IterableDataset, which is kind of unexpected, even though it doesn't break anything.

@ejguan

ejguan commented Aug 27, 2021

Hey @lebrice ,

Since it's been a while, I just want to follow up: are there any other concerns about this?

@lebrice
Owner Author

lebrice commented Aug 27, 2021

@ejguan Hey! No, it's all good, I just haven't taken the time to do it. I'll probably get this sorted out soon though.

Thanks again!
