Skip to content

Commit

Permalink
perf: Call nvmlInit() and nvmlShutdown() only once (#54)
Browse files Browse the repository at this point in the history
Lots of time is spent on nvmlInit() and nvmlShutdown() for each
new_query call. When running in a loop mode (-i), we do not need to
initialize and shutdown the nvml library because nvml APIs will be used
throughout the lifespan of the gpustat process.

Upon importing `gpustat.pynvml`, nvmlInit() will always be called.
  • Loading branch information
wookayin committed Nov 24, 2023
1 parent 64d22a9 commit 313b58d
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
11 changes: 3 additions & 8 deletions gpustat/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from blessed import Terminal

import gpustat.util as util
import gpustat.nvml as nvml
from gpustat.nvml import pynvml as N
from gpustat.nvml import check_driver_nvml_version

Expand Down Expand Up @@ -443,7 +444,7 @@ def clean_processes():
def new_query(debug=False, id=None) -> 'GPUStatCollection':
"""Query the information of all the GPUs on local machine"""

N.nvmlInit()
nvml.ensure_initialized()
log = util.DebugHelper()

def _decode(b: Union[str, bytes]) -> str:
Expand Down Expand Up @@ -625,7 +626,6 @@ def _wrapped(*args, **kwargs):
if debug:
log.report_summary()

N.nvmlShutdown()
return GPUStatCollection(gpu_list, driver_version=driver_version)

def __len__(self):
Expand Down Expand Up @@ -752,15 +752,10 @@ def new_query() -> GPUStatCollection:
def gpu_count() -> int:
'''Return the number of available GPUs in the system.'''
try:
N.nvmlInit()
nvml.ensure_initialized()
return N.nvmlDeviceGetCount()
except N.NVMLError:
return 0 # fallback
finally:
try:
N.nvmlShutdown()
except N.NVMLError:
pass


def is_available() -> bool:
Expand Down
30 changes: 28 additions & 2 deletions gpustat/nvml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pylint: disable=protected-access

from typing import Tuple
import atexit
import functools
import os
import sys
Expand Down Expand Up @@ -204,4 +204,30 @@ def nvmlDeviceGetMemoryInfo(handle):
setattr(pynvml, 'nvmlDeviceGetMemoryInfo', pynvml_monkeypatch.nvmlDeviceGetMemoryInfo)


__all__ = ['pynvml']
# Upon importing this module, let pynvml be initialized and remain active
# throughout the lifespan of the python process (until gpustat exists).
_initialized: bool
_init_error = None
try:
pynvml.nvmlInit()
_initialized = True

def _shutdown():
pynvml.nvmlShutdown()
atexit.register(_shutdown)

except pynvml.NVMLError as exc:
_initialized = False
_init_error = exc


def ensure_initialized():
if not _initialized:
raise _init_error # type: ignore


__all__ = [
'pynvml',
'check_driver_nvml_version',
'ensure_initialized',
]
1 change: 1 addition & 0 deletions gpustat/test_gpustat.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def _configure_mock(N=pynvml,
unstub(N) # reset all the stubs

when(N).nvmlInit().thenReturn()
gpustat.nvml._initialized = True # nvmlInit() is called upon module import
when(N).nvmlShutdown().thenReturn()
when(N).nvmlSystemGetDriverVersion().thenReturn('415.27.mock')

Expand Down

0 comments on commit 313b58d

Please sign in to comment.