Merge pull request #73 from salesforce/qlearner
Qlearner
Emerald01 committed Mar 8, 2023
2 parents ca6640d + 3a0c62c commit b5d46d4
Showing 12 changed files with 113 additions and 24 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -14,7 +14,7 @@

setup(
name="rl-warp-drive",
version="2.2.1",
version="2.2.2",
author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng",
author_email="stephan.zheng@salesforce.com",
description="Framework for fast end-to-end "
15 changes: 14 additions & 1 deletion warp_drive/cuda_includes/core/random.cu
@@ -49,12 +49,25 @@ __device__ int search_index(float* distr, float p, int l, int r) {
}

__global__ void sample_actions(float* distr, int* action_indices,
float* cum_distr, int num_agents, int num_actions) {
float* cum_distr, int num_agents, int num_actions, int use_argmax) {
int posidx = blockIdx.x*blockDim.x + threadIdx.x;
if (posidx >= wkNumberEnvs * num_agents)
return;
int dist_index = posidx * num_actions;

if (use_argmax > 0.5){
float max_p = distr[dist_index];
int max_ind = 0;
for (int i = 1; i < num_actions; i++){
if (max_p < distr[dist_index + i]){
max_p = distr[dist_index + i];
max_ind = i;
}
}
action_indices[posidx] = max_ind;
return;
}

curandState_t s = *states[posidx];
float p = curand_uniform(&s);
*states[posidx] = s;
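
For reference, a minimal NumPy sketch of the per-agent choice the new use_argmax branch makes; the function, names, and probabilities below are illustrative, not part of the kernel:

import numpy as np

def pick_actions(distr, use_argmax, rng):
    # distr: (num_agents, num_actions) per-agent action probabilities.
    if use_argmax:
        # Greedy branch added above: most probable action per agent.
        return distr.argmax(axis=1)
    # Stochastic branch: per-agent categorical sampling; the kernel realizes this
    # with a cumulative distribution plus the search_index() binary search.
    return np.array([rng.choice(distr.shape[1], p=p) for p in distr])

rng = np.random.default_rng(0)
probs = np.array([[0.1, 0.7, 0.2], [0.5, 0.25, 0.25]])
print(pick_actions(probs, use_argmax=True, rng=rng))   # [1 0]
print(pick_actions(probs, use_argmax=False, rng=rng))  # one sampled action per agent
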
2 changes: 1 addition & 1 deletion warp_drive/cuda_includes/core/random.h
@@ -18,6 +18,6 @@ extern "C" __global__ void free_random();
// binary search to get the action index
__device__ int search_index(float* distr, float p, int l, int r);

extern "C" __global__ void sample_actions(float*, int*, float*, int, int);
extern "C" __global__ void sample_actions(float*, int*, float*, int, int, int);

#endif // CUDA_INCLUDES_RANDOM_STATES_H_
3 changes: 3 additions & 0 deletions warp_drive/managers/numba_managers/numba_function_manager.py
@@ -291,6 +291,7 @@ def sample(
data_manager: NumbaDataManager,
distribution: torch.Tensor,
action_name: str,
use_argmax: bool = False,
):
"""
Sample based on the distribution
@@ -300,6 +301,7 @@
(num_env, num_agents, num_actions)
:param action_name: the name of action array that will
record the sampled actions
:param use_argmax: if True, pick the argmax of the distribution instead of sampling from it
"""
assert self._random_initialized, (
"sample() requires the random seed initialized first, "
@@ -323,6 +325,7 @@
data_manager.device_data(action_name),
data_manager.device_data(f"{action_name}_cum_distr"),
np.int32(n_actions),
np.int32(use_argmax),
)


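A usage sketch of the extended sample() call; the sampler, data manager, and action name are placeholders for objects assumed to be set up and registered elsewhere:

def select_actions(sampler, data_manager, probs, action_name, greedy=False):
    # `sampler` is a sample controller exposing the sample() method shown above,
    # `data_manager` its NumbaDataManager; the action array `action_name` and its
    # f"{action_name}_cum_distr" companion are assumed to be registered already.
    # probs: (num_envs, num_agents, num_actions) torch.Tensor on the GPU.
    probs = probs / probs.sum(dim=-1, keepdim=True)  # keep each row normalized
    sampler.sample(data_manager, probs, action_name, use_argmax=greedy)

# e.g. select_actions(sampler, data_manager, probs, "sampled_actions_a", greedy=True)
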
@@ -534,6 +534,7 @@ def sample(
data_manager: PyCUDADataManager,
distribution: torch.Tensor,
action_name: str,
use_argmax: bool = False,
):
"""
Sample based on the distribution
@@ -543,6 +544,7 @@
(num_env, num_agents, num_actions)
:param action_name: the name of action array that will
record the sampled actions
:param use_argmax: if True, pick the argmax of the distribution instead of sampling from it
"""
assert self._random_initialized, (
"sample() requires the random seed initialized first, "
@@ -564,6 +566,7 @@
data_manager.device_data(f"{action_name}_cum_distr"),
np.int32(n_agents),
np.int32(n_actions),
np.int32(use_argmax),
block=((n_agents - 1) // self._blocks_per_env + 1, 1, 1),
grid=self._grid,
)
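
The launch configuration in the last lines splits each environment's agents across self._blocks_per_env thread blocks via ceiling division; a small sketch of that arithmetic (the helper name is made up for illustration):

def threads_per_block(n_agents, blocks_per_env):
    # Ceiling division: choose enough threads per block so that
    # blocks_per_env * threads_per_block >= n_agents.
    return (n_agents - 1) // blocks_per_env + 1

assert threads_per_block(1000, 3) == 334   # 3 blocks x 334 threads cover 1000 agents
assert threads_per_block(1024, 1) == 1024
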
17 changes: 14 additions & 3 deletions warp_drive/numba_includes/core/random.py
@@ -1,5 +1,5 @@
from numba import cuda as numba_driver
from numba import float32, int32, from_dtype
from numba import float32, int32, boolean, from_dtype
from numba.cuda.random import init_xoroshiro128p_states, xoroshiro128p_uniform_float32
import numpy as np

@@ -30,14 +30,25 @@ def init_random(rng_states, seed):
init_xoroshiro128p_states(states=rng_states, seed=seed)


@numba_driver.jit((xoroshiro128p_type[::1], float32[:, :, ::1], int32[:, :, ::1], float32[:, :, ::1], int32))
def sample_actions(rng_states, distr, action_indices, cum_distr, num_actions):
@numba_driver.jit((xoroshiro128p_type[::1], float32[:, :, ::1], int32[:, :, ::1], float32[:, :, ::1], int32, int32))
def sample_actions(rng_states, distr, action_indices, cum_distr, num_actions, use_argmax):
env_id = numba_driver.blockIdx.x
# Block id in a 1D grid
agent_id = numba_driver.threadIdx.x
posidx = numba_driver.grid(1)
if posidx >= rng_states.shape[0]:
return

if use_argmax > 0.5:
max_dist = distr[env_id, agent_id, 0]
max_ind = 0
for i in range(1, num_actions):
if max_dist < distr[env_id, agent_id, i]:
max_dist = distr[env_id, agent_id, i]
max_ind = i
action_indices[env_id, agent_id, 0] = max_ind
return

p = xoroshiro128p_uniform_float32(rng_states, posidx)

cum_distr[env_id, agent_id, 0] = distr[env_id, agent_id, 0]
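
For the non-argmax path, sampling continues with the cumulative distribution started on the last line above and a binary search over it; a minimal NumPy sketch of that inverse-CDF step, with boundary handling assumed rather than copied from search_index():

import numpy as np

def sample_via_cdf(probs, p):
    # p is a uniform draw in [0, 1); return the first index whose
    # cumulative probability exceeds p.
    cum_distr = np.cumsum(probs)
    lo, hi = 0, len(cum_distr) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if cum_distr[mid] > p:
            hi = mid
        else:
            lo = mid + 1
    return lo

print(sample_via_cdf([0.1, 0.7, 0.2], p=0.05))  # 0
print(sample_via_cdf([0.1, 0.7, 0.2], p=0.75))  # 1
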
5 changes: 5 additions & 0 deletions warp_drive/training/algorithms/policygradient/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) 2021, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root
# or https://opensource.org/licenses/BSD-3-Clause
File renamed without changes: warp_drive/training/algorithms/a2c.py → warp_drive/training/algorithms/policygradient/a2c.py
File renamed without changes: warp_drive/training/algorithms/ppo.py → warp_drive/training/algorithms/policygradient/ppo.py
3 changes: 2 additions & 1 deletion warp_drive/training/models/fully_connected.py
@@ -210,7 +210,8 @@ def forward(self, obs=None, batch_index=None):
# Push the processed (in this case, flattened) obs to the GPU (device).
# The writing happens to a specific batch index in the processed obs batch.
# The processed obs batch is required for training.
self.push_processed_obs_to_batch(batch_index, ip)
if batch_index >= 0:
self.push_processed_obs_to_batch(batch_index, ip)

else:
ip = obs
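
This guard pairs with the batch_index=-1 convention that fetch_episode_states() uses further below; a small sketch of the assumed convention (the helper is hypothetical):

def maybe_push_processed_obs(model, batch_index, processed_obs):
    # batch_index >= 0: training rollout; write the processed obs into the
    # training batch at that index.
    # batch_index == -1 (as passed by fetch_episode_states below): inference-only
    # forward pass; skip the write.
    if batch_index >= 0:
        model.push_processed_obs_to_batch(batch_index, processed_obs)
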
4 changes: 2 additions & 2 deletions warp_drive/training/pytorch_lightning.py
@@ -28,8 +28,8 @@
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from warp_drive.training.algorithms.a2c import A2C
from warp_drive.training.algorithms.ppo import PPO
from warp_drive.training.algorithms.policygradient.a2c import A2C
from warp_drive.training.algorithms.policygradient.ppo import PPO
from warp_drive.training.models.fully_connected import FullyConnected
from warp_drive.training.trainer import Metrics
from warp_drive.training.utils.data_loader import create_and_push_data_placeholders
83 changes: 68 additions & 15 deletions warp_drive/training/trainer.py
@@ -21,8 +21,8 @@
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from warp_drive.training.algorithms.a2c import A2C
from warp_drive.training.algorithms.ppo import PPO
from warp_drive.training.algorithms.policygradient.a2c import A2C
from warp_drive.training.algorithms.policygradient.ppo import PPO
from warp_drive.training.models.fully_connected import FullyConnected
from warp_drive.training.utils.data_loader import create_and_push_data_placeholders
from warp_drive.training.utils.param_scheduler import ParamScheduler
@@ -255,6 +255,8 @@ def __init__(
self.num_completed_episodes = {}
self.episodic_reward_sum = {}
self.reward_running_sum = {}
self.episodic_step_sum = {}
self.step_running_sum = {}

# Indicates the current timestep of the policy model
self.current_timestep = {}
@@ -304,6 +306,10 @@ def __init__(
self.reward_running_sum[policy] = torch.zeros(
(self.num_envs, num_agents_for_policy)
).cuda()
self.episodic_step_sum[policy] = (
torch.tensor(0).type(torch.int32).cuda()
)
self.step_running_sum[policy] = torch.zeros(self.num_envs, dtype=torch.int32).cuda()

# Initialize the trainers
self.trainers = {}
@@ -371,6 +377,13 @@ def _initialize_policy_model(self, policy):
self.create_separate_placeholders_for_each_policy,
self.obs_dim_corresponding_to_num_agents,
)
if "init_method" in policy_model_config and \
policy_model_config["init_method"] == "xavier":
def init_weights_by_xavier_uniform(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform(m.weight)

model.apply(init_weights_by_xavier_uniform)
else:
raise NotImplementedError
self.models[policy] = model
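
A self-contained sketch of the Xavier-uniform hook added above, using the in-place initializer; the two-layer Sequential model and the config dict are illustrative only:

from torch import nn

def init_weights_by_xavier_uniform(m):
    # model.apply() visits every submodule; only Linear layers are re-initialized.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)

policy_model_config = {"type": "fully_connected", "init_method": "xavier"}
model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 4))
if policy_model_config.get("init_method") == "xavier":
    model.apply(init_weights_by_xavier_uniform)
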
@@ -508,7 +521,7 @@ def _evaluate_policies(self, batch_index=0):

return probabilities

def _sample_actions(self, probabilities, batch_index=0):
def _sample_actions(self, probabilities, batch_index=0, use_argmax=False):
"""
Sample action probabilities (and push the sampled actions to the device).
"""
@@ -518,7 +531,7 @@ def _sample_actions(self, probabilities, batch_index=0):
# Sample each individual policy
policy_suffix = f"_{policy}"
self._sample_actions_helper(
probabilities[policy], policy_suffix=policy_suffix
probabilities[policy], policy_suffix=policy_suffix, use_argmax=use_argmax
)
# Push the actions to the batch
actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch(
@@ -531,7 +544,7 @@ def _sample_actions(self, probabilities, batch_index=0):
assert len(probabilities) == 1
policy = list(probabilities.keys())[0]
# sample a single or a combined policy
self._sample_actions_helper(probabilities[policy])
self._sample_actions_helper(probabilities[policy], use_argmax=use_argmax)
actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch(
_ACTIONS
)
@@ -549,20 +562,20 @@ def _sample_actions(self, probabilities, batch_index=0):
name=f"{_ACTIONS}_batch_{policy}"
)[batch_index] = actions

def _sample_actions_helper(self, probabilities, policy_suffix=""):
def _sample_actions_helper(self, probabilities, policy_suffix="", use_argmax=False):
# Sample actions with policy_suffix tag
num_action_types = len(probabilities)

if num_action_types == 1:
action_name = _ACTIONS + policy_suffix
self.cuda_sample_controller.sample(
self.cuda_envs.cuda_data_manager, probabilities[0], action_name
self.cuda_envs.cuda_data_manager, probabilities[0], action_name, use_argmax
)
else:
for action_type_id, probs in enumerate(probabilities):
action_name = f"{_ACTIONS}_{action_type_id}" + policy_suffix
self.cuda_sample_controller.sample(
self.cuda_envs.cuda_data_manager, probs, action_name
self.cuda_envs.cuda_data_manager, probs, action_name, use_argmax
)
# Push (indexed) actions to 'actions'
actions = self.cuda_envs.cuda_data_manager.data_on_device_via_torch(
@@ -628,16 +641,21 @@ def _bookkeep_rewards_and_done_flags(self, batch_index):

def _update_episodic_rewards(self, rewards, done_env_ids, policy):
self.reward_running_sum[policy] += rewards
self.step_running_sum[policy] += 1

num_completed_episodes = len(done_env_ids)
if num_completed_episodes > 0:
# Update the episodic rewards
self.episodic_reward_sum[policy] += torch.sum(
self.reward_running_sum[policy][done_env_ids]
)
self.episodic_step_sum[policy] += torch.sum(
self.step_running_sum[policy][done_env_ids]
)
self.num_completed_episodes[policy] += num_completed_episodes
# Reset the reward running sum
self.reward_running_sum[policy][done_env_ids] = 0
self.step_running_sum[policy][done_env_ids] = 0

def _update_model_params(self, iteration):
start_event = torch.cuda.Event(enable_timing=True)
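
A minimal sketch of the new per-environment step bookkeeping, mirroring the running-sum/reset pattern already used for rewards; tensor sizes and the done indices are made up:

import torch

num_envs = 4
step_running_sum = torch.zeros(num_envs, dtype=torch.int32)
episodic_step_sum = torch.tensor(0, dtype=torch.int32)
num_completed_episodes = 0

# Every call to _update_episodic_rewards advances all envs by one step.
step_running_sum += 1

# Suppose envs 1 and 3 just reported done.
done_env_ids = torch.tensor([1, 3])
episodic_step_sum += step_running_sum[done_env_ids].sum()
num_completed_episodes += len(done_env_ids)
step_running_sum[done_env_ids] = 0  # reset, just like the reward running sum

# Later, the metrics report the average episode length:
mean_episodic_steps = episodic_step_sum.item() / max(num_completed_episodes, 1)
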
@@ -726,15 +744,21 @@ def _update_model_params(self, iteration):
"Learning rate": lr,
"Mean episodic reward": self.episodic_reward_sum[policy].item()
/ (self.num_completed_episodes[policy] + _EPSILON),
"Mean episodic steps": self.episodic_step_sum[policy].item()
/ (self.num_completed_episodes[policy] + _EPSILON),
}
)

# Reset sum and counter
self.episodic_reward_sum[policy] = (
torch.tensor(0).type(torch.float32).cuda()
)
self.episodic_step_sum[policy] = (
torch.tensor(0).type(torch.int32).cuda()
)
self.num_completed_episodes[policy] = 0


end_event.record()
torch.cuda.synchronize()

@@ -874,6 +898,9 @@ def fetch_episode_states(
self,
list_of_states=None, # list of states (data array names) to fetch
env_id=0, # environment id to fetch the states from
include_rewards_actions=False, # flag to also return the episode rewards and actions
policy="", # the policy tag (if any) to use when include_rewards_actions=True
use_argmax=False,
):
"""
Step through env and fetch the desired states (data arrays on the GPU)
@@ -890,7 +917,6 @@
env = self.cuda_envs.env

episode_states = {}

for state in list_of_states:
assert self.cuda_envs.cuda_data_manager.is_data_on_device(
state
@@ -903,23 +929,48 @@
[np.ones(array_shape) for _ in range(env.episode_length + 1)]
)

if include_rewards_actions:
policy_suffix = f"_{policy}" if len(policy) > 0 else ""
action_name = _ACTIONS + policy_suffix
reward_name = _REWARDS + policy_suffix
# Note the size is 1 step smaller than states because we do not have r_0 and a_T
episode_actions = np.zeros(
(
env.episode_length, *self.cuda_envs.cuda_data_manager.get_shape(action_name)[1:]
),
dtype=np.int32
)
episode_rewards = np.zeros(
(
env.episode_length, *self.cuda_envs.cuda_data_manager.get_shape(reward_name)[1:]
),
dtype=np.float32)

for timestep in range(env.episode_length):
# Update the episode states
# Update the episode states s_t
for state in list_of_states:
episode_states[state][
timestep
] = self.cuda_envs.cuda_data_manager.pull_data_from_device(state)[
env_id
]
# Evaluate policies to compute action probabilities
probabilities = self._evaluate_policies()
# Evaluate policies to compute action probabilities; batch_index=-1 skips writing to the training batch
probabilities = self._evaluate_policies(batch_index=-1)

# Sample actions
self._sample_actions(probabilities)
self._sample_actions(probabilities, use_argmax=use_argmax)

# Step through all the environments
self.cuda_envs.step_all_envs()

if include_rewards_actions:
# Update the episode action a_t
episode_actions[timestep] = \
self.cuda_envs.cuda_data_manager.pull_data_from_device(action_name)[env_id]
# Update the episode reward r_(t+1)
episode_rewards[timestep] = \
self.cuda_envs.cuda_data_manager.pull_data_from_device(reward_name)[env_id]

# Fetch the states when episode is complete
if env.cuda_data_manager.pull_data_from_device("_done_")[env_id]:
for state in list_of_states:
Expand All @@ -929,8 +980,10 @@ def fetch_episode_states(
env_id
]
break

return episode_states
if include_rewards_actions:
return episode_states, episode_actions, episode_rewards
else:
return episode_states
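
A hedged usage sketch of the extended fetch_episode_states() API; the trainer, data array names, and policy tag are assumptions, not prescribed by this diff:

def greedy_rollout(trainer):
    # `trainer` is assumed to be a warp_drive Trainer built as in the repo's
    # examples; "loc_x"/"loc_y" and the "runner" policy tag are illustrative and
    # must match your own environment and config.
    states, actions, rewards = trainer.fetch_episode_states(
        list_of_states=["loc_x", "loc_y"],
        env_id=0,
        include_rewards_actions=True,
        policy="runner",
        use_argmax=True,  # greedy rollout instead of stochastic sampling
    )
    # actions and rewards hold episode_length entries (no a_T and no r_0),
    # while each states array holds episode_length + 1 entries.
    return states, actions, rewards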


class PerfStats: