Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(xrk): add q-transformer #783

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

rongkunxue
Copy link
Contributor

Description

Related Issue

TODO

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@PaParaZz1 PaParaZz1 added the algo Add new algorithm or improve old one label Mar 22, 2024
@@ -19,6 +19,7 @@
from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .qtransformer import QtransformerPolicy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QTransformerPolicy

from ding.entry import serial_pipeline_offline
from ding.config import read_config
from pathlib import Path
from ding.model.template.qtransformer import QTransformer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import from the secondary directory, such as:

from ding.model import QTransformer

alpha=0.2,
discount_factor_gamma=0.9,
min_reward = 0.1,
auto_alpha=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unused fields like this

update_type='momentum',
update_kwargs={'theta': self._cfg.learn.target_theta}
)
self._low = np.array(self._cfg.other["low"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need low and high here, We always think that the action value range in the policy is [-1,1]

cuda=True,
model=dict(
num_actions = 3,
action_bins = 256,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this action_bins field is not used in policy

selected = t.gather(-1, indices)
return rearrange(selected, '... 1 -> ...')

def _discretize_action(self, actions):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can optimize this for loop:

action_values = np.linspace(-1, 1, 8)[np.newaxis, ...].repeat(4, 0)
action_values = torch.as_tensor(action_values).to(self._device)
diff = (actions.unsqueeze(-1) - action_values.unsqueeze(0)) ** 2
indices = diff.argmin(-1)

actions = data['action']

#get q
num_timesteps, device = states.shape[1], states.device
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use self._device, which is the default member variable of Policy

import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from ema_pytorch import EMA
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unused third party libraries


from pathlib import Path
from functools import partial
from contextlib import nullcontext
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish imports


from torchtyping import TensorType

from einops import rearrange, repeat, pack, unpack
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add einops in setup.py

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we will not use beartype to validate runtime types in the current version, thus remove it in this PR

@@ -0,0 +1,753 @@
from random import random
from functools import partial, cache
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cache is the new feature in python3.9, for compatibility, you should implement it as follows:

try:
    from functools import cache  # only in Python >= 3.9
except ImportError:
    from functools import lru_cache
    cache = lru_cache(maxsize=None)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algo Add new algorithm or improve old one
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants