Skip to content

Commit

Permalink
Merge pull request #87 from salesforce/random-reset
Browse files Browse the repository at this point in the history
Random reset from the pool
  • Loading branch information
Emerald01 committed Jul 28, 2023
2 parents d48766e + 97fa410 commit a98b712
Show file tree
Hide file tree
Showing 20 changed files with 644 additions and 66 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,4 +1,7 @@
# Changelog
# Release 2.5 (2022-07-27)
- Introduce environment reset pool, so concurrent enviornment replicas can randomly reset themselves from the pool.

# Release 2.4 (2022-06-16)
- Introduce new device context management and autoinit_pycuda
- Therefore, Torch (any version) will not conflict with PyCUDA in the GPU context
Expand Down
7 changes: 1 addition & 6 deletions example_envs/tag_continuous/tag_continuous.py
Expand Up @@ -9,7 +9,6 @@

import numpy as np
from gym import spaces
from gym.utils import seeding

from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed
Expand Down Expand Up @@ -313,7 +312,7 @@ def seed(self, seed=None):
Note: this uses the code in
https://github.com/openai/gym/blob/master/gym/utils/seeding.py
"""
self.np_random, seed = seeding.np_random(seed)
self.np_random.seed(seed)
return [seed]

def set_global_state(self, key=None, value=None, t=None, dtype=None):
Expand Down Expand Up @@ -756,10 +755,6 @@ def get_data_dictionary(self):
)
return data_dict

def get_tensor_dictionary(self):
tensor_dict = DataFeed()
return tensor_dict

def reset(self):
"""
Env reset().
Expand Down
100 changes: 95 additions & 5 deletions example_envs/tag_gridworld/tag_gridworld.py
Expand Up @@ -6,7 +6,6 @@

import numpy as np
from gym import spaces
from gym.utils import seeding

# seeding code from https://github.com/openai/gym/blob/master/gym/utils/seeding.py
from warp_drive.utils.constants import Constants
Expand Down Expand Up @@ -130,7 +129,7 @@ def __init__(
name = "TagGridWorld"

def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
self.np_random.seed(seed)
return [seed]

def set_global_state(self, key=None, value=None, t=None, dtype=None):
Expand Down Expand Up @@ -349,9 +348,100 @@ def get_data_dictionary(self):
)
return data_dict

def get_tensor_dictionary(self):
tensor_dict = DataFeed()
return tensor_dict
def step(self, actions=None):
self.timestep += 1
args = [
_LOC_X,
_LOC_Y,
_ACTIONS,
"_done_",
_REWARDS,
_OBSERVATIONS,
"wall_hit_penalty",
"tag_reward_for_tagger",
"tag_penalty_for_runner",
"step_cost_for_tagger",
"use_full_observation",
"world_boundary",
"_timestep_",
("episode_length", "meta"),
]
if self.env_backend == "pycuda":
self.cuda_step(
*self.cuda_step_function_feed(args),
block=self.cuda_function_manager.block,
grid=self.cuda_function_manager.grid,
)
elif self.env_backend == "numba":
self.cuda_step[
self.cuda_function_manager.grid, self.cuda_function_manager.block
](*self.cuda_step_function_feed(args))
else:
raise Exception("CUDATagGridWorld expects env_backend = 'pycuda' or 'numba' ")


class CUDATagGridWorldWithResetPool(TagGridWorld, CUDAEnvironmentContext):
"""
CUDA version of the TagGridWorld environment and with reset pool for the starting point.
Note: this class subclasses the Python environment class TagGridWorld,
and also the CUDAEnvironmentContext
"""

def get_data_dictionary(self):
data_dict = DataFeed()
for feature in [
_LOC_X,
_LOC_Y,
]:
data_dict.add_data(
name=feature,
data=self.global_state[feature][0],
save_copy_and_apply_at_reset=False,
log_data_across_episode=False,
)
data_dict.add_data_list(
[
("wall_hit_penalty", self.wall_hit_penalty),
("tag_reward_for_tagger", self.tag_reward_for_tagger),
("tag_penalty_for_runner", self.tag_penalty_for_runner),
("step_cost_for_tagger", self.step_cost_for_tagger),
("use_full_observation", self.use_full_observation),
("world_boundary", self.grid_length),
]
)
return data_dict

