perf: fix potential for multi-thread collision when creating pools
william-silversmith committed Sep 15, 2023
1 parent bee46e6 commit e4979ee
Showing 1 changed file with 35 additions and 10 deletions.
cloudfiles/interfaces.py (35 additions & 10 deletions)
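The collision being guarded against is the check-then-create gap in a shared keyed default-dict of connection pools: two threads that both miss on the same bucket key can each run the factory and build a separate pool. Below is a minimal sketch of that pattern under a lock, not the repository's code: keydefaultdict here is a trimmed version of the recipe the first hunk header hints at, while FakePool, POOLS, POOL_LOCK, and get_pool are hypothetical names used only for illustration.

import threading
from collections import defaultdict

class keydefaultdict(defaultdict):
  """defaultdict variant whose factory receives the missing key (trimmed sketch)."""
  def __missing__(self, key):
    if self.default_factory is None:
      raise KeyError(key)
    value = self[key] = self.default_factory(key)
    return value

class FakePool:
  """Hypothetical stand-in for a per-bucket connection pool."""
  def __init__(self, bucket_name):
    self.bucket_name = bucket_name

POOL_LOCK = threading.Lock()
POOLS = keydefaultdict(lambda bucket_name: FakePool(bucket_name))

def get_pool(bucket_name):
  # Without the lock, two threads that miss on the same key can each run the
  # factory, momentarily creating two pools for one bucket. Holding the lock
  # across the lookup keeps creation to one pool per key.
  with POOL_LOCK:
    return POOLS[bucket_name]

In the commit itself each backend (GCS, S3, memory) gets its own module-level lock, so contention on one interface type does not serialize construction of the others.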
@@ -50,14 +50,27 @@ def __missing__(self, key):
 GCloudBucketPoolParams = namedtuple('GCloudBucketPoolParams', 'bucket_name request_payer')
 MemoryPoolParams = namedtuple('MemoryPoolParams', 'bucket_name')
 
+GCS_BUCKET_POOL_LOCK = threading.Lock()
+S3_BUCKET_POOL_LOCK = threading.Lock()
+MEM_BUCKET_POOL_LOCK = threading.Lock()
+
 def reset_connection_pools():
   global S3_POOL
   global GC_POOL
   global MEM_POOL
-  S3_POOL = keydefaultdict(lambda params: S3ConnectionPool(params.service, params.bucket_name))
-  GC_POOL = keydefaultdict(lambda params: GCloudBucketPool(params.bucket_name, params.request_payer))
-  MEM_POOL = keydefaultdict(lambda params: MemoryPool(params.bucket_name))
-  MEMORY_DATA.clear()
+  global GCS_BUCKET_POOL_LOCK
+  global S3_BUCKET_POOL_LOCK
+  global MEM_BUCKET_POOL_LOCK
+
+  with S3_BUCKET_POOL_LOCK:
+    S3_POOL = keydefaultdict(lambda params: S3ConnectionPool(params.service, params.bucket_name))
+
+  with GCS_BUCKET_POOL_LOCK:
+    GC_POOL = keydefaultdict(lambda params: GCloudBucketPool(params.bucket_name, params.request_payer))
+
+  with MEM_BUCKET_POOL_LOCK:
+    MEM_POOL = keydefaultdict(lambda params: MemoryPool(params.bucket_name))
+    MEMORY_DATA.clear()
 
   import gc
   gc.collect()

@@ -340,14 +353,19 @@ def stripext(fname):
     filenames = list(map(stripext, filenames))
     filenames.sort()
     return iter(filenames)
 
 class MemoryInterface(StorageInterface):
   def __init__(self, path, secrets=None, request_payer=None, **kwargs):
+    global MEM_BUCKET_POOL_LOCK
+
     super(StorageInterface, self).__init__()
     if request_payer is not None:
       raise ValueError("Specifying a request payer for the MemoryInterface is not supported. request_payer must be None, got '{}'.", request_payer)
     self._path = path
-    self._data = MEM_POOL[MemoryPoolParams(path.bucket)].get_connection(secrets, None)
+
+    with MEM_BUCKET_POOL_LOCK:
+      pool = MEM_POOL[MemoryPoolParams(path.bucket)]
+      self._data = pool.get_connection(secrets, None)
 
   def get_path_to_file(self, file_path):
     return posixpath.join(
@@ -498,9 +516,13 @@ class GoogleCloudStorageInterface(StorageInterface):
   def __init__(self, path, secrets=None, request_payer=None, **kwargs):
     super(StorageInterface, self).__init__()
     global GC_POOL
+    global GCS_BUCKET_POOL_LOCK
     self._path = path
     self._request_payer = request_payer
-    self._bucket = GC_POOL[GCloudBucketPoolParams(self._path.bucket, self._request_payer)].get_connection(secrets, None)
+
+    with GCS_BUCKET_POOL_LOCK:
+      pool = GC_POOL[GCloudBucketPoolParams(self._path.bucket, self._request_payer)]
+      self._bucket = pool.get_connection(secrets, None)
     self._secrets = secrets
 
   def get_path_to_file(self, file_path):
@@ -799,10 +821,13 @@ def __init__(self, path, secrets=None, request_payer=None, composite_upload_thre
     self.composite_upload_threshold = composite_upload_threshold
 
   def _get_bucket(self, bucket_name):
+    global S3_BUCKET_POOL_LOCK
     service = self._path.alias or 's3'
-    return S3_POOL[S3ConnectionPoolParams(service, bucket_name, self._request_payer)].get_connection(
-      self._secrets, self._path.host
-    )
+
+    with S3_BUCKET_POOL_LOCK:
+      pool = S3_POOL[S3ConnectionPoolParams(service, bucket_name, self._request_payer)]
+
+    return pool.get_connection(self._secrets, self._path.host)
 
   def get_path_to_file(self, file_path):
     return posixpath.join(self._path.path, file_path)
