
Merge pull request #83 from salesforce/pycuda-context
new pycuda device context
Emerald01 committed Jun 16, 2023
2 parents 2632a07 + 6848024 commit 4c6acb2
Showing 21 changed files with 82 additions and 28 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,4 +1,8 @@
# Changelog
+# Release 2.4 (2023-06-16)
+- Introduce new device context management and autoinit_pycuda
+- As a result, Torch (any version) no longer conflicts with PyCUDA over the GPU context
+
# Release 2.3 (2022-03-22)
- Add ModelFactory class to manage custom models
- Add Xavier initialization for the model
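To make the changelog entry concrete, here is a minimal usage sketch (only the module path warp_drive.utils.autoinit_pycuda comes from this commit; the rest is an assumed example): importing the new shim in place of pycuda.autoinit attaches PyCUDA to the device's primary context, the same context Torch uses, so the two libraries no longer fight over the GPU.

import torch
from warp_drive.utils import autoinit_pycuda  # side effect: pushes the primary context
import pycuda.driver as cuda_driver

x = torch.ones(8, device="cuda")                # Torch allocates in the same context
print(cuda_driver.Context.get_device().name())  # the device PyCUDA is bound to
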
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,8 +1,8 @@
gym>=0.18, <0.26
matplotlib>=3.2.1
numpy>=1.18.1
-pycuda==2022.1
+pycuda>=2022.1
pytest>=6.1.0
pyyaml>=5.4
-torch>=1.9, <1.11
+torch>=1.9
numba>=0.54.0
2 changes: 1 addition & 1 deletion setup.py
@@ -14,7 +14,7 @@

setup(
    name="rl-warp-drive",
-    version="2.3",
+    version="2.4",
    author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng",
    author_email="tian.lan@salesforce.com",
    description="Framework for fast end-to-end "
@@ -19,7 +19,6 @@
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed

-pytorch_cuda_init_success = torch.cuda.FloatTensor(8)

_CUBIN_FILEPATH = f"{get_project_root()}/warp_drive/cuda_bin"
_OBSERVATIONS = Constants.OBSERVATIONS
@@ -18,7 +18,6 @@
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed

-pytorch_cuda_init_success = torch.cuda.FloatTensor(8)

_CUBIN_FILEPATH = f"{get_project_root()}/warp_drive/cuda_bin"
_ACTIONS = Constants.ACTIONS
@@ -17,7 +17,6 @@
from warp_drive.utils.common import get_project_root
from warp_drive.utils.data_feed import DataFeed

-pytorch_cuda_init_success = torch.cuda.FloatTensor(8)

_CUBIN_FILEPATH = f"{get_project_root()}/warp_drive/cuda_bin"

@@ -19,7 +19,6 @@
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed

-pytorch_cuda_init_success = torch.cuda.FloatTensor(8)

_CUBIN_FILEPATH = f"{get_project_root()}/warp_drive/cuda_bin"
_ACTIONS = Constants.ACTIONS
1 change: 0 additions & 1 deletion tests/warp_drive/pycuda_tests/test_action_sampler.py
@@ -18,7 +18,6 @@
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed

-pytorch_cuda_init_success = torch.cuda.FloatTensor(8)

_CUBIN_FILEPATH = f"{get_project_root()}/warp_drive/cuda_bin"
_ACTIONS = Constants.ACTIONS
2 changes: 0 additions & 2 deletions tests/warp_drive/pycuda_tests/test_env_reset.py
@@ -17,8 +17,6 @@
from warp_drive.utils.common import get_project_root
from warp_drive.utils.data_feed import DataFeed

-pytorch_cuda_init_success = torch.cuda.FloatTensor(8)
-
_CUBIN_FILEPATH = f"{get_project_root()}/warp_drive/cuda_bin"


1 change: 0 additions & 1 deletion tests/warp_drive/pycuda_tests/test_function_manager.py
@@ -19,7 +19,6 @@
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed

-pytorch_cuda_init_success = torch.cuda.FloatTensor(8)

_CUBIN_FILEPATH = f"{get_project_root()}/warp_drive/cuda_bin"
_ACTIONS = Constants.ACTIONS
4 changes: 1 addition & 3 deletions tutorials/simple-end-to-end-example.ipynb
@@ -82,9 +82,7 @@
"source": [
"from example_envs.tag_continuous.tag_continuous import TagContinuous\n",
"from warp_drive.env_wrapper import EnvWrapper\n",
"from warp_drive.training.trainer import Trainer\n",
"\n",
"pytorch_cuda_init_success = torch.cuda.FloatTensor(8)"
"from warp_drive.training.trainer import Trainer"
]
},
{
1 change: 0 additions & 1 deletion warp_drive/env_cpu_gpu_consistency_checker.py
@@ -23,7 +23,6 @@
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed

-pytorch_cuda_init_success = torch.cuda.FloatTensor(8)
_OBSERVATIONS = Constants.OBSERVATIONS
_ACTIONS = Constants.ACTIONS
_REWARDS = Constants.REWARDS
2 changes: 1 addition & 1 deletion warp_drive/managers/pycuda_managers/pycuda_data_manager.py
@@ -9,7 +9,7 @@
from typing import Optional

import numpy as np
-import pycuda.autoinit
+from warp_drive.utils import autoinit_pycuda
import pycuda.driver as pycuda_driver
import torch

@@ -13,7 +13,7 @@
from typing import Optional

import numpy as np
-import pycuda.autoinit
+from warp_drive.utils import autoinit_pycuda
import pycuda.driver as cuda_driver
import torch
from pycuda.compiler import SourceModule
1 change: 0 additions & 1 deletion warp_drive/training/example_training_script_numba.py
@@ -50,7 +50,6 @@ def setup_trainer_and_train(
    and create the trainer object. Also, perform training.
    """
    logging.getLogger().setLevel(logging.ERROR)
-    torch.cuda.FloatTensor(8)  # add this line for successful cuda_init

num_envs = run_configuration["trainer"]["num_envs"]

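For reference, a hedged reconstruction of the workaround these hunks delete (the call is taken from the removed lines; the explanation is inferred): constructing any CUDA tensor forced Torch to initialize its CUDA context before pycuda.autoinit claimed the device, which is exactly what torch.cuda.init() inside the new make_current_context() now does explicitly.

import torch

# Side effect only: allocating a CUDA tensor triggers Torch's lazy CUDA init,
# so it happens before `import pycuda.autoinit` grabs the GPU.
_ = torch.cuda.FloatTensor(8)
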
2 changes: 1 addition & 1 deletion warp_drive/training/example_training_script_pycuda.py
@@ -27,6 +27,7 @@
from warp_drive.training.utils.vertical_scaler import perform_auto_vertical_scaling
from warp_drive.utils.common import get_project_root

+
_ROOT_DIR = get_project_root()

_TAG_CONTINUOUS = "tag_continuous"
@@ -50,7 +51,6 @@ def setup_trainer_and_train(
    and create the trainer object. Also, perform training.
    """
    logging.getLogger().setLevel(logging.ERROR)
-    torch.cuda.FloatTensor(8)  # add this line for successful cuda_init

num_envs = run_configuration["trainer"]["num_envs"]

@@ -1,6 +1,6 @@
import time

-import pycuda.autoinit
+from warp_drive.utils import autoinit_pycuda
import pycuda.driver as pycuda_driver

from warp_drive.training.utils.device_child_process.child_process_base import event_messenger
@@ -2,7 +2,8 @@

import pycuda.driver as cuda_driver
import torch.distributed as dist
-from pycuda.tools import clear_context_caches, make_default_context
+from pycuda.tools import clear_context_caches
+from warp_drive.utils.device_context import make_current_context


class PyCUDASingleDeviceContext:
@@ -13,12 +14,8 @@ class PyCUDASingleDeviceContext:
    _context = None

    def init_context(self, device_id=None):
-        if device_id is None:
-            context = make_default_context()
-            self._context = context
-        else:
-            context = cuda_driver.Device(device_id).make_context()
-            self._context = context
+        context = make_current_context(device_id)
+        self._context = context

    @property
    def context(self):
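A short usage sketch of the simplified manager (the class, method, and property come from the hunk above; the call sequence is assumed):

ctx_manager = PyCUDASingleDeviceContext()
ctx_manager.init_context(device_id=0)           # retain + push device 0's primary context
print(ctx_manager.context.get_device().name())  # the context the manager now holds
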
2 changes: 1 addition & 1 deletion warp_drive/utils/architecture_validate.py
@@ -6,7 +6,7 @@

import logging

-import pycuda.autoinit
+from warp_drive.utils import autoinit_pycuda
from pycuda.driver import Context


8 changes: 8 additions & 0 deletions warp_drive/utils/autoinit_pycuda.py
@@ -0,0 +1,8 @@
import atexit
from warp_drive.utils.device_context import make_current_context

# Initialize torch and CUDA context

context = make_current_context()
device = context.get_device()
atexit.register(context.pop)
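One property of this import-side-effect pattern worth noting (our observation, not stated in the diff): Python caches imported modules, so no matter how many files import the shim, make_current_context() runs and the context is pushed exactly once per process.

from warp_drive.utils import autoinit_pycuda as a1
from warp_drive.utils import autoinit_pycuda as a2  # cached; no second push
assert a1 is a2 and a1.context is a2.context
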
57 changes: 57 additions & 0 deletions warp_drive/utils/device_context.py
@@ -0,0 +1,57 @@
import torch
import pycuda.driver as cuda_driver


def make_current_context(device_id=None):
    torch.cuda.init()
    cuda_driver.init()
    if device_id is None:
        context = _get_primary_context_for_current_device()
    else:
        context = cuda_driver.Device(device_id).retain_primary_context()
    context.push()
    return context


def _get_primary_context_for_current_device():
    ndevices = cuda_driver.Device.count()
    if ndevices == 0:
        raise RuntimeError("No CUDA enabled device found. "
                           "Please check your installation.")

    # Is CUDA_DEVICE set?
    import os
    devn = os.environ.get("CUDA_DEVICE")

    # Is $HOME/.cuda_device set?
    if devn is None:
        try:
            homedir = os.environ.get("HOME")
            assert homedir is not None
            devn = (open(os.path.join(homedir, ".cuda_device"))
                    .read().strip())
        except (AssertionError, OSError):  # missing HOME or unreadable file
            pass

    # If either CUDA_DEVICE or $HOME/.cuda_device is set, try to use it ;-)
    if devn is not None:
        try:
            devn = int(devn)
        except ValueError:  # int() of a malformed string raises ValueError
            raise TypeError("CUDA device number (CUDA_DEVICE or ~/.cuda_device)"
                            " must be an integer")

        dev = cuda_driver.Device(devn)
        return dev.retain_primary_context()

    # Otherwise, try to use any available device
    else:
        for devn in range(ndevices):
            dev = cuda_driver.Device(devn)
            try:
                return dev.retain_primary_context()
            except cuda_driver.Error:
                pass

        raise RuntimeError("_get_primary_context_for_current_device() wasn't able to create a context "
                           "on any of the %d detected devices" % ndevices)
