From dfaafa988338f9932b5633e07c7adf4799340170 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Thu, 14 Dec 2023 22:25:40 -0800 Subject: [PATCH 01/19] pendulum env, continuous action data loader --- .../classic_control/pendulum/pendulum.py | 110 ++++++++++++++++++ .../pendulum/pendulum_step_numba.py | 72 ++++++++++++ .../classic_control/test_pendulum.py | 86 ++++++++++++++ warp_drive/env_cpu_gpu_consistency_checker.py | 7 +- warp_drive/training/utils/data_loader.py | 48 +++++++- warp_drive/utils/numba_utils/misc.py | 1 + 6 files changed, 318 insertions(+), 6 deletions(-) create mode 100644 example_envs/single_agent/classic_control/pendulum/pendulum.py create mode 100644 example_envs/single_agent/classic_control/pendulum/pendulum_step_numba.py create mode 100644 tests/example_envs/numba_tests/single_agent/classic_control/test_pendulum.py diff --git a/example_envs/single_agent/classic_control/pendulum/pendulum.py b/example_envs/single_agent/classic_control/pendulum/pendulum.py new file mode 100644 index 0000000..a2d61de --- /dev/null +++ b/example_envs/single_agent/classic_control/pendulum/pendulum.py @@ -0,0 +1,110 @@ +import numpy as np +from warp_drive.utils.constants import Constants +from warp_drive.utils.data_feed import DataFeed +from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext + +from example_envs.single_agent.base import SingleAgentEnv, map_to_single_agent, get_action_for_single_agent +from gym.envs.classic_control.pendulum import PendulumEnv + +_OBSERVATIONS = Constants.OBSERVATIONS +_ACTIONS = Constants.ACTIONS +_REWARDS = Constants.REWARDS + + +class ClassicControlPendulumEnv(SingleAgentEnv): + + name = "ClassicControlPendulumEnv" + + def __init__(self, episode_length, env_backend="cpu", reset_pool_size=0, seed=None): + super().__init__(episode_length, env_backend, reset_pool_size, seed=seed) + + self.gym_env = PendulumEnv(g=9.81) + + self.action_space = map_to_single_agent(self.gym_env.action_space) + self.observation_space = map_to_single_agent(self.gym_env.observation_space) + + def step(self, action=None): + self.timestep += 1 + action = get_action_for_single_agent(action) + observation, reward, terminated, _, _ = self.gym_env.step(action) + + obs = map_to_single_agent(observation) + rew = map_to_single_agent(reward) + done = {"__all__": self.timestep >= self.episode_length or terminated} + info = {} + + return obs, rew, done, info + + def reset(self): + self.timestep = 0 + if self.reset_pool_size < 2: + # we use a fixed initial state all the time + initial_obs, _ = self.gym_env.reset(seed=self.seed) + else: + initial_obs, _ = self.gym_env.reset(seed=None) + obs = map_to_single_agent(initial_obs) + + return obs + + +class CUDAClassicControlPendulumEnv(ClassicControlPendulumEnv, CUDAEnvironmentContext): + + def get_data_dictionary(self): + data_dict = DataFeed() + # the reset function returns the initial observation which is a processed tuple from state + # so we will call env.state to access the initial state + self.gym_env.reset(seed=self.seed) + initial_state = self.gym_env.state + + if self.reset_pool_size < 2: + data_dict.add_data( + name="state", + data=np.atleast_2d(initial_state), + save_copy_and_apply_at_reset=True, + ) + else: + data_dict.add_data( + name="state", + data=np.atleast_2d(initial_state), + save_copy_and_apply_at_reset=False, + ) + return data_dict + + def get_tensor_dictionary(self): + tensor_dict = DataFeed() + return tensor_dict + + def get_reset_pool_dictionary(self): + reset_pool_dict = DataFeed() + if self.reset_pool_size >= 2: + state_reset_pool = [] + for _ in range(self.reset_pool_size): + self.gym_env.reset(seed=None) + initial_state = self.gym_env.state + state_reset_pool.append(np.atleast_2d(initial_state)) + state_reset_pool = np.stack(state_reset_pool, axis=0) + assert len(state_reset_pool.shape) == 3 and state_reset_pool.shape[2] == 2 + + reset_pool_dict.add_pool_for_reset(name="state_reset_pool", + data=state_reset_pool, + reset_target="state") + return reset_pool_dict + + def step(self, actions=None): + self.timestep += 1 + args = [ + "state", + _ACTIONS, + "_done_", + _REWARDS, + _OBSERVATIONS, + "_timestep_", + ("episode_length", "meta"), + ] + if self.env_backend == "numba": + self.cuda_step[ + self.cuda_function_manager.grid, self.cuda_function_manager.block + ](*self.cuda_step_function_feed(args)) + else: + raise Exception("CUDAClassicControlPendulumEnv expects env_backend = 'numba' ") + diff --git a/example_envs/single_agent/classic_control/pendulum/pendulum_step_numba.py b/example_envs/single_agent/classic_control/pendulum/pendulum_step_numba.py new file mode 100644 index 0000000..2bde00b --- /dev/null +++ b/example_envs/single_agent/classic_control/pendulum/pendulum_step_numba.py @@ -0,0 +1,72 @@ +import numba +import numba.cuda as numba_driver +import numpy as np +import math + +DEFAULT_X = np.pi +DEFAULT_Y = 1.0 + +max_speed = 8 +max_torque = 2.0 +dt = 0.05 +g = 9.81 +m = 1.0 +l = 1.0 + +@numba_driver.jit +def _clip(v, min, max): + if v < min: + return min + if v > max: + return max + return v + + +@numba_driver.jit +def angle_normalize(x): + return ((x + np.pi) % (2 * np.pi)) - np.pi + + +@numba_driver.jit +def NumbaClassicControlPendulumEnvStep( + state_arr, + action_arr, + done_arr, + reward_arr, + observation_arr, + env_timestep_arr, + episode_length): + + kEnvId = numba_driver.blockIdx.x + kThisAgentId = numba_driver.threadIdx.x + + assert kThisAgentId == 0, "We only have one agent per environment" + + env_timestep_arr[kEnvId] += 1 + + assert 0 < env_timestep_arr[kEnvId] <= episode_length + + action = action_arr[kEnvId, kThisAgentId, 0] + + u = _clip(action, -max_torque, max_torque) + + th = state_arr[kEnvId, kThisAgentId, 0] + thdot = state_arr[kEnvId, kThisAgentId, 1] + + costs = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * (u ** 2) + + newthdot = thdot + (3 * g / (2 * l) * math.sin(th) + 3.0 / (m * l ** 2) * u) * dt + newthdot = _clip(newthdot, -max_speed, max_speed) + newth = th + newthdot * dt + + state_arr[kEnvId, kThisAgentId, 0] = newth + state_arr[kEnvId, kThisAgentId, 1] = newthdot + + observation_arr[kEnvId, kThisAgentId, 0] = math.cos(newth) + observation_arr[kEnvId, kThisAgentId, 1] = math.sin(newth) + observation_arr[kEnvId, kThisAgentId, 2] = newthdot + + reward_arr[kEnvId, kThisAgentId] = -costs + + if env_timestep_arr[kEnvId] == episode_length: + done_arr[kEnvId] = 1 diff --git a/tests/example_envs/numba_tests/single_agent/classic_control/test_pendulum.py b/tests/example_envs/numba_tests/single_agent/classic_control/test_pendulum.py new file mode 100644 index 0000000..7e1c0de --- /dev/null +++ b/tests/example_envs/numba_tests/single_agent/classic_control/test_pendulum.py @@ -0,0 +1,86 @@ +import unittest +import numpy as np +import torch + +from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU +from example_envs.single_agent.classic_control.pendulum.pendulum import \ + ClassicControlPendulumEnv, CUDAClassicControlPendulumEnv +from warp_drive.env_wrapper import EnvWrapper + + +env_configs = { + "test1": { + "episode_length": 200, + "reset_pool_size": 0, + "seed": 32145, + }, +} + + +class MyTestCase(unittest.TestCase): + """ + CPU v GPU consistency unit tests + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.testing_class = EnvironmentCPUvsGPU( + cpu_env_class=ClassicControlPendulumEnv, + cuda_env_class=CUDAClassicControlPendulumEnv, + env_configs=env_configs, + gpu_env_backend="numba", + num_envs=5, + num_episodes=2, + ) + + def test_env_consistency(self): + try: + self.testing_class.test_env_reset_and_step() + except AssertionError: + self.fail("ClassicControlPendulumEnv environment consistency tests failed") + + def test_reset_pool(self): + env_wrapper = EnvWrapper( + env_obj=CUDAClassicControlPendulumEnv(episode_length=100, reset_pool_size=8), + num_envs=3, + env_backend="numba", + ) + env_wrapper.reset_all_envs() + env_wrapper.env_resetter.init_reset_pool(env_wrapper.cuda_data_manager, seed=12345) + self.assertTrue(env_wrapper.cuda_data_manager.reset_target_to_pool["state"] == "state_reset_pool") + + # squeeze() the agent dimension which is 1 always + state_after_initial_reset = env_wrapper.cuda_data_manager.pull_data_from_device("state").squeeze() + + reset_pool = env_wrapper.cuda_data_manager.pull_data_from_device( + env_wrapper.cuda_data_manager.get_reset_pool("state")) + reset_pool_mean = reset_pool.mean(axis=0).squeeze() + + self.assertTrue(reset_pool.std(axis=0).mean() > 1e-4) + + env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy( + np.array([1, 1, 0]) + ).cuda() + + state_values = {0: [], 1: [], 2: []} + for _ in range(10000): + env_wrapper.env_resetter.reset_when_done(env_wrapper.cuda_data_manager, mode="if_done", undo_done_after_reset=False) + res = env_wrapper.cuda_data_manager.pull_data_from_device("state") + state_values[0].append(res[0]) + state_values[1].append(res[1]) + state_values[2].append(res[2]) + + state_values_env0_mean = np.stack(state_values[0]).mean(axis=0).squeeze() + state_values_env1_mean = np.stack(state_values[1]).mean(axis=0).squeeze() + state_values_env2_mean = np.stack(state_values[2]).mean(axis=0).squeeze() + + for i in range(len(reset_pool_mean)): + self.assertTrue(np.absolute(state_values_env0_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i])) + self.assertTrue(np.absolute(state_values_env1_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i])) + self.assertTrue( + np.absolute( + state_values_env2_mean[i] - state_after_initial_reset[0][i] + ) < 0.001 * abs(state_after_initial_reset[0][i]) + ) + + diff --git a/warp_drive/env_cpu_gpu_consistency_checker.py b/warp_drive/env_cpu_gpu_consistency_checker.py index 0382ab0..f42ce96 100644 --- a/warp_drive/env_cpu_gpu_consistency_checker.py +++ b/warp_drive/env_cpu_gpu_consistency_checker.py @@ -12,7 +12,7 @@ import numpy as np import torch -from gym.spaces import Discrete, MultiDiscrete +from gym.spaces import Discrete, MultiDiscrete, Box from warp_drive.env_wrapper import EnvWrapper from warp_drive.training.utils.data_loader import ( @@ -61,8 +61,11 @@ def _generate_random_actions_helper(action_space, np_random): high=action_space.nvec, dtype=np.int32, ) + if isinstance(action_space, Box): + return np_random.uniform(low=action_space.low, high=action_space.high) + raise NotImplementedError( - "Only 'Discrete' or 'MultiDiscrete' type action spaces are supported" + "Only 'Discrete', 'MultiDiscrete' or 'Box' type action spaces are supported" ) diff --git a/warp_drive/training/utils/data_loader.py b/warp_drive/training/utils/data_loader.py index 165dc3b..b785a45 100644 --- a/warp_drive/training/utils/data_loader.py +++ b/warp_drive/training/utils/data_loader.py @@ -299,9 +299,11 @@ def _validate_obs_action_spaces(agent_ids, env_wrapper): action_dims = [tuple(act_space.nvec) for act_space in action_spaces] elif isinstance(first_agent_action_space, Discrete): action_dims = [tuple([act_space.n]) for act_space in action_spaces] + elif isinstance(first_agent_action_space, Box): + action_dims = [act_space.shape for act_space in action_spaces] else: raise NotImplementedError( - "Only 'Discrete' or 'MultiDiscrete' type action spaces are supported!" + "Only 'Discrete', 'MultiDiscrete' or 'Box' type action spaces are supported!" ) assert all_equal(action_dims) @@ -463,10 +465,43 @@ def _create_action_placeholders_helper( dtype=np.int32, ), ) - + # continuous action space + elif isinstance(action_space, Box): + num_action_types = action_space.shape[0] + if num_action_types == 1: + tensor_feed.add_data( + name=_ACTIONS + policy_suffix, + data=np.zeros( + (num_envs, num_agents, 1), + dtype=np.float32, + ), + ) + else: + # Add separate placeholders for each type of action space. + # This is required since our sampler will be invoked for each + # action dimension separately. + for action_type_id in range(num_action_types): + tensor_feed.add_data( + name=f"{_ACTIONS}_{action_type_id}" + policy_suffix, + data=np.zeros( + (num_envs, num_agents, 1), + dtype=np.float32, + ), + ) + tensor_feed.add_data( + name=_ACTIONS + policy_suffix, + data=np.zeros( + ( + num_envs, + num_agents, + ) + + (num_action_types,), + dtype=np.float32, + ), + ) else: raise NotImplementedError( - "Only 'Discrete' or 'MultiDiscrete' type action spaces are supported!" + "Only 'Discrete', 'MultiDiscrete' or 'Box' type action spaces are supported!" ) env_wrapper.cuda_data_manager.push_data_to_device( @@ -487,9 +522,14 @@ def _create_action_batches_helper( num_agents = len(agent_ids) first_agent_id = agent_ids[0] action_space = env_wrapper.env.action_space[first_agent_id] + action_dtype = np.int32 if isinstance(action_space, MultiDiscrete): action_dim = action_space.nvec num_action_types = len(action_dim) + # continuous action space + elif isinstance(action_space, Box): + num_action_types = action_space.shape[0] + action_dtype = np.float32 else: num_action_types = 1 @@ -504,7 +544,7 @@ def _create_action_batches_helper( num_agents, ) + (num_action_types,), - dtype=np.int32, + dtype=action_dtype, ), ) diff --git a/warp_drive/utils/numba_utils/misc.py b/warp_drive/utils/numba_utils/misc.py index e2f473d..855af97 100644 --- a/warp_drive/utils/numba_utils/misc.py +++ b/warp_drive/utils/numba_utils/misc.py @@ -20,6 +20,7 @@ def get_default_env_directory(env_name): "ClassicControlCartPoleEnv": "example_envs.single_agent.classic_control.cartpole.cartpole_step_numba", "ClassicControlMountainCarEnv": "example_envs.single_agent.classic_control.mountain_car.mountain_car_step_numba", "ClassicControlAcrobotEnv": "example_envs.single_agent.classic_control.acrobot.acrobot_step_numba", + "ClassicControlPendulumEnv": "example_envs.single_agent.classic_control.pendulum.pendulum_step_numba", "YOUR_ENVIRONMENT": "PYTHON_PATH_TO_YOUR_ENV_SRC", } return envs.get(env_name, None) From ff9e7448603e49f6f7fdfb2a721db3c595377be4 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Tue, 19 Dec 2023 21:39:18 -0800 Subject: [PATCH 02/19] sampler class --- warp_drive/managers/function_manager.py | 17 ++++- .../numba_managers/numba_function_manager.py | 74 +++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/warp_drive/managers/function_manager.py b/warp_drive/managers/function_manager.py index 1fe2a68..95244e7 100644 --- a/warp_drive/managers/function_manager.py +++ b/warp_drive/managers/function_manager.py @@ -168,7 +168,11 @@ def init_random(self, seed: Optional[int] = None): raise NotImplementedError def register_actions( - self, data_manager: CUDADataManager, action_name: str, num_actions: int + self, + data_manager: CUDADataManager, + action_name: str, + num_actions: int, + is_continuous=False, ): """ Register an action @@ -177,13 +181,21 @@ def register_actions( record the sampled actions :param num_actions: the number of actions for this action_name (the last dimension of the action distribution) + :param is_continuous: discrete or continuous action """ n_agents = data_manager.get_shape(action_name)[1] + if is_continuous: + num_actions = 1 host_array = np.zeros( shape=(self._grid[0], n_agents, num_actions), dtype=np.float32 ) data_feed = DataFeed() - data_feed.add_data(name=f"{action_name}_cum_distr", data=host_array) + if is_continuous: + # add ou noise data array + data_feed.add_data(name=f"{action_name}_ou_state", data=host_array) + else: + # add cumulative distribution data array + data_feed.add_data(name=f"{action_name}_cum_distr", data=host_array) data_manager.push_data_to_device(data_feed) def sample( @@ -191,6 +203,7 @@ def sample( data_manager: CUDADataManager, distribution: torch.Tensor, action_name: str, + **kwargs, ): raise NotImplementedError diff --git a/warp_drive/managers/numba_managers/numba_function_manager.py b/warp_drive/managers/numba_managers/numba_function_manager.py index df7ce33..6a9dcee 100644 --- a/warp_drive/managers/numba_managers/numba_function_manager.py +++ b/warp_drive/managers/numba_managers/numba_function_manager.py @@ -695,3 +695,77 @@ def _reset_log_mask(self, data_manager: NumbaDataManager): data_manager.device_data("_log_mask_"), data_manager.meta_info("episode_length"), ) + + +class NumbaOUProcess(CUDASampler): + + def __init__(self, function_manager: NumbaFunctionManager): + """ + :param function_manager: CUDAFunctionManager object + """ + super().__init__(function_manager) + + self.sample_ou_process = self._function_manager.get_function("sample_ou_process") + + self.rng_states_dict = {} + + def init_random(self, seed: Optional[int] = None): + """ + Init random function for all the threads + :param seed: random seed selected for the initialization + """ + if seed is None: + seed = time.time() + logging.info( + f"random seed is not provided, by default, " + f"using the current timestamp {seed} as seed" + ) + seed = np.int32(seed) + xoroshiro128p_dtype = np.dtype( + [("s0", np.uint64), ("s1", np.uint64)], align=True + ) + sz = self._function_manager._num_envs * self._function_manager._num_agents + rng_states = numba_driver.device_array(sz, dtype=xoroshiro128p_dtype) + init = self._function_manager.get_function("init_random") + init(rng_states, seed) + self.rng_states_dict["rng_states"] = rng_states + self._random_initialized = True + + def sample( + self, + data_manager: NumbaDataManager, + distribution: torch.Tensor, + action_name: str, + **kwargs, + ): + """ + Sample continuous actions based on the Ornstein–Uhlenbeck process + + :param data_manager: NumbaDataManager object + :param distribution: Torch tensor of deterministic action in the shape of + (num_env, num_agents, 1) + :param action_name: the name of action array that will + record the sampled actions + """ + assert self._random_initialized, ( + "sample() requires the random seed initialized first, " + "please call init_random()" + ) + assert torch.is_tensor(distribution) + assert distribution.shape[0] == self._num_envs + n_agents = int(distribution.shape[1]) + assert data_manager.get_shape(action_name)[1] == n_agents + assert data_manager.get_shape(f"{action_name}_ou_state")[2] == 1 + + # distribution is a runtime output from pytorch at device, + # it should not be managed by data manager because + # it is a temporary output and never sit at the host + self.sample_ou_process[ + self._grid, (int((n_agents - 1) // self._blocks_per_env + 1), 1, 1) + ]( + self.rng_states_dict["rng_states"], + numba_driver.as_cuda_array(distribution.detach()), + data_manager.device_data(action_name), + data_manager.device_data(f"{action_name}_ou_state"), + ) + From e4bd828958c5d0575dbf88abe5d64f235435738d Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Mon, 25 Dec 2023 18:35:19 -0800 Subject: [PATCH 03/19] ou sampler --- .../warp_drive/numba_tests/test_ou_sampler.py | 82 +++++++++++++++++++ .../numba_managers/numba_function_manager.py | 19 ++++- warp_drive/numba_includes/core/random.py | 44 +++++++++- 3 files changed, 143 insertions(+), 2 deletions(-) create mode 100644 tests/warp_drive/numba_tests/test_ou_sampler.py diff --git a/tests/warp_drive/numba_tests/test_ou_sampler.py b/tests/warp_drive/numba_tests/test_ou_sampler.py new file mode 100644 index 0000000..d704f2b --- /dev/null +++ b/tests/warp_drive/numba_tests/test_ou_sampler.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause + +import unittest + +import numpy as np +import torch + +from warp_drive.managers.numba_managers.numba_data_manager import NumbaDataManager +from warp_drive.managers.numba_managers.numba_function_manager import ( + NumbaFunctionManager, + NumbaOUProcess, +) +from warp_drive.utils.common import get_project_root +from warp_drive.utils.constants import Constants +from warp_drive.utils.data_feed import DataFeed + +_NUMBA_FILEPATH = f"warp_drive.numba_includes" +_ACTIONS = Constants.ACTIONS + + +class TestOUProcessSampler(unittest.TestCase): + """ + Unit tests for the CUDA action sampler + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dm = NumbaDataManager(num_agents=5, episode_length=1, num_envs=1000) + self.fm = NumbaFunctionManager( + num_agents=int(self.dm.meta_info("n_agents")), + num_envs=int(self.dm.meta_info("n_envs")), + ) + self.fm.import_numba_from_source_code(f"{_NUMBA_FILEPATH}.test_build") + self.sampler = NumbaOUProcess(function_manager=self.fm) + self.sampler.init_random(seed=None) + + def test_variation(self): + tensor = DataFeed() + tensor.add_data(name=f"{_ACTIONS}_a", data=np.zeros((1000, 5, 1), dtype=np.float32)) + self.dm.push_data_to_device(tensor, torch_accessible=True) + self.sampler.register_actions(self.dm, f"{_ACTIONS}_a", 1, is_continuous=True) + + # deterministic agent actions + agent_distribution = np.zeros((1000, 5, 1), dtype=np.float32) + agent_distribution = torch.from_numpy(agent_distribution) + agent_distribution = agent_distribution.float().cuda() + + actions_a_cuda = torch.from_numpy( + np.empty((10000, 1000, 5), dtype=np.float32) + ).cuda() + + damping = 0.15 + stddev = 0.2 + for i in range(10000): + self.sampler.sample(self.dm, agent_distribution, action_name=f"{_ACTIONS}_a", damping=damping, stddev=stddev,) + actions_a_cuda[i] = self.dm.data_on_device_via_torch(f"{_ACTIONS}_a")[:, :, 0] + actions_a = actions_a_cuda.cpu().numpy() + + var_list = [] + for i in range(100, 10000): + var_list.append(actions_a[i].flatten().std()) + var_mean = np.array(var_list).mean() + + var_theory = stddev/(1 - (1 - damping)**2)**0.5 + + self.assertAlmostEqual(var_mean, var_theory, delta=0.001) + + cov_list = [] + # test the cov of step difference of 10 + # stddev^2/(1-(1-damping)^2)*(1-damping)^(n-k)*[1-(1-damping)^(n+k)] + # roughly it is stddev^2/(1-(1-damping)^2)*(1-damping)^(n-k) + for i in range(100, 9990): + cov_list.append(np.cov(actions_a[i].flatten(), actions_a[i + 10].flatten())[0, 1]) + cov_mean = np.array(cov_list).mean() + + cov_theory = stddev**2 / (1 - (1 - damping)**2) * (1 - damping)**10 + + self.assertAlmostEqual(cov_mean, cov_theory, delta=0.001) diff --git a/warp_drive/managers/numba_managers/numba_function_manager.py b/warp_drive/managers/numba_managers/numba_function_manager.py index 6a9dcee..eb6152a 100644 --- a/warp_drive/managers/numba_managers/numba_function_manager.py +++ b/warp_drive/managers/numba_managers/numba_function_manager.py @@ -193,6 +193,7 @@ def initialize_default_functions(self): "log_one_step_3d", "init_random", "sample_actions", + "sample_ou_process", "reset_when_done_1d", "reset_when_done_2d", "reset_when_done_3d", @@ -731,12 +732,22 @@ def init_random(self, seed: Optional[int] = None): self.rng_states_dict["rng_states"] = rng_states self._random_initialized = True + def reset_state(self, data_manager: NumbaDataManager, action_name: str,): + host_array = np.zeros( + shape=data_manager.get_shape(f"{action_name}_ou_state"), dtype=np.float32 + ) + data_feed = DataFeed() + data_feed.add_data(name=f"{action_name}_ou_state", data=host_array) + data_manager.push_data_to_device(data_feed) + def sample( self, data_manager: NumbaDataManager, distribution: torch.Tensor, action_name: str, - **kwargs, + damping=0.15, + stddev=0.2, + scale=1.0 ): """ Sample continuous actions based on the Ornstein–Uhlenbeck process @@ -746,6 +757,9 @@ def sample( (num_env, num_agents, 1) :param action_name: the name of action array that will record the sampled actions + :param damping: damping factor for OU process + :param stddev: standard dev for normal process + :param scale: scale of ou process """ assert self._random_initialized, ( "sample() requires the random seed initialized first, " @@ -767,5 +781,8 @@ def sample( numba_driver.as_cuda_array(distribution.detach()), data_manager.device_data(action_name), data_manager.device_data(f"{action_name}_ou_state"), + np.float32(damping), + np.float32(stddev), + np.float32(scale), ) diff --git a/warp_drive/numba_includes/core/random.py b/warp_drive/numba_includes/core/random.py index d338b19..5a24031 100644 --- a/warp_drive/numba_includes/core/random.py +++ b/warp_drive/numba_includes/core/random.py @@ -1,6 +1,6 @@ from numba import cuda as numba_driver from numba import float32, int32, boolean, from_dtype -from numba.cuda.random import init_xoroshiro128p_states, xoroshiro128p_uniform_float32 +from numba.cuda.random import init_xoroshiro128p_states, xoroshiro128p_uniform_float32, xoroshiro128p_normal_float32 import numpy as np kEps = 1.0e-8 @@ -61,3 +61,45 @@ def sample_actions(rng_states, distr, action_indices, cum_distr, num_actions, us ind = search_index(cum_distr, p, env_id, agent_id, num_actions - 1) # action_indices in the shape of [n_env, n_agent, 1] action_indices[env_id, agent_id, 0] = ind + + +@numba_driver.jit((xoroshiro128p_type[::1], + float32[:, :, ::1], + float32[:, :, ::1], + float32[:, :, ::1], + float32, + float32, + float32) + ) +def sample_ou_process( + rng_states, + distr, + actions, + ou_states, + damping=0.15, + stddev=0.2, + scale=1.0): + + # The temporal noise update equation is: + # ou_next = (1 - damping) * ou + N(0, std_dev)` + # ou = ou_next + # action = distr + scale * ou + + env_id = numba_driver.blockIdx.x + agent_id = numba_driver.threadIdx.x + + if scale < kEps: + # there is no random noise, assign the model deterministic distribution to action directly + actions[env_id, agent_id, 0] = distr[env_id, agent_id, 0] + return + + posidx = numba_driver.grid(1) + if posidx >= rng_states.shape[0]: + return + + normal_var = xoroshiro128p_normal_float32(rng_states, posidx) + normal_var = stddev * normal_var # an normal noise with std = stddev and mean 0 + + ou_states[env_id, agent_id, 0] = (1.0 - damping) * ou_states[env_id, agent_id, 0] + normal_var + + actions[env_id, agent_id, 0] = distr[env_id, agent_id, 0] + scale * ou_states[env_id, agent_id, 0] From 79be045b3bd305afaf56bf4813a26b5d39c7bbca Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Sun, 31 Dec 2023 22:40:35 -0800 Subject: [PATCH 04/19] model update; factorize sampler --- .../warp_drive/numba_tests/test_ou_sampler.py | 6 +- warp_drive/managers/function_manager.py | 8 +-- .../numba_managers/numba_function_manager.py | 55 ++++++++++++++----- warp_drive/training/models/fully_connected.py | 51 +++++++++-------- warp_drive/training/utils/data_loader.py | 19 +++++++ 5 files changed, 97 insertions(+), 42 deletions(-) diff --git a/tests/warp_drive/numba_tests/test_ou_sampler.py b/tests/warp_drive/numba_tests/test_ou_sampler.py index d704f2b..43a028a 100644 --- a/tests/warp_drive/numba_tests/test_ou_sampler.py +++ b/tests/warp_drive/numba_tests/test_ou_sampler.py @@ -12,7 +12,7 @@ from warp_drive.managers.numba_managers.numba_data_manager import NumbaDataManager from warp_drive.managers.numba_managers.numba_function_manager import ( NumbaFunctionManager, - NumbaOUProcess, + NumbaSampler, ) from warp_drive.utils.common import get_project_root from warp_drive.utils.constants import Constants @@ -35,14 +35,14 @@ def __init__(self, *args, **kwargs): num_envs=int(self.dm.meta_info("n_envs")), ) self.fm.import_numba_from_source_code(f"{_NUMBA_FILEPATH}.test_build") - self.sampler = NumbaOUProcess(function_manager=self.fm) + self.sampler = NumbaSampler(function_manager=self.fm) self.sampler.init_random(seed=None) def test_variation(self): tensor = DataFeed() tensor.add_data(name=f"{_ACTIONS}_a", data=np.zeros((1000, 5, 1), dtype=np.float32)) self.dm.push_data_to_device(tensor, torch_accessible=True) - self.sampler.register_actions(self.dm, f"{_ACTIONS}_a", 1, is_continuous=True) + self.sampler.register_actions(self.dm, f"{_ACTIONS}_a", 1, is_deterministic=True) # deterministic agent actions agent_distribution = np.zeros((1000, 5, 1), dtype=np.float32) diff --git a/warp_drive/managers/function_manager.py b/warp_drive/managers/function_manager.py index 95244e7..1bb0dfc 100644 --- a/warp_drive/managers/function_manager.py +++ b/warp_drive/managers/function_manager.py @@ -172,7 +172,7 @@ def register_actions( data_manager: CUDADataManager, action_name: str, num_actions: int, - is_continuous=False, + is_deterministic=False, ): """ Register an action @@ -181,16 +181,16 @@ def register_actions( record the sampled actions :param num_actions: the number of actions for this action_name (the last dimension of the action distribution) - :param is_continuous: discrete or continuous action + :param is_deterministic: if True: deterministic action, usually it means continuous action like for DDPG """ n_agents = data_manager.get_shape(action_name)[1] - if is_continuous: + if is_deterministic: num_actions = 1 host_array = np.zeros( shape=(self._grid[0], n_agents, num_actions), dtype=np.float32 ) data_feed = DataFeed() - if is_continuous: + if is_deterministic: # add ou noise data array data_feed.add_data(name=f"{action_name}_ou_state", data=host_array) else: diff --git a/warp_drive/managers/numba_managers/numba_function_manager.py b/warp_drive/managers/numba_managers/numba_function_manager.py index eb6152a..fc0f6b7 100644 --- a/warp_drive/managers/numba_managers/numba_function_manager.py +++ b/warp_drive/managers/numba_managers/numba_function_manager.py @@ -267,6 +267,8 @@ def __init__(self, function_manager: NumbaFunctionManager): self.sample_actions = self._function_manager.get_function("sample_actions") + self.sample_ou_process = self._function_manager.get_function("sample_ou_process") + self.rng_states_dict = {} def init_random(self, seed: Optional[int] = None): @@ -296,7 +298,7 @@ def sample( data_manager: NumbaDataManager, distribution: torch.Tensor, action_name: str, - use_argmax: bool = False, + **kwargs, ): """ Sample based on the distribution @@ -306,7 +308,12 @@ def sample( (num_env, num_agents, num_actions) :param action_name: the name of action array that will record the sampled actions - :param use_argmax: if True, sample based on the argmax(distribution) + :param kwargs: + kwargs["use_argmax"] = True, sample based on the argmax(distribution) + kwargs["damping"] + kwargs["stddev"] + kwargs["scale"] + """ assert self._random_initialized, ( "sample() requires the random seed initialized first, " @@ -317,21 +324,43 @@ def sample( n_agents = int(distribution.shape[1]) assert data_manager.get_shape(action_name)[1] == n_agents n_actions = distribution.shape[2] - assert data_manager.get_shape(f"{action_name}_cum_distr")[2] == n_actions # distribution is a runtime output from pytorch at device, # it should not be managed by data manager because # it is a temporary output and never sit at the host - self.sample_actions[ - self._grid, (int((n_agents - 1) // self._blocks_per_env + 1), 1, 1) - ]( - self.rng_states_dict["rng_states"], - numba_driver.as_cuda_array(distribution.detach()), - data_manager.device_data(action_name), - data_manager.device_data(f"{action_name}_cum_distr"), - np.int32(n_actions), - np.int32(use_argmax), - ) + if n_actions > 1: + # has a probability distribution over multiple discrete actions + assert data_manager.get_shape(f"{action_name}_cum_distr")[2] == n_actions + + use_argmax = kwargs.get("use_argmax", False) + + self.sample_actions[ + self._grid, (int((n_agents - 1) // self._blocks_per_env + 1), 1, 1) + ]( + self.rng_states_dict["rng_states"], + numba_driver.as_cuda_array(distribution.detach()), + data_manager.device_data(action_name), + data_manager.device_data(f"{action_name}_cum_distr"), + np.int32(n_actions), + np.int32(use_argmax), + ) + else: + # deterministic action + damping = kwargs.get("damping", 0.15) + stddev = kwargs.get("stddev", 0.2) + scale = kwargs.get("scale", 1.0) + + self.sample_ou_process[ + self._grid, (int((n_agents - 1) // self._blocks_per_env + 1), 1, 1) + ]( + self.rng_states_dict["rng_states"], + numba_driver.as_cuda_array(distribution.detach()), + data_manager.device_data(action_name), + data_manager.device_data(f"{action_name}_ou_state"), + np.float32(damping), + np.float32(stddev), + np.float32(scale), + ) class NumbaEnvironmentReset(CUDAEnvironmentReset): diff --git a/warp_drive/training/models/fully_connected.py b/warp_drive/training/models/fully_connected.py index fa85cf5..cc8f841 100644 --- a/warp_drive/training/models/fully_connected.py +++ b/warp_drive/training/models/fully_connected.py @@ -73,12 +73,16 @@ def __init__( sample_agent_id = self.policy_tag_to_agent_id_map[self.policy][0] # Flatten obs space self.observation_space = self.env.env.observation_space[sample_agent_id] - self.flattened_obs_size = self.get_flattened_obs_size(self.observation_space) - + self.flattened_obs_size = get_flattened_obs_size(self.observation_space) + self.is_deterministic = False if isinstance(self.env.env.action_space[sample_agent_id], Discrete): action_space = [self.env.env.action_space[sample_agent_id].n] elif isinstance(self.env.env.action_space[sample_agent_id], MultiDiscrete): action_space = self.env.env.action_space[sample_agent_id].nvec + elif isinstance(self.env.env.action_space[sample_agent_id], Box): + # deterministic action space + action_space = [1] * self.env.env.action_space[sample_agent_id].shape[0] + self.is_deterministic = True else: raise NotImplementedError @@ -93,12 +97,16 @@ def __init__( ) # policy network (list of heads) - policy_heads = [None for _ in range(len(action_space))] - self.output_dims = [] # Network output dimension(s) - for idx, act_space in enumerate(action_space): - self.output_dims += [act_space] - policy_heads[idx] = nn.Linear(fc_dims[-1], act_space) - self.policy_head = nn.ModuleList(policy_heads) + if self.is_deterministic: + self.output_dims = [len(action_space)] + self.policy_head = nn.Linear(fc_dims, len(action_space)) + else: + policy_heads = [None for _ in range(len(action_space))] + self.output_dims = [] # Network output dimension(s) + for idx, act_space in enumerate(action_space): + self.output_dims += [act_space] + policy_heads[idx] = nn.Linear(fc_dims[-1], act_space) + self.policy_head = nn.ModuleList(policy_heads) # value-function network head self.vf_head = nn.Linear(fc_dims[-1], 1) @@ -110,10 +118,6 @@ def __init__( name = f"{_PROCESSED_OBSERVATIONS}_batch_{self.policy}" self.batch_size = self.env.cuda_data_manager.get_shape(name=name)[0] - def get_flattened_obs_size(self, observation_space): - """Get the total size of the observations after flattening""" - return get_flattened_obs_size(observation_space) - def reshape_and_flatten_obs(self, obs): """ # Note: WarpDrive assumes that all the observation are shaped @@ -225,16 +229,19 @@ def forward(self, obs=None, batch_index=None): # Compute the action probabilities and the value function estimate # Apply action mask to the logits as well. - action_masks = [None for _ in range(len(self.output_dims))] - if self.action_mask is not None: - start = 0 - for idx, dim in enumerate(self.output_dims): - action_masks[idx] = self.action_mask[..., start : start + dim] - start = start + dim - action_probs = [ - func.softmax(apply_logit_mask(ph(logits), action_masks[idx]), dim=-1) - for idx, ph in enumerate(self.policy_head) - ] + if self.is_deterministic: + action_probs = func.tanh(apply_logit_mask(self.policy_head(logits), self.action_mask)) + else: + action_masks = [None for _ in range(len(self.output_dims))] + if self.action_mask is not None: + start = 0 + for idx, dim in enumerate(self.output_dims): + action_masks[idx] = self.action_mask[..., start : start + dim] + start = start + dim + action_probs = [ + func.softmax(apply_logit_mask(ph(logits), action_masks[idx]), dim=-1) + for idx, ph in enumerate(self.policy_head) + ] vals = self.vf_head(logits)[..., 0] return action_probs, vals diff --git a/warp_drive/training/utils/data_loader.py b/warp_drive/training/utils/data_loader.py index b785a45..b0bee5e 100644 --- a/warp_drive/training/utils/data_loader.py +++ b/warp_drive/training/utils/data_loader.py @@ -569,6 +569,7 @@ def _prepare_action_sampler_helper( env_wrapper.cuda_data_manager, action_name=_ACTIONS + policy_suffix, num_actions=action_dim, + is_deterministic=False, ) elif isinstance(action_space, MultiDiscrete): action_dim = action_space.nvec @@ -578,7 +579,25 @@ def _prepare_action_sampler_helper( env_wrapper.cuda_data_manager, action_name=f"{_ACTIONS}_{action_type_id}" + policy_suffix, num_actions=action_type_dim, + is_deterministic=False, ) + elif isinstance(action_space, Box): + num_action_types = action_space.shape[0] + if num_action_types == 1: + action_sampler.register_actions( + env_wrapper.cuda_data_manager, + action_name=_ACTIONS + policy_suffix, + num_actions=1, + is_deterministic=True, + ) + else: + for action_type_id in range(num_action_types): + action_sampler.register_actions( + env_wrapper.cuda_data_manager, + action_name=f"{_ACTIONS}_{action_type_id}" + policy_suffix, + num_actions=1, + is_deterministic=True, + ) else: raise NotImplementedError( "Only 'Discrete' or 'MultiDiscrete' type action spaces are supported!" From 3bf6bf060474577c98e46d77c35b173819ed9d30 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Mon, 1 Jan 2024 21:54:34 -0800 Subject: [PATCH 05/19] change --- .../managers/numba_managers/numba_function_manager.py | 1 + warp_drive/training/models/fully_connected.py | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/warp_drive/managers/numba_managers/numba_function_manager.py b/warp_drive/managers/numba_managers/numba_function_manager.py index fc0f6b7..8e716d4 100644 --- a/warp_drive/managers/numba_managers/numba_function_manager.py +++ b/warp_drive/managers/numba_managers/numba_function_manager.py @@ -320,6 +320,7 @@ def sample( "please call init_random()" ) assert torch.is_tensor(distribution) + assert distribution.is_contiguous(), "distribution is required to be C contiguous" assert distribution.shape[0] == self._num_envs n_agents = int(distribution.shape[1]) assert data_manager.get_shape(action_name)[1] == n_agents diff --git a/warp_drive/training/models/fully_connected.py b/warp_drive/training/models/fully_connected.py index cc8f841..9e790e2 100644 --- a/warp_drive/training/models/fully_connected.py +++ b/warp_drive/training/models/fully_connected.py @@ -230,7 +230,15 @@ def forward(self, obs=None, batch_index=None): # Compute the action probabilities and the value function estimate # Apply action mask to the logits as well. if self.is_deterministic: - action_probs = func.tanh(apply_logit_mask(self.policy_head(logits), self.action_mask)) + combined_action_probs = func.tanh(apply_logit_mask(self.policy_head(logits), self.action_mask)) + if self.output_dims[0] > 1: + # we split the actions to their corresponding heads + # we make sure after the split, we rearrange the memory so each chunk is still C-continguous + # otherwise the sampler may have index issue + action_probs = list(torch.split(combined_action_probs, 1, dim=-1)) + action_probs = [ap.contiguous() for ap in action_probs] + else: + action_probs = [combined_action_probs] else: action_masks = [None for _ in range(len(self.output_dims))] if self.action_mask is not None: From 5bbcd101a2917c0f8bd816e8104aaa26417c611c Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Wed, 3 Jan 2024 21:18:23 -0800 Subject: [PATCH 06/19] sample configuration --- warp_drive/managers/function_manager.py | 2 +- .../numba_managers/numba_function_manager.py | 20 ++++++++--------- warp_drive/training/trainer.py | 22 ++++++++++--------- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/warp_drive/managers/function_manager.py b/warp_drive/managers/function_manager.py index 1bb0dfc..c521fae 100644 --- a/warp_drive/managers/function_manager.py +++ b/warp_drive/managers/function_manager.py @@ -203,7 +203,7 @@ def sample( data_manager: CUDADataManager, distribution: torch.Tensor, action_name: str, - **kwargs, + **sample_params, ): raise NotImplementedError diff --git a/warp_drive/managers/numba_managers/numba_function_manager.py b/warp_drive/managers/numba_managers/numba_function_manager.py index 8e716d4..43894b5 100644 --- a/warp_drive/managers/numba_managers/numba_function_manager.py +++ b/warp_drive/managers/numba_managers/numba_function_manager.py @@ -298,7 +298,7 @@ def sample( data_manager: NumbaDataManager, distribution: torch.Tensor, action_name: str, - **kwargs, + **sample_params, ): """ Sample based on the distribution @@ -308,11 +308,11 @@ def sample( (num_env, num_agents, num_actions) :param action_name: the name of action array that will record the sampled actions - :param kwargs: - kwargs["use_argmax"] = True, sample based on the argmax(distribution) - kwargs["damping"] - kwargs["stddev"] - kwargs["scale"] + :param sample_params: + sample_params["use_argmax"] = True, sample based on the argmax(distribution) + sample_params["damping"] + sample_params["stddev"] + sample_params["scale"] """ assert self._random_initialized, ( @@ -333,7 +333,7 @@ def sample( # has a probability distribution over multiple discrete actions assert data_manager.get_shape(f"{action_name}_cum_distr")[2] == n_actions - use_argmax = kwargs.get("use_argmax", False) + use_argmax = sample_params.get("use_argmax", False) self.sample_actions[ self._grid, (int((n_agents - 1) // self._blocks_per_env + 1), 1, 1) @@ -347,9 +347,9 @@ def sample( ) else: # deterministic action - damping = kwargs.get("damping", 0.15) - stddev = kwargs.get("stddev", 0.2) - scale = kwargs.get("scale", 1.0) + damping = sample_params.get("damping", 0.15) + stddev = sample_params.get("stddev", 0.2) + scale = sample_params.get("scale", 1.0) self.sample_ou_process[ self._grid, (int((n_agents - 1) // self._blocks_per_env + 1), 1, 1) diff --git a/warp_drive/training/trainer.py b/warp_drive/training/trainer.py index 4c2fc9e..3a002ce 100644 --- a/warp_drive/training/trainer.py +++ b/warp_drive/training/trainer.py @@ -147,6 +147,9 @@ def __init__( self.config["policy"][key] = recursive_merge_config_dicts( self.config["policy"][key], default_config["policy"] ) + # Sampler-related configurations (usually Optional) + self.sample_params = self._get_config(["sampler", "params"]) if "sampler" in self.config else {} + # Saving-related configurations self.config["saving"] = recursive_merge_config_dicts( self.config["saving"], default_config["saving"] @@ -448,7 +451,7 @@ def _generate_rollout_batch(self): # Sample actions using the computed probabilities # and push to the batch of actions start_event.record() - self._sample_actions(probabilities, batch_index=batch_index) + self._sample_actions(probabilities, batch_index=batch_index, **self.sample_params) end_event.record() torch.cuda.synchronize() self.perf_stats.action_sample_time += ( @@ -522,7 +525,7 @@ def _evaluate_policies(self, batch_index=0): return probabilities - def _sample_actions(self, probabilities, batch_index=0, use_argmax=False): + def _sample_actions(self, probabilities, batch_index=0, **sample_params): """ Sample action probabilities (and push the sampled actions to the device). """ @@ -532,7 +535,7 @@ def _sample_actions(self, probabilities, batch_index=0, use_argmax=False): # Sample each individual policy policy_suffix = f"_{policy}" self._sample_actions_helper( - probabilities[policy], policy_suffix=policy_suffix, use_argmax=use_argmax + probabilities[policy], policy_suffix=policy_suffix, **sample_params ) # Push the actions to the batch actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( @@ -545,7 +548,7 @@ def _sample_actions(self, probabilities, batch_index=0, use_argmax=False): assert len(probabilities) == 1 policy = list(probabilities.keys())[0] # sample a single or a combined policy - self._sample_actions_helper(probabilities[policy], use_argmax=use_argmax) + self._sample_actions_helper(probabilities[policy], **sample_params) actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( _ACTIONS ) @@ -563,20 +566,20 @@ def _sample_actions(self, probabilities, batch_index=0, use_argmax=False): name=f"{_ACTIONS}_batch_{policy}" )[batch_index] = actions - def _sample_actions_helper(self, probabilities, policy_suffix="", use_argmax=False): + def _sample_actions_helper(self, probabilities, policy_suffix="", **sample_params): # Sample actions with policy_suffix tag num_action_types = len(probabilities) if num_action_types == 1: action_name = _ACTIONS + policy_suffix self.cuda_sample_controller.sample( - self.cuda_envs.cuda_data_manager, probabilities[0], action_name, use_argmax + self.cuda_envs.cuda_data_manager, probabilities[0], action_name, **sample_params ) else: for action_type_id, probs in enumerate(probabilities): action_name = f"{_ACTIONS}_{action_type_id}" + policy_suffix self.cuda_sample_controller.sample( - self.cuda_envs.cuda_data_manager, probs, action_name, use_argmax + self.cuda_envs.cuda_data_manager, probs, action_name, **sample_params ) # Push (indexed) actions to 'actions' actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( @@ -759,7 +762,6 @@ def _update_model_params(self, iteration): ) self.num_completed_episodes[policy] = 0 - end_event.record() torch.cuda.synchronize() @@ -901,7 +903,7 @@ def fetch_episode_states( env_id=0, # environment id to fetch the states from include_rewards_actions=False, # flag to output reward and action policy="", # if include_rewards_actions=True, the corresponding policy tag if any - use_argmax=False, + **sample_params ): """ Step through env and fetch the desired states (data arrays on the GPU) @@ -959,7 +961,7 @@ def fetch_episode_states( probabilities = self._evaluate_policies(batch_index=-1) # Sample actions - self._sample_actions(probabilities, use_argmax=use_argmax) + self._sample_actions(probabilities, **sample_params) # Step through all the environments self.cuda_envs.step_all_envs() From 7b96281d0f9884112d6d56aa847777a472f9172d Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Fri, 12 Jan 2024 22:12:17 -0800 Subject: [PATCH 07/19] action value network --- warp_drive/training/models/factory.py | 4 + warp_drive/training/models/fully_connected.py | 200 ++--------------- .../models/fully_connected_actor_critic.py | 138 ++++++++++++ warp_drive/training/models/model_base.py | 207 ++++++++++++++++++ warp_drive/training/trainer.py | 13 +- 5 files changed, 375 insertions(+), 187 deletions(-) create mode 100644 warp_drive/training/models/fully_connected_actor_critic.py create mode 100644 warp_drive/training/models/model_base.py diff --git a/warp_drive/training/models/factory.py b/warp_drive/training/models/factory.py index 8bb039c..afa746b 100644 --- a/warp_drive/training/models/factory.py +++ b/warp_drive/training/models/factory.py @@ -3,6 +3,10 @@ # warpdrive reserved models default_models = { "fully_connected": "warp_drive.training.models.fully_connected:FullyConnected", + "fully_connected_actor": + "warp_drive.training.models.fully_connected_actor_critic:FullyConnectedActor", + "fully_connected_action_value_critic": + "warp_drive.training.models.fully_connected_actor_critic:FullyConnectedActionValueCritic", } diff --git a/warp_drive/training/models/fully_connected.py b/warp_drive/training/models/fully_connected.py index 9e790e2..12fd348 100644 --- a/warp_drive/training/models/fully_connected.py +++ b/warp_drive/training/models/fully_connected.py @@ -11,36 +11,13 @@ import numpy as np import torch import torch.nn.functional as func -from gym.spaces import Box, Dict, Discrete, MultiDiscrete from torch import nn +from warp_drive.training.models.model_base import ModelBaseFullyConnected, apply_logit_mask -from warp_drive.utils.constants import Constants -from warp_drive.utils.data_feed import DataFeed -from warp_drive.training.utils.data_loader import get_flattened_obs_size -_OBSERVATIONS = Constants.OBSERVATIONS -_PROCESSED_OBSERVATIONS = Constants.PROCESSED_OBSERVATIONS -_ACTION_MASK = Constants.ACTION_MASK - -_LARGE_NEG_NUM = -1e20 - - -def apply_logit_mask(logits, mask=None): - """ - Mask values of 1 are valid actions. - Add huge negative values to logits with 0 mask values. - """ - if mask is None: - return logits - - logit_mask = torch.ones_like(logits) * _LARGE_NEG_NUM - logit_mask = logit_mask * (1 - mask) - return logits + logit_mask - - -# Policy networks +# Policy + Value networks # --------------- -class FullyConnected(nn.Module): +class FullyConnected(ModelBaseFullyConnected): """ Fully connected network implementation in Pytorch """ @@ -56,171 +33,27 @@ def __init__( create_separate_placeholders_for_each_policy=False, obs_dim_corresponding_to_num_agents="first", ): - super().__init__() - - self.env = env - fc_dims = model_config["fc_dims"] - assert isinstance(fc_dims, list) - num_fc_layers = len(fc_dims) - self.policy = policy - self.policy_tag_to_agent_id_map = policy_tag_to_agent_id_map - self.create_separate_placeholders_for_each_policy = ( - create_separate_placeholders_for_each_policy - ) - assert obs_dim_corresponding_to_num_agents in ["first", "last"] - self.obs_dim_corresponding_to_num_agents = obs_dim_corresponding_to_num_agents - - sample_agent_id = self.policy_tag_to_agent_id_map[self.policy][0] - # Flatten obs space - self.observation_space = self.env.env.observation_space[sample_agent_id] - self.flattened_obs_size = get_flattened_obs_size(self.observation_space) - self.is_deterministic = False - if isinstance(self.env.env.action_space[sample_agent_id], Discrete): - action_space = [self.env.env.action_space[sample_agent_id].n] - elif isinstance(self.env.env.action_space[sample_agent_id], MultiDiscrete): - action_space = self.env.env.action_space[sample_agent_id].nvec - elif isinstance(self.env.env.action_space[sample_agent_id], Box): - # deterministic action space - action_space = [1] * self.env.env.action_space[sample_agent_id].shape[0] - self.is_deterministic = True - else: - raise NotImplementedError - - input_dims = [self.flattened_obs_size] + fc_dims[:-1] - output_dims = fc_dims - - self.fc = nn.ModuleDict() + super().__init__(env, + model_config, + policy, + policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents,) + num_fc_layers = len(self.fc_dims) + input_dims = [self.flattened_obs_size] + self.fc_dims[:-1] + output_dims = self.fc_dims for fc_layer in range(num_fc_layers): self.fc[str(fc_layer)] = nn.Sequential( nn.Linear(input_dims[fc_layer], output_dims[fc_layer]), nn.ReLU(), ) - # policy network (list of heads) - if self.is_deterministic: - self.output_dims = [len(action_space)] - self.policy_head = nn.Linear(fc_dims, len(action_space)) - else: - policy_heads = [None for _ in range(len(action_space))] - self.output_dims = [] # Network output dimension(s) - for idx, act_space in enumerate(action_space): - self.output_dims += [act_space] - policy_heads[idx] = nn.Linear(fc_dims[-1], act_space) - self.policy_head = nn.ModuleList(policy_heads) - - # value-function network head - self.vf_head = nn.Linear(fc_dims[-1], 1) - - # used for action masking - self.action_mask = None - - # max batch size allowed - name = f"{_PROCESSED_OBSERVATIONS}_batch_{self.policy}" - self.batch_size = self.env.cuda_data_manager.get_shape(name=name)[0] - - def reshape_and_flatten_obs(self, obs): - """ - # Note: WarpDrive assumes that all the observation are shaped - # (num_agents, *feature_dim), i.e., the observation dimension - # corresponding to 'num_agents' is the first one. If the observation - # dimension corresponding to num_agents is last, we will need to - # permute the axes to align with WarpDrive's assumption. - """ - num_envs = obs.shape[0] - if self.create_separate_placeholders_for_each_policy: - num_agents = len(self.policy_tag_to_agent_id_map[self.policy]) - else: - num_agents = self.env.n_agents - - if self.obs_dim_corresponding_to_num_agents == "first": - pass - elif self.obs_dim_corresponding_to_num_agents == "last": - shape_len = len(obs.shape) - if shape_len == 1: - obs = obs.reshape(-1, num_agents) # valid only when num_agents = 1 - obs = obs.permute(0, -1, *range(1, shape_len - 1)) - else: - raise ValueError( - "num_agents can only be the first " - "or the last dimension in the observations." - ) - return obs.reshape(num_envs, num_agents, -1) - - def get_flattened_obs(self): - """ - If the obs is of Box type, it will already be flattened. - If the obs is of Dict type, then concatenate all the - obs values and flatten them out. - Returns the concatenated and flattened obs. - - """ - if isinstance(self.observation_space, Box): - if self.create_separate_placeholders_for_each_policy: - obs = self.env.cuda_data_manager.data_on_device_via_torch( - f"{_OBSERVATIONS}_{self.policy}" - ) - else: - obs = self.env.cuda_data_manager.data_on_device_via_torch(_OBSERVATIONS) - - flattened_obs = self.reshape_and_flatten_obs(obs) - elif isinstance(self.observation_space, Dict): - obs_dict = {} - for key in self.observation_space: - if self.create_separate_placeholders_for_each_policy: - obs = self.env.cuda_data_manager.data_on_device_via_torch( - f"{_OBSERVATIONS}_{self.policy}_{key}" - ) - else: - obs = self.env.cuda_data_manager.data_on_device_via_torch( - f"{_OBSERVATIONS}_{key}" - ) - - if key == _ACTION_MASK: - self.action_mask = self.reshape_and_flatten_obs(obs) - assert self.action_mask.shape[-1] == sum(self.output_dims) - else: - obs_dict[key] = obs - - flattened_obs_dict = {} - for key, value in obs_dict.items(): - flattened_obs_dict[key] = self.reshape_and_flatten_obs(value) - flattened_obs = torch.cat(list(flattened_obs_dict.values()), dim=-1) - else: - raise NotImplementedError("Observation space must be of Box or Dict type") - - assert flattened_obs.shape[-1] == self.flattened_obs_size, \ - f"The flattened observation size of {flattened_obs.shape[-1]} is different " \ - f"from the designated size of {self.flattened_obs_size} " - - return flattened_obs - - def forward(self, obs=None, batch_index=None): + def forward(self, obs=None, action=None): """ Forward pass through the model. Returns action probabilities and value functions. """ - if obs is None: - assert batch_index < self.batch_size, f"batch_index: {batch_index}, self.batch_size: {self.batch_size}" - # Read in observation from the placeholders and flatten them - # before passing through the fully connected layers. - # This is particularly relevant if the observations space is a Dict. - obs = self.get_flattened_obs() - - if self.create_separate_placeholders_for_each_policy: - ip = obs - else: - agent_ids_for_policy = self.policy_tag_to_agent_id_map[self.policy] - ip = obs[:, agent_ids_for_policy] - - # Push the processed (in this case, flattened) obs to the GPU (device). - # The writing happens to a specific batch index in the processed obs batch. - # The processed obs batch is required for training. - if batch_index >= 0: - self.push_processed_obs_to_batch(batch_index, ip) - - else: - ip = obs - + ip = obs # Feed through the FC layers for layer in range(len(self.fc)): op = self.fc[str(layer)](ip) @@ -254,8 +87,3 @@ def forward(self, obs=None, batch_index=None): return action_probs, vals - def push_processed_obs_to_batch(self, batch_index, processed_obs): - name = f"{_PROCESSED_OBSERVATIONS}_batch_{self.policy}" - self.env.cuda_data_manager.data_on_device_via_torch(name=name)[ - batch_index - ] = processed_obs diff --git a/warp_drive/training/models/fully_connected_actor_critic.py b/warp_drive/training/models/fully_connected_actor_critic.py new file mode 100644 index 0000000..185a5fc --- /dev/null +++ b/warp_drive/training/models/fully_connected_actor_critic.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause +# +""" +The Fully Connected Network class +""" + +import numpy as np +import torch +import torch.nn.functional as func +from torch import nn +from warp_drive.training.models.model_base import ModelBaseFullyConnected, apply_logit_mask + + +# Policy networks +# --------------- +class FullyConnectedActor(ModelBaseFullyConnected): + """ + Fully connected network implementation in Pytorch + """ + + name = "torch_fully_connected_actor" + + def __init__( + self, + env, + model_config, + policy, + policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=False, + obs_dim_corresponding_to_num_agents="first", + ): + super().__init__(env, + model_config, + policy, + policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents, + include_value_head=False,) + num_fc_layers = len(self.fc_dims) + input_dims = [self.flattened_obs_size] + self.fc_dims[:-1] + output_dims = self.fc_dims + for fc_layer in range(num_fc_layers): + self.fc[str(fc_layer)] = nn.Sequential( + nn.Linear(input_dims[fc_layer], output_dims[fc_layer]), + nn.ReLU(), + ) + + def forward(self, obs=None, action=None): + """ + Forward pass through the model. + Returns action probabilities. + """ + ip = obs + # Feed through the FC layers + for layer in range(len(self.fc)): + op = self.fc[str(layer)](ip) + ip = op + logits = op + + # Compute the action probabilities and the value function estimate + # Apply action mask to the logits as well. + if self.is_deterministic: + combined_action_probs = func.tanh(apply_logit_mask(self.policy_head(logits), self.action_mask)) + if self.output_dims[0] > 1: + # we split the actions to their corresponding heads + # we make sure after the split, we rearrange the memory so each chunk is still C-continguous + # otherwise the sampler may have index issue + action_probs = list(torch.split(combined_action_probs, 1, dim=-1)) + action_probs = [ap.contiguous() for ap in action_probs] + else: + action_probs = [combined_action_probs] + else: + action_masks = [None for _ in range(len(self.output_dims))] + if self.action_mask is not None: + start = 0 + for idx, dim in enumerate(self.output_dims): + action_masks[idx] = self.action_mask[..., start : start + dim] + start = start + dim + action_probs = [ + func.softmax(apply_logit_mask(ph(logits), action_masks[idx]), dim=-1) + for idx, ph in enumerate(self.policy_head) + ] + + return action_probs + + +# Q-Critic networks +# --------------- +class FullyConnectedActionValueCritic(ModelBaseFullyConnected): + """ + Fully connected network implementation in Pytorch + """ + + name = "torch_fully_connected_q" + + def __init__( + self, + env, + model_config, + policy, + policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=False, + obs_dim_corresponding_to_num_agents="first", + ): + super().__init__(env, + model_config, + policy, + policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents, + include_policy_head=False, ) + num_fc_layers = len(self.fc_dims) + input_dims = [self.flattened_obs_size + self.flattened_action_size] + self.fc_dims[:-1] + output_dims = self.fc_dims + for fc_layer in range(num_fc_layers): + self.fc[str(fc_layer)] = nn.Sequential( + nn.Linear(input_dims[fc_layer], output_dims[fc_layer]), + nn.ReLU(), + ) + + def forward(self, obs=None, action=None): + """ + Forward pass through the model. + Returns Q value. + """ + ip = torch.cat([obs, action], dim=-1) + # Feed through the FC layers + for layer in range(len(self.fc)): + op = self.fc[str(layer)](ip) + ip = op + logits = op + + vals = self.vf_head(logits)[..., 0] + return vals diff --git a/warp_drive/training/models/model_base.py b/warp_drive/training/models/model_base.py new file mode 100644 index 0000000..82a0606 --- /dev/null +++ b/warp_drive/training/models/model_base.py @@ -0,0 +1,207 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause +# +""" +The Fully Connected Network class +""" + +import numpy as np +import torch +import torch.nn.functional as func +from gym.spaces import Box, Dict, Discrete, MultiDiscrete +from torch import nn + +from warp_drive.utils.constants import Constants +from warp_drive.utils.data_feed import DataFeed +from warp_drive.training.utils.data_loader import get_flattened_obs_size + +_OBSERVATIONS = Constants.OBSERVATIONS +_PROCESSED_OBSERVATIONS = Constants.PROCESSED_OBSERVATIONS +_ACTION_MASK = Constants.ACTION_MASK + +_LARGE_NEG_NUM = -1e20 + + +class ModelBaseFullyConnected(nn.Module): + """ + Fully connected network implementation in Pytorch + """ + + name = "model_base_fully_connected" + + def __init__( + self, + env, + model_config, + policy, + policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=False, + obs_dim_corresponding_to_num_agents="first", + include_policy_head=True, + include_value_head=True, + ): + super().__init__() + + self.env = env + self.fc_dims = model_config["fc_dims"] + assert isinstance(self.fc_dims, list) + self.policy = policy + self.policy_tag_to_agent_id_map = policy_tag_to_agent_id_map + self.create_separate_placeholders_for_each_policy = ( + create_separate_placeholders_for_each_policy + ) + assert obs_dim_corresponding_to_num_agents in ["first", "last"] + self.obs_dim_corresponding_to_num_agents = obs_dim_corresponding_to_num_agents + + sample_agent_id = self.policy_tag_to_agent_id_map[self.policy][0] + # Flatten obs space + self.observation_space = self.env.env.observation_space[sample_agent_id] + self.flattened_obs_size = get_flattened_obs_size(self.observation_space) + self.is_deterministic = False + if isinstance(self.env.env.action_space[sample_agent_id], Discrete): + action_space = [self.env.env.action_space[sample_agent_id].n] + elif isinstance(self.env.env.action_space[sample_agent_id], MultiDiscrete): + action_space = self.env.env.action_space[sample_agent_id].nvec + elif isinstance(self.env.env.action_space[sample_agent_id], Box): + # deterministic action space + action_space = [1] * self.env.env.action_space[sample_agent_id].shape[0] + self.is_deterministic = True + else: + raise NotImplementedError + + self.flattened_action_size = len(action_space) + + self.fc = nn.ModuleDict() # this is defined in the child class + + if include_policy_head: + # policy network (list of heads) + if self.is_deterministic: + self.output_dims = [len(action_space)] + self.policy_head = nn.Linear(self.fc_dims, len(action_space)) + else: + policy_heads = [None for _ in range(len(action_space))] + self.output_dims = [] # Network output dimension(s) + for idx, act_space in enumerate(action_space): + self.output_dims += [act_space] + policy_heads[idx] = nn.Linear(self.fc_dims[-1], act_space) + self.policy_head = nn.ModuleList(policy_heads) + if include_value_head: + # value-function network head + self.vf_head = nn.Linear(self.fc_dims[-1], 1) + + # used for action masking + self.action_mask = None + + # max batch size allowed + name = f"{_PROCESSED_OBSERVATIONS}_batch_{self.policy}" + self.batch_size = self.env.cuda_data_manager.get_shape(name=name)[0] + + def reshape_and_flatten_obs(self, obs): + """ + # Note: WarpDrive assumes that all the observation are shaped + # (num_agents, *feature_dim), i.e., the observation dimension + # corresponding to 'num_agents' is the first one. If the observation + # dimension corresponding to num_agents is last, we will need to + # permute the axes to align with WarpDrive's assumption. + """ + num_envs = obs.shape[0] + if self.create_separate_placeholders_for_each_policy: + num_agents = len(self.policy_tag_to_agent_id_map[self.policy]) + else: + num_agents = self.env.n_agents + + if self.obs_dim_corresponding_to_num_agents == "first": + pass + elif self.obs_dim_corresponding_to_num_agents == "last": + shape_len = len(obs.shape) + if shape_len == 1: + obs = obs.reshape(-1, num_agents) # valid only when num_agents = 1 + obs = obs.permute(0, -1, *range(1, shape_len - 1)) + else: + raise ValueError( + "num_agents can only be the first " + "or the last dimension in the observations." + ) + return obs.reshape(num_envs, num_agents, -1) + + def get_flattened_obs(self): + """ + If the obs is of Box type, it will already be flattened. + If the obs is of Dict type, then concatenate all the + obs values and flatten them out. + Returns the concatenated and flattened obs. + + """ + if isinstance(self.observation_space, Box): + if self.create_separate_placeholders_for_each_policy: + obs = self.env.cuda_data_manager.data_on_device_via_torch( + f"{_OBSERVATIONS}_{self.policy}" + ) + else: + obs = self.env.cuda_data_manager.data_on_device_via_torch(_OBSERVATIONS) + + flattened_obs = self.reshape_and_flatten_obs(obs) + elif isinstance(self.observation_space, Dict): + obs_dict = {} + for key in self.observation_space: + if self.create_separate_placeholders_for_each_policy: + obs = self.env.cuda_data_manager.data_on_device_via_torch( + f"{_OBSERVATIONS}_{self.policy}_{key}" + ) + else: + obs = self.env.cuda_data_manager.data_on_device_via_torch( + f"{_OBSERVATIONS}_{key}" + ) + + if key == _ACTION_MASK: + self.action_mask = self.reshape_and_flatten_obs(obs) + assert self.action_mask.shape[-1] == sum(self.output_dims) + else: + obs_dict[key] = obs + + flattened_obs_dict = {} + for key, value in obs_dict.items(): + flattened_obs_dict[key] = self.reshape_and_flatten_obs(value) + flattened_obs = torch.cat(list(flattened_obs_dict.values()), dim=-1) + else: + raise NotImplementedError("Observation space must be of Box or Dict type") + + assert flattened_obs.shape[-1] == self.flattened_obs_size, \ + f"The flattened observation size of {flattened_obs.shape[-1]} is different " \ + f"from the designated size of {self.flattened_obs_size} " + + return flattened_obs + + def process_one_step_obs(self): + obs = self.get_flattened_obs() + if not self.create_separate_placeholders_for_each_policy: + agent_ids_for_policy = self.policy_tag_to_agent_id_map[self.policy] + obs = obs[:, agent_ids_for_policy] + return obs + + def forward(self, obs=None, action=None): + raise NotImplementedError + + def push_processed_obs_to_batch(self, batch_index, processed_obs): + if batch_index >= 0: + assert batch_index < self.batch_size, f"batch_index: {batch_index}, self.batch_size: {self.batch_size}" + name = f"{_PROCESSED_OBSERVATIONS}_batch_{self.policy}" + self.env.cuda_data_manager.data_on_device_via_torch(name=name)[ + batch_index + ] = processed_obs + + +def apply_logit_mask(logits, mask=None): + """ + Mask values of 1 are valid actions. + Add huge negative values to logits with 0 mask values. + """ + if mask is None: + return logits + + logit_mask = torch.ones_like(logits) * _LARGE_NEG_NUM + logit_mask = logit_mask * (1 - mask) + return logits + logit_mask diff --git a/warp_drive/training/trainer.py b/warp_drive/training/trainer.py index 3a002ce..b6eba8f 100644 --- a/warp_drive/training/trainer.py +++ b/warp_drive/training/trainer.py @@ -282,6 +282,7 @@ def __init__( # Note: Loading the model checkpoint may also update the current timestep! self.load_model_checkpoint() + self.ddp_mode = {} for policy in self.policies: # Push the models to the GPU self.models[policy].cuda() @@ -290,6 +291,9 @@ def __init__( self.models[policy] = DDP( self.models[policy], device_ids=[self.device_id] ) + self.ddp_mode[policy] = True + else: + self.ddp_mode[policy] = False # Initialize the (ADAM) optimizer lr_config = self._get_config(["policy", policy, "lr"]) @@ -481,7 +485,14 @@ def _evaluate_policies(self, batch_index=0): assert isinstance(batch_index, int) probabilities = {} for policy in self.policies: - probabilities[policy], _ = self.models[policy](batch_index=batch_index) + if self.ddp_mode[policy]: + # self.models[policy] is a DDP wrapper of the model instance + obs = self.models[policy].module.process_one_step_obs() + self.models[policy].module.push_processed_obs_to_batch(batch_index, obs) + else: + obs = self.models[policy].process_one_step_obs() + self.models[policy].push_processed_obs_to_batch(batch_index, obs) + probabilities[policy], _ = self.models[policy](obs) # Combine probabilities across policies if there are multiple policies, # yet they share the same action placeholders. From 7a1a1eaaf7cb6e606098ee9ae35f76540c8bc7c5 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Wed, 17 Jan 2024 12:32:56 -0800 Subject: [PATCH 08/19] refactor trainers --- .../{ => deprecate}/pytorch_lightning.py | 2 +- .../training/{ => deprecate}/trainer.py | 0 .../training/run_configs/single_pendulum.yaml | 47 ++ .../example_training_script_numba.py | 4 +- .../example_training_script_pycuda.py | 4 +- warp_drive/training/trainer_a2c.py | 364 ++++++++ warp_drive/training/trainer_base.py | 788 ++++++++++++++++++ 7 files changed, 1204 insertions(+), 5 deletions(-) rename warp_drive/training/{ => deprecate}/pytorch_lightning.py (99%) rename warp_drive/training/{ => deprecate}/trainer.py (100%) create mode 100644 warp_drive/training/run_configs/single_pendulum.yaml rename warp_drive/training/{ => scripts}/example_training_script_numba.py (99%) rename warp_drive/training/{ => scripts}/example_training_script_pycuda.py (99%) create mode 100644 warp_drive/training/trainer_a2c.py create mode 100644 warp_drive/training/trainer_base.py diff --git a/warp_drive/training/pytorch_lightning.py b/warp_drive/training/deprecate/pytorch_lightning.py similarity index 99% rename from warp_drive/training/pytorch_lightning.py rename to warp_drive/training/deprecate/pytorch_lightning.py index 4834f18..66cb3b6 100644 --- a/warp_drive/training/pytorch_lightning.py +++ b/warp_drive/training/deprecate/pytorch_lightning.py @@ -31,7 +31,7 @@ from warp_drive.training.algorithms.policygradient.a2c import A2C from warp_drive.training.algorithms.policygradient.ppo import PPO from warp_drive.training.models.factory import ModelFactory -from warp_drive.training.trainer import Metrics +from warp_drive.training.trainer_a2c import Metrics from warp_drive.training.utils.data_loader import create_and_push_data_placeholders from warp_drive.training.utils.param_scheduler import LRScheduler, ParamScheduler from warp_drive.utils.common import get_project_root diff --git a/warp_drive/training/trainer.py b/warp_drive/training/deprecate/trainer.py similarity index 100% rename from warp_drive/training/trainer.py rename to warp_drive/training/deprecate/trainer.py diff --git a/warp_drive/training/run_configs/single_pendulum.yaml b/warp_drive/training/run_configs/single_pendulum.yaml new file mode 100644 index 0000000..2d99ca4 --- /dev/null +++ b/warp_drive/training/run_configs/single_pendulum.yaml @@ -0,0 +1,47 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause + +# YAML configuration for the tag gridworld environment +name: "single_pendulum" +# Environment settings +env: + episode_length: 500 + reset_pool_size: 1000 +# Trainer settings +trainer: + num_envs: 100 # number of environment replicas + num_episodes: 200000 # number of episodes to run the training for. Can be arbitrarily high! + train_batch_size: 50000 # total batch size used for training per iteration (across all the environments) + env_backend: "numba" # environment backend, pycuda or numba +# Policy network settings +policy: # list all the policies below + shared: + to_train: True # flag indicating whether the model needs to be trained + algorithm: "DDPG" # algorithm used to train the policy + vf_loss_coeff: 1 # loss coefficient schedule for the value function loss + entropy_coeff: 0.05 # loss coefficient schedule for the entropy loss + clip_grad_norm: True # flag indicating whether to clip the gradient norm or not + max_grad_norm: 3 # when clip_grad_norm is True, the clip level + normalize_advantage: False # flag indicating whether to normalize advantage or not + normalize_return: False # flag indicating whether to normalize return or not + gamma: 0.99 # discount factor + lr: 0.001 # learning rate + model: # policy model settings + type: + actor: "fully_connected_actor" # model type + critic: "fully_connected_action_value_critic" + fc_dims: + actor: [32, 32] # dimension(s) of the fully connected layers as a list + critic: [32, 32] + model_ckpt_filepath: "" # filepath (used to restore a previously saved model) +# Checkpoint saving setting +saving: + metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics + model_params_save_freq: 5000 # how often (in iterations) to save the model parameters + basedir: "/tmp" # base folder used for saving + name: "single_pendulum" # base folder used for saving + tag: "experiments" # experiment name + diff --git a/warp_drive/training/example_training_script_numba.py b/warp_drive/training/scripts/example_training_script_numba.py similarity index 99% rename from warp_drive/training/example_training_script_numba.py rename to warp_drive/training/scripts/example_training_script_numba.py index 77a8b08..0baf85b 100644 --- a/warp_drive/training/example_training_script_numba.py +++ b/warp_drive/training/scripts/example_training_script_numba.py @@ -23,7 +23,7 @@ from example_envs.single_agent.classic_control.mountain_car.mountain_car import CUDAClassicControlMountainCarEnv from example_envs.single_agent.classic_control.acrobot.acrobot import CUDAClassicControlAcrobotEnv from warp_drive.env_wrapper import EnvWrapper -from warp_drive.training.trainer import Trainer +from warp_drive.training.trainer_a2c import TrainerA2C from warp_drive.training.utils.distributed_train.distributed_trainer_numba import ( perform_distributed_training, ) @@ -152,7 +152,7 @@ def setup_trainer_and_train( ) # Trainer object # -------------- - trainer = Trainer( + trainer = TrainerA2C( env_wrapper=env_wrapper, config=run_configuration, policy_tag_to_agent_id_map=policy_tag_to_agent_id_map, diff --git a/warp_drive/training/example_training_script_pycuda.py b/warp_drive/training/scripts/example_training_script_pycuda.py similarity index 99% rename from warp_drive/training/example_training_script_pycuda.py rename to warp_drive/training/scripts/example_training_script_pycuda.py index 10b78ac..efb9d56 100644 --- a/warp_drive/training/example_training_script_pycuda.py +++ b/warp_drive/training/scripts/example_training_script_pycuda.py @@ -20,7 +20,7 @@ from example_envs.tag_continuous.tag_continuous import TagContinuous from example_envs.tag_gridworld.tag_gridworld import CUDATagGridWorld from warp_drive.env_wrapper import EnvWrapper -from warp_drive.training.trainer import Trainer +from warp_drive.training.trainer_a2c import TrainerA2C from warp_drive.training.utils.distributed_train.distributed_trainer_pycuda import ( perform_distributed_training, ) @@ -112,7 +112,7 @@ def setup_trainer_and_train( ) # Trainer object # -------------- - trainer = Trainer( + trainer = TrainerA2C( env_wrapper=env_wrapper, config=run_configuration, policy_tag_to_agent_id_map=policy_tag_to_agent_id_map, diff --git a/warp_drive/training/trainer_a2c.py b/warp_drive/training/trainer_a2c.py new file mode 100644 index 0000000..f1f45a1 --- /dev/null +++ b/warp_drive/training/trainer_a2c.py @@ -0,0 +1,364 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause + +""" +The Trainer, PerfStats and Metrics classes +""" + +import json +import logging +import os +import random +import time + +import numpy as np +import torch +import yaml +from gym.spaces import Discrete, MultiDiscrete +from torch import nn +from torch.nn.parallel import DistributedDataParallel as DDP +from warp_drive.training.trainer_base import TrainerBase, all_equal, verbose_print + +from warp_drive.training.algorithms.policygradient.a2c import A2C +from warp_drive.training.algorithms.policygradient.ppo import PPO +from warp_drive.training.models.factory import ModelFactory +from warp_drive.training.utils.data_loader import create_and_push_data_placeholders +from warp_drive.training.utils.param_scheduler import ParamScheduler +from warp_drive.utils.common import get_project_root +from warp_drive.utils.constants import Constants + +_ROOT_DIR = get_project_root() + +_ACTIONS = Constants.ACTIONS +_REWARDS = Constants.REWARDS +_DONE_FLAGS = Constants.DONE_FLAGS +_PROCESSED_OBSERVATIONS = Constants.PROCESSED_OBSERVATIONS +_COMBINED = "combined" +_EPSILON = 1e-10 # small number to prevent indeterminate divisions + + +class TrainerA2C(TrainerBase): + def __init__( + self, + env_wrapper=None, + config=None, + policy_tag_to_agent_id_map=None, + create_separate_placeholders_for_each_policy=False, + obs_dim_corresponding_to_num_agents="first", + num_devices=1, + device_id=0, + results_dir=None, + verbose=True, + ): + self.models = {} + super().__init__( + env_wrapper=env_wrapper, + config=config, + policy_tag_to_agent_id_map=policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents=obs_dim_corresponding_to_num_agents, + num_devices=num_devices, + device_id=device_id, + results_dir=results_dir, + verbose=verbose, + ) + + def _initialize_policy_algorithm(self, policy): + algorithm = self._get_config(["policy", policy, "algorithm"]) + assert algorithm in ["A2C", "PPO"] + entropy_coeff = self._get_config(["policy", policy, "entropy_coeff"]) + vf_loss_coeff = self._get_config(["policy", policy, "vf_loss_coeff"]) + self.clip_grad_norm[policy] = self._get_config( + ["policy", policy, "clip_grad_norm"] + ) + if self.clip_grad_norm[policy]: + self.max_grad_norm[policy] = self._get_config( + ["policy", policy, "max_grad_norm"] + ) + normalize_advantage = self._get_config( + ["policy", policy, "normalize_advantage"] + ) + normalize_return = self._get_config(["policy", policy, "normalize_return"]) + gamma = self._get_config(["policy", policy, "gamma"]) + if algorithm == "A2C": + # Advantage Actor-Critic + self.trainers[policy] = A2C( + discount_factor_gamma=gamma, + normalize_advantage=normalize_advantage, + normalize_return=normalize_return, + vf_loss_coeff=vf_loss_coeff, + entropy_coeff=entropy_coeff, + ) + logging.info(f"Initializing the A2C trainer for policy {policy}") + elif algorithm == "PPO": + # Proximal Policy Optimization + clip_param = self._get_config(["policy", policy, "clip_param"]) + self.trainers[policy] = PPO( + discount_factor_gamma=gamma, + clip_param=clip_param, + normalize_advantage=normalize_advantage, + normalize_return=normalize_return, + vf_loss_coeff=vf_loss_coeff, + entropy_coeff=entropy_coeff, + ) + logging.info(f"Initializing the PPO trainer for policy {policy}") + else: + raise NotImplementedError + + def _initialize_policy_model(self, policy): + policy_model_config = self._get_config(["policy", policy, "model"]) + model_obj = ModelFactory.create(policy_model_config["type"]) + model = model_obj( + env=self.cuda_envs, + model_config=policy_model_config, + policy=policy, + policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, + ) + + if "init_method" in policy_model_config and \ + policy_model_config["init_method"] == "xavier": + def init_weights_by_xavier_uniform(m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform(m.weight) + + model.apply(init_weights_by_xavier_uniform) + + self.models[policy] = model + + def _send_policy_model_to_device(self, policy): + self.models[policy].cuda() + # If distributed train, sync model using DDP + if self.num_devices > 1: + self.models[policy] = DDP( + self.models[policy], device_ids=[self.device_id] + ) + self.ddp_mode[policy] = True + else: + self.ddp_mode[policy] = False + + def _initialize_optimizer(self, policy): + # Initialize the (ADAM) optimizer + lr_config = self._get_config(["policy", policy, "lr"]) + self.lr_schedules[policy] = ParamScheduler(lr_config) + initial_lr = self.lr_schedules[policy].get_param_value( + timestep=self.current_timestep[policy] + ) + self.optimizers[policy] = torch.optim.Adam( + self.models[policy].parameters(), lr=initial_lr + ) + + def _evaluate_policies(self, batch_index=0): + """ + Perform the policy evaluation (forward pass through the models) + and compute action probabilities + """ + assert isinstance(batch_index, int) + probabilities = {} + for policy in self.policies: + if self.ddp_mode[policy]: + # self.models[policy] is a DDP wrapper of the model instance + obs = self.models[policy].module.process_one_step_obs() + self.models[policy].module.push_processed_obs_to_batch(batch_index, obs) + else: + obs = self.models[policy].process_one_step_obs() + self.models[policy].push_processed_obs_to_batch(batch_index, obs) + probabilities[policy], _ = self.models[policy](obs) + + # Combine probabilities across policies if there are multiple policies, + # yet they share the same action placeholders. + # The action sampler will then need to run just once on each action type. + if ( + len(self.policies) > 1 + and not self.create_separate_placeholders_for_each_policy + ): + # Assert that all the probabilities are of the same length + # in other words the number of action types for each policy + # is the same. + num_action_types = {} + for policy in self.policies: + num_action_types[policy] = len(probabilities[policy]) + assert all_equal(list(num_action_types.values())) + + # Initialize combined_probabilities. + first_policy = list(probabilities.keys())[0] + num_action_types = num_action_types[first_policy] + + first_action_type_id = 0 + num_envs = probabilities[first_policy][first_action_type_id].shape[0] + num_agents = self.cuda_envs.env.num_agents + + combined_probabilities = [None for _ in range(num_action_types)] + for action_type_id in range(num_action_types): + action_dim = probabilities[first_policy][action_type_id].shape[-1] + combined_probabilities[action_type_id] = torch.zeros( + (num_envs, num_agents, action_dim) + ).cuda() + + # Combine the probabilities across policies + for action_type_id in range(num_action_types): + for policy, prob_values in probabilities.items(): + agent_to_id_mapping = self.policy_tag_to_agent_id_map[policy] + combined_probabilities[action_type_id][ + :, agent_to_id_mapping + ] = prob_values[action_type_id] + + probabilities = {_COMBINED: combined_probabilities} + + return probabilities + + def _update_model_params(self, iteration): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + # Flag for logging (which also happens after the last iteration) + logging_flag = ( + iteration % self.config["saving"]["metrics_log_freq"] == 0 + or iteration == self.num_iters - 1 + ) + + metrics_dict = {} + + done_flags_batch = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_DONE_FLAGS}_batch" + ) + # On the device, observations_batch, actions_batch, + # rewards_batch are all shaped + # (batch_size, num_envs, num_agents, *feature_dim). + # done_flags_batch is shaped (batch_size, num_envs) + # Perform training sequentially for each policy + for policy in self.policies_to_train: + actions_batch = ( + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_ACTIONS}_batch_{policy}" + ) + ) + rewards_batch = ( + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_REWARDS}_batch_{policy}" + ) + ) + processed_obs_batch = ( + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_PROCESSED_OBSERVATIONS}_batch_{policy}" + ) + ) + # Policy evaluation for the entire batch + probabilities_batch, value_functions_batch = self.models[policy]( + obs=processed_obs_batch + ) + # Loss and metrics computation + loss, metrics = self.trainers[policy].compute_loss_and_metrics( + self.current_timestep[policy], + actions_batch, + rewards_batch, + done_flags_batch, + probabilities_batch, + value_functions_batch, + perform_logging=logging_flag, + ) + # Compute the gradient norm + grad_norm = 0.0 + for param in list( + filter(lambda p: p.grad is not None, self.models[policy].parameters()) + ): + grad_norm += param.grad.data.norm(2).item() + + # Update the timestep and learning rate based on the schedule + self.current_timestep[policy] += self.training_batch_size + lr = self.lr_schedules[policy].get_param_value( + self.current_timestep[policy] + ) + for param_group in self.optimizers[policy].param_groups: + param_group["lr"] = lr + + # Loss backpropagation and optimization step + self.optimizers[policy].zero_grad() + loss.backward() + if self.clip_grad_norm[policy]: + nn.utils.clip_grad_norm_( + self.models[policy].parameters(), self.max_grad_norm[policy] + ) + + self.optimizers[policy].step() + # Logging + if logging_flag: + metrics_dict[policy] = metrics + # Update the metrics dictionary + metrics_dict[policy].update( + { + "Current timestep": self.current_timestep[policy], + "Gradient norm": grad_norm, + "Learning rate": lr, + "Mean episodic reward": self.episodic_reward_sum[policy].item() + / (self.num_completed_episodes[policy] + _EPSILON), + "Mean episodic steps": self.episodic_step_sum[policy].item() + / (self.num_completed_episodes[policy] + _EPSILON), + } + ) + + # Reset sum and counter + self.episodic_reward_sum[policy] = ( + torch.tensor(0).type(torch.float32).cuda() + ) + self.episodic_step_sum[policy] = ( + torch.tensor(0).type(torch.int64).cuda() + ) + self.num_completed_episodes[policy] = 0 + + end_event.record() + torch.cuda.synchronize() + + self.perf_stats.training_time += start_event.elapsed_time(end_event) / 1000 + return metrics_dict + + def _load_model_checkpoint_helper(self, policy, ckpt_filepath): + if ckpt_filepath != "": + assert os.path.isfile(ckpt_filepath), "Invalid model checkpoint path!" + if self.verbose: + verbose_print( + f"Loading the '{policy}' torch model " + f"from the previously saved checkpoint: '{ckpt_filepath}'", + self.device_id, + ) + self.models[policy].load_state_dict(torch.load(ckpt_filepath)) + + # Update the current timestep using the saved checkpoint filename + timestep = int(ckpt_filepath.split(".state_dict")[0].split("_")[-1]) + if self.verbose: + verbose_print( + f"Updating the timestep for the '{policy}' model to {timestep}.", + self.device_id, + ) + self.current_timestep[policy] = timestep + + def save_model_checkpoint(self, iteration=0): + """ + Save the model parameters + """ + # If multiple devices, save the synced-up model only for device id 0 + if self.device_id == 0: + # Save model checkpoints if specified (and also for the last iteration) + if ( + iteration % self.config["saving"]["model_params_save_freq"] == 0 + or iteration == self.num_iters - 1 + ): + for policy, model in self.models.items(): + filepath = os.path.join( + self.save_dir, + f"{policy}_{self.current_timestep[policy]}.state_dict", + ) + if self.verbose: + verbose_print( + f"Saving the '{policy}' torch model " + f"to the file: '{filepath}'.", + self.device_id, + ) + + torch.save(model.state_dict(), filepath) diff --git a/warp_drive/training/trainer_base.py b/warp_drive/training/trainer_base.py new file mode 100644 index 0000000..8d488e3 --- /dev/null +++ b/warp_drive/training/trainer_base.py @@ -0,0 +1,788 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause + +""" +The Trainer, PerfStats and Metrics classes +""" + +import json +import logging +import os +import random +import time + +import numpy as np +import torch +import yaml +from gym.spaces import Discrete, MultiDiscrete +from torch import nn + +from warp_drive.training.utils.data_loader import create_and_push_data_placeholders +from warp_drive.utils.common import get_project_root +from warp_drive.utils.constants import Constants + +_ROOT_DIR = get_project_root() + +_ACTIONS = Constants.ACTIONS +_REWARDS = Constants.REWARDS +_DONE_FLAGS = Constants.DONE_FLAGS +_PROCESSED_OBSERVATIONS = Constants.PROCESSED_OBSERVATIONS +_COMBINED = "combined" +_EPSILON = 1e-10 # small number to prevent indeterminate divisions + + +def all_equal(iterable): + """ + Check all elements of an iterable (e.g., list) are identical + """ + return len(set(iterable)) <= 1 + + +def recursive_merge_config_dicts(config, default_config): + """ + Merge the configuration dictionary with the default configuration + dictionary to fill in any missing configuration keys. + """ + assert isinstance(config, dict) + assert isinstance(default_config, dict) + + for k, v in default_config.items(): + if k not in config: + config[k] = v + else: + if isinstance(v, dict): + recursive_merge_config_dicts(config[k], v) + return config + + +def verbose_print(message, device_id=None): + if device_id is None: + device_id = 0 + print(f"[Device {device_id}]: {message} ") + + +class TrainerBase: + """ + The trainer object. Contains modules train(), save_model_checkpoint() and + fetch_episode_global_states() + """ + + def __init__( + self, + env_wrapper=None, + config=None, + policy_tag_to_agent_id_map=None, + create_separate_placeholders_for_each_policy=False, + obs_dim_corresponding_to_num_agents="first", + num_devices=1, + device_id=0, + results_dir=None, + verbose=True, + ): + """ + Args: + env_wrapper: the wrapped environment object. + config: the experiment run configuration. + policy_tag_to_agent_id_map: + a dictionary mapping policy tag to agent ids. + create_separate_placeholders_for_each_policy: + a flag indicating whether there exist separate observations, + actions and rewards placeholders, for each policy, + as designed in the step function. The placeholders will be + used in the step() function and during training. + When there's only a single policy, this flag will be False. + It can also be True when there are multiple policies, yet + all the agents have the same obs and action space shapes, + so we can share the same placeholder. + Defaults to "False". + obs_dim_corresponding_to_num_agents: + indicative of which dimension in the observation corresponds + to the number of agents, as designed in the step function. + It may be "first" or "last". In other words, + observations may be shaped (num_agents, *feature_dim) or + (*feature_dim, num_agents). This is required in order for + WarpDrive to process the observations correctly. This is only + relevant when a single obs key corresponds to multiple agents. + Defaults to "first". + num_devices: number of GPU devices used for (distributed) training. + Defaults to 1. + device_id: device ID. This is set in the context of multi-GPU training. + results_dir: (optional) name of the directory to save results into. + verbose: + if False, training metrics are not printed to the screen. + Defaults to True. + """ + assert env_wrapper is not None + assert not env_wrapper.env_backend == "cpu" + assert config is not None + assert isinstance(create_separate_placeholders_for_each_policy, bool) + assert obs_dim_corresponding_to_num_agents in ["first", "last"] + self.obs_dim_corresponding_to_num_agents = obs_dim_corresponding_to_num_agents + + self.cuda_envs = env_wrapper + + # Load in the default configuration + default_config_path = os.path.join( + _ROOT_DIR, "warp_drive", "training", "run_configs", "default_configs.yaml" + ) + with open(default_config_path, "r", encoding="utf8") as fp: + default_config = yaml.safe_load(fp) + + self.config = config + # Fill in any missing configuration parameters using the default values + # Trainer-related configurations + self.config["trainer"] = recursive_merge_config_dicts( + self.config["trainer"], default_config["trainer"] + ) + # Policy-related configurations + for key in config["policy"]: + self.config["policy"][key] = recursive_merge_config_dicts( + self.config["policy"][key], default_config["policy"] + ) + # Sampler-related configurations (usually Optional) + self.sample_params = self._get_config(["sampler", "params"]) if "sampler" in self.config else {} + + # Saving-related configurations + self.config["saving"] = recursive_merge_config_dicts( + self.config["saving"], default_config["saving"] + ) + + if results_dir is None: + # Use the current time as the name for the results directory. + results_dir = f"{time.time():10.0f}" + + # Directory to save model checkpoints and metrics + self.save_dir = os.path.join( + self._get_config(["saving", "basedir"]), + self._get_config(["saving", "name"]), + self._get_config(["saving", "tag"]), + results_dir, + ) + if not os.path.isdir(self.save_dir): + os.makedirs(self.save_dir, exist_ok=True) + + # Save the run configuration + config_filename = os.path.join(self.save_dir, "run_config.json") + with open(config_filename, "a+", encoding="utf8") as fp: + json.dump(self.config, fp) + fp.write("\n") + + # Flag to determine whether to print training metrics + self.verbose = verbose + + # Number of GPU devices in the train + self.num_devices = num_devices + self.device_id = device_id + + # Policies + self.policy_tag_to_agent_id_map = policy_tag_to_agent_id_map + self.policies = list(self._get_config(["policy"]).keys()) + self.policies_to_train = [ + policy + for policy in self.policies + if self._get_config(["policy", policy, "to_train"]) + ] + + # Flag indicating whether there needs to be separate placeholders / arrays + # for observation, actions and rewards, for each policy + self.create_separate_placeholders_for_each_policy = ( + create_separate_placeholders_for_each_policy + ) + # Note: separate placeholders are needed only when there are + # multiple policies + if self.create_separate_placeholders_for_each_policy: + assert len(self.policies) > 1 + + # Number of iterations algebra + self.num_episodes = self._get_config(["trainer", "num_episodes"]) + assert self.num_episodes > 0 + self.training_batch_size = self._get_config(["trainer", "train_batch_size"]) + self.num_envs = self._get_config(["trainer", "num_envs"]) + + self.training_batch_size_per_env = self.training_batch_size // self.num_envs + assert self.training_batch_size_per_env > 0 + + # Push all the data and tensor arrays to the GPU + # upon resetting environments for the very first time. + self.cuda_envs.reset_all_envs() + + if env_wrapper.env_backend == "pycuda": + from warp_drive.managers.pycuda_managers.pycuda_function_manager import ( + PyCUDASampler, + ) + + self.cuda_sample_controller = PyCUDASampler( + self.cuda_envs.cuda_function_manager + ) + elif env_wrapper.env_backend == "numba": + from warp_drive.managers.numba_managers.numba_function_manager import ( + NumbaSampler, + ) + + self.cuda_sample_controller = NumbaSampler( + self.cuda_envs.cuda_function_manager + ) + + # Create and push data placeholders to the device + create_and_push_data_placeholders( + env_wrapper=self.cuda_envs, + action_sampler=self.cuda_sample_controller, + policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, + training_batch_size_per_env=self.training_batch_size_per_env, + ) + # Seeding (device_id is included for distributed training) + seed = ( + self.config["trainer"].get("seed", np.int32(time.time())) + self.device_id + ) + self.cuda_sample_controller.init_random(seed) + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + self.cuda_envs.init_reset_pool(seed + random.randint(1, 10000)) + + # Define optimizers, and learning rate schedules + self.optimizers = {} + self.lr_schedules = {} + + # For logging episodic reward + self.num_completed_episodes = {} + self.episodic_reward_sum = {} + self.reward_running_sum = {} + self.episodic_step_sum = {} + self.step_running_sum = {} + + # Indicates the current timestep of the policy model + self.current_timestep = {} + + self.total_steps = self.cuda_envs.episode_length * self.num_episodes + self.num_iters = int(self.total_steps // self.training_batch_size) + if self.num_iters == 0: + raise ValueError( + "Not enough steps to even perform a single training iteration!. " + "Please increase the number of episodes or reduce the training " + "batch size." + ) + + for policy in self.policies: + self.current_timestep[policy] = 0 + self._initialize_policy_model(policy) + + # Load the model parameters (if model checkpoints are specified) + # Note: Loading the model checkpoint may also update the current timestep! + self.load_model_checkpoint() + self.ddp_mode = {} + for policy in self.policies: + self._send_policy_model_to_device(policy) + self._initialize_optimizer(policy) + # Initialize episodic rewards and push to the GPU + num_agents_for_policy = len(self.policy_tag_to_agent_id_map[policy]) + self.num_completed_episodes[policy] = 0 + self.episodic_reward_sum[policy] = ( + torch.tensor(0).type(torch.float32).cuda() + ) + self.reward_running_sum[policy] = torch.zeros( + (self.num_envs, num_agents_for_policy) + ).cuda() + self.episodic_step_sum[policy] = ( + torch.tensor(0).type(torch.int64).cuda() + ) + self.step_running_sum[policy] = torch.zeros(self.num_envs, dtype=torch.int64).cuda() + + # Initialize the trainers + self.trainers = {} + self.clip_grad_norm = {} + self.max_grad_norm = {} + for policy in self.policies_to_train: + self._initialize_policy_algorithm(policy) + + # Performance Stats + self.perf_stats = PerfStats() + + # Metrics + self.metrics = Metrics() + + def _get_config(self, args): + assert isinstance(args, (tuple, list)) + config = self.config + for arg in args: + try: + config = config[arg] + except ValueError: + logging.error("Missing configuration '{arg}'!") + return config + + # The followings are abstract classes that trainer class needs to finalize + # They are mostly about how to manage and run the models + def _initialize_policy_algorithm(self, policy): + raise NotImplementedError + + def _initialize_policy_model(self, policy): + raise NotImplementedError + + def _send_policy_model_to_device(self, policy): + raise NotImplementedError + + def _initialize_optimizer(self, policy): + raise NotImplementedError + + def _evaluate_policies(self, batch_index=0): + raise NotImplementedError + + def _update_model_params(self, iteration): + raise NotImplementedError + + def _load_model_checkpoint_helper(self, policy, ckpt_filepath): + raise NotImplementedError + + def save_model_checkpoint(self, iteration=0): + raise NotImplementedError + + # End of abstract classes + + def train(self): + """ + Perform training. + """ + # Ensure env is reset before the start of training, and done flags are False + self.cuda_envs.reset_all_envs() + + for iteration in range(self.num_iters): + start_time = time.time() + + # Generate a batched rollout for every CUDA environment. + self._generate_rollout_batch() + + # Train / update model parameters. + metrics = self._update_model_params(iteration) + + self.perf_stats.iters = iteration + 1 + self.perf_stats.steps = self.perf_stats.iters * self.training_batch_size + end_time = time.time() + self.perf_stats.total_time += end_time - start_time + + # Log the training metrics + self._log_metrics(metrics) + + # Save torch model + self.save_model_checkpoint(iteration) + + def _generate_rollout_batch(self): + """ + Generate an environment rollout batch. + """ + # Code timing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for batch_index in range(self.training_batch_size_per_env): + + # Evaluate policies to compute action probabilities + start_event.record() + probabilities = self._evaluate_policies(batch_index=batch_index) + end_event.record() + torch.cuda.synchronize() + self.perf_stats.policy_eval_time += ( + start_event.elapsed_time(end_event) / 1000 + ) + + # Sample actions using the computed probabilities + # and push to the batch of actions + start_event.record() + self._sample_actions(probabilities, batch_index=batch_index, **self.sample_params) + end_event.record() + torch.cuda.synchronize() + self.perf_stats.action_sample_time += ( + start_event.elapsed_time(end_event) / 1000 + ) + + # Step through all the environments + start_event.record() + self.cuda_envs.step_all_envs() + + # Bookkeeping rewards and done flags + _, done_flags = self._bookkeep_rewards_and_done_flags(batch_index=batch_index) + + # Reset all the environments that are in done state. + if done_flags.any(): + self.cuda_envs.reset_only_done_envs() + + end_event.record() + torch.cuda.synchronize() + self.perf_stats.env_step_time += start_event.elapsed_time(end_event) / 1000 + + def _sample_actions(self, probabilities, batch_index=0, **sample_params): + """ + Sample action probabilities (and push the sampled actions to the device). + """ + assert isinstance(batch_index, int) + if self.create_separate_placeholders_for_each_policy: + for policy in self.policies: + # Sample each individual policy + policy_suffix = f"_{policy}" + self._sample_actions_helper( + probabilities[policy], policy_suffix=policy_suffix, **sample_params + ) + # Push the actions to the batch + actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + _ACTIONS + policy_suffix + ) + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=f"{_ACTIONS}_batch" + policy_suffix + )[batch_index] = actions + else: + assert len(probabilities) == 1 + policy = list(probabilities.keys())[0] + # sample a single or a combined policy + self._sample_actions_helper(probabilities[policy], **sample_params) + actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + _ACTIONS + ) + # Push the actions to the batch, if action sampler has no policy tag + # (1) there is only one policy, then action -> action_batch_policy + # (2) there are multiple policies, then action[policy_tag_to_agent_id[policy]] -> action_batch_policy + for policy in self.policies: + if len(self.policies) > 1: + agent_ids_for_policy = self.policy_tag_to_agent_id_map[policy] + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=f"{_ACTIONS}_batch_{policy}" + )[batch_index] = actions[:, agent_ids_for_policy] + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=f"{_ACTIONS}_batch_{policy}" + )[batch_index] = actions + + def _sample_actions_helper(self, probabilities, policy_suffix="", **sample_params): + # Sample actions with policy_suffix tag + num_action_types = len(probabilities) + + if num_action_types == 1: + action_name = _ACTIONS + policy_suffix + self.cuda_sample_controller.sample( + self.cuda_envs.cuda_data_manager, probabilities[0], action_name, **sample_params + ) + else: + for action_type_id, probs in enumerate(probabilities): + action_name = f"{_ACTIONS}_{action_type_id}" + policy_suffix + self.cuda_sample_controller.sample( + self.cuda_envs.cuda_data_manager, probs, action_name, **sample_params + ) + # Push (indexed) actions to 'actions' + actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + action_name + ) + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=_ACTIONS + policy_suffix + )[:, :, action_type_id] = actions[:, :, 0] + + def _bookkeep_rewards_and_done_flags(self, batch_index): + """ + Push rewards and done flags to the corresponding batched versions. + Also, update the episodic reward + """ + assert isinstance(batch_index, int) + # Push done flags to done_flags_batch + done_flags = ( + self.cuda_envs.cuda_data_manager.data_on_device_via_torch("_done_") > 0 + ) + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=f"{_DONE_FLAGS}_batch" + )[batch_index] = done_flags + + done_env_ids = done_flags.nonzero() + + # Push rewards to rewards_batch and update the episodic rewards + if self.create_separate_placeholders_for_each_policy: + for policy in self.policies: + rewards = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_REWARDS}_{policy}" + ) + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=f"{_REWARDS}_batch_{policy}" + )[batch_index] = rewards + + # Update the episodic rewards + self._update_episodic_rewards(rewards, done_env_ids, policy) + + else: + rewards = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + _REWARDS + ) + for policy in self.policies: + if len(self.policies) > 1: + agent_ids_for_policy = self.policy_tag_to_agent_id_map[policy] + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=f"{_REWARDS}_batch_{policy}" + )[batch_index] = rewards[:, agent_ids_for_policy] + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=f"{_REWARDS}_batch_{policy}" + )[batch_index] = rewards + + # Update the episodic rewards + # (sum of individual step rewards over an episode) + for policy in self.policies: + self._update_episodic_rewards( + rewards[:, self.policy_tag_to_agent_id_map[policy]], + done_env_ids, + policy, + ) + return rewards, done_flags + + def _update_episodic_rewards(self, rewards, done_env_ids, policy): + self.reward_running_sum[policy] += rewards + self.step_running_sum[policy] += 1 + + num_completed_episodes = len(done_env_ids) + if num_completed_episodes > 0: + # Update the episodic rewards + self.episodic_reward_sum[policy] += torch.sum( + self.reward_running_sum[policy][done_env_ids] + ) + self.episodic_step_sum[policy] += torch.sum( + self.step_running_sum[policy][done_env_ids] + ) + self.num_completed_episodes[policy] += num_completed_episodes + # Reset the reward running sum + self.reward_running_sum[policy][done_env_ids] = 0 + self.step_running_sum[policy][done_env_ids] = 0 + + def _log_metrics(self, metrics): + # Log the metrics if it is not empty + if len(metrics) > 0: + perf_stats = self.perf_stats.get_perf_stats() + + if self.verbose: + print("\n") + print("=" * 40) + print(f"Device: {self.device_id}") + print( + f"{'Iterations Completed':40}: " + f"{self.perf_stats.iters} / {self.num_iters}" + ) + self.perf_stats.pretty_print(perf_stats) + self.metrics.pretty_print(metrics) + print("=" * 40, "\n") + + # Log metrics and performance stats + logs = {"Iterations Completed": self.perf_stats.iters} + logs.update(metrics) + logs.update({"Perf. Stats": perf_stats}) + + if self.num_devices > 1: + fn = f"results_device_{self.device_id}.json" + else: + fn = "results.json" + results_filename = os.path.join(self.save_dir, fn) + if self.verbose: + verbose_print( + f"Saving the results to the file '{results_filename}'", + self.device_id, + ) + with open(results_filename, "a+", encoding="utf8") as fp: + json.dump(logs, fp) + fp.write("\n") + self._heartbeat_check_printout(metrics) + + def _heartbeat_check_printout(self, metrics, check=False): + if check is False: + return + + if self.num_devices > 1: + heartbeat_print = ( + "Iterations Completed: " + + f"{self.perf_stats.iters} / {self.num_iters}: \n" + ) + for policy in self.policies: + heartbeat_print += ( + f"policy '{policy}' - Mean episodic reward: " + f"{metrics[policy]['Mean episodic reward']} \n" + ) + verbose_print(heartbeat_print, self.device_id) + + def load_model_checkpoint(self, ckpts_dict=None): + """ + Load the model parameters if a checkpoint path is specified. + """ + if ckpts_dict is None: + logging.info( + "Loading trainer model checkpoints from the run configuration." + ) + for policy in self.policies: + ckpt_filepath = self.config["policy"][policy]["model"][ + "model_ckpt_filepath" + ] + self._load_model_checkpoint_helper(policy, ckpt_filepath) + else: + assert isinstance(ckpts_dict, dict) + if self.verbose: + verbose_print( + "Loading the provided trainer model checkpoints.", self.device_id + ) + for policy, ckpt_filepath in ckpts_dict.items(): + assert policy in self.policies + self._load_model_checkpoint_helper(policy, ckpt_filepath) + + def graceful_close(self): + # Delete the sample controller to clear + # the random seeds defined in the CUDA memory heap. + # Warning: Not closing gracefully could lead to a memory leak. + del self.cuda_sample_controller + if self.verbose: + verbose_print("Trainer exits gracefully", self.device_id) + + def fetch_episode_states( + self, + list_of_states=None, # list of states (data array names) to fetch + env_id=0, # environment id to fetch the states from + include_rewards_actions=False, # flag to output reward and action + policy="", # if include_rewards_actions=True, the corresponding policy tag if any + **sample_params + ): + """ + Step through env and fetch the desired states (data arrays on the GPU) + for an entire episode. The trained models will be used for evaluation. + """ + assert 0 <= env_id < self.num_envs + assert list_of_states is not None + assert isinstance(list_of_states, list) + assert len(list_of_states) > 0 + + logging.info(f"Fetching the episode states: {list_of_states} from the GPU.") + # Ensure env is reset before the start of training, and done flags are False + self.cuda_envs.reset_all_envs() + env = self.cuda_envs.env + + episode_states = {} + for state in list_of_states: + assert self.cuda_envs.cuda_data_manager.is_data_on_device( + state + ), f"{state} is not a valid array name on the GPU!" + # Note: Discard the first dimension, which is the env dimension + array_shape = self.cuda_envs.cuda_data_manager.get_shape(state)[1:] + + # Initialize the episode states + episode_states[state] = np.nan * np.stack( + [np.ones(array_shape) for _ in range(env.episode_length + 1)] + ) + + if include_rewards_actions: + policy_suffix = f"_{policy}" if len(policy) > 0 else "" + action_name = _ACTIONS + policy_suffix + reward_name = _REWARDS + policy_suffix + # Note the size is 1 step smaller than states because we do not have r_0 and a_T + episode_actions = np.zeros( + ( + env.episode_length, *self.cuda_envs.cuda_data_manager.get_shape(action_name)[1:] + ), + dtype=np.int32 + ) + episode_rewards= np.zeros( + ( + env.episode_length, *self.cuda_envs.cuda_data_manager.get_shape(reward_name)[1:] + ), + dtype=np.float32) + + for timestep in range(env.episode_length): + # Update the episode states s_t + for state in list_of_states: + episode_states[state][ + timestep + ] = self.cuda_envs.cuda_data_manager.pull_data_from_device(state)[ + env_id + ] + # Evaluate policies to compute action probabilities, we set batch_index=-1 to avoid batch writing + probabilities = self._evaluate_policies(batch_index=-1) + + # Sample actions + self._sample_actions(probabilities, **sample_params) + + # Step through all the environments + self.cuda_envs.step_all_envs() + + if include_rewards_actions: + # Update the episode action a_t + episode_actions[timestep] = \ + self.cuda_envs.cuda_data_manager.pull_data_from_device(action_name)[env_id] + # Update the episode reward r_(t+1) + episode_rewards[timestep] = \ + self.cuda_envs.cuda_data_manager.pull_data_from_device(reward_name)[env_id] + + # Fetch the states when episode is complete + if env.cuda_data_manager.pull_data_from_device("_done_")[env_id]: + for state in list_of_states: + episode_states[state][ + timestep + 1 + ] = self.cuda_envs.cuda_data_manager.pull_data_from_device(state)[ + env_id + ] + break + if include_rewards_actions: + return episode_states, episode_actions, episode_rewards + else: + return episode_states + + +class PerfStats: + """ + Performance stats that will be included in rollout metrics. + """ + + def __init__(self): + self.iters = 0 + self.steps = 0 + self.policy_eval_time = 0.0 + self.action_sample_time = 0.0 + self.env_step_time = 0.0 + self.training_time = 0.0 + self.total_time = 0.0 + + def get_perf_stats(self): + return { + "Mean policy eval time per iter (ms)": self.policy_eval_time + * 1000 + / self.iters, + "Mean action sample time per iter (ms)": self.action_sample_time + * 1000 + / self.iters, + "Mean env. step time per iter (ms)": self.env_step_time * 1000 / self.iters, + "Mean training time per iter (ms)": self.training_time * 1000 / self.iters, + "Mean total time per iter (ms)": self.total_time * 1000 / self.iters, + "Mean steps per sec (policy eval)": self.steps / self.policy_eval_time, + "Mean steps per sec (action sample)": self.steps / self.action_sample_time, + "Mean steps per sec (env. step)": self.steps / self.env_step_time, + "Mean steps per sec (training time)": self.steps / self.training_time, + "Mean steps per sec (total)": self.steps / self.total_time, + } + + @staticmethod + def pretty_print(stats): + print("=" * 40) + print("Speed performance stats") + print("=" * 40) + for k, v in stats.items(): + print(f"{k:40}: {v:10.2f}") + + +class Metrics: + """ + Metrics class to log and print the key metrics + """ + + def __init__(self): + pass + + def pretty_print(self, metrics): + assert metrics is not None + assert isinstance(metrics, dict) + + for policy in metrics: + print("=" * 40) + print(f"Metrics for policy '{policy}'") + print("=" * 40) + for k, v in metrics[policy].items(): + print(f"{k:40}: {v:10.5f}") From e55bf2f2fae30e55b1941b4ab015f4c4a86fc421 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Wed, 17 Jan 2024 21:56:44 -0800 Subject: [PATCH 09/19] ddpg trainer --- .../training/deprecate/pytorch_lightning.py | 2 +- .../training/run_configs/single_pendulum.yaml | 16 +- .../scripts/example_training_script_numba.py | 2 +- .../scripts/example_training_script_pycuda.py | 2 +- warp_drive/training/trainers/__init__.py | 5 + .../training/{ => trainers}/trainer_a2c.py | 6 +- .../training/{ => trainers}/trainer_base.py | 4 - warp_drive/training/trainers/trainer_ddpg.py | 461 ++++++++++++++++++ 8 files changed, 483 insertions(+), 15 deletions(-) create mode 100644 warp_drive/training/trainers/__init__.py rename warp_drive/training/{ => trainers}/trainer_a2c.py (98%) rename warp_drive/training/{ => trainers}/trainer_base.py (99%) create mode 100644 warp_drive/training/trainers/trainer_ddpg.py diff --git a/warp_drive/training/deprecate/pytorch_lightning.py b/warp_drive/training/deprecate/pytorch_lightning.py index 66cb3b6..fab22b2 100644 --- a/warp_drive/training/deprecate/pytorch_lightning.py +++ b/warp_drive/training/deprecate/pytorch_lightning.py @@ -31,7 +31,7 @@ from warp_drive.training.algorithms.policygradient.a2c import A2C from warp_drive.training.algorithms.policygradient.ppo import PPO from warp_drive.training.models.factory import ModelFactory -from warp_drive.training.trainer_a2c import Metrics +from warp_drive.training.trainers.trainer_a2c import Metrics from warp_drive.training.utils.data_loader import create_and_push_data_placeholders from warp_drive.training.utils.param_scheduler import LRScheduler, ParamScheduler from warp_drive.utils.common import get_project_root diff --git a/warp_drive/training/run_configs/single_pendulum.yaml b/warp_drive/training/run_configs/single_pendulum.yaml index 2d99ca4..5181c57 100644 --- a/warp_drive/training/run_configs/single_pendulum.yaml +++ b/warp_drive/training/run_configs/single_pendulum.yaml @@ -30,13 +30,15 @@ policy: # list all the policies below gamma: 0.99 # discount factor lr: 0.001 # learning rate model: # policy model settings - type: - actor: "fully_connected_actor" # model type - critic: "fully_connected_action_value_critic" - fc_dims: - actor: [32, 32] # dimension(s) of the fully connected layers as a list - critic: [32, 32] - model_ckpt_filepath: "" # filepath (used to restore a previously saved model) + actor: + type: "fully_connected_actor" # model type + fc_dims: [32, 32] + critic: + type: "fully_connected_action_value_critic" # model type + fc_dims: [32, 32] + model_ckpt_filepath: + actor: "" # filepath (used to restore a previously saved model) + critic: "" # Checkpoint saving setting saving: metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics diff --git a/warp_drive/training/scripts/example_training_script_numba.py b/warp_drive/training/scripts/example_training_script_numba.py index 0baf85b..8907f9d 100644 --- a/warp_drive/training/scripts/example_training_script_numba.py +++ b/warp_drive/training/scripts/example_training_script_numba.py @@ -23,7 +23,7 @@ from example_envs.single_agent.classic_control.mountain_car.mountain_car import CUDAClassicControlMountainCarEnv from example_envs.single_agent.classic_control.acrobot.acrobot import CUDAClassicControlAcrobotEnv from warp_drive.env_wrapper import EnvWrapper -from warp_drive.training.trainer_a2c import TrainerA2C +from warp_drive.training.trainers.trainer_a2c import TrainerA2C from warp_drive.training.utils.distributed_train.distributed_trainer_numba import ( perform_distributed_training, ) diff --git a/warp_drive/training/scripts/example_training_script_pycuda.py b/warp_drive/training/scripts/example_training_script_pycuda.py index efb9d56..c8f867e 100644 --- a/warp_drive/training/scripts/example_training_script_pycuda.py +++ b/warp_drive/training/scripts/example_training_script_pycuda.py @@ -20,7 +20,7 @@ from example_envs.tag_continuous.tag_continuous import TagContinuous from example_envs.tag_gridworld.tag_gridworld import CUDATagGridWorld from warp_drive.env_wrapper import EnvWrapper -from warp_drive.training.trainer_a2c import TrainerA2C +from warp_drive.training.trainers.trainer_a2c import TrainerA2C from warp_drive.training.utils.distributed_train.distributed_trainer_pycuda import ( perform_distributed_training, ) diff --git a/warp_drive/training/trainers/__init__.py b/warp_drive/training/trainers/__init__.py new file mode 100644 index 0000000..93bee4b --- /dev/null +++ b/warp_drive/training/trainers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause diff --git a/warp_drive/training/trainer_a2c.py b/warp_drive/training/trainers/trainer_a2c.py similarity index 98% rename from warp_drive/training/trainer_a2c.py rename to warp_drive/training/trainers/trainer_a2c.py index f1f45a1..c172e5a 100644 --- a/warp_drive/training/trainer_a2c.py +++ b/warp_drive/training/trainers/trainer_a2c.py @@ -20,7 +20,7 @@ from gym.spaces import Discrete, MultiDiscrete from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP -from warp_drive.training.trainer_base import TrainerBase, all_equal, verbose_print +from warp_drive.training.trainers.trainer_base import TrainerBase, all_equal, verbose_print from warp_drive.training.algorithms.policygradient.a2c import A2C from warp_drive.training.algorithms.policygradient.ppo import PPO @@ -53,7 +53,11 @@ def __init__( results_dir=None, verbose=True, ): + # Define models, optimizers, and learning rate schedules self.models = {} + self.optimizers = {} + self.lr_schedules = {} + super().__init__( env_wrapper=env_wrapper, config=config, diff --git a/warp_drive/training/trainer_base.py b/warp_drive/training/trainers/trainer_base.py similarity index 99% rename from warp_drive/training/trainer_base.py rename to warp_drive/training/trainers/trainer_base.py index 8d488e3..cb00085 100644 --- a/warp_drive/training/trainer_base.py +++ b/warp_drive/training/trainers/trainer_base.py @@ -245,10 +245,6 @@ def __init__( np.random.seed(seed) self.cuda_envs.init_reset_pool(seed + random.randint(1, 10000)) - # Define optimizers, and learning rate schedules - self.optimizers = {} - self.lr_schedules = {} - # For logging episodic reward self.num_completed_episodes = {} self.episodic_reward_sum = {} diff --git a/warp_drive/training/trainers/trainer_ddpg.py b/warp_drive/training/trainers/trainer_ddpg.py new file mode 100644 index 0000000..05dd2f9 --- /dev/null +++ b/warp_drive/training/trainers/trainer_ddpg.py @@ -0,0 +1,461 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause + +""" +The Trainer, PerfStats and Metrics classes +""" + +import json +import logging +import os +import random +import time + +import numpy as np +import torch +import yaml +from gym.spaces import Discrete, MultiDiscrete +from torch import nn +from torch.nn.parallel import DistributedDataParallel as DDP +from warp_drive.training.trainers.trainer_base import TrainerBase, all_equal, verbose_print + +from warp_drive.training.algorithms.policygradient.a2c import A2C +from warp_drive.training.algorithms.policygradient.ppo import PPO +from warp_drive.training.models.factory import ModelFactory +from warp_drive.training.utils.data_loader import create_and_push_data_placeholders +from warp_drive.training.utils.param_scheduler import ParamScheduler +from warp_drive.utils.common import get_project_root +from warp_drive.utils.constants import Constants + +_ROOT_DIR = get_project_root() + +_ACTIONS = Constants.ACTIONS +_REWARDS = Constants.REWARDS +_DONE_FLAGS = Constants.DONE_FLAGS +_PROCESSED_OBSERVATIONS = Constants.PROCESSED_OBSERVATIONS +_COMBINED = "combined" +_EPSILON = 1e-10 # small number to prevent indeterminate divisions + + +class TrainerDDPG(TrainerBase): + def __init__( + self, + env_wrapper=None, + config=None, + policy_tag_to_agent_id_map=None, + create_separate_placeholders_for_each_policy=False, + obs_dim_corresponding_to_num_agents="first", + num_devices=1, + device_id=0, + results_dir=None, + verbose=True, + ): + # Define models, optimizers, and learning rate schedules + self.actor_models = {} + self.critic_models = {} + self.actor_optimizers = {} + self.critic_optimizers = {} + self.actor_lr_schedules = {} + self.critic_lr_schedules = {} + + super().__init__( + env_wrapper=env_wrapper, + config=config, + policy_tag_to_agent_id_map=policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents=obs_dim_corresponding_to_num_agents, + num_devices=num_devices, + device_id=device_id, + results_dir=results_dir, + verbose=verbose, + ) + + def _initialize_policy_algorithm(self, policy): + algorithm = self._get_config(["policy", policy, "algorithm"]) + assert algorithm in ["DDPG"] + entropy_coeff = self._get_config(["policy", policy, "entropy_coeff"]) + vf_loss_coeff = self._get_config(["policy", policy, "vf_loss_coeff"]) + self.clip_grad_norm[policy] = self._get_config( + ["policy", policy, "clip_grad_norm"] + ) + if self.clip_grad_norm[policy]: + self.max_grad_norm[policy] = self._get_config( + ["policy", policy, "max_grad_norm"] + ) + normalize_advantage = self._get_config( + ["policy", policy, "normalize_advantage"] + ) + normalize_return = self._get_config(["policy", policy, "normalize_return"]) + gamma = self._get_config(["policy", policy, "gamma"]) + if algorithm == "A2C": + # Advantage Actor-Critic + self.trainers[policy] = A2C( + discount_factor_gamma=gamma, + normalize_advantage=normalize_advantage, + normalize_return=normalize_return, + vf_loss_coeff=vf_loss_coeff, + entropy_coeff=entropy_coeff, + ) + logging.info(f"Initializing the A2C trainer for policy {policy}") + elif algorithm == "PPO": + # Proximal Policy Optimization + clip_param = self._get_config(["policy", policy, "clip_param"]) + self.trainers[policy] = PPO( + discount_factor_gamma=gamma, + clip_param=clip_param, + normalize_advantage=normalize_advantage, + normalize_return=normalize_return, + vf_loss_coeff=vf_loss_coeff, + entropy_coeff=entropy_coeff, + ) + logging.info(f"Initializing the PPO trainer for policy {policy}") + else: + raise NotImplementedError + + def _initialize_policy_model(self, policy): + if "actor" not in self._get_config(["policy", policy, "model"]) or "critic" \ + not in self._get_config(["policy", policy, "model"]): + actor_model_config = self._get_config(["policy", policy, "model"]) + critic_model_config = actor_model_config + else: + actor_model_config = self._get_config(["policy", policy, "model", "actor"]) + critic_model_config = self._get_config(["policy", policy, "model", "critic"]) + + model_obj_actor = ModelFactory.create(actor_model_config["type"]) + actor = model_obj_actor( + env=self.cuda_envs, + model_config=actor_model_config, + policy=policy, + policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, + ) + + if "init_method" in actor_model_config and \ + actor_model_config["init_method"] == "xavier": + def init_weights_by_xavier_uniform(m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform(m.weight) + + actor.apply(init_weights_by_xavier_uniform) + + self.actor_models[policy] = actor + + model_obj_critic = ModelFactory.create(critic_model_config["type"]) + critic = model_obj_critic( + env=self.cuda_envs, + model_config=critic_model_config, + policy=policy, + policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, + ) + + if "init_method" in critic_model_config and \ + critic_model_config["init_method"] == "xavier": + def init_weights_by_xavier_uniform(m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform(m.weight) + + critic.apply(init_weights_by_xavier_uniform) + + self.critic_models[policy] = critic + + def _send_policy_model_to_device(self, policy): + self.actor_models[policy].cuda() + self.critic_models[policy].cuda() + # If distributed train, sync model using DDP + if self.num_devices > 1: + self.actor_models[policy] = DDP( + self.actor_models[policy], device_ids=[self.device_id] + ) + self.critic_models[policy] = DDP( + self.critic_models[policy], device_ids=[self.device_id] + ) + self.ddp_mode[policy] = True + else: + self.ddp_mode[policy] = False + + def _initialize_optimizer(self, policy): + # Initialize the (ADAM) optimizer + if "actor" not in self._get_config(["policy", policy, "lr"]) or "critic" \ + not in self._get_config(["policy", policy, "lr"]): + actor_lr_config = self._get_config(["policy", policy, "lr"]) + critic_lr_config = actor_lr_config + else: + actor_lr_config = self._get_config(["policy", policy, "lr", "actor"]) + critic_lr_config = self._get_config(["policy", policy, "lr", "critic"]) + + self.actor_lr_schedules[policy] = ParamScheduler(actor_lr_config) + self.critic_lr_schedules[policy] = ParamScheduler(critic_lr_config) + initial_actor_lr = self.actor_lr_schedules[policy].get_param_value( + timestep=self.current_timestep[policy] + ) + initial_critic_lr = self.critic_lr_schedules[policy].get_param_value( + timestep=self.current_timestep[policy] + ) + self.actor_optimizers[policy] = torch.optim.Adam( + self.actor_models[policy].parameters(), lr=initial_actor_lr + ) + self.critic_optimizers[policy] = torch.optim.Adam( + self.critic_models[policy].parameters(), lr=initial_critic_lr + ) + + def _evaluate_policies(self, batch_index=0): + """ + Perform the policy evaluation (forward pass through the models) + and compute action probabilities + """ + assert isinstance(batch_index, int) + probabilities = {} + for policy in self.policies: + if self.ddp_mode[policy]: + # self.models[policy] is a DDP wrapper of the model instance + obs = self.actor_models[policy].module.process_one_step_obs() + self.actor_models[policy].module.push_processed_obs_to_batch(batch_index, obs) + else: + obs = self.actor_models[policy].process_one_step_obs() + self.actor_models[policy].push_processed_obs_to_batch(batch_index, obs) + probabilities[policy], _ = self.actor_models[policy](obs) + + # Combine probabilities across policies if there are multiple policies, + # yet they share the same action placeholders. + # The action sampler will then need to run just once on each action type. + if ( + len(self.policies) > 1 + and not self.create_separate_placeholders_for_each_policy + ): + # Assert that all the probabilities are of the same length + # in other words the number of action types for each policy + # is the same. + num_action_types = {} + for policy in self.policies: + num_action_types[policy] = len(probabilities[policy]) + assert all_equal(list(num_action_types.values())) + + # Initialize combined_probabilities. + first_policy = list(probabilities.keys())[0] + num_action_types = num_action_types[first_policy] + + first_action_type_id = 0 + num_envs = probabilities[first_policy][first_action_type_id].shape[0] + num_agents = self.cuda_envs.env.num_agents + + combined_probabilities = [None for _ in range(num_action_types)] + for action_type_id in range(num_action_types): + action_dim = probabilities[first_policy][action_type_id].shape[-1] + combined_probabilities[action_type_id] = torch.zeros( + (num_envs, num_agents, action_dim) + ).cuda() + + # Combine the probabilities across policies + for action_type_id in range(num_action_types): + for policy, prob_values in probabilities.items(): + agent_to_id_mapping = self.policy_tag_to_agent_id_map[policy] + combined_probabilities[action_type_id][ + :, agent_to_id_mapping + ] = prob_values[action_type_id] + + probabilities = {_COMBINED: combined_probabilities} + + return probabilities + + def _update_model_params(self, iteration): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + # Flag for logging (which also happens after the last iteration) + logging_flag = ( + iteration % self.config["saving"]["metrics_log_freq"] == 0 + or iteration == self.num_iters - 1 + ) + + metrics_dict = {} + + done_flags_batch = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_DONE_FLAGS}_batch" + ) + # On the device, observations_batch, actions_batch, + # rewards_batch are all shaped + # (batch_size, num_envs, num_agents, *feature_dim). + # done_flags_batch is shaped (batch_size, num_envs) + # Perform training sequentially for each policy + for policy in self.policies_to_train: + actions_batch = ( + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_ACTIONS}_batch_{policy}" + ) + ) + rewards_batch = ( + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_REWARDS}_batch_{policy}" + ) + ) + processed_obs_batch = ( + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + f"{_PROCESSED_OBSERVATIONS}_batch_{policy}" + ) + ) + # Policy evaluation for the entire batch + probabilities_batch = self.actor_models[policy]( + obs=processed_obs_batch + ) + value_functions_batch = self.critic_models[policy]( + obs=processed_obs_batch, action=actions_batch + ) + # Loss and metrics computation + actor_loss, critic_loss, metrics = self.trainers[policy].compute_loss_and_metrics( + self.current_timestep[policy], + actions_batch, + rewards_batch, + done_flags_batch, + probabilities_batch, + value_functions_batch, + perform_logging=logging_flag, + ) + # Compute the gradient norm + actor_grad_norm = 0.0 + for param in list( + filter(lambda p: p.grad is not None, self.actor_models[policy].parameters()) + ): + actor_grad_norm += param.grad.data.norm(2).item() + + critic_grad_norm = 0.0 + for param in list( + filter(lambda p: p.grad is not None, self.critic_models[policy].parameters()) + ): + critic_grad_norm += param.grad.data.norm(2).item() + + # Update the timestep and learning rate based on the schedule + self.current_timestep[policy] += self.training_batch_size + actor_lr = self.actor_lr_schedules[policy].get_param_value( + self.current_timestep[policy] + ) + for param_group in self.actor_optimizers[policy].param_groups: + param_group["lr"] = actor_lr + + critic_lr = self.critic_lr_schedules[policy].get_param_value( + self.current_timestep[policy] + ) + for param_group in self.critic_optimizers[policy].param_groups: + param_group["lr"] = critic_lr + + # Loss backpropagation and optimization step + self.actor_optimizers[policy].zero_grad() + self.critic_optimizers[policy].zero_grad() + actor_loss.backward() + critic_loss.backward() + if self.clip_grad_norm[policy]: + nn.utils.clip_grad_norm_( + self.actor_models[policy].parameters(), self.max_grad_norm[policy] + ) + nn.utils.clip_grad_norm_( + self.critic_models[policy].parameters(), self.max_grad_norm[policy] + ) + + self.actor_optimizers[policy].step() + self.critic_optimizers[policy].step() + # Logging + if logging_flag: + metrics_dict[policy] = metrics + # Update the metrics dictionary + metrics_dict[policy].update( + { + "Current timestep": self.current_timestep[policy], + "Gradient norm (Actor)": actor_grad_norm, + "Gradient norm (Critic)": critic_grad_norm, + "Learning rate (Actor)": actor_lr, + "Learning rate (Critic)": critic_lr, + "Mean episodic reward": self.episodic_reward_sum[policy].item() + / (self.num_completed_episodes[policy] + _EPSILON), + "Mean episodic steps": self.episodic_step_sum[policy].item() + / (self.num_completed_episodes[policy] + _EPSILON), + } + ) + + # Reset sum and counter + self.episodic_reward_sum[policy] = ( + torch.tensor(0).type(torch.float32).cuda() + ) + self.episodic_step_sum[policy] = ( + torch.tensor(0).type(torch.int64).cuda() + ) + self.num_completed_episodes[policy] = 0 + + end_event.record() + torch.cuda.synchronize() + + self.perf_stats.training_time += start_event.elapsed_time(end_event) / 1000 + return metrics_dict + + def _load_model_checkpoint_helper(self, policy, ckpt_filepath): + if isinstance(ckpt_filepath, dict) and "actor" in ckpt_filepath and "critic" in ckpt_filepath: + if ckpt_filepath["actor"] != "" and ckpt_filepath["critic"] != "": + assert os.path.isfile(ckpt_filepath["actor"]), "Invalid actor model checkpoint path!" + assert os.path.isfile(ckpt_filepath["critic"]), "Invalid critic model checkpoint path!" + # Update the current timestep using the saved checkpoint filename + actor_timestep = int(ckpt_filepath["actor"].split(".state_dict")[0].split("_")[-1]) + critic_timestep = int(ckpt_filepath["critic"].split(".state_dict")[0].split("_")[-1]) + assert actor_timestep == critic_timestep, \ + "The timestep is different between the actor model and the critic model " + if self.verbose: + verbose_print( + f"Loading the '{policy}' torch model " + f"from the previously saved checkpoint: " + f"actor: '{ckpt_filepath['actor']}'" + f"critic: '{ckpt_filepath['critic']}'", + self.device_id, + ) + self.actor_models[policy].load_state_dict(torch.load(ckpt_filepath["actor"])) + self.critic_models[policy].load_state_dict(torch.load(ckpt_filepath["critic"])) + + if self.verbose: + verbose_print( + f"Updating the timestep for the '{policy}' model to {timestep}.", + self.device_id, + ) + self.current_timestep[policy] = actor_timestep + + def save_model_checkpoint(self, iteration=0): + """ + Save the model parameters + """ + # If multiple devices, save the synced-up model only for device id 0 + if self.device_id == 0: + # Save model checkpoints if specified (and also for the last iteration) + if ( + iteration % self.config["saving"]["model_params_save_freq"] == 0 + or iteration == self.num_iters - 1 + ): + for policy, actor_model in self.actor_models.items(): + filepath = os.path.join( + self.save_dir, + f"{policy}_actor_{self.current_timestep[policy]}.state_dict", + ) + if self.verbose: + verbose_print( + f"Saving the '{policy}' (actor) torch model " + f"to the file: '{filepath}'.", + self.device_id, + ) + + torch.save(actor_model.state_dict(), filepath) + + for policy, critic_model in self.critic_models.items(): + filepath = os.path.join( + self.save_dir, + f"{policy}_critic_{self.current_timestep[policy]}.state_dict", + ) + if self.verbose: + verbose_print( + f"Saving the '{policy}' (critic) torch model " + f"to the file: '{filepath}'.", + self.device_id, + ) + + torch.save(critic_model.state_dict(), filepath) From 0696237a662a1a145f20a755febb37ff3d8b5f72 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Sun, 21 Jan 2024 10:35:44 -0800 Subject: [PATCH 10/19] mostly finalize ddpg --- .../algorithms/policygradient/ddpg.py | 158 ++++++++++++++++++ warp_drive/training/models/fully_connected.py | 1 + .../models/fully_connected_actor_critic.py | 8 +- warp_drive/training/models/model_base.py | 5 +- .../training/run_configs/single_pendulum.yaml | 24 ++- .../scripts/example_training_script_numba.py | 44 ++++- warp_drive/training/trainers/trainer_base.py | 2 +- warp_drive/training/trainers/trainer_ddpg.py | 44 ++--- 8 files changed, 239 insertions(+), 47 deletions(-) create mode 100644 warp_drive/training/algorithms/policygradient/ddpg.py diff --git a/warp_drive/training/algorithms/policygradient/ddpg.py b/warp_drive/training/algorithms/policygradient/ddpg.py new file mode 100644 index 0000000..aa7c941 --- /dev/null +++ b/warp_drive/training/algorithms/policygradient/ddpg.py @@ -0,0 +1,158 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause +# + +import torch +from torch import nn +from torch.distributions import Categorical + +from warp_drive.training.utils.param_scheduler import ParamScheduler + +_EPSILON = 1e-10 # small number to prevent indeterminate division + + +class DDPG: + """ + The Advantage Actor-Critic Class + https://arxiv.org/abs/1602.01783 + """ + + def __init__( + self, + discount_factor_gamma=1.0, + normalize_advantage=False, + normalize_return=False, + ): + assert 0 <= discount_factor_gamma <= 1 + self.discount_factor_gamma = discount_factor_gamma + self.normalize_advantage = normalize_advantage + self.normalize_return = normalize_return + + def compute_loss_and_metrics( + self, + timestep=None, + actions_batch=None, + rewards_batch=None, + done_flags_batch=None, + value_functions_batch=None, + j_functions_batch=None, + perform_logging=False, + ): + assert actions_batch is not None + assert timestep is not None + assert rewards_batch is not None + assert done_flags_batch is not None + assert value_functions_batch is not None + assert j_functions_batch is not None + + # Detach value_functions_batch from the computation graph + # for return and advantage computations. + value_functions_batch_detached = value_functions_batch.detach() + + # Value objective. + returns_batch = torch.zeros_like(rewards_batch) + + returns_batch[-1] = ( + done_flags_batch[-1][:, None] * rewards_batch[-1] + + (1 - done_flags_batch[-1][:, None]) * value_functions_batch_detached[-1] + ) + for step in range(-2, -returns_batch.shape[0] - 1, -1): + future_return = ( + done_flags_batch[step][:, None] * torch.zeros_like(rewards_batch[step]) + + (1 - done_flags_batch[step][:, None]) + * self.discount_factor_gamma + * returns_batch[step + 1] + ) + returns_batch[step] = rewards_batch[step] + future_return + + # Normalize across the agents and env dimensions + if self.normalize_return: + normalized_returns_batch = ( + returns_batch - returns_batch.mean(dim=(1, 2), keepdim=True) + ) / (returns_batch.std(dim=(1, 2), keepdim=True) + torch.tensor(_EPSILON)) + else: + normalized_returns_batch = returns_batch + + critic_loss = nn.MSELoss()(normalized_returns_batch, value_functions_batch) + + advantages_batch = normalized_returns_batch - value_functions_batch_detached + + # Normalize across the agents and env dimensions + if self.normalize_advantage: + normalized_advantages_batch = ( + advantages_batch - advantages_batch.mean(dim=(1, 2), keepdim=True) + ) / ( + advantages_batch.std(dim=(1, 2), keepdim=True) + torch.tensor(_EPSILON) + ) + else: + normalized_advantages_batch = advantages_batch + + # Policy objective + + actor_loss = -j_functions_batch.mean() + + variance_explained = max( + torch.tensor(-1.0), + ( + 1 + - ( + normalized_advantages_batch.detach().var() + / (normalized_returns_batch.detach().var() + torch.tensor(_EPSILON)) + ) + ), + ) + + if perform_logging: + metrics = { + "Total loss": actor_loss.item() + critic_loss.item(), + "Actor loss": actor_loss.item(), + "Critic loss": critic_loss.item(), + "Mean rewards": rewards_batch.mean().item(), + "Max. rewards": rewards_batch.max().item(), + "Min. rewards": rewards_batch.min().item(), + "Mean value function": value_functions_batch.mean().item(), + "Mean J function": j_functions_batch.mean().item(), + "Mean advantages": advantages_batch.mean().item(), + "Mean (norm.) advantages": normalized_advantages_batch.mean().item(), + "Mean (discounted) returns": returns_batch.mean().item(), + "Mean normalized returns": normalized_returns_batch.mean().item(), + "Variance explained by the value function": variance_explained.item(), + } + # mean of the standard deviation of sampled actions + std_over_agent_per_action = ( + actions_batch.float().std(axis=2).mean(axis=(0, 1)) + ) + std_over_time_per_action = ( + actions_batch.float().std(axis=0).mean(axis=(0, 1)) + ) + std_over_env_per_action = ( + actions_batch.float().std(axis=1).mean(axis=(0, 1)) + ) + # max_per_action = [actions_batch.float().max()] + # min_per_action = [actions_batch.float().min()] + + for idx, _ in enumerate(std_over_agent_per_action): + std_action = { + f"Std. of action_{idx} over agents": std_over_agent_per_action[ + idx + ].item(), + f"Std. of action_{idx} over envs": std_over_env_per_action[ + idx + ].item(), + f"Std. of action_{idx} over time": std_over_time_per_action[ + idx + ].item(), + # f"Max of action_{idx}": max_per_action[ + # idx + # ].item(), + # f"Min of action_{idx}": min_per_action[ + # idx + # ].item(), + } + metrics.update(std_action) + else: + metrics = {} + return actor_loss, critic_loss, metrics diff --git a/warp_drive/training/models/fully_connected.py b/warp_drive/training/models/fully_connected.py index 12fd348..6d5b628 100644 --- a/warp_drive/training/models/fully_connected.py +++ b/warp_drive/training/models/fully_connected.py @@ -64,6 +64,7 @@ def forward(self, obs=None, action=None): # Apply action mask to the logits as well. if self.is_deterministic: combined_action_probs = func.tanh(apply_logit_mask(self.policy_head(logits), self.action_mask)) + combined_action_probs = self.action_scale * combined_action_probs + self.action_bias if self.output_dims[0] > 1: # we split the actions to their corresponding heads # we make sure after the split, we rearrange the memory so each chunk is still C-continguous diff --git a/warp_drive/training/models/fully_connected_actor_critic.py b/warp_drive/training/models/fully_connected_actor_critic.py index 185a5fc..cb665fe 100644 --- a/warp_drive/training/models/fully_connected_actor_critic.py +++ b/warp_drive/training/models/fully_connected_actor_critic.py @@ -65,6 +65,7 @@ def forward(self, obs=None, action=None): # Apply action mask to the logits as well. if self.is_deterministic: combined_action_probs = func.tanh(apply_logit_mask(self.policy_head(logits), self.action_mask)) + combined_action_probs = self.action_scale * combined_action_probs + self.action_bias if self.output_dims[0] > 1: # we split the actions to their corresponding heads # we make sure after the split, we rearrange the memory so each chunk is still C-continguous @@ -127,7 +128,12 @@ def forward(self, obs=None, action=None): Forward pass through the model. Returns Q value. """ - ip = torch.cat([obs, action], dim=-1) + assert action is not None + if isinstance(action, list): + obs_n_act = [obs] + action + else: + obs_n_act = [obs, action] + ip = torch.cat(obs_n_act, dim=-1) # Feed through the FC layers for layer in range(len(self.fc)): op = self.fc[str(layer)](ip) diff --git a/warp_drive/training/models/model_base.py b/warp_drive/training/models/model_base.py index 82a0606..a97bac8 100644 --- a/warp_drive/training/models/model_base.py +++ b/warp_drive/training/models/model_base.py @@ -47,6 +47,9 @@ def __init__( self.env = env self.fc_dims = model_config["fc_dims"] + self.action_scale = model_config["output_w"] if "output_w" in model_config else 1.0 + self.action_bias = model_config["output_b"] if "output_b" in model_config else 0.0 + assert isinstance(self.fc_dims, list) self.policy = policy self.policy_tag_to_agent_id_map = policy_tag_to_agent_id_map @@ -80,7 +83,7 @@ def __init__( # policy network (list of heads) if self.is_deterministic: self.output_dims = [len(action_space)] - self.policy_head = nn.Linear(self.fc_dims, len(action_space)) + self.policy_head = nn.Linear(self.fc_dims[-1], len(action_space)) else: policy_heads = [None for _ in range(len(action_space))] self.output_dims = [] # Network output dimension(s) diff --git a/warp_drive/training/run_configs/single_pendulum.yaml b/warp_drive/training/run_configs/single_pendulum.yaml index 5181c57..38abd20 100644 --- a/warp_drive/training/run_configs/single_pendulum.yaml +++ b/warp_drive/training/run_configs/single_pendulum.yaml @@ -9,36 +9,42 @@ name: "single_pendulum" # Environment settings env: episode_length: 500 - reset_pool_size: 1000 + reset_pool_size: 10000 # Trainer settings trainer: - num_envs: 100 # number of environment replicas - num_episodes: 200000 # number of episodes to run the training for. Can be arbitrarily high! - train_batch_size: 50000 # total batch size used for training per iteration (across all the environments) + num_envs: 10000 # number of environment replicas + num_episodes: 20000000 # number of episodes to run the training for. Can be arbitrarily high! + train_batch_size: 5000000 # total batch size used for training per iteration (across all the environments) env_backend: "numba" # environment backend, pycuda or numba # Policy network settings policy: # list all the policies below shared: to_train: True # flag indicating whether the model needs to be trained algorithm: "DDPG" # algorithm used to train the policy - vf_loss_coeff: 1 # loss coefficient schedule for the value function loss - entropy_coeff: 0.05 # loss coefficient schedule for the entropy loss clip_grad_norm: True # flag indicating whether to clip the gradient norm or not max_grad_norm: 3 # when clip_grad_norm is True, the clip level - normalize_advantage: False # flag indicating whether to normalize advantage or not - normalize_return: False # flag indicating whether to normalize return or not + normalize_advantage: True # flag indicating whether to normalize advantage or not + normalize_return: True # flag indicating whether to normalize return or not gamma: 0.99 # discount factor - lr: 0.001 # learning rate + lr: + actor: 0.001 # learning rate + critic: 0.001 model: # policy model settings actor: type: "fully_connected_actor" # model type fc_dims: [32, 32] + output_w: 2.0 # model default range is (-1, 1), this changes to (-2, 2) critic: type: "fully_connected_action_value_critic" # model type fc_dims: [32, 32] model_ckpt_filepath: actor: "" # filepath (used to restore a previously saved model) critic: "" +sampler: + params: + damping: 0.15 + stddev: 0.2 + scale: 1.0 # Checkpoint saving setting saving: metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics diff --git a/warp_drive/training/scripts/example_training_script_numba.py b/warp_drive/training/scripts/example_training_script_numba.py index 8907f9d..b2d6895 100644 --- a/warp_drive/training/scripts/example_training_script_numba.py +++ b/warp_drive/training/scripts/example_training_script_numba.py @@ -22,8 +22,10 @@ from example_envs.single_agent.classic_control.cartpole.cartpole import CUDAClassicControlCartPoleEnv from example_envs.single_agent.classic_control.mountain_car.mountain_car import CUDAClassicControlMountainCarEnv from example_envs.single_agent.classic_control.acrobot.acrobot import CUDAClassicControlAcrobotEnv +from example_envs.single_agent.classic_control.pendulum.pendulum import CUDAClassicControlPendulumEnv from warp_drive.env_wrapper import EnvWrapper from warp_drive.training.trainers.trainer_a2c import TrainerA2C +from warp_drive.training.trainers.trainer_ddpg import TrainerDDPG from warp_drive.training.utils.distributed_train.distributed_trainer_numba import ( perform_distributed_training, ) @@ -39,6 +41,7 @@ _CLASSIC_CONTROL_CARTPOLE = "single_cartpole" _CLASSIC_CONTROL_MOUNTAIN_CAR = "single_mountain_car" _CLASSIC_CONTROL_ACROBOT = "single_acrobot" +_CLASSIC_CONTROL_PENDULUM = "single_pendulum" # Example usages (from the root folder): @@ -113,6 +116,14 @@ def setup_trainer_and_train( event_messenger=event_messenger, process_id=device_id, ) + elif run_configuration["name"] == _CLASSIC_CONTROL_PENDULUM: + env_wrapper = EnvWrapper( + CUDAClassicControlPendulumEnv(**run_configuration["env"]), + num_envs=num_envs, + env_backend="numba", + event_messenger=event_messenger, + process_id=device_id, + ) else: raise NotImplementedError( f"Currently, the environments supported are [" @@ -120,6 +131,9 @@ def setup_trainer_and_train( f"{_TAG_CONTINUOUS}" f"{_TAG_GRIDWORLD_WITH_RESET_POOL}" f"{_CLASSIC_CONTROL_CARTPOLE}" + f"{_CLASSIC_CONTROL_MOUNTAIN_CAR}" + f"{_CLASSIC_CONTROL_ACROBOT}" + f"{_CLASSIC_CONTROL_PENDULUM}" f"]", ) # Policy mapping to agent ids: agents can share models @@ -152,15 +166,27 @@ def setup_trainer_and_train( ) # Trainer object # -------------- - trainer = TrainerA2C( - env_wrapper=env_wrapper, - config=run_configuration, - policy_tag_to_agent_id_map=policy_tag_to_agent_id_map, - device_id=device_id, - num_devices=num_devices, - results_dir=results_directory, - verbose=verbose, - ) + first_policy_name = list(run_configuration["policy"])[0] + if run_configuration["policy"][first_policy_name]["algorithm"] == "DDPG": + trainer = TrainerDDPG( + env_wrapper=env_wrapper, + config=run_configuration, + policy_tag_to_agent_id_map=policy_tag_to_agent_id_map, + device_id=device_id, + num_devices=num_devices, + results_dir=results_directory, + verbose=verbose, + ) + else: + trainer = TrainerA2C( + env_wrapper=env_wrapper, + config=run_configuration, + policy_tag_to_agent_id_map=policy_tag_to_agent_id_map, + device_id=device_id, + num_devices=num_devices, + results_dir=results_directory, + verbose=verbose, + ) # Perform training # ---------------- diff --git a/warp_drive/training/trainers/trainer_base.py b/warp_drive/training/trainers/trainer_base.py index cb00085..beb227c 100644 --- a/warp_drive/training/trainers/trainer_base.py +++ b/warp_drive/training/trainers/trainer_base.py @@ -219,7 +219,7 @@ def __init__( ) elif env_wrapper.env_backend == "numba": from warp_drive.managers.numba_managers.numba_function_manager import ( - NumbaSampler, + NumbaSampler ) self.cuda_sample_controller = NumbaSampler( diff --git a/warp_drive/training/trainers/trainer_ddpg.py b/warp_drive/training/trainers/trainer_ddpg.py index 05dd2f9..ca8339d 100644 --- a/warp_drive/training/trainers/trainer_ddpg.py +++ b/warp_drive/training/trainers/trainer_ddpg.py @@ -22,8 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from warp_drive.training.trainers.trainer_base import TrainerBase, all_equal, verbose_print -from warp_drive.training.algorithms.policygradient.a2c import A2C -from warp_drive.training.algorithms.policygradient.ppo import PPO +from warp_drive.training.algorithms.policygradient.ddpg import DDPG from warp_drive.training.models.factory import ModelFactory from warp_drive.training.utils.data_loader import create_and_push_data_placeholders from warp_drive.training.utils.param_scheduler import ParamScheduler @@ -76,8 +75,6 @@ def __init__( def _initialize_policy_algorithm(self, policy): algorithm = self._get_config(["policy", policy, "algorithm"]) assert algorithm in ["DDPG"] - entropy_coeff = self._get_config(["policy", policy, "entropy_coeff"]) - vf_loss_coeff = self._get_config(["policy", policy, "vf_loss_coeff"]) self.clip_grad_norm[policy] = self._get_config( ["policy", policy, "clip_grad_norm"] ) @@ -90,33 +87,20 @@ def _initialize_policy_algorithm(self, policy): ) normalize_return = self._get_config(["policy", policy, "normalize_return"]) gamma = self._get_config(["policy", policy, "gamma"]) - if algorithm == "A2C": + if algorithm == "DDPG": # Advantage Actor-Critic - self.trainers[policy] = A2C( + self.trainers[policy] = DDPG( discount_factor_gamma=gamma, normalize_advantage=normalize_advantage, normalize_return=normalize_return, - vf_loss_coeff=vf_loss_coeff, - entropy_coeff=entropy_coeff, ) - logging.info(f"Initializing the A2C trainer for policy {policy}") - elif algorithm == "PPO": - # Proximal Policy Optimization - clip_param = self._get_config(["policy", policy, "clip_param"]) - self.trainers[policy] = PPO( - discount_factor_gamma=gamma, - clip_param=clip_param, - normalize_advantage=normalize_advantage, - normalize_return=normalize_return, - vf_loss_coeff=vf_loss_coeff, - entropy_coeff=entropy_coeff, - ) - logging.info(f"Initializing the PPO trainer for policy {policy}") + logging.info(f"Initializing the DDPG trainer for policy {policy}") else: raise NotImplementedError def _initialize_policy_model(self, policy): - if "actor" not in self._get_config(["policy", policy, "model"]) or "critic" \ + if not isinstance(self._get_config(["policy", policy, "model"]), dict) or \ + "actor" not in self._get_config(["policy", policy, "model"]) or "critic" \ not in self._get_config(["policy", policy, "model"]): actor_model_config = self._get_config(["policy", policy, "model"]) critic_model_config = actor_model_config @@ -181,7 +165,8 @@ def _send_policy_model_to_device(self, policy): def _initialize_optimizer(self, policy): # Initialize the (ADAM) optimizer - if "actor" not in self._get_config(["policy", policy, "lr"]) or "critic" \ + if not isinstance(self._get_config(["policy", policy, "lr"]), dict) or \ + "actor" not in self._get_config(["policy", policy, "lr"]) or "critic" \ not in self._get_config(["policy", policy, "lr"]): actor_lr_config = self._get_config(["policy", policy, "lr"]) critic_lr_config = actor_lr_config @@ -219,7 +204,7 @@ def _evaluate_policies(self, batch_index=0): else: obs = self.actor_models[policy].process_one_step_obs() self.actor_models[policy].push_processed_obs_to_batch(batch_index, obs) - probabilities[policy], _ = self.actor_models[policy](obs) + probabilities[policy] = self.actor_models[policy](obs) # Combine probabilities across policies if there are multiple policies, # yet they share the same action placeholders. @@ -247,6 +232,7 @@ def _evaluate_policies(self, batch_index=0): combined_probabilities = [None for _ in range(num_action_types)] for action_type_id in range(num_action_types): action_dim = probabilities[first_policy][action_type_id].shape[-1] + assert action_dim == 1, "action_dim != 1 but DDPG samples deterministic actions" combined_probabilities[action_type_id] = torch.zeros( (num_envs, num_agents, action_dim) ).cuda() @@ -305,17 +291,23 @@ def _update_model_params(self, iteration): probabilities_batch = self.actor_models[policy]( obs=processed_obs_batch ) + # Critic Q(s, a) is a function of both obs and action + # value_functions_batch includes sampled actions value_functions_batch = self.critic_models[policy]( obs=processed_obs_batch, action=actions_batch ) + # j_functions_batch includes actor network for the back-propagation + j_functions_batch = self.critic_models[policy]( + obs=processed_obs_batch, action=probabilities_batch + ) # Loss and metrics computation actor_loss, critic_loss, metrics = self.trainers[policy].compute_loss_and_metrics( self.current_timestep[policy], actions_batch, rewards_batch, done_flags_batch, - probabilities_batch, value_functions_batch, + j_functions_batch, perform_logging=logging_flag, ) # Compute the gradient norm @@ -416,7 +408,7 @@ def _load_model_checkpoint_helper(self, policy, ckpt_filepath): if self.verbose: verbose_print( - f"Updating the timestep for the '{policy}' model to {timestep}.", + f"Updating the timestep for the '{policy}' model to {actor_timestep}.", self.device_id, ) self.current_timestep[policy] = actor_timestep From 9a8641fbec33ca2a067bc94e5db5a0672da4f04f Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Wed, 7 Feb 2024 22:13:12 -0800 Subject: [PATCH 11/19] configs --- .../algorithms/policygradient/ddpg.py | 43 +++++++------ .../training/run_configs/single_pendulum.yaml | 21 ++++--- warp_drive/training/trainers/trainer_base.py | 24 ++++++-- warp_drive/training/trainers/trainer_ddpg.py | 61 ++++++++++++++++++- 4 files changed, 112 insertions(+), 37 deletions(-) diff --git a/warp_drive/training/algorithms/policygradient/ddpg.py b/warp_drive/training/algorithms/policygradient/ddpg.py index aa7c941..399730d 100644 --- a/warp_drive/training/algorithms/policygradient/ddpg.py +++ b/warp_drive/training/algorithms/policygradient/ddpg.py @@ -38,6 +38,7 @@ def compute_loss_and_metrics( rewards_batch=None, done_flags_batch=None, value_functions_batch=None, + next_value_functions_batch=None, j_functions_batch=None, perform_logging=False, ): @@ -46,27 +47,22 @@ def compute_loss_and_metrics( assert rewards_batch is not None assert done_flags_batch is not None assert value_functions_batch is not None + assert next_value_functions_batch is not None assert j_functions_batch is not None # Detach value_functions_batch from the computation graph # for return and advantage computations. - value_functions_batch_detached = value_functions_batch.detach() + next_value_functions_batch_detached = next_value_functions_batch.detach() # Value objective. returns_batch = torch.zeros_like(rewards_batch) returns_batch[-1] = ( done_flags_batch[-1][:, None] * rewards_batch[-1] - + (1 - done_flags_batch[-1][:, None]) * value_functions_batch_detached[-1] + + (1 - done_flags_batch[-1][:, None]) * next_value_functions_batch_detached[-1] ) - for step in range(-2, -returns_batch.shape[0] - 1, -1): - future_return = ( - done_flags_batch[step][:, None] * torch.zeros_like(rewards_batch[step]) - + (1 - done_flags_batch[step][:, None]) - * self.discount_factor_gamma - * returns_batch[step + 1] - ) - returns_batch[step] = rewards_batch[step] + future_return + returns_batch[:-1] = rewards_batch[:-1] + \ + self.discount_factor_gamma * (1 - done_flags_batch[:-1][:, :, None]) * next_value_functions_batch_detached # Normalize across the agents and env dimensions if self.normalize_return: @@ -78,7 +74,7 @@ def compute_loss_and_metrics( critic_loss = nn.MSELoss()(normalized_returns_batch, value_functions_batch) - advantages_batch = normalized_returns_batch - value_functions_batch_detached + advantages_batch = normalized_returns_batch - value_functions_batch # Normalize across the agents and env dimensions if self.normalize_advantage: @@ -91,8 +87,15 @@ def compute_loss_and_metrics( normalized_advantages_batch = advantages_batch # Policy objective + if self.normalize_return: + normalized_j_functions_batch = ( + j_functions_batch - j_functions_batch.mean(dim=(1, 2), keepdim=True) + ) / ( + j_functions_batch.std(dim=(1, 2), keepdim=True) + torch.tensor(_EPSILON)) + else: + normalized_j_functions_batch = j_functions_batch - actor_loss = -j_functions_batch.mean() + actor_loss = -normalized_j_functions_batch.mean() variance_explained = max( torch.tensor(-1.0), @@ -131,8 +134,8 @@ def compute_loss_and_metrics( std_over_env_per_action = ( actions_batch.float().std(axis=1).mean(axis=(0, 1)) ) - # max_per_action = [actions_batch.float().max()] - # min_per_action = [actions_batch.float().min()] + max_per_action = torch.amax(actions_batch, dim=(0, 1, 2)) + min_per_action = torch.amin(actions_batch, dim=(0, 1, 2)) for idx, _ in enumerate(std_over_agent_per_action): std_action = { @@ -145,12 +148,12 @@ def compute_loss_and_metrics( f"Std. of action_{idx} over time": std_over_time_per_action[ idx ].item(), - # f"Max of action_{idx}": max_per_action[ - # idx - # ].item(), - # f"Min of action_{idx}": min_per_action[ - # idx - # ].item(), + f"Max of action_{idx}": max_per_action[ + idx + ].item(), + f"Min of action_{idx}": min_per_action[ + idx + ].item(), } metrics.update(std_action) else: diff --git a/warp_drive/training/run_configs/single_pendulum.yaml b/warp_drive/training/run_configs/single_pendulum.yaml index 38abd20..05b20f9 100644 --- a/warp_drive/training/run_configs/single_pendulum.yaml +++ b/warp_drive/training/run_configs/single_pendulum.yaml @@ -12,9 +12,9 @@ env: reset_pool_size: 10000 # Trainer settings trainer: - num_envs: 10000 # number of environment replicas - num_episodes: 20000000 # number of episodes to run the training for. Can be arbitrarily high! - train_batch_size: 5000000 # total batch size used for training per iteration (across all the environments) + num_envs: 100 # number of environment replicas + num_episodes: 2000000 # number of episodes to run the training for. Can be arbitrarily high! + train_batch_size: 500 # total batch size used for training per iteration (across all the environments) env_backend: "numba" # environment backend, pycuda or numba # Policy network settings policy: # list all the policies below @@ -23,20 +23,21 @@ policy: # list all the policies below algorithm: "DDPG" # algorithm used to train the policy clip_grad_norm: True # flag indicating whether to clip the gradient norm or not max_grad_norm: 3 # when clip_grad_norm is True, the clip level - normalize_advantage: True # flag indicating whether to normalize advantage or not - normalize_return: True # flag indicating whether to normalize return or not + normalize_advantage: False # flag indicating whether to normalize advantage or not + normalize_return: False # flag indicating whether to normalize return or not gamma: 0.99 # discount factor + tau: 0.01 # target copy rate lr: - actor: 0.001 # learning rate - critic: 0.001 + actor: 0.0001 # learning rate + critic: 0.0001 model: # policy model settings actor: type: "fully_connected_actor" # model type - fc_dims: [32, 32] + fc_dims: [64, 64] output_w: 2.0 # model default range is (-1, 1), this changes to (-2, 2) critic: type: "fully_connected_action_value_critic" # model type - fc_dims: [32, 32] + fc_dims: [64, 64] model_ckpt_filepath: actor: "" # filepath (used to restore a previously saved model) critic: "" @@ -47,7 +48,7 @@ sampler: scale: 1.0 # Checkpoint saving setting saving: - metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics + metrics_log_freq: 1000 # how often (in iterations) to log (and print) the metrics model_params_save_freq: 5000 # how often (in iterations) to save the model parameters basedir: "/tmp" # base folder used for saving name: "single_pendulum" # base folder used for saving diff --git a/warp_drive/training/trainers/trainer_base.py b/warp_drive/training/trainers/trainer_base.py index beb227c..01a7b32 100644 --- a/warp_drive/training/trainers/trainer_base.py +++ b/warp_drive/training/trainers/trainer_base.py @@ -636,6 +636,7 @@ def fetch_episode_states( list_of_states=None, # list of states (data array names) to fetch env_id=0, # environment id to fetch the states from include_rewards_actions=False, # flag to output reward and action + include_probabilities=False, # flag to output action probability policy="", # if include_rewards_actions=True, the corresponding policy tag if any **sample_params ): @@ -644,9 +645,9 @@ def fetch_episode_states( for an entire episode. The trained models will be used for evaluation. """ assert 0 <= env_id < self.num_envs - assert list_of_states is not None + if list_of_states is None: + list_of_states = [] assert isinstance(list_of_states, list) - assert len(list_of_states) > 0 logging.info(f"Fetching the episode states: {list_of_states} from the GPU.") # Ensure env is reset before the start of training, and done flags are False @@ -675,7 +676,7 @@ def fetch_episode_states( ( env.episode_length, *self.cuda_envs.cuda_data_manager.get_shape(action_name)[1:] ), - dtype=np.int32 + dtype=self.cuda_envs.cuda_data_manager.get_dtype(action_name) ) episode_rewards= np.zeros( ( @@ -683,6 +684,9 @@ def fetch_episode_states( ), dtype=np.float32) + if include_probabilities: + episode_probabilities = {} + for timestep in range(env.episode_length): # Update the episode states s_t for state in list_of_states: @@ -707,7 +711,15 @@ def fetch_episode_states( # Update the episode reward r_(t+1) episode_rewards[timestep] = \ self.cuda_envs.cuda_data_manager.pull_data_from_device(reward_name)[env_id] - + if include_probabilities: + # Update the episode action probability p_t + if len(policy) > 0: + probs = {policy: [p[env_id].detach().cpu().numpy() for p in probabilities[policy]]} + else: + probs = {} + for policy, value in probabilities.items(): + probs[policy] = [v[env_id].detach().cpu().numpy() for v in value] + episode_probabilities[timestep] = probs # Fetch the states when episode is complete if env.cuda_data_manager.pull_data_from_device("_done_")[env_id]: for state in list_of_states: @@ -717,8 +729,10 @@ def fetch_episode_states( env_id ] break - if include_rewards_actions: + if include_rewards_actions and not include_probabilities: return episode_states, episode_actions, episode_rewards + elif include_rewards_actions and include_probabilities: + return episode_states, episode_actions, episode_rewards, episode_probabilities else: return episode_states diff --git a/warp_drive/training/trainers/trainer_ddpg.py b/warp_drive/training/trainers/trainer_ddpg.py index ca8339d..5e47cce 100644 --- a/warp_drive/training/trainers/trainer_ddpg.py +++ b/warp_drive/training/trainers/trainer_ddpg.py @@ -39,6 +39,18 @@ _EPSILON = 1e-10 # small number to prevent indeterminate divisions +def soft_update(target, source, tau): + for target_param, param in zip(target.parameters(), source.parameters()): + target_param.data.copy_( + target_param.data * (1.0 - tau) + param.data * tau + ) + + +def hard_update(target, source): + for target_param, param in zip(target.parameters(), source.parameters()): + target_param.data.copy_(param.data) + + class TrainerDDPG(TrainerBase): def __init__( self, @@ -55,10 +67,13 @@ def __init__( # Define models, optimizers, and learning rate schedules self.actor_models = {} self.critic_models = {} + self.target_actor_models = {} + self.target_critic_models = {} self.actor_optimizers = {} self.critic_optimizers = {} self.actor_lr_schedules = {} self.critic_lr_schedules = {} + self.tau = {} super().__init__( env_wrapper=env_wrapper, @@ -87,6 +102,7 @@ def _initialize_policy_algorithm(self, policy): ) normalize_return = self._get_config(["policy", policy, "normalize_return"]) gamma = self._get_config(["policy", policy, "gamma"]) + self.tau[policy] = self._get_config(["policy", policy, "tau"]) if algorithm == "DDPG": # Advantage Actor-Critic self.trainers[policy] = DDPG( @@ -117,6 +133,14 @@ def _initialize_policy_model(self, policy): create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, ) + target_actor = model_obj_actor( + env=self.cuda_envs, + model_config=actor_model_config, + policy=policy, + policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, + ) if "init_method" in actor_model_config and \ actor_model_config["init_method"] == "xavier": @@ -126,7 +150,9 @@ def init_weights_by_xavier_uniform(m): actor.apply(init_weights_by_xavier_uniform) + hard_update(target_actor, actor) self.actor_models[policy] = actor + self.target_actor_models[policy] = target_actor model_obj_critic = ModelFactory.create(critic_model_config["type"]) critic = model_obj_critic( @@ -137,6 +163,14 @@ def init_weights_by_xavier_uniform(m): create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, ) + target_critic = model_obj_critic( + env=self.cuda_envs, + model_config=critic_model_config, + policy=policy, + policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map, + create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, + obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, + ) if "init_method" in critic_model_config and \ critic_model_config["init_method"] == "xavier": @@ -146,11 +180,15 @@ def init_weights_by_xavier_uniform(m): critic.apply(init_weights_by_xavier_uniform) + hard_update(target_critic, critic) self.critic_models[policy] = critic + self.target_critic_models[policy] = target_critic def _send_policy_model_to_device(self, policy): self.actor_models[policy].cuda() self.critic_models[policy].cuda() + self.target_actor_models[policy].cuda() + self.target_critic_models[policy].cuda() # If distributed train, sync model using DDP if self.num_devices > 1: self.actor_models[policy] = DDP( @@ -159,6 +197,12 @@ def _send_policy_model_to_device(self, policy): self.critic_models[policy] = DDP( self.critic_models[policy], device_ids=[self.device_id] ) + self.target_actor_models[policy] = DDP( + self.target_actor_models[policy], device_ids=[self.device_id] + ) + self.target_critic_models[policy] = DDP( + self.target_critic_models[policy], device_ids=[self.device_id] + ) self.ddp_mode[policy] = True else: self.ddp_mode[policy] = False @@ -291,12 +335,21 @@ def _update_model_params(self, iteration): probabilities_batch = self.actor_models[policy]( obs=processed_obs_batch ) - # Critic Q(s, a) is a function of both obs and action + target_probabilities_batch = self.target_actor_models[policy]( + obs=processed_obs_batch + ) + # Critic Q(s_t, a_t) is a function of both obs and action # value_functions_batch includes sampled actions value_functions_batch = self.critic_models[policy]( obs=processed_obs_batch, action=actions_batch ) - # j_functions_batch includes actor network for the back-propagation + # Critic Q(s_t+1, a_t+1) is a function of both obs and action + # next_value_functions_batch not includes sampled action but + # the detached output from actor directly + next_value_functions_batch = self.target_critic_models[policy]( + obs=processed_obs_batch[1:], action=[pb[1:].detach() for pb in target_probabilities_batch] + ) + # j_functions_batch includes the graph of actor network for the back-propagation j_functions_batch = self.critic_models[policy]( obs=processed_obs_batch, action=probabilities_batch ) @@ -307,6 +360,7 @@ def _update_model_params(self, iteration): rewards_batch, done_flags_batch, value_functions_batch, + next_value_functions_batch, j_functions_batch, perform_logging=logging_flag, ) @@ -352,6 +406,9 @@ def _update_model_params(self, iteration): self.actor_optimizers[policy].step() self.critic_optimizers[policy].step() + + soft_update(self.target_actor_models[policy], self.actor_models[policy], self.tau[policy]) + soft_update(self.target_critic_models[policy], self.critic_models[policy], self.tau[policy]) # Logging if logging_flag: metrics_dict[policy] = metrics From 7bf8eb9e824b8591c0c90c6e193161310e4e0dcc Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Fri, 16 Feb 2024 21:34:29 -0800 Subject: [PATCH 12/19] ring buffer and fix ddpg --- .../numba_tests/test_ring_buffer.py | 81 +++++++++++++++++ warp_drive/managers/data_manager.py | 4 + .../algorithms/policygradient/ddpg.py | 35 ++++++-- warp_drive/training/models/model_base.py | 11 ++- .../training/run_configs/single_pendulum.yaml | 1 + warp_drive/training/trainers/trainer_base.py | 76 +++++++++++----- warp_drive/training/trainers/trainer_ddpg.py | 40 ++++----- warp_drive/training/utils/ring_buffer.py | 87 +++++++++++++++++++ 8 files changed, 280 insertions(+), 55 deletions(-) create mode 100644 tests/warp_drive/numba_tests/test_ring_buffer.py create mode 100644 warp_drive/training/utils/ring_buffer.py diff --git a/tests/warp_drive/numba_tests/test_ring_buffer.py b/tests/warp_drive/numba_tests/test_ring_buffer.py new file mode 100644 index 0000000..760e877 --- /dev/null +++ b/tests/warp_drive/numba_tests/test_ring_buffer.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021, salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root +# or https://opensource.org/licenses/BSD-3-Clause + +import unittest +import numpy as np +import torch +from warp_drive.managers.numba_managers.numba_data_manager import NumbaDataManager +from warp_drive.training.utils.ring_buffer import RingBuffer, RingBufferManager +from warp_drive.utils.data_feed import DataFeed + + +class TestRingBuffer(unittest.TestCase): + """ + Unit tests for the RingBuffer + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dm = NumbaDataManager(num_agents=5, num_envs=1, episode_length=3) + self.rbm = RingBufferManager() + + def test(self): + x = np.zeros((5, 3), dtype=np.float32) + data = DataFeed() + data.add_data(name="X", data=x) + self.dm.push_data_to_device(data, torch_accessible=True) + self.rbm.add(name="X", data_manager=self.dm) + buffer = self.rbm.get("X") + for i in [0, 1, 2]: + buffer.enqueue(torch.tensor([i, i, i])) + + res1 = buffer.unroll().cpu().numpy() + self.assertEqual( + res1.tolist(), + np.array([[0, 0, 0], + [1, 1, 1], + [2, 2, 2]] + ).tolist() + ) + + self.assertTrue(not buffer.isfull()) + + for i in [3, 4]: + buffer.enqueue(torch.tensor([i, i, i])) + res2 = buffer.unroll().cpu().numpy() + self.assertEqual( + res2.tolist(), + np.array([[0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [3, 3, 3], + [4, 4, 4]] + ).tolist() + ) + + self.assertTrue(buffer.isfull()) + + for i in [5, 6, 7]: + buffer.enqueue(torch.tensor([i, i, i])) + res3 = buffer.unroll().cpu().numpy() + self.assertEqual( + res3.tolist(), + np.array([[3, 3, 3], + [4, 4, 4], + [5, 5, 5], + [6, 6, 6], + [7, 7, 7]] + ).tolist() + ) + + self.assertTrue(buffer.isfull()) + + + + + + + diff --git a/warp_drive/managers/data_manager.py b/warp_drive/managers/data_manager.py index bb95689..6b8b9db 100644 --- a/warp_drive/managers/data_manager.py +++ b/warp_drive/managers/data_manager.py @@ -479,3 +479,7 @@ def reset_target_to_pool(self): @property def log_data_list(self): return self._log_data_list + + @property + def device_data_via_torch(self): + return self._device_data_via_torch diff --git a/warp_drive/training/algorithms/policygradient/ddpg.py b/warp_drive/training/algorithms/policygradient/ddpg.py index 399730d..dca6f95 100644 --- a/warp_drive/training/algorithms/policygradient/ddpg.py +++ b/warp_drive/training/algorithms/policygradient/ddpg.py @@ -8,6 +8,7 @@ import torch from torch import nn from torch.distributions import Categorical +import numpy as np from warp_drive.training.utils.param_scheduler import ParamScheduler @@ -25,11 +26,14 @@ def __init__( discount_factor_gamma=1.0, normalize_advantage=False, normalize_return=False, + n_step=1, ): assert 0 <= discount_factor_gamma <= 1 + assert n_step >= 1 self.discount_factor_gamma = discount_factor_gamma self.normalize_advantage = normalize_advantage self.normalize_return = normalize_return + self.n_step = n_step def compute_loss_and_metrics( self, @@ -50,20 +54,31 @@ def compute_loss_and_metrics( assert next_value_functions_batch is not None assert j_functions_batch is not None + # we only calculate up to batch - n_step + 1 point, + # after that it is not enough points to calculate n_step + valid_batch_range = rewards_batch.shape[0] - self.n_step + 1 # Detach value_functions_batch from the computation graph # for return and advantage computations. next_value_functions_batch_detached = next_value_functions_batch.detach() # Value objective. - returns_batch = torch.zeros_like(rewards_batch) - - returns_batch[-1] = ( - done_flags_batch[-1][:, None] * rewards_batch[-1] - + (1 - done_flags_batch[-1][:, None]) * next_value_functions_batch_detached[-1] - ) - returns_batch[:-1] = rewards_batch[:-1] + \ - self.discount_factor_gamma * (1 - done_flags_batch[:-1][:, :, None]) * next_value_functions_batch_detached - + returns_batch = torch.zeros_like(rewards_batch[:valid_batch_range]) + + for i in range(valid_batch_range): + last_step = i + self.n_step - 1 + if last_step < rewards_batch.shape[0] - 1: + r = rewards_batch[last_step] + \ + (1 - done_flags_batch[last_step][:, None]) * \ + self.discount_factor_gamma * next_value_functions_batch_detached[last_step] + else: + r = done_flags_batch[last_step][:, None] * rewards_batch[last_step] + \ + (1 - done_flags_batch[last_step][:, None]) * \ + self.discount_factor_gamma * next_value_functions_batch_detached[-1] + for j in range(1, self.n_step): + r = (1 - done_flags_batch[last_step - j][:, None]) * self.discount_factor_gamma * r + \ + done_flags_batch[last_step - j][:, None] * torch.zeros_like(rewards_batch[last_step - j]) + r += rewards_batch[last_step - j] + returns_batch[i] = r # Normalize across the agents and env dimensions if self.normalize_return: normalized_returns_batch = ( @@ -72,6 +87,7 @@ def compute_loss_and_metrics( else: normalized_returns_batch = returns_batch + value_functions_batch = value_functions_batch[:valid_batch_range] critic_loss = nn.MSELoss()(normalized_returns_batch, value_functions_batch) advantages_batch = normalized_returns_batch - value_functions_batch @@ -87,6 +103,7 @@ def compute_loss_and_metrics( normalized_advantages_batch = advantages_batch # Policy objective + j_functions_batch = j_functions_batch[:valid_batch_range] if self.normalize_return: normalized_j_functions_batch = ( j_functions_batch - j_functions_batch.mean(dim=(1, 2), keepdim=True) diff --git a/warp_drive/training/models/model_base.py b/warp_drive/training/models/model_base.py index a97bac8..c1a3b06 100644 --- a/warp_drive/training/models/model_base.py +++ b/warp_drive/training/models/model_base.py @@ -188,13 +188,16 @@ def process_one_step_obs(self): def forward(self, obs=None, action=None): raise NotImplementedError - def push_processed_obs_to_batch(self, batch_index, processed_obs): + def push_processed_obs_to_batch(self, batch_index, processed_obs, ring_buffer=None): if batch_index >= 0: assert batch_index < self.batch_size, f"batch_index: {batch_index}, self.batch_size: {self.batch_size}" name = f"{_PROCESSED_OBSERVATIONS}_batch_{self.policy}" - self.env.cuda_data_manager.data_on_device_via_torch(name=name)[ - batch_index - ] = processed_obs + if ring_buffer is not None and ring_buffer.has(name): + ring_buffer.get(name).enqueue(processed_obs) + else: + self.env.cuda_data_manager.data_on_device_via_torch(name=name)[ + batch_index + ] = processed_obs def apply_logit_mask(logits, mask=None): diff --git a/warp_drive/training/run_configs/single_pendulum.yaml b/warp_drive/training/run_configs/single_pendulum.yaml index 05b20f9..e274e9a 100644 --- a/warp_drive/training/run_configs/single_pendulum.yaml +++ b/warp_drive/training/run_configs/single_pendulum.yaml @@ -15,6 +15,7 @@ trainer: num_envs: 100 # number of environment replicas num_episodes: 2000000 # number of episodes to run the training for. Can be arbitrarily high! train_batch_size: 500 # total batch size used for training per iteration (across all the environments) + n_step: 5 # n_step for calculating return env_backend: "numba" # environment backend, pycuda or numba # Policy network settings policy: # list all the policies below diff --git a/warp_drive/training/trainers/trainer_base.py b/warp_drive/training/trainers/trainer_base.py index 01a7b32..86b391a 100644 --- a/warp_drive/training/trainers/trainer_base.py +++ b/warp_drive/training/trainers/trainer_base.py @@ -21,6 +21,7 @@ from torch import nn from warp_drive.training.utils.data_loader import create_and_push_data_placeholders +from warp_drive.training.utils.ring_buffer import RingBufferManager from warp_drive.utils.common import get_project_root from warp_drive.utils.constants import Constants @@ -205,6 +206,8 @@ def __init__( self.training_batch_size_per_env = self.training_batch_size // self.num_envs assert self.training_batch_size_per_env > 0 + self.n_step = self.config["trainer"].get("n_step", 1) + # Push all the data and tensor arrays to the GPU # upon resetting environments for the very first time. self.cuda_envs.reset_all_envs() @@ -233,7 +236,7 @@ def __init__( policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map, create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy, obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents, - training_batch_size_per_env=self.training_batch_size_per_env, + training_batch_size_per_env=self.training_batch_size_per_env + self.n_step - 1, ) # Seeding (device_id is included for distributed training) seed = ( @@ -302,6 +305,9 @@ def __init__( # Metrics self.metrics = Metrics() + # Ring Buffer to save batch data + self.ring_buffer = RingBufferManager() + def _get_config(self, args): assert isinstance(args, (tuple, list)) config = self.config @@ -427,9 +433,13 @@ def _sample_actions(self, probabilities, batch_index=0, **sample_params): actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( _ACTIONS + policy_suffix ) - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - name=f"{_ACTIONS}_batch" + policy_suffix - )[batch_index] = actions + action_batch_name = f"{_ACTIONS}_batch" + policy_suffix + if self.ring_buffer.has(action_batch_name): + self.ring_buffer.get(action_batch_name).enqueue(actions) + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=action_batch_name + )[batch_index] = actions else: assert len(probabilities) == 1 policy = list(probabilities.keys())[0] @@ -442,15 +452,22 @@ def _sample_actions(self, probabilities, batch_index=0, **sample_params): # (1) there is only one policy, then action -> action_batch_policy # (2) there are multiple policies, then action[policy_tag_to_agent_id[policy]] -> action_batch_policy for policy in self.policies: + action_batch_name = f"{_ACTIONS}_batch_{policy}" if len(self.policies) > 1: agent_ids_for_policy = self.policy_tag_to_agent_id_map[policy] - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - name=f"{_ACTIONS}_batch_{policy}" - )[batch_index] = actions[:, agent_ids_for_policy] + if self.ring_buffer.has(action_batch_name): + self.ring_buffer.get(action_batch_name).enqueue(actions[:, agent_ids_for_policy]) + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=action_batch_name + )[batch_index] = actions[:, agent_ids_for_policy] else: - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - name=f"{_ACTIONS}_batch_{policy}" - )[batch_index] = actions + if self.ring_buffer.has(action_batch_name): + self.ring_buffer.get(action_batch_name).enqueue(actions) + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=action_batch_name + )[batch_index] = actions def _sample_actions_helper(self, probabilities, policy_suffix="", **sample_params): # Sample actions with policy_suffix tag @@ -485,9 +502,13 @@ def _bookkeep_rewards_and_done_flags(self, batch_index): done_flags = ( self.cuda_envs.cuda_data_manager.data_on_device_via_torch("_done_") > 0 ) - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - name=f"{_DONE_FLAGS}_batch" - )[batch_index] = done_flags + done_batch_name = f"{_DONE_FLAGS}_batch" + if self.ring_buffer.has(done_batch_name): + self.ring_buffer.get(done_batch_name).enqueue(done_flags) + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=done_batch_name + )[batch_index] = done_flags done_env_ids = done_flags.nonzero() @@ -497,9 +518,13 @@ def _bookkeep_rewards_and_done_flags(self, batch_index): rewards = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( f"{_REWARDS}_{policy}" ) - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - name=f"{_REWARDS}_batch_{policy}" - )[batch_index] = rewards + reward_batch_name = f"{_REWARDS}_batch_{policy}" + if self.ring_buffer.has(reward_batch_name): + self.ring_buffer.get(reward_batch_name).enqueue(rewards) + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=reward_batch_name + )[batch_index] = rewards # Update the episodic rewards self._update_episodic_rewards(rewards, done_env_ids, policy) @@ -509,15 +534,22 @@ def _bookkeep_rewards_and_done_flags(self, batch_index): _REWARDS ) for policy in self.policies: + reward_batch_name = f"{_REWARDS}_batch_{policy}" if len(self.policies) > 1: agent_ids_for_policy = self.policy_tag_to_agent_id_map[policy] - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - name=f"{_REWARDS}_batch_{policy}" - )[batch_index] = rewards[:, agent_ids_for_policy] + if self.ring_buffer.has(reward_batch_name): + self.ring_buffer.get(reward_batch_name).enqueue(rewards[:, agent_ids_for_policy]) + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=reward_batch_name + )[batch_index] = rewards[:, agent_ids_for_policy] else: - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - name=f"{_REWARDS}_batch_{policy}" - )[batch_index] = rewards + if self.ring_buffer.has(reward_batch_name): + self.ring_buffer.get(reward_batch_name).enqueue(rewards) + else: + self.cuda_envs.cuda_data_manager.data_on_device_via_torch( + name=reward_batch_name + )[batch_index] = rewards # Update the episodic rewards # (sum of individual step rewards over an episode) diff --git a/warp_drive/training/trainers/trainer_ddpg.py b/warp_drive/training/trainers/trainer_ddpg.py index 5e47cce..354e30f 100644 --- a/warp_drive/training/trainers/trainer_ddpg.py +++ b/warp_drive/training/trainers/trainer_ddpg.py @@ -86,6 +86,12 @@ def __init__( results_dir=results_dir, verbose=verbose, ) + self._init_ring_buffer() + + def _init_ring_buffer(self): + for k in self.cuda_envs.cuda_data_manager.device_data_via_torch.keys(): + if "_batch" in k: + self.ring_buffer.add(name=k, data_manager=self.cuda_envs.cuda_data_manager) def _initialize_policy_algorithm(self, policy): algorithm = self._get_config(["policy", policy, "algorithm"]) @@ -109,6 +115,7 @@ def _initialize_policy_algorithm(self, policy): discount_factor_gamma=gamma, normalize_advantage=normalize_advantage, normalize_return=normalize_return, + n_step=self.n_step, ) logging.info(f"Initializing the DDPG trainer for policy {policy}") else: @@ -244,10 +251,12 @@ def _evaluate_policies(self, batch_index=0): if self.ddp_mode[policy]: # self.models[policy] is a DDP wrapper of the model instance obs = self.actor_models[policy].module.process_one_step_obs() - self.actor_models[policy].module.push_processed_obs_to_batch(batch_index, obs) + self.actor_models[policy].module.push_processed_obs_to_batch( + batch_index, obs, ring_buffer=self.ring_buffer) else: obs = self.actor_models[policy].process_one_step_obs() - self.actor_models[policy].push_processed_obs_to_batch(batch_index, obs) + self.actor_models[policy].push_processed_obs_to_batch( + batch_index, obs, ring_buffer=self.ring_buffer) probabilities[policy] = self.actor_models[policy](obs) # Combine probabilities across policies if there are multiple policies, @@ -307,30 +316,21 @@ def _update_model_params(self, iteration): metrics_dict = {} - done_flags_batch = self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - f"{_DONE_FLAGS}_batch" - ) + if not self.ring_buffer.get(f"{_DONE_FLAGS}_batch").isfull(): + return metrics_dict + + done_flags_batch = self.ring_buffer.get(f"{_DONE_FLAGS}_batch").unroll() + # On the device, observations_batch, actions_batch, # rewards_batch are all shaped # (batch_size, num_envs, num_agents, *feature_dim). + # Notice that, in ddpg we used a ring_buffer to rolling store the batch # done_flags_batch is shaped (batch_size, num_envs) # Perform training sequentially for each policy for policy in self.policies_to_train: - actions_batch = ( - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - f"{_ACTIONS}_batch_{policy}" - ) - ) - rewards_batch = ( - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - f"{_REWARDS}_batch_{policy}" - ) - ) - processed_obs_batch = ( - self.cuda_envs.cuda_data_manager.data_on_device_via_torch( - f"{_PROCESSED_OBSERVATIONS}_batch_{policy}" - ) - ) + actions_batch = self.ring_buffer.get(f"{_ACTIONS}_batch_{policy}").unroll() + rewards_batch = self.ring_buffer.get(f"{_REWARDS}_batch_{policy}").unroll() + processed_obs_batch = self.ring_buffer.get(f"{_PROCESSED_OBSERVATIONS}_batch_{policy}").unroll() # Policy evaluation for the entire batch probabilities_batch = self.actor_models[policy]( obs=processed_obs_batch diff --git a/warp_drive/training/utils/ring_buffer.py b/warp_drive/training/utils/ring_buffer.py new file mode 100644 index 0000000..3ebd4dd --- /dev/null +++ b/warp_drive/training/utils/ring_buffer.py @@ -0,0 +1,87 @@ +import torch +from warp_drive.managers.data_manager import CUDADataManager + + +class RingBuffer: + """ + We manage the batch data as a circular queue + """ + + def __init__( + self, + name: str = None, + size: int = None, + data_manager: CUDADataManager = None, + ): + self.buffer_name = f"RingBuffer_{name}" + assert data_manager.is_data_on_device_via_torch(name) + # initializing queue with none + self.front = -1 + self.rear = -1 + self.current_size = 0 + + self.queue = data_manager.data_on_device_via_torch(name=name) + if size is None: + self.size = data_manager.get_shape(name)[0] + else: + self.size = size + assert self.size <= data_manager.get_shape(name)[0], \ + f"The managed the ring buffer size could not exceed the size of the container: {name}" + + def enqueue(self, data): + assert isinstance(data, torch.Tensor) + # condition if queue is full + if (self.rear + 1) % self.size == self.front: + self._dequeue() + # condition for empty queue + if self.front == -1: + self.front = 0 + self.rear = 0 + self.queue[self.rear] = data + else: + # next position of rear + self.rear = (self.rear + 1) % self.size + self.queue[self.rear] = data + self.current_size += 1 + + def _dequeue(self): + if self.front == -1: + return + # condition for only one element + elif self.front == self.rear: + self.front = -1 + self.rear = -1 + else: + self.front = (self.front + 1) % self.size + self.current_size -= 1 + + def unroll(self): + # we unroll the circular queue to a flattened array with index following the order from front to tail + if self.front == -1: + return None + + elif self.rear >= self.front: + return self.queue[self.front: self.rear+1] + + else: + return torch.roll(self.queue, shifts=-self.front, dims=0)[:self.current_size] + + def isfull(self): + return self.current_size == self.size + + +class RingBufferManager(dict): + + def add(self, name, size=None, data_manager=None): + r = RingBuffer(name=name, size=size, data_manager=data_manager) + self[name] = r + + def get(self, name): + assert name in self, \ + f"{name} not in the RingBufferManager" + return self[name] + + def has(self, name): + return name in self + + From b3ecd380bdd47ce6c9dbe855105337efedabaf2a Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Sat, 17 Feb 2024 11:02:53 -0800 Subject: [PATCH 13/19] finish configu --- .../training/run_configs/single_pendulum.yaml | 18 +- warp_drive/training/trainers/trainer_ddpg.py | 166 +++++++++--------- 2 files changed, 96 insertions(+), 88 deletions(-) diff --git a/warp_drive/training/run_configs/single_pendulum.yaml b/warp_drive/training/run_configs/single_pendulum.yaml index e274e9a..02e707f 100644 --- a/warp_drive/training/run_configs/single_pendulum.yaml +++ b/warp_drive/training/run_configs/single_pendulum.yaml @@ -12,9 +12,9 @@ env: reset_pool_size: 10000 # Trainer settings trainer: - num_envs: 100 # number of environment replicas - num_episodes: 2000000 # number of episodes to run the training for. Can be arbitrarily high! - train_batch_size: 500 # total batch size used for training per iteration (across all the environments) + num_envs: 10000 # number of environment replicas + num_episodes: 10000000 # number of episodes to run the training for. Can be arbitrarily high! + train_batch_size: 50000 # total batch size used for training per iteration (across all the environments) n_step: 5 # n_step for calculating return env_backend: "numba" # environment backend, pycuda or numba # Policy network settings @@ -27,10 +27,10 @@ policy: # list all the policies below normalize_advantage: False # flag indicating whether to normalize advantage or not normalize_return: False # flag indicating whether to normalize return or not gamma: 0.99 # discount factor - tau: 0.01 # target copy rate + tau: 0.05 # target copy rate lr: - actor: 0.0001 # learning rate - critic: 0.0001 + actor: [[2500000000, 0.001], [3750000000, 0.0005]] # learning rate + critic: [[2500000000, 0.0001], [3750000000, 0.00005]] model: # policy model settings actor: type: "fully_connected_actor" # model type @@ -51,7 +51,7 @@ sampler: saving: metrics_log_freq: 1000 # how often (in iterations) to log (and print) the metrics model_params_save_freq: 5000 # how often (in iterations) to save the model parameters - basedir: "/tmp" # base folder used for saving - name: "single_pendulum" # base folder used for saving - tag: "experiments" # experiment name + basedir: "/export/home/experiments/warpdrive" # base folder used for saving + name: "pendulum" # base folder used for saving + tag: "10000" # experiment name diff --git a/warp_drive/training/trainers/trainer_ddpg.py b/warp_drive/training/trainers/trainer_ddpg.py index 354e30f..ec431f8 100644 --- a/warp_drive/training/trainers/trainer_ddpg.py +++ b/warp_drive/training/trainers/trainer_ddpg.py @@ -316,9 +316,6 @@ def _update_model_params(self, iteration): metrics_dict = {} - if not self.ring_buffer.get(f"{_DONE_FLAGS}_batch").isfull(): - return metrics_dict - done_flags_batch = self.ring_buffer.get(f"{_DONE_FLAGS}_batch").unroll() # On the device, observations_batch, actions_batch, @@ -328,87 +325,98 @@ def _update_model_params(self, iteration): # done_flags_batch is shaped (batch_size, num_envs) # Perform training sequentially for each policy for policy in self.policies_to_train: - actions_batch = self.ring_buffer.get(f"{_ACTIONS}_batch_{policy}").unroll() - rewards_batch = self.ring_buffer.get(f"{_REWARDS}_batch_{policy}").unroll() - processed_obs_batch = self.ring_buffer.get(f"{_PROCESSED_OBSERVATIONS}_batch_{policy}").unroll() - # Policy evaluation for the entire batch - probabilities_batch = self.actor_models[policy]( - obs=processed_obs_batch - ) - target_probabilities_batch = self.target_actor_models[policy]( - obs=processed_obs_batch - ) - # Critic Q(s_t, a_t) is a function of both obs and action - # value_functions_batch includes sampled actions - value_functions_batch = self.critic_models[policy]( - obs=processed_obs_batch, action=actions_batch - ) - # Critic Q(s_t+1, a_t+1) is a function of both obs and action - # next_value_functions_batch not includes sampled action but - # the detached output from actor directly - next_value_functions_batch = self.target_critic_models[policy]( - obs=processed_obs_batch[1:], action=[pb[1:].detach() for pb in target_probabilities_batch] - ) - # j_functions_batch includes the graph of actor network for the back-propagation - j_functions_batch = self.critic_models[policy]( - obs=processed_obs_batch, action=probabilities_batch - ) - # Loss and metrics computation - actor_loss, critic_loss, metrics = self.trainers[policy].compute_loss_and_metrics( - self.current_timestep[policy], - actions_batch, - rewards_batch, - done_flags_batch, - value_functions_batch, - next_value_functions_batch, - j_functions_batch, - perform_logging=logging_flag, - ) - # Compute the gradient norm - actor_grad_norm = 0.0 - for param in list( - filter(lambda p: p.grad is not None, self.actor_models[policy].parameters()) - ): - actor_grad_norm += param.grad.data.norm(2).item() - - critic_grad_norm = 0.0 - for param in list( - filter(lambda p: p.grad is not None, self.critic_models[policy].parameters()) - ): - critic_grad_norm += param.grad.data.norm(2).item() - - # Update the timestep and learning rate based on the schedule - self.current_timestep[policy] += self.training_batch_size - actor_lr = self.actor_lr_schedules[policy].get_param_value( - self.current_timestep[policy] - ) - for param_group in self.actor_optimizers[policy].param_groups: - param_group["lr"] = actor_lr - critic_lr = self.critic_lr_schedules[policy].get_param_value( - self.current_timestep[policy] - ) - for param_group in self.critic_optimizers[policy].param_groups: - param_group["lr"] = critic_lr - - # Loss backpropagation and optimization step - self.actor_optimizers[policy].zero_grad() - self.critic_optimizers[policy].zero_grad() - actor_loss.backward() - critic_loss.backward() - if self.clip_grad_norm[policy]: - nn.utils.clip_grad_norm_( - self.actor_models[policy].parameters(), self.max_grad_norm[policy] + if self.ring_buffer.get(f"{_DONE_FLAGS}_batch").isfull(): + # buffer is full, we can train + # this should skip the first roll-out train only + actions_batch = self.ring_buffer.get(f"{_ACTIONS}_batch_{policy}").unroll() + rewards_batch = self.ring_buffer.get(f"{_REWARDS}_batch_{policy}").unroll() + processed_obs_batch = self.ring_buffer.get(f"{_PROCESSED_OBSERVATIONS}_batch_{policy}").unroll() + # Policy evaluation for the entire batch + probabilities_batch = self.actor_models[policy]( + obs=processed_obs_batch + ) + target_probabilities_batch = self.target_actor_models[policy]( + obs=processed_obs_batch + ) + # Critic Q(s_t, a_t) is a function of both obs and action + # value_functions_batch includes sampled actions + value_functions_batch = self.critic_models[policy]( + obs=processed_obs_batch, action=actions_batch + ) + # Critic Q(s_t+1, a_t+1) is a function of both obs and action + # next_value_functions_batch not includes sampled action but + # the detached output from actor directly + next_value_functions_batch = self.target_critic_models[policy]( + obs=processed_obs_batch[1:], action=[pb[1:].detach() for pb in target_probabilities_batch] + ) + # j_functions_batch includes the graph of actor network for the back-propagation + j_functions_batch = self.critic_models[policy]( + obs=processed_obs_batch, action=probabilities_batch + ) + # Loss and metrics computation + actor_loss, critic_loss, metrics = self.trainers[policy].compute_loss_and_metrics( + self.current_timestep[policy], + actions_batch, + rewards_batch, + done_flags_batch, + value_functions_batch, + next_value_functions_batch, + j_functions_batch, + perform_logging=logging_flag, ) - nn.utils.clip_grad_norm_( - self.critic_models[policy].parameters(), self.max_grad_norm[policy] + # Compute the gradient norm + actor_grad_norm = 0.0 + for param in list( + filter(lambda p: p.grad is not None, self.actor_models[policy].parameters()) + ): + actor_grad_norm += param.grad.data.norm(2).item() + + critic_grad_norm = 0.0 + for param in list( + filter(lambda p: p.grad is not None, self.critic_models[policy].parameters()) + ): + critic_grad_norm += param.grad.data.norm(2).item() + + # Update the timestep and learning rate based on the schedule + self.current_timestep[policy] += self.training_batch_size + actor_lr = self.actor_lr_schedules[policy].get_param_value( + self.current_timestep[policy] ) + for param_group in self.actor_optimizers[policy].param_groups: + param_group["lr"] = actor_lr - self.actor_optimizers[policy].step() - self.critic_optimizers[policy].step() + critic_lr = self.critic_lr_schedules[policy].get_param_value( + self.current_timestep[policy] + ) + for param_group in self.critic_optimizers[policy].param_groups: + param_group["lr"] = critic_lr + + # Loss backpropagation and optimization step + self.actor_optimizers[policy].zero_grad() + self.critic_optimizers[policy].zero_grad() + actor_loss.backward() + critic_loss.backward() + if self.clip_grad_norm[policy]: + nn.utils.clip_grad_norm_( + self.actor_models[policy].parameters(), self.max_grad_norm[policy] + ) + nn.utils.clip_grad_norm_( + self.critic_models[policy].parameters(), self.max_grad_norm[policy] + ) - soft_update(self.target_actor_models[policy], self.actor_models[policy], self.tau[policy]) - soft_update(self.target_critic_models[policy], self.critic_models[policy], self.tau[policy]) + self.actor_optimizers[policy].step() + self.critic_optimizers[policy].step() + + soft_update(self.target_actor_models[policy], self.actor_models[policy], self.tau[policy]) + soft_update(self.target_critic_models[policy], self.critic_models[policy], self.tau[policy]) + + else: + metrics = {} + actor_grad_norm = 0.0 + critic_grad_norm = 0.0 + actor_lr = 0.0 + critic_lr = 0.0 # Logging if logging_flag: metrics_dict[policy] = metrics From 2df2bd2fcebae0f4f4baac3b5108cc98154efc1f Mon Sep 17 00:00:00 2001 From: Tian Lan <31748898+Emerald01@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:00:52 -0800 Subject: [PATCH 14/19] Update README.md --- README.md | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 48583fb..1d3353b 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,20 @@ Using the extreme parallelization capability of GPUs, WarpDrive enables orders-o faster RL compared to CPU simulation + GPU model implementations. It is extremely efficient as it avoids back-and-forth data copying between the CPU and the GPU, and runs simulations across multiple agents and multiple environment replicas in parallel. -We have some main updates since its initial open source, -- version 1.3: provides the auto scaling tools to achieve the optimal throughput per device. -- version 1.4: supports the distributed asynchronous training among multiple GPU devices. -- version 1.6: supports the aggregation of multiple GPU blocks for one environment replica. -- version 2.0: supports the dual backends of both CUDA C and JIT compiled Numba. [(Our Blog article)](https://blog.salesforceairesearch.com/warpdrive-v2-numba-nvidia-gpu-simulations/) -- version 2.6: supports single agent environments, including Cartpole, MountainCar, Acrobot +| | Support | Concurrent Number | Version +:--- | :---: | :---: | :---: +| Environments | Single ✅ Multi ✅ | >= 1000 per GPU | 1.0 +| Agents | Single ✅ Multi ✅ | 1024 | 1.0 +| Agents | Multi across blocks ✅| 1024 per block | 1.6 +| Discrete Actions | Single ✅ Multi ✅| - | 1.0 +| Continuous Action | Single ✅ Multi ✅| - | 2.7 +| On-Policy Policy Gradient | A2C ✅, PPO ✅ | - | 1.0 +| Off-Policy Policy Gradient| DDPG ✅ | - | 2.7 +| Auto-Scaling | ✅ | - | 1.3 +| Distributed Simulation | ✅ | 2 to 16 GPUs node| 1.4 +| Environment Backend | CUDA C ✅ | - | 1.0 +| Environment Backend | CUDA C ✅ Numba ✅ | - | 2.0 +| Training Backend | Pytorch ✅ | - | 1.0 Together, these allow the user to run thousands or even millions of concurrent simulations and train on extremely large batches of experience, achieving at least 100x throughput over CPU-based counterparts. From 3485b6ce5b2a4bc51ebb61e144e91b235409255b Mon Sep 17 00:00:00 2001 From: Tian Lan <31748898+Emerald01@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:01:35 -0800 Subject: [PATCH 15/19] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1d3353b..dafc1b3 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Using the extreme parallelization capability of GPUs, WarpDrive enables orders-o faster RL compared to CPU simulation + GPU model implementations. It is extremely efficient as it avoids back-and-forth data copying between the CPU and the GPU, and runs simulations across multiple agents and multiple environment replicas in parallel. -| | Support | Concurrent Number | Version +| | Support | Concurrency | Version :--- | :---: | :---: | :---: | Environments | Single ✅ Multi ✅ | >= 1000 per GPU | 1.0 | Agents | Single ✅ Multi ✅ | 1024 | 1.0 From 7c2bb3c9fae7faf777b5ad89a5a871996b8a6cf7 Mon Sep 17 00:00:00 2001 From: Tian Lan <31748898+Emerald01@users.noreply.github.com> Date: Sun, 18 Feb 2024 11:03:44 -0800 Subject: [PATCH 16/19] Update README.md --- README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index dafc1b3..411a1d5 100644 --- a/README.md +++ b/README.md @@ -7,14 +7,17 @@ Using the extreme parallelization capability of GPUs, WarpDrive enables orders-o faster RL compared to CPU simulation + GPU model implementations. It is extremely efficient as it avoids back-and-forth data copying between the CPU and the GPU, and runs simulations across multiple agents and multiple environment replicas in parallel. +Together, these allow the user to run thousands or even millions of concurrent simulations and train +on extremely large batches of experience, achieving at least 100x throughput over CPU-based counterparts. + | | Support | Concurrency | Version :--- | :---: | :---: | :---: -| Environments | Single ✅ Multi ✅ | >= 1000 per GPU | 1.0 -| Agents | Single ✅ Multi ✅ | 1024 | 1.0 +| Environments | Single ✅ Multi ✅ | 1 to 1000 per GPU | 1.0 +| Agents | Single ✅ Multi ✅ | 1 to 1024 per environment | 1.0 | Agents | Multi across blocks ✅| 1024 per block | 1.6 | Discrete Actions | Single ✅ Multi ✅| - | 1.0 | Continuous Action | Single ✅ Multi ✅| - | 2.7 -| On-Policy Policy Gradient | A2C ✅, PPO ✅ | - | 1.0 +| On-Policy Policy Gradient | A2C ✅ PPO ✅ | - | 1.0 | Off-Policy Policy Gradient| DDPG ✅ | - | 2.7 | Auto-Scaling | ✅ | - | 1.3 | Distributed Simulation | ✅ | 2 to 16 GPUs node| 1.4 @@ -22,8 +25,6 @@ and runs simulations across multiple agents and multiple environment replicas in | Environment Backend | CUDA C ✅ Numba ✅ | - | 2.0 | Training Backend | Pytorch ✅ | - | 1.0 -Together, these allow the user to run thousands or even millions of concurrent simulations and train -on extremely large batches of experience, achieving at least 100x throughput over CPU-based counterparts. ## Environments 1. We include several default multi-agent environments From 162cbc18845dd92919dfade3d9bebed2854bef38 Mon Sep 17 00:00:00 2001 From: Tian Lan <31748898+Emerald01@users.noreply.github.com> Date: Sun, 18 Feb 2024 13:15:20 -0800 Subject: [PATCH 17/19] Update README.md --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 411a1d5..5817d43 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,11 @@ WarpDrive is a flexible, lightweight, and easy-to-use open-source reinforcement framework that implements end-to-end multi-agent RL on a single or multiple GPUs (Graphics Processing Unit). Using the extreme parallelization capability of GPUs, WarpDrive enables orders-of-magnitude -faster RL compared to CPU simulation + GPU model implementations. It is extremely efficient as it avoids back-and-forth data copying between the CPU and the GPU, -and runs simulations across multiple agents and multiple environment replicas in parallel. - +faster RL compared to CPU simulation + GPU model implementations. It is extremely efficient as it avoids back-and-forth data copying between the CPU and the GPU, and runs simulations across multiple agents and multiple environment replicas in parallel. Together, these allow the user to run thousands or even millions of concurrent simulations and train -on extremely large batches of experience, achieving at least 100x throughput over CPU-based counterparts. +on extremely large batches of experience, achieving at least 100x throughput over CPU-based counterparts. + +The table below provides a visual overview of Warpdrive's key features and scalability over various dimensions. | | Support | Concurrency | Version :--- | :---: | :---: | :---: @@ -20,7 +20,7 @@ on extremely large batches of experience, achieving at least 100x throughput ove | On-Policy Policy Gradient | A2C ✅ PPO ✅ | - | 1.0 | Off-Policy Policy Gradient| DDPG ✅ | - | 2.7 | Auto-Scaling | ✅ | - | 1.3 -| Distributed Simulation | ✅ | 2 to 16 GPUs node| 1.4 +| Distributed Simulation | 1 GPU ✅ 2-16 GPU node ✅ | - | 1.4 | Environment Backend | CUDA C ✅ | - | 1.0 | Environment Backend | CUDA C ✅ Numba ✅ | - | 2.0 | Training Backend | Pytorch ✅ | - | 1.0 From fc65445f1bf8990fd834b63ce217977659cb3ff0 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Sun, 18 Feb 2024 17:01:10 -0800 Subject: [PATCH 18/19] version 2.7 --- CHANGELOG.md | 5 +++++ setup.py | 2 +- warp_drive/managers/numba_managers/numba_function_manager.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81dbe41..aa10123 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,9 @@ # Changelog +# Release 2.7 (2024-02-17) +- Support continuous actions +- Add Pendulum environment that can run up to 100K concurrent replicates +- Add DDPG algorithms for training continuous action policies + # Release 2.6.2 (2023-12-12) - Add Acrobot environment that can run up to 100K concurrent replicates. - Add Mountain Car environment that can run up to 100K concurrent replicates. diff --git a/setup.py b/setup.py index 1a27673..0d29b8c 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ setup( name="rl-warp-drive", - version="2.6.2", + version="2.7", author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng", author_email="tian.lan@salesforce.com", description="Framework for fast end-to-end " diff --git a/warp_drive/managers/numba_managers/numba_function_manager.py b/warp_drive/managers/numba_managers/numba_function_manager.py index 43894b5..e122d71 100644 --- a/warp_drive/managers/numba_managers/numba_function_manager.py +++ b/warp_drive/managers/numba_managers/numba_function_manager.py @@ -574,7 +574,7 @@ def reset_when_done_from_pool( ) for name, pool_name in data_manager.reset_target_to_pool.items(): - f_shape = data_manager.get_shape(name) + f_shape = data_manager.get_shape(pool_name) assert f_shape[0] > 1, "reset function assumes the 0th dimension is n_pool" if len(f_shape) >= 3: if len(f_shape) > 3: From 3c96d8ec2d32ecda6e114064e9aeb204dbbaf4ae Mon Sep 17 00:00:00 2001 From: Tian Lan <31748898+Emerald01@users.noreply.github.com> Date: Sun, 18 Feb 2024 17:03:13 -0800 Subject: [PATCH 19/19] Update README.md --- README.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 5817d43..985bf84 100644 --- a/README.md +++ b/README.md @@ -12,18 +12,18 @@ The table below provides a visual overview of Warpdrive's key features and scala | | Support | Concurrency | Version :--- | :---: | :---: | :---: -| Environments | Single ✅ Multi ✅ | 1 to 1000 per GPU | 1.0 -| Agents | Single ✅ Multi ✅ | 1 to 1024 per environment | 1.0 -| Agents | Multi across blocks ✅| 1024 per block | 1.6 -| Discrete Actions | Single ✅ Multi ✅| - | 1.0 -| Continuous Action | Single ✅ Multi ✅| - | 2.7 -| On-Policy Policy Gradient | A2C ✅ PPO ✅ | - | 1.0 -| Off-Policy Policy Gradient| DDPG ✅ | - | 2.7 -| Auto-Scaling | ✅ | - | 1.3 -| Distributed Simulation | 1 GPU ✅ 2-16 GPU node ✅ | - | 1.4 -| Environment Backend | CUDA C ✅ | - | 1.0 -| Environment Backend | CUDA C ✅ Numba ✅ | - | 2.0 -| Training Backend | Pytorch ✅ | - | 1.0 +| Environments | Single ✅ Multi ✅ | 1 to 1000 per GPU | 1.0+ +| Agents | Single ✅ Multi ✅ | 1 to 1024 per environment | 1.0+ +| Agents | Multi across blocks ✅| 1024 per block | 1.6+ +| Discrete Actions | Single ✅ Multi ✅| - | 1.0+ +| Continuous Action | Single ✅ Multi ✅| - | 2.7+ +| On-Policy Policy Gradient | A2C ✅ PPO ✅ | - | 1.0+ +| Off-Policy Policy Gradient| DDPG ✅ | - | 2.7+ +| Auto-Scaling | ✅ | - | 1.3+ +| Distributed Simulation | 1 GPU ✅ 2-16 GPU node ✅ | - | 1.4+ +| Environment Backend | CUDA C ✅ | - | 1.0+ +| Environment Backend | CUDA C ✅ Numba ✅ | - | 2.0+ +| Training Backend | Pytorch ✅ | - | 1.0+ ## Environments