def get_reset_pool_dictionary(self):

def _random_location_generator():
starting_location_x = self.np_random.choice(
np.linspace(1, int(self.grid_length) - 1, int(self.grid_length) - 1),
self.num_agents
).astype(np.int32)
starting_location_x[-1] = 0
starting_location_y = self.np_random.choice(
np.linspace(1, int(self.grid_length) - 1, int(self.grid_length) - 1),
self.num_agents
).astype(np.int32)
starting_location_y[-1] = 0
return starting_location_x, starting_location_y

N = 5 # we hard code the number of env pool for this demo purpose
x_pool = []
y_pool = []
for _ in range(N):
x, y = _random_location_generator()
x_pool.append(x)
y_pool.append(y)

x_pool = np.stack(x_pool, axis=0)
y_pool = np.stack(y_pool, axis=0)

reset_pool_dict = DataFeed()
reset_pool_dict.add_pool_for_reset(name=f"{_LOC_X}_reset_pool", data=x_pool, reset_target=_LOC_X)
reset_pool_dict.add_pool_for_reset(name=f"{_LOC_Y}_reset_pool", data=y_pool, reset_target=_LOC_Y)

return reset_pool_dict

def step(self, actions=None):
self.timestep += 1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
gym>=0.18, <0.26
gym>=0.26
matplotlib>=3.2.1
numpy>=1.18.1
pycuda>=2022.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -14,7 +14,7 @@

