diff --git a/CHANGELOG.md b/CHANGELOG.md index ee531bc..bd86cbe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Changelog +# Release 2.5 (2022-07-27) +- Introduce environment reset pool, so concurrent environment replicas can randomly reset themselves from the pool. + # Release 2.4 (2022-06-16) - Introduce new device context management and autoinit_pycuda - Therefore, Torch (any version) will not conflict with PyCUDA in the GPU context diff --git a/example_envs/tag_continuous/tag_continuous.py b/example_envs/tag_continuous/tag_continuous.py index e1ce2e4..51c45ad 100644 --- a/example_envs/tag_continuous/tag_continuous.py +++ b/example_envs/tag_continuous/tag_continuous.py @@ -9,7 +9,6 @@ import numpy as np from gym import spaces -from gym.utils import seeding from warp_drive.utils.constants import Constants from warp_drive.utils.data_feed import DataFeed @@ -313,7 +312,7 @@ def seed(self, seed=None): Note: this uses the code in https://github.com/openai/gym/blob/master/gym/utils/seeding.py """ - self.np_random, seed = seeding.np_random(seed) + self.np_random.seed(seed) return [seed] def set_global_state(self, key=None, value=None, t=None, dtype=None): @@ -756,10 +755,6 @@ def get_data_dictionary(self): ) return data_dict - def get_tensor_dictionary(self): - tensor_dict = DataFeed() - return tensor_dict - def reset(self): """ Env reset(). 
diff --git a/example_envs/tag_gridworld/tag_gridworld.py b/example_envs/tag_gridworld/tag_gridworld.py index e0c146b..e58d88f 100644 --- a/example_envs/tag_gridworld/tag_gridworld.py +++ b/example_envs/tag_gridworld/tag_gridworld.py @@ -6,7 +6,6 @@ import numpy as np from gym import spaces -from gym.utils import seeding # seeding code from https://github.com/openai/gym/blob/master/gym/utils/seeding.py from warp_drive.utils.constants import Constants @@ -130,7 +129,7 @@ def __init__( name = "TagGridWorld" def seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) + self.np_random.seed(seed) return [seed] def set_global_state(self, key=None, value=None, t=None, dtype=None): @@ -349,9 +348,100 @@ def get_data_dictionary(self): ) return data_dict - def get_tensor_dictionary(self): - tensor_dict = DataFeed() - return tensor_dict + def step(self, actions=None): + self.timestep += 1 + args = [ + _LOC_X, + _LOC_Y, + _ACTIONS, + "_done_", + _REWARDS, + _OBSERVATIONS, + "wall_hit_penalty", + "tag_reward_for_tagger", + "tag_penalty_for_runner", + "step_cost_for_tagger", + "use_full_observation", + "world_boundary", + "_timestep_", + ("episode_length", "meta"), + ] + if self.env_backend == "pycuda": + self.cuda_step( + *self.cuda_step_function_feed(args), + block=self.cuda_function_manager.block, + grid=self.cuda_function_manager.grid, + ) + elif self.env_backend == "numba": + self.cuda_step[ + self.cuda_function_manager.grid, self.cuda_function_manager.block + ](*self.cuda_step_function_feed(args)) + else: + raise Exception("CUDATagGridWorld expects env_backend = 'pycuda' or 'numba' ") + + +class CUDATagGridWorldWithResetPool(TagGridWorld, CUDAEnvironmentContext): + """ + CUDA version of the TagGridWorld environment and with reset pool for the starting point. 
+ Note: this class subclasses the Python environment class TagGridWorld, + and also the CUDAEnvironmentContext + """ + + def get_data_dictionary(self): + data_dict = DataFeed() + for feature in [ + _LOC_X, + _LOC_Y, + ]: + data_dict.add_data( + name=feature, + data=self.global_state[feature][0], + save_copy_and_apply_at_reset=False, + log_data_across_episode=False, + ) + data_dict.add_data_list( + [ + ("wall_hit_penalty", self.wall_hit_penalty), + ("tag_reward_for_tagger", self.tag_reward_for_tagger), + ("tag_penalty_for_runner", self.tag_penalty_for_runner), + ("step_cost_for_tagger", self.step_cost_for_tagger), + ("use_full_observation", self.use_full_observation), + ("world_boundary", self.grid_length), + ] + ) + return data_dict + + def get_reset_pool_dictionary(self): + + def _random_location_generator(): + starting_location_x = self.np_random.choice( + np.linspace(1, int(self.grid_length) - 1, int(self.grid_length) - 1), + self.num_agents + ).astype(np.int32) + starting_location_x[-1] = 0 + starting_location_y = self.np_random.choice( + np.linspace(1, int(self.grid_length) - 1, int(self.grid_length) - 1), + self.num_agents + ).astype(np.int32) + starting_location_y[-1] = 0 + return starting_location_x, starting_location_y + + N = 5 # we hard code the number of env pool for this demo purpose + x_pool = [] + y_pool = [] + for _ in range(N): + x, y = _random_location_generator() + x_pool.append(x) + y_pool.append(y) + + x_pool = np.stack(x_pool, axis=0) + y_pool = np.stack(y_pool, axis=0) + + reset_pool_dict = DataFeed() + reset_pool_dict.add_pool_for_reset(name=f"{_LOC_X}_reset_pool", data=x_pool, reset_target=_LOC_X) + reset_pool_dict.add_pool_for_reset(name=f"{_LOC_Y}_reset_pool", data=y_pool, reset_target=_LOC_Y) + + return reset_pool_dict def step(self, actions=None): self.timestep += 1 diff --git a/requirements.txt b/requirements.txt index 7f083d0..61162e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -gym>=0.18, <0.26 +gym>=0.26 
matplotlib>=3.2.1 numpy>=1.18.1 pycuda>=2022.1 diff --git a/setup.py b/setup.py index 96beb34..b388731 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ setup( name="rl-warp-drive", - version="2.4", + version="2.5.0", author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng", author_email="tian.lan@salesforce.com", description="Framework for fast end-to-end " diff --git a/tests/warp_drive/numba_tests/test_pool_reset.py b/tests/warp_drive/numba_tests/test_pool_reset.py new file mode 100644 index 0000000..a58d4e0 --- /dev/null +++ b/tests/warp_drive/numba_tests/test_pool_reset.py @@ -0,0 +1,145 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause + +import unittest + +import numpy as np +import torch + +from warp_drive.managers.numba_managers.numba_data_manager import NumbaDataManager +from warp_drive.managers.numba_managers.numba_function_manager import ( + NumbaEnvironmentReset, + NumbaFunctionManager, +) +from warp_drive.utils.common import get_project_root +from warp_drive.utils.data_feed import DataFeed + +_NUMBA_FILEPATH = f"warp_drive.numba_includes" + + +class TestEnvironmentReset(unittest.TestCase): + """ + Unit tests for the CUDA environment resetter + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dm = NumbaDataManager(num_agents=5, num_envs=2, episode_length=2) + self.fm = NumbaFunctionManager( + num_agents=int(self.dm.meta_info("n_agents")), + num_envs=int(self.dm.meta_info("n_envs")), + ) + self.fm.import_numba_from_source_code(f"{_NUMBA_FILEPATH}.test_build") + self.resetter = NumbaEnvironmentReset(function_manager=self.fm) + + def test_reset_for_different_dim(self): + + self.dm.data_on_device_via_torch("_done_")[:] = torch.from_numpy( + np.array([1, 0]) + ).cuda() + + done = self.dm.pull_data_from_device("_done_") + 
self.assertSequenceEqual(list(done), [1, 0]) + + # expected mean would be around 0.5 * (1+2+3+15) / 4 = 2.625 + a_reset_pool = np.random.rand(4, 10, 10) + a_reset_pool[1] *= 2 + a_reset_pool[2] *= 3 + a_reset_pool[3] *= 15 + + b_reset_pool = np.random.rand(4, 100) + b_reset_pool[1] *= 2 + b_reset_pool[2] *= 3 + b_reset_pool[3] *= 15 + + c_reset_pool = np.random.rand(100) + + data_feed = DataFeed() + data_feed.add_data( + name="a", data=np.random.randn(2, 10, 10), save_copy_and_apply_at_reset=False + ) + data_feed.add_pool_for_reset(name="a_reset_pool", data=a_reset_pool, reset_target="a") + data_feed.add_data( + name="b", data=np.random.randn(2, 100), save_copy_and_apply_at_reset=False + ) + data_feed.add_pool_for_reset(name="b_reset_pool", data=b_reset_pool, reset_target="b") + data_feed.add_data( + name="c", data=np.random.randn(2), save_copy_and_apply_at_reset=False + ) + data_feed.add_pool_for_reset(name="c_reset_pool", data=c_reset_pool, reset_target="c") + + self.dm.push_data_to_device(data_feed) + + self.resetter.init_reset_pool(self.dm) + + a = self.dm.pull_data_from_device("a") + b = self.dm.pull_data_from_device("b") + c = self.dm.pull_data_from_device("c") + + # soft reset + a_after_reset_0_mean = [] + a_after_reset_1_mean = [] + b_after_reset_0_mean = [] + b_after_reset_1_mean = [] + c_after_reset_0_mean = [] + c_after_reset_1_mean = [] + + for _ in range(2000): + self.resetter.reset_when_done(self.dm, mode="if_done", undo_done_after_reset=False) + a_after_reset = self.dm.pull_data_from_device("a") + a_after_reset_0_mean.append(a_after_reset[0].mean()) + a_after_reset_1_mean.append(a_after_reset[1].mean()) + b_after_reset = self.dm.pull_data_from_device("b") + b_after_reset_0_mean.append(b_after_reset[0].mean()) + b_after_reset_1_mean.append(b_after_reset[1].mean()) + c_after_reset = self.dm.pull_data_from_device("c") + c_after_reset_0_mean.append(c_after_reset[0].mean()) + c_after_reset_1_mean.append(c_after_reset[1].mean()) + # env 0 should have 1000 
times random reset from the pool, so it should close to a_reset_pool.mean() + print(a_reset_pool.mean()) + print(np.array(a_after_reset_0_mean).mean()) + self.assertTrue(np.absolute(a_reset_pool.mean() - np.array(a_after_reset_0_mean).mean()) < 5e-1) + print(b_reset_pool.mean()) + print(np.array(b_after_reset_0_mean).mean()) + self.assertTrue(np.absolute(b_reset_pool.mean() - np.array(b_after_reset_0_mean).mean()) < 5e-1) + print(c_reset_pool.mean()) + print(np.array(c_after_reset_0_mean).mean()) + self.assertTrue(np.absolute(c_reset_pool.mean() - np.array(c_after_reset_0_mean).mean()) < 5e-1) + # env 1 has no reset at all, so it should be exactly the same as the original one + self.assertTrue(np.absolute(a[1].mean() - np.array(a_after_reset_1_mean).mean()) < 1e-5) + self.assertTrue(np.absolute(b[1].mean() - np.array(b_after_reset_1_mean).mean()) < 1e-5) + self.assertTrue(np.absolute(c[1].mean() - np.array(c_after_reset_1_mean).mean()) < 1e-5) + + # hard reset + a_after_reset_0_mean = [] + a_after_reset_1_mean = [] + b_after_reset_0_mean = [] + b_after_reset_1_mean = [] + c_after_reset_0_mean = [] + c_after_reset_1_mean = [] + for _ in range(2000): + self.resetter.reset_when_done(self.dm, mode="force_reset", undo_done_after_reset=False) + a_after_reset = self.dm.pull_data_from_device("a") + a_after_reset_0_mean.append(a_after_reset[0].mean()) + a_after_reset_1_mean.append(a_after_reset[1].mean()) + b_after_reset = self.dm.pull_data_from_device("b") + b_after_reset_0_mean.append(b_after_reset[0].mean()) + b_after_reset_1_mean.append(b_after_reset[1].mean()) + c_after_reset = self.dm.pull_data_from_device("c") + c_after_reset_0_mean.append(c_after_reset[0].mean()) + c_after_reset_1_mean.append(c_after_reset[1].mean()) + # env 0 should have 1000 times random reset from the pool, so it should close to a_reset_pool.mean() + self.assertTrue(np.absolute(a_reset_pool.mean() - np.array(a_after_reset_0_mean).mean()) < 5e-1) + self.assertTrue(np.absolute(b_reset_pool.mean() - 
np.array(b_after_reset_0_mean).mean()) < 5e-1) + self.assertTrue(np.absolute(c_reset_pool.mean() - np.array(c_after_reset_0_mean).mean()) < 5e-1) + # env 1 should have 1000 times random reset from the pool, so it should close to a_reset_pool.mean() + self.assertTrue(np.absolute(a_reset_pool.mean() - np.array(a_after_reset_1_mean).mean()) < 5e-1) + self.assertTrue(np.absolute(b_reset_pool.mean() - np.array(b_after_reset_1_mean).mean()) < 5e-1) + self.assertTrue(np.absolute(c_reset_pool.mean() - np.array(c_after_reset_1_mean).mean()) < 5e-1) + + + + diff --git a/warp_drive/env_cpu_gpu_consistency_checker.py b/warp_drive/env_cpu_gpu_consistency_checker.py index f2daab6..0382ab0 100644 --- a/warp_drive/env_cpu_gpu_consistency_checker.py +++ b/warp_drive/env_cpu_gpu_consistency_checker.py @@ -13,7 +13,6 @@ import numpy as np import torch from gym.spaces import Discrete, MultiDiscrete -from gym.utils import seeding from warp_drive.env_wrapper import EnvWrapper from warp_drive.training.utils.data_loader import ( @@ -38,10 +37,9 @@ def generate_random_actions(env, num_envs, seed=None): Generate random actions for each agent and each env. 
""" agent_ids = list(env.action_space.keys()) + np_random = np.random if seed is not None: - np_random = seeding.np_random(seed)[0] - else: - np_random = np.random + np_random.seed(seed) return [ { diff --git a/warp_drive/env_wrapper.py b/warp_drive/env_wrapper.py index a3d18dc..202f807 100644 --- a/warp_drive/env_wrapper.py +++ b/warp_drive/env_wrapper.py @@ -291,9 +291,10 @@ def repeat_across_env_dimension(array, num_envs): # Copy host data and tensors to device # Note: this happens only once after the first reset on the host - # Add env dimension to data if "save_copy_and_apply_at_reset" is True data_dictionary = self.env.get_data_dictionary() tensor_dictionary = self.env.get_tensor_dictionary() + reset_pool_dictionary = self.env.get_reset_pool_dictionary() + # Add env dimension to data if "save_copy_and_apply_at_reset" is True for key in data_dictionary: if data_dictionary[key]["attributes"][ "save_copy_and_apply_at_reset" @@ -309,6 +310,26 @@ def repeat_across_env_dimension(array, num_envs): tensor_dictionary[key]["data"] = repeat_across_env_dimension( tensor_dictionary[key]["data"], self.n_envs ) + # Add env dimension to data if "is_reset_pool" exists for this data + # if so, also check this data has "save_copy_and_apply_at_reset" = False + for key in reset_pool_dictionary: + if "is_reset_pool" in reset_pool_dictionary[key]["attributes"] and \ + reset_pool_dictionary[key]["attributes"]["is_reset_pool"]: + # find the corresponding target data + reset_target = reset_pool_dictionary[key]["attributes"]["reset_target"] + if reset_target in data_dictionary: + assert not data_dictionary[reset_target]["attributes"]["save_copy_and_apply_at_reset"] + data_dictionary[reset_target]["data"] = repeat_across_env_dimension( + data_dictionary[reset_target]["data"], self.n_envs + ) + elif reset_target in tensor_dictionary: + assert not tensor_dictionary[reset_target]["attributes"]["save_copy_and_apply_at_reset"] + tensor_dictionary[reset_target]["data"] = 
repeat_across_env_dimension( + tensor_dictionary[reset_target]["data"], self.n_envs + ) + else: + raise Exception(f"Fail to locate the target data {reset_target} for the reset pool " + f"in neither data_dictionary nor tensor_dictionary") self.cuda_data_manager.push_data_to_device(data_dictionary) @@ -316,6 +337,8 @@ def repeat_across_env_dimension(array, num_envs): tensor_dictionary, torch_accessible=True ) + self.cuda_data_manager.push_data_to_device(reset_pool_dictionary) + # All subsequent resets happen on the GPU self.reset_on_host = False @@ -329,6 +352,9 @@ def repeat_across_env_dimension(array, num_envs): return {} return obs # CPU version + def init_reset_pool(self, seed=None): + self.env_resetter.init_reset_pool(self.cuda_data_manager, seed) + def reset_only_done_envs(self): """ This function only works for GPU example_envs. diff --git a/warp_drive/managers/data_manager.py b/warp_drive/managers/data_manager.py index 7289fff..bb95689 100644 --- a/warp_drive/managers/data_manager.py +++ b/warp_drive/managers/data_manager.py @@ -48,6 +48,7 @@ def __init__( self._device_data_pointer = {} self._scalar_data_list = [] self._reset_data_list = [] + self._reset_target_to_pool = {} self._log_data_list = [] self._device_data_via_torch = {} self._shared_constants = {} @@ -212,10 +213,14 @@ def push_data_to_device(self, data: Dict, torch_accessible: bool = False): key not in self._host_data ), f"the data with name: {key} has already been registered at the host" value = content["data"] + is_reset_pool = False if "is_reset_pool" not in content["attributes"].keys() else \ + content["attributes"]["is_reset_pool"] save_copy_and_apply_at_reset = content["attributes"][ "save_copy_and_apply_at_reset" - ] - log_data_across_episode = content["attributes"]["log_data_across_episode"] + ] if not is_reset_pool else False + log_data_across_episode = content["attributes"][ + "log_data_across_episode" + ] if not is_reset_pool else False if isinstance(value, (np.ndarray, list)): @@ 
-223,6 +228,18 @@ def push_data_to_device(self, data: Dict, torch_accessible: bool = False): f"the data with name: {key} has " f"already been pushed to device" ) + if is_reset_pool: + reset_target_key = content["attributes"]["reset_target"] + assert reset_target_key not in self._reset_target_to_pool, ( + f"the data with name: {key} has " + f"already been registered at the reset_target_to_pool" + ) + assert reset_target_key not in self._reset_data_list, ( + f"the data with name: {reset_target_key} has " + f"already been registered at the reset_data_list" + ) + self._reset_target_to_pool[reset_target_key] = key + if isinstance(value, np.ndarray): if not value.flags.c_contiguous: array = np.array(value, order="C") @@ -267,6 +284,11 @@ def push_data_to_device(self, data: Dict, torch_accessible: bool = False): f"the data with name: {key} has " f"already been registered at the reset_data_list" ) + assert key not in self._reset_target_to_pool, ( + f"the data with name: {key} has " + f"already been registered at the reset_target_to_pool" + ) + key_at_reset = f"{key}_at_reset" self._shape[key_at_reset] = self._host_data[key].shape self._dtype[key_at_reset] = self._host_data[key].dtype.name @@ -278,7 +300,7 @@ def push_data_to_device(self, data: Dict, torch_accessible: bool = False): self._to_device( key, name_on_device=key_at_reset, - torch_accessible=torch_accessible, + torch_accessible=False, ) self._reset_data_list.append(key) @@ -422,6 +444,10 @@ def get_dtype(self, name: str): assert name in self._dtype return self._dtype[name] + def get_reset_pool(self, name: str): + assert name in self._reset_target_to_pool + return self._reset_target_to_pool[name] + def _type_warning_helper(self, key: str, old: str, new: str, comment=None): logging.warning( f"{self.__class__.__name__} casts the data '{key}' " @@ -446,6 +472,10 @@ def scalar_data_list(self): def reset_data_list(self): return self._reset_data_list + @property + def reset_target_to_pool(self): + return 
self._reset_target_to_pool + @property def log_data_list(self): return self._log_data_list diff --git a/warp_drive/managers/function_manager.py b/warp_drive/managers/function_manager.py index 0e51de0..1fe2a68 100644 --- a/warp_drive/managers/function_manager.py +++ b/warp_drive/managers/function_manager.py @@ -223,6 +223,7 @@ def __init__(self, function_manager: CUDAFunctionManager): self._blocks_per_env = function_manager.blocks_per_env self._cuda_custom_reset = None self._cuda_reset_feed = None + self._random_initialized = False def register_custom_reset_function( self, data_manager: CUDADataManager, reset_function_name=None @@ -232,27 +233,43 @@ def register_custom_reset_function( def custom_reset(self, args: Optional[list] = None, block=None, grid=None): raise NotImplementedError + def init_reset_pool( + self, + data_manager: CUDADataManager, + seed: Optional[int] = None, + ): + raise NotImplementedError + def reset_when_done( self, data_manager: CUDADataManager, mode: str = "if_done", undo_done_after_reset: bool = True, - use_random_reset: bool = False, ): - if not use_random_reset: - self.reset_when_done_deterministic( - data_manager, mode, undo_done_after_reset - ) + if mode == "if_done": + force_reset = np.int32(0) + elif mode == "force_reset": + force_reset = np.int32(1) else: - # TODO: To be implemented - # self.reset_when_done_random(data_manager, mode, undo_done_after_reset) - raise NotImplementedError + raise Exception( + f"unknown reset mode: {mode}, only accept 'if_done' and 'force_reset' " + ) + self.reset_when_done_deterministic(data_manager, force_reset) + self.reset_when_done_from_pool(data_manager, force_reset) + if undo_done_after_reset: + self._undo_done_flag_and_reset_timestep(data_manager, force_reset) def reset_when_done_deterministic( self, data_manager: CUDADataManager, - mode: str = "if_done", - undo_done_after_reset: bool = True, + force_reset: int, + ): + raise NotImplementedError + + def reset_when_done_from_pool( + self, + 
data_manager: CUDADataManager, + force_reset: int, ): raise NotImplementedError diff --git a/warp_drive/managers/numba_managers/numba_function_manager.py b/warp_drive/managers/numba_managers/numba_function_manager.py index 7994fe1..df7ce33 100644 --- a/warp_drive/managers/numba_managers/numba_function_manager.py +++ b/warp_drive/managers/numba_managers/numba_function_manager.py @@ -196,6 +196,10 @@ def initialize_default_functions(self): "reset_when_done_1d", "reset_when_done_2d", "reset_when_done_3d", + "init_random_for_reset", + "reset_when_done_1d_from_pool", + "reset_when_done_2d_from_pool", + "reset_when_done_3d_from_pool", "undo_done_flag_and_reset_timestep", ] self.initialize_functions(default_func_names) @@ -351,9 +355,15 @@ def __init__(self, function_manager: NumbaFunctionManager): self.reset_func_1d = self._function_manager.get_function("reset_when_done_1d") self.reset_func_2d = self._function_manager.get_function("reset_when_done_2d") self.reset_func_3d = self._function_manager.get_function("reset_when_done_3d") + + self.reset_func_1d_from_pool = self._function_manager.get_function("reset_when_done_1d_from_pool") + self.reset_func_2d_from_pool = self._function_manager.get_function("reset_when_done_2d_from_pool") + self.reset_func_3d_from_pool = self._function_manager.get_function("reset_when_done_3d_from_pool") + self.undo = self._function_manager.get_function( "undo_done_flag_and_reset_timestep" ) + self.rng_states_dict = {} def register_custom_reset_function( self, data_manager: NumbaDataManager, reset_function_name=None @@ -386,11 +396,57 @@ def custom_reset(self, args: Optional[list] = None, block=None, grid=None): else: self._cuda_custom_reset[grid, block](*self._cuda_reset_feed(args)) + def init_reset_pool( + self, + data_manager: NumbaDataManager, + seed: Optional[int] = None, + ): + """ + Init random function for the reset pool + :param data_manager: NumbaDataManager object + :param seed: random seed selected for the initialization + """ + if 
len(data_manager.reset_target_to_pool) == 0: + return + + self._security_check_reset_pool(data_manager) + + if seed is None: + seed = time.time() + logging.info( + f"random seed is not provided, by default, " + f"using the current timestamp {seed} as seed" + ) + seed = np.int32(seed) + xoroshiro128p_dtype = np.dtype( + [("s0", np.uint64), ("s1", np.uint64)], align=True + ) + sz = self._function_manager._num_envs + rng_states = numba_driver.device_array(sz, dtype=xoroshiro128p_dtype) + init_random_for_reset = self._function_manager.get_function("init_random_for_reset") + init_random_for_reset(rng_states, seed) + self.rng_states_dict["rng_states"] = rng_states + self._random_initialized = True + + def _security_check_reset_pool(self, data_manager): + for name, pool_name in data_manager.reset_target_to_pool.items(): + data_shape = data_manager.get_shape(name) + data_type = data_manager.get_dtype(name) + pool_shape = data_manager.get_shape(pool_name) + pool_type = data_manager.get_dtype(pool_name) + assert data_type == pool_type, \ + f"Inconsistency of dtype is found for data: {name} has type {data_type} " \ + f"and its reset pool: {pool_name} has type {pool_type}" + assert data_shape[0] == self._function_manager._num_envs + assert pool_shape[0] > 1 + for i in range(1, len(data_shape)): + assert data_shape[i] == pool_shape[i], \ + f"Inconsistency of shape is found for data: {name} and its reset pool: {pool_name}" + def reset_when_done_deterministic( self, data_manager: NumbaDataManager, - mode: str = "if_done", - undo_done_after_reset: bool = True, + force_reset: int, ): """ Monitor the done flag for each env. If any env is done, it will reset this @@ -399,20 +455,11 @@ def reset_when_done_deterministic( and turn off the done flag. Therefore, this env can safely get restarted. 
:param data_manager: NumbaDataManager object - :param mode: "if_done": reset an env if done flag is observed for that env, - "force_reset": reset all env in a hard way - :param undo_done_after_reset: If True, turn off the done flag - and reset timestep after all data have been reset - (the flag should be True for most cases) + :param force_reset: 0: reset an env if done flag is observed for that env, + 1: reset all env in a hard way """ - if mode == "if_done": - force_reset = np.int32(0) - elif mode == "force_reset": - force_reset = np.int32(1) - else: - raise Exception( - f"unknown reset mode: {mode}, only accept 'if_done' and 'force_reset' " - ) + if len(data_manager.reset_data_list) == 0: + return for name in data_manager.reset_data_list: f_shape = data_manager.get_shape(name) @@ -471,8 +518,87 @@ def reset_when_done_deterministic( force_reset, ) - if undo_done_after_reset: - self._undo_done_flag_and_reset_timestep(data_manager, force_reset) + def reset_when_done_from_pool( + self, + data_manager: NumbaDataManager, + force_reset: int, + ): + """ + Monitor the done flag for each env. If any env is done, it will reset this + particular env without interrupting other example_envs. + The reset includes randomly select starting values from the candidate pool, + and copy the starting values of this env back, + and turn off the done flag. Therefore, this env can safely get restarted. 
+ + :param data_manager: NumbaDataManager object + :param force_reset: 0: reset an env if done flag is observed for that env, + 1: reset all env in a hard way + """ + if len(data_manager.reset_target_to_pool) == 0: + return + + assert self._random_initialized, ( + "reset_when_done_from_pool() requires the random seed initialized first, " + "please call init_reset_pool()" + ) + + for name, pool_name in data_manager.reset_target_to_pool.items(): + f_shape = data_manager.get_shape(name) + assert f_shape[0] > 1, "reset function assumes the 0th dimension is n_pool" + if len(f_shape) >= 3: + if len(f_shape) > 3: + raise Exception( + "Numba environment.reset() temporarily " + "not supports array dimension > 3" + ) + agent_dim = np.int32(f_shape[1]) + feature_dim = np.int32(np.prod(f_shape[2:])) + data_shape = "is_3d" + elif len(f_shape) == 2: + feature_dim = np.int32(f_shape[1]) + data_shape = "is_2d" + else: # len(f_shape) == 1: + feature_dim = np.int32(1) + data_shape = "is_1d" + + dtype = data_manager.get_dtype(name) + if "float" not in dtype and "int" not in dtype: + raise Exception(f"unknown dtype: {dtype}") + if data_shape == "is_3d": + reset_func = self.reset_func_3d_from_pool + reset_func[ + self._grid, (int((agent_dim - 1) // self._blocks_per_env + 1), 1, 1) + ]( + self.rng_states_dict["rng_states"], + data_manager.device_data(name), + data_manager.device_data(pool_name), + data_manager.device_data("_done_"), + agent_dim, + feature_dim, + force_reset, + ) + elif data_shape == "is_2d": + reset_func = self.reset_func_2d_from_pool + reset_func[ + self._grid, + (int((feature_dim - 1) // self._blocks_per_env + 1), 1, 1), + ]( + self.rng_states_dict["rng_states"], + data_manager.device_data(name), + data_manager.device_data(pool_name), + data_manager.device_data("_done_"), + feature_dim, + force_reset, + ) + elif data_shape == "is_1d": + reset_func = self.reset_func_1d_from_pool + reset_func[self._grid, (1, 1, 1)]( + self.rng_states_dict["rng_states"], + 
data_manager.device_data(name), + data_manager.device_data(pool_name), + data_manager.device_data("_done_"), + force_reset, + ) def _undo_done_flag_and_reset_timestep( self, data_manager: NumbaDataManager, force_reset diff --git a/warp_drive/managers/pycuda_managers/pycuda_function_manager.py b/warp_drive/managers/pycuda_managers/pycuda_function_manager.py index 8adf818..8a301e5 100644 --- a/warp_drive/managers/pycuda_managers/pycuda_function_manager.py +++ b/warp_drive/managers/pycuda_managers/pycuda_function_manager.py @@ -658,11 +658,17 @@ def custom_reset(self, args: Optional[list] = None, block=None, grid=None): *self._cuda_reset_feed(args), block=block, grid=grid ) + def init_reset_pool( + self, + data_manager: PyCUDADataManager, + seed: Optional[int] = None, + ): + return + def reset_when_done_deterministic( self, data_manager: PyCUDADataManager, - mode: str = "if_done", - undo_done_after_reset: bool = True, + force_reset: int, ): """ Monitor the done flag for each env. If any env is done, it will reset this @@ -671,20 +677,11 @@ def reset_when_done_deterministic( and turn off the done flag. Therefore, this env can safely get restarted. 
:param data_manager: PyCUDADataManager object - :param mode: "if_done": reset an env if done flag is observed for that env, - "force_reset": reset all env in a hard way - :param undo_done_after_reset: If True, turn off the done flag - and reset timestep after all data have been reset - (the flag should be True for most cases) + :param force_reset: 0: reset an env if done flag is observed for that env, + 1: reset all env in a hard way """ - if mode == "if_done": - force_reset = np.int32(0) - elif mode == "force_reset": - force_reset = np.int32(1) - else: - raise Exception( - f"unknown reset mode: {mode}, only accept 'if_done' and 'force_reset' " - ) + if len(data_manager.reset_data_list) == 0: + return for name in data_manager.reset_data_list: f_shape = data_manager.get_shape(name) @@ -736,8 +733,13 @@ def reset_when_done_deterministic( grid=self._grid, ) - if undo_done_after_reset: - self._undo_done_flag_and_reset_timestep(data_manager, force_reset) + def reset_when_done_from_pool( + self, + data_manager: PyCUDADataManager, + force_reset: int, + ): + # TO DO + return def _undo_done_flag_and_reset_timestep( self, data_manager: PyCUDADataManager, force_reset diff --git a/warp_drive/numba_includes/core/pool_reset.py b/warp_drive/numba_includes/core/pool_reset.py new file mode 100644 index 0000000..8dec750 --- /dev/null +++ b/warp_drive/numba_includes/core/pool_reset.py @@ -0,0 +1,53 @@ +from numba import cuda as numba_driver +from numba import float32, int32, from_dtype +from numba.cuda.random import init_xoroshiro128p_states, xoroshiro128p_uniform_float32 +import numpy as np + +xoroshiro128p_type = from_dtype(np.dtype([("s0", np.uint64), ("s1", np.uint64)], align=True)) + + +def init_random_for_reset(rng_states, seed): + init_xoroshiro128p_states(states=rng_states, seed=seed) + + +# @numba_driver.jit([(int32[:], int32[:], int32[:], int32), +# (float32[:], float32[:], int32[:], int32)]) +@numba_driver.jit +def reset_when_done_1d_from_pool(rng_states, data, ref, done, 
force_reset): + env_id = numba_driver.blockIdx.x + tid = numba_driver.threadIdx.x + if tid == 0: + if force_reset > 0.5 or done[env_id] > 0.5: + p = xoroshiro128p_uniform_float32(rng_states, env_id) + ref_id_float = p * ref.shape[0] + ref_id = int(ref_id_float) + data[env_id] = ref[ref_id] + + +# @numba_driver.jit([(int32[:, :], int32[:, :], int32[:], int32, int32), +# (float32[:, :], float32[:, :], int32[:], int32, int32)]) +@numba_driver.jit +def reset_when_done_2d_from_pool(rng_states, data, ref, done, feature_dim, force_reset): + env_id = numba_driver.blockIdx.x + tid = numba_driver.threadIdx.x + if force_reset > 0.5 or done[env_id] > 0.5: + p = xoroshiro128p_uniform_float32(rng_states, env_id) + ref_id_float = p * ref.shape[0] + ref_id = int(ref_id_float) + if tid < feature_dim: + data[env_id, tid] = ref[ref_id, tid] + + +# @numba_driver.jit([(int32[:, :, :], int32[:, :, :], int32[:], int32, int32, int32), +# (float32[:, :, :], float32[:, :, :], int32[:], int32, int32, int32)]) +@numba_driver.jit +def reset_when_done_3d_from_pool(rng_states, data, ref, done, agent_dim, feature_dim, force_reset): + env_id = numba_driver.blockIdx.x + tid = numba_driver.threadIdx.x + if force_reset > 0.5 or done[env_id] > 0.5: + p = xoroshiro128p_uniform_float32(rng_states, env_id) + ref_id_float = p * ref.shape[0] + ref_id = int(ref_id_float) + if tid < agent_dim: + for i in range(feature_dim): + data[env_id, tid, i] = ref[ref_id, tid, i] \ No newline at end of file diff --git a/warp_drive/numba_includes/template_env_runner.txt b/warp_drive/numba_includes/template_env_runner.txt index 7ab4ab6..ee069e1 100644 --- a/warp_drive/numba_includes/template_env_runner.txt +++ b/warp_drive/numba_includes/template_env_runner.txt @@ -7,5 +7,6 @@ from warp_drive.numba_includes.core.random import * from warp_drive.numba_includes.core.reset import * from warp_drive.numba_includes.core.log import * +from warp_drive.numba_includes.core.pool_reset import * from <> import * diff --git 
a/warp_drive/numba_includes/test_build.py b/warp_drive/numba_includes/test_build.py index 2c41561..c6f6005 100644 --- a/warp_drive/numba_includes/test_build.py +++ b/warp_drive/numba_includes/test_build.py @@ -1,3 +1,4 @@ from warp_drive.numba_includes.core.log import * from warp_drive.numba_includes.core.random import * from warp_drive.numba_includes.core.reset import * +from warp_drive.numba_includes.core.pool_reset import * diff --git a/warp_drive/training/example_training_script_numba.py b/warp_drive/training/example_training_script_numba.py index 4ab885f..1ce280a 100644 --- a/warp_drive/training/example_training_script_numba.py +++ b/warp_drive/training/example_training_script_numba.py @@ -18,7 +18,7 @@ import yaml from example_envs.tag_continuous.tag_continuous import TagContinuous -from example_envs.tag_gridworld.tag_gridworld import CUDATagGridWorld +from example_envs.tag_gridworld.tag_gridworld import CUDATagGridWorld, CUDATagGridWorldWithResetPool from warp_drive.env_wrapper import EnvWrapper from warp_drive.training.trainer import Trainer from warp_drive.training.utils.distributed_train.distributed_trainer_numba import ( @@ -31,6 +31,7 @@ _TAG_CONTINUOUS = "tag_continuous" _TAG_GRIDWORLD = "tag_gridworld" +_TAG_GRIDWORLD_WITH_RESET_POOL = "tag_gridworld_with_reset_pool" # Example usages (from the root folder): # >> python warp_drive/training/example_training_script.py -e tag_gridworld @@ -72,11 +73,20 @@ def setup_trainer_and_train( event_messenger=event_messenger, process_id=device_id, ) + elif run_configuration["name"] == _TAG_GRIDWORLD_WITH_RESET_POOL: + env_wrapper = EnvWrapper( + CUDATagGridWorldWithResetPool(**run_configuration["env"]), + num_envs=num_envs, + env_backend="numba", + event_messenger=event_messenger, + process_id=device_id, + ) else: raise NotImplementedError( f"Currently, the environments supported are [" f"{_TAG_GRIDWORLD}, " f"{_TAG_CONTINUOUS}" + f", {_TAG_GRIDWORLD_WITH_RESET_POOL}" f"]", ) # Policy mapping to agent ids: agents can
share models diff --git a/warp_drive/training/run_configs/tag_gridworld_with_reset_pool.yaml b/warp_drive/training/run_configs/tag_gridworld_with_reset_pool.yaml new file mode 100644 index 0000000..27d7e48 --- /dev/null +++ b/warp_drive/training/run_configs/tag_gridworld_with_reset_pool.yaml @@ -0,0 +1,49 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause + +# YAML configuration for the tag gridworld environment +name: "tag_gridworld_with_reset_pool" +# Environment settings +env: + num_taggers: 4 + grid_length: 100 + episode_length: 100 + seed: 20 + wall_hit_penalty: 0.1 + tag_reward_for_tagger: 10.0 + tag_penalty_for_runner: 5.0 + step_cost_for_tagger: 0.01 +# Trainer settings +trainer: + num_envs: 2000 # number of environment replicas + num_episodes: 20000 # number of episodes to run the training for. Can be arbitrarily high! 
+ train_batch_size: 200000 # total batch size used for training per iteration (across all the environments) + env_backend: "numba" # environment backend; must be numba — the reset-pool kernels and example script only support numba +# Policy network settings +policy: # list all the policies below + shared: + to_train: True # flag indicating whether the model needs to be trained + algorithm: "A2C" # algorithm used to train the policy + vf_loss_coeff: 1 # loss coefficient schedule for the value function loss + entropy_coeff: 0.05 # loss coefficient schedule for the entropy loss + clip_grad_norm: True # flag indicating whether to clip the gradient norm or not + max_grad_norm: 3 # when clip_grad_norm is True, the clip level + normalize_advantage: False # flag indicating whether to normalize advantage or not + normalize_return: False # flag indicating whether to normalize return or not + gamma: 0.98 # discount factor + lr: 0.001 # learning rate + model: # policy model settings + type: "fully_connected" # model type + fc_dims: [32, 32] # dimension(s) of the fully connected layers as a list + model_ckpt_filepath: "" # filepath (used to restore a previously saved model) +# Checkpoint saving setting +saving: + metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics + model_params_save_freq: 5000 # how often (in iterations) to save the model parameters + basedir: "/tmp" # base folder used for saving + name: "tag_gridworld" # base name used for saving + tag: "experiments" # experiment name + diff --git a/warp_drive/training/trainer.py b/warp_drive/training/trainer.py index d74600f..b9c17d6 100644 --- a/warp_drive/training/trainer.py +++ b/warp_drive/training/trainer.py @@ -245,6 +245,7 @@ def __init__( torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) + self.cuda_envs.init_reset_pool(seed + random.randint(1, 10000)) # Define models, optimizers, and learning rate schedules self.models = {} diff --git a/warp_drive/utils/data_feed.py b/warp_drive/utils/data_feed.py index 7428ed7..ce8f881 100644 ---
a/warp_drive/utils/data_feed.py +++ b/warp_drive/utils/data_feed.py @@ -85,3 +85,21 @@ def add_data_list(self, data_list): raise Exception( "Unknown type of data configure, only support tuple and dictionary" ) + + def add_pool_for_reset(self, name, data, reset_target): + """ + Add a special data entry that the reset function samples values from at random. + :param name: name of the reset-pool data + :param data: data in the form of list, array or scalar + :param reset_target: name of the data entry that this reset pool applies to + + For example, the following will add a reset pool called 'position_reset_pool' to reset 'position': + add_pool_for_reset("position_reset_pool", [1,2,3], "position") + """ + self.add_data(name, + data, + save_copy_and_apply_at_reset=False, + log_data_across_episode=False, + is_reset_pool=True, + reset_target=reset_target) + diff --git a/warp_drive/utils/gpu_environment_context.py b/warp_drive/utils/gpu_environment_context.py index b87fbc1..212dc3e 100644 --- a/warp_drive/utils/gpu_environment_context.py +++ b/warp_drive/utils/gpu_environment_context.py @@ -1,4 +1,5 @@ import logging +from warp_drive.utils.data_feed import DataFeed class CUDAEnvironmentContext: @@ -30,3 +31,15 @@ def initialize_step_function_context( except Exception as err: logging.error(err) return False + + def get_data_dictionary(self): + data_dict = DataFeed() + return data_dict + + def get_tensor_dictionary(self): + tensor_dict = DataFeed() + return tensor_dict + + def get_reset_pool_dictionary(self): + reset_pool_dict = DataFeed() + return reset_pool_dict