Skip to content

Commit

Permalink
trying to make cuSZp work
Browse files Browse the repository at this point in the history
  • Loading branch information
Danil committed Mar 22, 2024
1 parent a61a658 commit c8447ad
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
22 changes: 16 additions & 6 deletions qtensor/compression/Compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
print(Path(__file__).parent/'szx/src/')
sys.path.append(str(Path(__file__).parent/'szx/src/'))
sys.path.append('./szx/src')
# sys.path.append(str(Path(__file__).parent/'szp/src/'))
# sys.path.append('./szp/src')
sys.path.append(str(Path(__file__).parent/'szp/src/'))
sys.path.append('./szp/src')

sys.path.append(str(Path(__file__).parent/'cusz/src'))
sys.path.append('./cusz/src')
Expand All @@ -19,7 +19,7 @@
import torch
try:
from cuszx_wrapper import cuszx_host_compress, cuszx_host_decompress, cuszx_device_compress, cuszx_device_decompress
# from cuSZp_wrapper import cuszp_device_compress, cuszp_device_decompress
from cuSZp_wrapper import cuszp_device_compress, cuszp_device_decompress
from cusz_wrapper import cusz_device_compress, cusz_device_decompress
from torch_quant_perchannel import quant_device_compress, quant_device_decompress
from newsz_wrapper import newsz_device_compress, newsz_device_decompress
Expand Down Expand Up @@ -166,14 +166,24 @@ def free_decompressed(self):
self.decompressed_own = []

def free_compressed(self, ptr):
import ctypes, cupy
cmp_bytes, num_elements_eff, shape, dtype, _ = ptr
p_decompressed_ptr = ctypes.addressof(cmp_bytes[0])
# cast to int64 pointer
# (effectively converting pointer to pointer to addr to pointer to int64)
p_decompressed_int= ctypes.cast(p_decompressed_ptr, ctypes.POINTER(ctypes.c_uint64))
decompressed_int = p_decompressed_int.contents
cupy.cuda.runtime.free(decompressed_int.value)
cupy.get_default_memory_pool().free_all_blocks()
del cmp_bytes

def compress(self, data):
isCupy, num_elements_eff = _get_data_info(data)
dtype = data.dtype
cmp_bytes, outSize_ptr = cuszp_device_compress(data, self.r2r_error,self.r2r_threshold)
return (cmp_bytes, num_elements_eff, data.shape, dtype, outSize_ptr)
print("Compressing")
print(type(data), type(num_elements_eff))
cmp_bytes, outSize_ptr = cuszp_device_compress(data, self.r2r_error,num_elements_eff, self.r2r_threshold)
return (cmp_bytes, num_elements_eff, data.shape, dtype, outSize_ptr.contents.value)

# return (cmp_bytes, num_elements_eff, isCuPy, data.shape, dtype, outSize_ptr.contents.value)
def compress_size(self, ptr):
Expand All @@ -182,7 +192,7 @@ def compress_size(self, ptr):
def decompress(self, obj):
import cupy
cmp_bytes, num_elements_eff, shape, dtype, cmpsize = obj
decompressed_ptr = cuszp_device_decompress(num_elements_eff, cmp_bytes)
decompressed_ptr = cuszp_device_decompress(num_elements_eff, cmp_bytes, cmpsize, self, dtype)
arr_cp = decompressed_ptr[0]

arr = cupy.reshape(arr_cp, shape)
Expand Down
4 changes: 2 additions & 2 deletions qtensor/compression/szp/src/cuSZp_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch

from pathlib import Path
#LIB_PATH = str(Path(__file__).parent/'libcuszp_wrapper.so')
LIB_PATH = '/home/mkshah5/QTensor/qtensor/compression/szp/src/libcuszp_wrapper.so'
LIB_PATH = str(Path(__file__).parent/'libcuszp_wrapper.so')
#LIB_PATH = '/home/mkshah5/QTensor/qtensor/compression/szp/src/libcuszp_wrapper.so'
# unsigned char* cuSZp_device_compress(float *oriData, size_t *outSize, float absErrBound, size_t nbEle){

def get_device_compress():
Expand Down

0 comments on commit c8447ad

Please sign in to comment.