setup(
name="rl-warp-drive",
version="2.4",
version="2.5.0",
author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng",
author_email="tian.lan@salesforce.com",
description="Framework for fast end-to-end "
Expand Down
145 changes: 145 additions & 0 deletions tests/warp_drive/numba_tests/test_pool_reset.py
@@ -0,0 +1,145 @@
# Copyright (c) 2021, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root
# or https://opensource.org/licenses/BSD-3-Clause

import unittest

import numpy as np
import torch

from warp_drive.managers.numba_managers.numba_data_manager import NumbaDataManager
from warp_drive.managers.numba_managers.numba_function_manager import (
NumbaEnvironmentReset,
NumbaFunctionManager,
)
from warp_drive.utils.common import get_project_root
from warp_drive.utils.data_feed import DataFeed

_NUMBA_FILEPATH = f"warp_drive.numba_includes"


class TestEnvironmentReset(unittest.TestCase):
"""
Unit tests for the CUDA environment resetter
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dm = NumbaDataManager(num_agents=5, num_envs=2, episode_length=2)
self.fm = NumbaFunctionManager(
num_agents=int(self.dm.meta_info("n_agents")),
num_envs=int(self.dm.meta_info("n_envs")),
)
self.fm.import_numba_from_source_code(f"{_NUMBA_FILEPATH}.test_build")
self.resetter = NumbaEnvironmentReset(function_manager=self.fm)

def test_reset_for_different_dim(self):

self.dm.data_on_device_via_torch("_done_")[:] = torch.from_numpy(
np.array([1, 0])
).cuda()

done = self.dm.pull_data_from_device("_done_")
self.assertSequenceEqual(list(done), [1, 0])

# expected mean would be around 0.5 * (1+2+3+15) / 4 = 2.625
a_reset_pool = np.random.rand(4, 10, 10)
a_reset_pool[1] *= 2
a_reset_pool[2] *= 3
a_reset_pool[3] *= 15

b_reset_pool = np.random.rand(4, 100)
b_reset_pool[1] *= 2
b_reset_pool[2] *= 3
b_reset_pool[3] *= 15

c_reset_pool = np.random.rand(100)

data_feed = DataFeed()
data_feed.add_data(
name="a", data=np.random.randn(2, 10, 10), save_copy_and_apply_at_reset=False
)
data_feed.add_pool_for_reset(name="a_reset_pool", data=a_reset_pool, reset_target="a")
data_feed.add_data(
name="b", data=np.random.randn(2, 100), save_copy_and_apply_at_reset=False
)
data_feed.add_pool_for_reset(name="b_reset_pool", data=b_reset_pool, reset_target="b")
data_feed.add_data(
name="c", data=np.random.randn(2), save_copy_and_apply_at_reset=False
)
data_feed.add_pool_for_reset(name="c_reset_pool", data=c_reset_pool, reset_target="c")

self.dm.push_data_to_device(data_feed)

self.resetter.init_reset_pool(self.dm)

a = self.dm.pull_data_from_device("a")
b = self.dm.pull_data_from_device("b")
c = self.dm.pull_data_from_device("c")

# soft reset
a_after_reset_0_mean = []
a_after_reset_1_mean = []
b_after_reset_0_mean = []
b_after_reset_1_mean = []
c_after_reset_0_mean = []
c_after_reset_1_mean = []

for _ in range(2000):
self.resetter.reset_when_done(self.dm, mode="if_done", undo_done_after_reset=False)
a_after_reset = self.dm.pull_data_from_device("a")
a_after_reset_0_mean.append(a_after_reset[0].mean())
a_after_reset_1_mean.append(a_after_reset[1].mean())
b_after_reset = self.dm.pull_data_from_device("b")
b_after_reset_0_mean.append(b_after_reset[0].mean())
b_after_reset_1_mean.append(b_after_reset[1].mean())
c_after_reset = self.dm.pull_data_from_device("c")
c_after_reset_0_mean.append(c_after_reset[0].mean())
c_after_reset_1_mean.append(c_after_reset[1].mean())
# env 0 should have 1000 times random reset from the pool, so it should close to a_reset_pool.mean()
print(a_reset_pool.mean())
print(np.array(a_after_reset_0_mean).mean())
self.assertTrue(np.absolute(a_reset_pool.mean() - np.array(a_after_reset_0_mean).mean()) < 5e-1)
print(b_reset_pool.mean())
print(np.array(b_after_reset_0_mean).mean())
self.assertTrue(np.absolute(b_reset_pool.mean() - np.array(b_after_reset_0_mean).mean()) < 5e-1)
print(c_reset_pool.mean())
print(np.array(c_after_reset_0_mean).mean())
self.assertTrue(np.absolute(c_reset_pool.mean() - np.array(c_after_reset_0_mean).mean()) < 5e-1)
# env 1 has no reset at all, so it should be exactly the same as the original one
self.assertTrue(np.absolute(a[1].mean() - np.array(a_after_reset_1_mean).mean()) < 1e-5)
self.assertTrue(np.absolute(b[1].mean() - np.array(b_after_reset_1_mean).mean()) < 1e-5)
self.assertTrue(np.absolute(c[1].mean() - np.array(c_after_reset_1_mean).mean()) < 1e-5)

# hard reset
a_after_reset_0_mean = []
a_after_reset_1_mean = []
b_after_reset_0_mean = []
b_after_reset_1_mean = []
c_after_reset_0_mean = []
c_after_reset_1_mean = []
for _ in range(2000):
self.resetter.reset_when_done(self.dm, mode="force_reset", undo_done_after_reset=False)
a_after_reset = self.dm.pull_data_from_device("a")
a_after_reset_0_mean.append(a_after_reset[0].mean())
a_after_reset_1_mean.append(a_after_reset[1].mean())
b_after_reset = self.dm.pull_data_from_device("b")
b_after_reset_0_mean.append(b_after_reset[0].mean())
b_after_reset_1_mean.append(b_after_reset[1].mean())
c_after_reset = self.dm.pull_data_from_device("c")
c_after_reset_0_mean.append(c_after_reset[0].mean())
c_after_reset_1_mean.append(c_after_reset[1].mean())
# env 0 should have 1000 times random reset from the pool, so it should close to a_reset_pool.mean()
self.assertTrue(np.absolute(a_reset_pool.mean() - np.array(a_after_reset_0_mean).mean()) < 5e-1)
self.assertTrue(np.absolute(b_reset_pool.mean() - np.array(b_after_reset_0_mean).mean()) < 5e-1)
self.assertTrue(np.absolute(c_reset_pool.mean() - np.array(c_after_reset_0_mean).mean()) < 5e-1)
# env 1 should have 1000 times random reset from the pool, so it should close to a_reset_pool.mean()
self.assertTrue(np.absolute(a_reset_pool.mean() - np.array(a_after_reset_1_mean).mean()) < 5e-1)
self.assertTrue(np.absolute(b_reset_pool.mean() - np.array(b_after_reset_1_mean).mean()) < 5e-1)
self.assertTrue(np.absolute(c_reset_pool.mean() - np.array(c_after_reset_1_mean).mean()) < 5e-1)




6 changes: 2 additions & 4 deletions warp_drive/env_cpu_gpu_consistency_checker.py
Expand Up @@ -13,7 +13,6 @@
import numpy as np
import torch
from gym.spaces import Discrete, MultiDiscrete
from gym.utils import seeding

from warp_drive.env_wrapper import EnvWrapper
from warp_drive.training.utils.data_loader import (
Expand All @@ -38,10 +37,9 @@ def generate_random_actions(env, num_envs, seed=None):
Generate random actions for each agent and each env.
"""
agent_ids = list(env.action_space.keys())
np_random = np.random
if seed is not None:
np_random = seeding.np_random(seed)[0]
else:
np_random = np.random
np_random.seed(seed)

return [
{
Expand Down
28 changes: 27 additions & 1 deletion warp_drive/env_wrapper.py
Expand Up @@ -291,9 +291,10 @@ def repeat_across_env_dimension(array, num_envs):
# Copy host data and tensors to device
# Note: this happens only once after the first reset on the host

# Add env dimension to data if "save_copy_and_apply_at_reset" is True
data_dictionary = self.env.get_data_dictionary()
tensor_dictionary = self.env.get_tensor_dictionary()
reset_pool_dictionary = self.env.get_reset_pool_dictionary()
# Add env dimension to data if "save_copy_and_apply_at_reset" is True
for key in data_dictionary:
if data_dictionary[key]["attributes"][
"save_copy_and_apply_at_reset"
Expand All @@ -309,13 +310,35 @@ def repeat_across_env_dimension(array, num_envs):
tensor_dictionary[key]["data"] = repeat_across_env_dimension(
tensor_dictionary[key]["data"], self.n_envs
)
# Add env dimension to data if "is_reset_pool" exists for this data
# if so, also check this data has "save_copy_and_apply_at_reset" = False
for key in reset_pool_dictionary:
if "is_reset_pool" in reset_pool_dictionary[key]["attributes"] and \
reset_pool_dictionary[key]["attributes"]["is_reset_pool"]:
# find the corresponding target data
reset_target = reset_pool_dictionary[key]["attributes"]["reset_target"]
if reset_target in data_dictionary:
assert not data_dictionary[reset_target]["attributes"]["save_copy_and_apply_at_reset"]
data_dictionary[reset_target]["data"] = repeat_across_env_dimension(
data_dictionary[reset_target]["data"], self.n_envs
)
elif reset_target in tensor_dictionary:
assert not tensor_dictionary[reset_target]["attributes"]["save_copy_and_apply_at_reset"]
tensor_dictionary[reset_target]["data"] = repeat_across_env_dimension(
tensor_dictionary[reset_target]["data"], self.n_envs
)
else:
raise Exception(f"Fail to locate the target data {reset_target} for the reset pool "
f"in neither data_dictionary nor tensor_dictionary")

self.cuda_data_manager.push_data_to_device(data_dictionary)

self.cuda_data_manager.push_data_to_device(
tensor_dictionary, torch_accessible=True
)

self.cuda_data_manager.push_data_to_device(reset_pool_dictionary)

# All subsequent resets happen on the GPU
self.reset_on_host = False

Expand All @@ -329,6 +352,9 @@ def repeat_across_env_dimension(array, num_envs):
return {}
return obs # CPU version

def init_reset_pool(self, seed=None):
self.env_resetter.init_reset_pool(self.cuda_data_manager, seed)

def reset_only_done_envs(self):
"""
This function only works for GPU example_envs.
Expand Down

0 comments on commit a98b712

Please sign in to comment.