From 5e103b3dee0d580a478860af65180e586aa76056 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Wed, 27 Jul 2022 13:55:53 -0700 Subject: [PATCH] add custom resetter --- warp_drive/env_wrapper.py | 9 +++++++++ warp_drive/managers/function_manager.py | 26 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/warp_drive/env_wrapper.py b/warp_drive/env_wrapper.py index b2ff9d5..9cfe2ce 100644 --- a/warp_drive/env_wrapper.py +++ b/warp_drive/env_wrapper.py @@ -185,6 +185,11 @@ def __init__( self.env_resetter = CUDAEnvironmentReset( function_manager=self.cuda_function_manager ) + # custom reset function, if not found, will ignore + reset_function = f"Cuda{self.name}Reset" + self.env_resetter.register_custom_reset_function( + self.cuda_data_manager, + reset_function_name=reset_function) def reset_all_envs(self): """ @@ -268,6 +273,10 @@ def reset_only_done_envs(self): self.env_resetter.reset_when_done(self.cuda_data_manager, mode="if_done") return {} + def custom_reset_all_envs(self, args=None, block=None, grid=None): + self.env_resetter.custom_reset(args=args, block=block, grid=grid) + return {} + def step_all_envs(self, actions=None): """ Step through all the environments diff --git a/warp_drive/managers/function_manager.py b/warp_drive/managers/function_manager.py index 25d47a2..9282520 100644 --- a/warp_drive/managers/function_manager.py +++ b/warp_drive/managers/function_manager.py @@ -827,6 +827,32 @@ def __init__(self, function_manager: CUDAFunctionManager): "undo_done_flag_and_reset_timestep" ) + self._cuda_custom_reset = None + self._cuda_reset_feed = None + + def register_custom_reset_function(self, data_manager: CUDADataManager, reset_function_name=None): + if reset_function_name is None or reset_function_name not in self._function_manager._cuda_function_names: + return + self._cuda_custom_reset = self._function_manager.get_function(reset_function_name) + self._cuda_reset_feed = CUDAFunctionFeed(data_manager) + + def custom_reset(self, + args: Optional[list] = None, + block=None, + grid=None): + + assert self._cuda_custom_reset is not None and self._cuda_reset_feed is not None, \ + "Custom Reset function is not defined, call register_custom_reset_function() first" + assert args is None or isinstance(args, list) + if block is None: + block = self._block + if grid is None: + grid = self._grid + if args is None or len(args) == 0: + self._cuda_custom_reset(block=block, grid=grid) + else: + self._cuda_custom_reset(*self._cuda_reset_feed(args), block=block, grid=grid) + def reset_when_done( self, data_manager: CUDADataManager,