Skip to content

Commit 2f5be8f

Browse files
authored
Map weights to CUDA or CPU depending on which is enabled (#216)
1 parent d8fab14 commit 2f5be8f

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

ada_feeding_action_select/ada_feeding_action_select/adapters/hapticnet_adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,13 @@ def __init__(
4747

4848
# Init CUDA
4949
self.use_cuda = torch.cuda.is_available()
50+
self.device = torch.device("cuda") if self.use_cuda else torch.device("cpu")
5051
if self.use_cuda:
5152
logger.info("Init HapticNet with CUDA")
5253
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
5354
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
55+
else:
56+
logger.info("Init HapticNet with CPU")
5457

5558
# Init HapticNet
5659
self.config = HapticNetConfig(n_output=n_features)
@@ -60,7 +63,7 @@ def __init__(
6063
ckpt_file = os.path.join(
6164
get_package_share_directory("ada_feeding_action_select"), "data", checkpoint
6265
)
63-
ckpt = torch.load(ckpt_file)
66+
ckpt = torch.load(ckpt_file, map_location=self.device)
6467
self.hapticnet.load_state_dict(ckpt["state_dict"])
6568
self.hapticnet.eval()
6669
if self.use_cuda:

ada_feeding_action_select/ada_feeding_action_select/adapters/spanet_adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,13 @@ def __init__(
4949

5050
# Init CUDA
5151
self.use_cuda = torch.cuda.is_available()
52+
self.device = torch.device("cuda") if self.use_cuda else torch.device("cpu")
5253
if self.use_cuda:
5354
logger.info("Init SPANet with CUDA")
5455
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
5556
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
57+
else:
58+
logger.info("Init SPANet with CPU")
5659

5760
# Init SPANet
5861
self.config = SPANetConfig(n_features=n_features)
@@ -62,7 +65,7 @@ def __init__(
6265
ckpt_file = os.path.join(
6366
get_package_share_directory("ada_feeding_action_select"), "data", checkpoint
6467
)
65-
ckpt = torch.load(ckpt_file)
68+
ckpt = torch.load(ckpt_file, map_location=self.device)
6669
self.spanet.load_state_dict(ckpt["net"])
6770
self.spanet.eval()
6871
if self.use_cuda:

ada_feeding_action_select/ada_feeding_action_select/policy_service.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _init_checkpoints_record(self, context_cls: type, posthoc_cls: type) -> None
226226
)
227227
if len(pt_files) > 0:
228228
with open(pt_files[-1], "rb") as ckpt_file:
229-
ckpt = torch.load(ckpt_file)
229+
ckpt = torch.load(ckpt_file, map_location=self.device)
230230
try:
231231
if ckpt["context_cls"] != context_cls:
232232
self.get_logger().warning(
@@ -253,6 +253,8 @@ def __init__(self):
253253
register_logger(self.get_logger())
254254
self._declare_parameters()
255255

256+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
257+
256258
# Name of the Policy
257259
policy_name = self.get_parameter("policy").value
258260
if policy_name is None:

0 commit comments

Comments
 (0)