diff --git a/CHANGELOG.md b/CHANGELOG.md
index e203c12..e4bfa33 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,9 @@
 # Changelog
+# Release 2.3 (2023-03-22)
+- Add ModelFactory class to manage custom models.
+- Add Xavier initialization for the model.
+- Improve trainer.fetch_episode_states() so it can fetch (s, a, r) and replay with argmax.
+
 # Release 2.2 (2022-12-20)
 - Factorize the data loading for placeholders and batches (obs, actions and rewards) for the trainer.
diff --git a/setup.py b/setup.py
index 2136f18..d50970b 100644
--- a/setup.py
+++ b/setup.py
@@ -14,9 +14,9 @@
 setup(
     name="rl-warp-drive",
-    version="2.2.2",
+    version="2.3",
     author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng",
-    author_email="stephan.zheng@salesforce.com",
+    author_email="tian.lan@salesforce.com",
     description="Framework for fast end-to-end "
     "multi-agent reinforcement learning on GPUs.",
     long_description=open("README.md", "r", encoding="utf-8").read(),
diff --git a/warp_drive/training/models/factory.py b/warp_drive/training/models/factory.py
new file mode 100644
index 0000000..8bb039c
--- /dev/null
+++ b/warp_drive/training/models/factory.py
@@ -0,0 +1,55 @@
+import importlib
+
+# WarpDrive reserved models
+default_models = {
+    "fully_connected": "warp_drive.training.models.fully_connected:FullyConnected",
+}
+
+
+def dynamic_import(model_name: str, model_pool: dict):
+    """
+    Dynamically import a member from the specified module.
+    :param model_name: the name of the model, e.g., fully_connected
+    :param model_pool: the dictionary of all available models
+    :return: the imported class
+    """
+    if model_name not in model_pool:
+        raise ValueError(
+            f"model_name {model_name} should be registered in the model factory, "
+            "e.g., {'fully_connected': "
+            "'warp_drive.training.models.fully_connected:FullyConnected'}"
+        )
+    if ":" not in model_pool[model_name]:
+        raise ValueError(
+            "Invalid model path format; expected ':' to separate the module path "
+            "from the object name, e.g., "
+            "'warp_drive.training.models.fully_connected:FullyConnected'"
+        )
+
+    module_name, obj_name = model_pool[model_name].split(":")
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
+
+
+class ModelFactory:
+
+    model_pool = {}
+
+    @classmethod
+    def add(cls, model_name: str, model_path: str, object_name: str):
+        """
+        Register a custom model with the factory.
+        :param model_name: e.g., "fully_connected"
+        :param model_path: e.g., "warp_drive.training.models.fully_connected"
+        :param object_name: e.g., "FullyConnected"
+        """
+        assert model_name not in default_models and model_name not in cls.model_pool, \
+            f"{model_name} has already been used by the model collection"
+
+        cls.model_pool.update({model_name: f"{model_path}:{object_name}"})
+
+    @classmethod
+    def create(cls, model_name):
+        # The reserved default models are always available for lookup
+        cls.model_pool.update(default_models)
+        return dynamic_import(model_name, model_pool=cls.model_pool)
diff --git a/warp_drive/training/models/fully_connected.py b/warp_drive/training/models/fully_connected.py
index b59893a..fa85cf5 100644
--- a/warp_drive/training/models/fully_connected.py
+++ b/warp_drive/training/models/fully_connected.py
@@ -50,7 +50,7 @@ class FullyConnected(nn.Module):
     def __init__(
         self,
         env,
-        fc_dims,
+        model_config,
         policy,
         policy_tag_to_agent_id_map,
         create_separate_placeholders_for_each_policy=False,
@@ -59,6 +59,7 @@ def __init__(
         super().__init__()

         self.env = env
+        fc_dims = model_config["fc_dims"]
         assert isinstance(fc_dims, list)
         num_fc_layers = len(fc_dims)
         self.policy = policy
diff --git a/warp_drive/training/pytorch_lightning.py b/warp_drive/training/pytorch_lightning.py
index 654cbb3..4834f18 100644
--- a/warp_drive/training/pytorch_lightning.py
+++ b/warp_drive/training/pytorch_lightning.py
@@ -30,7 +30,7 @@
 from warp_drive.training.algorithms.policygradient.a2c import A2C
 from warp_drive.training.algorithms.policygradient.ppo import PPO
-from warp_drive.training.models.fully_connected import FullyConnected
+from warp_drive.training.models.factory import ModelFactory
 from warp_drive.training.trainer import Metrics
 from warp_drive.training.utils.data_loader import create_and_push_data_placeholders
 from warp_drive.training.utils.param_scheduler import LRScheduler, ParamScheduler
@@ -353,17 +353,15 @@ def _initialize_policy_algorithm(self, policy):
     def _initialize_policy_model(self, policy):
         policy_model_config = self._get_config(["policy", policy, "model"])
-        if policy_model_config["type"] == "fully_connected":
-            model = FullyConnected(
-                self.cuda_envs,
-                policy_model_config["fc_dims"],
-                policy,
-                self.policy_tag_to_agent_id_map,
-                self.create_separate_placeholders_for_each_policy,
-                self.obs_dim_corresponding_to_num_agents,
-            )
-        else:
-            raise NotImplementedError
+        model_obj = ModelFactory.create(policy_model_config["type"])
+        model = model_obj(
+            env=self.cuda_envs,
+            model_config=policy_model_config,
+            policy=policy,
+            policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map,
+            create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy,
+            obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents,
+        )
         self.models[policy] = model

     def _get_config(self, args):
diff --git a/warp_drive/training/trainer.py b/warp_drive/training/trainer.py
index 7d66bab..d74600f 100644
--- a/warp_drive/training/trainer.py
+++ b/warp_drive/training/trainer.py
@@ -23,7 +23,7 @@
 from warp_drive.training.algorithms.policygradient.a2c import A2C
 from warp_drive.training.algorithms.policygradient.ppo import PPO
-from warp_drive.training.models.fully_connected import FullyConnected
+from warp_drive.training.models.factory import ModelFactory
 from warp_drive.training.utils.data_loader import create_and_push_data_placeholders
 from warp_drive.training.utils.param_scheduler import ParamScheduler
 from warp_drive.utils.common import get_project_root
@@ -368,24 +368,24 @@ def _initialize_policy_algorithm(self, policy):
     def _initialize_policy_model(self, policy):
         policy_model_config = self._get_config(["policy", policy, "model"])
-        if policy_model_config["type"] == "fully_connected":
-            model = FullyConnected(
-                self.cuda_envs,
-                policy_model_config["fc_dims"],
-                policy,
-                self.policy_tag_to_agent_id_map,
-                self.create_separate_placeholders_for_each_policy,
-                self.obs_dim_corresponding_to_num_agents,
-            )
-            if "init_method" in policy_model_config and \
-                    policy_model_config["init_method"] == "xavier":
-                def init_weights_by_xavier_uniform(m):
-                    if isinstance(m, nn.Linear):
-                        torch.nn.init.xavier_uniform(m.weight)
+        model_obj = ModelFactory.create(policy_model_config["type"])
+        model = model_obj(
+            env=self.cuda_envs,
+            model_config=policy_model_config,
+            policy=policy,
+            policy_tag_to_agent_id_map=self.policy_tag_to_agent_id_map,
+            create_separate_placeholders_for_each_policy=self.create_separate_placeholders_for_each_policy,
+            obs_dim_corresponding_to_num_agents=self.obs_dim_corresponding_to_num_agents,
+        )
+
+        if "init_method" in policy_model_config and \
+                policy_model_config["init_method"] == "xavier":
+            def init_weights_by_xavier_uniform(m):
+                if isinstance(m, nn.Linear):
+                    torch.nn.init.xavier_uniform_(m.weight)
+
+            model.apply(init_weights_by_xavier_uniform)
-                model.apply(init_weights_by_xavier_uniform)
-        else:
-            raise NotImplementedError
         self.models[policy] = model

     def _get_config(self, args):
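
Usage sketch for the new ModelFactory API (the module "my_models" and the class
"MyPolicyNet" below are hypothetical placeholders; only ModelFactory.add(),
ModelFactory.create(), and the constructor keywords come from this diff):

    from warp_drive.training.models.factory import ModelFactory

    # Register a custom model under a new key; the factory stores it as
    # "<model_path>:<object_name>" and asserts the key is not already taken.
    ModelFactory.add(
        model_name="my_policy_net",  # key later referenced by the model config "type"
        model_path="my_models",      # hypothetical importable module
        object_name="MyPolicyNet",   # hypothetical nn.Module subclass in that module
    )

    # create() merges in the reserved default models (e.g., "fully_connected")
    # and returns the class itself; the trainer then instantiates it with
    # keyword arguments (env=..., model_config=..., policy=..., etc.).
    model_cls = ModelFactory.create("my_policy_net")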
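
For reference, a sketch of the policy model config fields that
_initialize_policy_model now reads, written as a Python dict (in a WarpDrive run
this section typically comes from the YAML run config; the surrounding structure
is an assumption, but "type", "fc_dims", and "init_method" are the keys accessed
in the code above):

    policy_model_config = {
        "type": "fully_connected",  # resolved to a class via ModelFactory.create()
        "fc_dims": [256, 256],      # read by FullyConnected as model_config["fc_dims"]
        "init_method": "xavier",    # applies torch.nn.init.xavier_uniform_ to nn.Linear weights
    }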