Skip to content

Commit

Permalink
fix cuszp implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Danil committed May 8, 2024
1 parent 6137d2c commit 1506748
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions qtensor/compression/Compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,8 @@ def free_compressed(self, ptr):
#return
import ctypes, cupy
#cmp_bytes, num_elements_eff, shape, dtype, _ = ptr
cmp_t_real, cmp_t_imag, shape, dtype = ptr
del cmp_t_real
del cmp_t_imag
cmp_t, shape, dtype = ptr
del cmp_t
torch.cuda.empty_cache()
return
print(f"Freeing compressed data {num_elements_eff}")
Expand All @@ -194,31 +193,34 @@ def free_compressed(self, ptr):
def compress(self, data):
isCupy, num_elements_eff = _get_data_info(data)
dtype = data.dtype
shape = data.shape
# convert cupy to torch
data_imag = torch.as_tensor(data.imag, device='cuda').contiguous()
data_real = torch.as_tensor(data.real, device='cuda').contiguous()
print(f"cuszp Compressing {type(data)}")
# TODO: cast to one array of double the number of elements
torch_data = torch.tensor(data, device='cuda')
data_view = torch.view_as_real(torch_data)
#print(f"cuszp Compressing {type(data)}")
#cmp_bytes, outSize_ptr = cuszp_device_compress(data, self.r2r_error, num_elements_eff, self.r2r_threshold)
cmp_t_real = cuszp.compress(data_real, self.r2r_error, 'rel')
cmp_t_imag = cuszp.compress(data_imag, self.r2r_error, 'rel')
return (cmp_t_real, cmp_t_imag, data.shape, dtype)
cmp_t = cuszp.compress(data_view, self.r2r_error, 'rel')
return (cmp_t, shape, dtype)

# return (cmp_bytes, num_elements_eff, isCuPy, data.shape, dtype, outSize_ptr.contents.value)
def compress_size(self, ptr):
#return ptr[4]
return ptr[0].nbytes + ptr[1].nbytes
return ptr[0].nbytes

def decompress(self, obj):
import cupy
#cmp_bytes, num_elements_eff, shape, dtype, cmpsize = obj
#decompressed_ptr = cuszp_device_decompress(num_elements_eff, cmp_bytes, cmpsize, self, dtype)
cmp_t_real, cmp_t_imag, shape, dtype = obj
cmp_t, shape, dtype = obj
num_elements_decompressed = 1
for s in shape:
num_elements_decompressed *= s
decomp_t_real = cuszp.decompress(cmp_t_real, num_elements_decompressed, cmp_t_real.nbytes, self.r2r_error, 'rel')
decomp_t_imag = cuszp.decompress(cmp_t_imag, num_elements_decompressed, cmp_t_imag.nbytes, self.r2r_error, 'rel')
decomp_t = decomp_t_real + 1j * decomp_t_imag
# Number of elements is twice because the shape is for complex numbers
num_elements_decompressed *= 2
decomp_t_float = cuszp.decompress(cmp_t, num_elements_decompressed, cmp_t.nbytes, self.r2r_error, 'rel')
decomp_t_float = decomp_t_float.view(decomp_t_float.shape[0]//2, 2)
decomp_t = torch.view_as_complex(decomp_t_float)
arr_cp = cupy.asarray(decomp_t)
arr = cupy.reshape(arr_cp, shape)
return arr
Expand Down

0 comments on commit 1506748

Please sign in to comment.