Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #83 from salesforce/pycuda-context
new pycuda device context
- Loading branch information
Showing
21 changed files
with
82 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
gym>=0.18, <0.26 | ||
matplotlib>=3.2.1 | ||
numpy>=1.18.1 | ||
pycuda==2022.1 | ||
pycuda>=2022.1 | ||
pytest>=6.1.0 | ||
pyyaml>=5.4 | ||
torch>=1.9, <1.11 | ||
torch>=1.9 | ||
numba>=0.54.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
warp_drive/training/utils/distributed_train/distributed_trainer_pycuda.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import atexit | ||
from warp_drive.utils.device_context import make_current_context | ||
|
||
# Initialize torch and CUDA context | ||
|
||
context = make_current_context() | ||
device = context.get_device() | ||
atexit.register(context.pop) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import torch | ||
import pycuda.driver as cuda_driver | ||
|
||
|
||
def make_current_context(device_id=None): | ||
torch.cuda.init() | ||
cuda_driver.init() | ||
if device_id is None: | ||
context = _get_primary_context_for_current_device() | ||
else: | ||
context = cuda_driver.Device(device_id).retain_primary_context() | ||
context.push() | ||
return context | ||
|
||
|
||
def _get_primary_context_for_current_device(): | ||
ndevices = cuda_driver.Device.count() | ||
if ndevices == 0: | ||
raise RuntimeError("No CUDA enabled device found. " | ||
"Please check your installation.") | ||
|
||
# Is CUDA_DEVICE set? | ||
import os | ||
devn = os.environ.get("CUDA_DEVICE") | ||
|
||
# Is $HOME/.cuda_device set ? | ||
if devn is None: | ||
try: | ||
homedir = os.environ.get("HOME") | ||
assert homedir is not None | ||
devn = (open(os.path.join(homedir, ".cuda_device")) | ||
.read().strip()) | ||
except: | ||
pass | ||
|
||
# If either CUDA_DEVICE or $HOME/.cuda_device is set, try to use it ;-) | ||
if devn is not None: | ||
try: | ||
devn = int(devn) | ||
except TypeError: | ||
raise TypeError("CUDA device number (CUDA_DEVICE or ~/.cuda_device)" | ||
" must be an integer") | ||
|
||
dev = cuda_driver.Device(devn) | ||
return dev.retain_primary_context() | ||
|
||
# Otherwise, try to use any available device | ||
else: | ||
for devn in range(ndevices): | ||
dev = cuda_driver.Device(devn) | ||
try: | ||
return dev.retain_primary_context() | ||
except cuda_driver.Error: | ||
pass | ||
|
||
raise RuntimeError("_get_primary_context_for_current_device() wasn't able to create a context " | ||
"on any of the %d detected devices" % ndevices) |