Skip to content

Commit

Permalink
torch matm backend to support sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
danlkv committed Apr 5, 2024
1 parent f515c1d commit 334148c
Showing 1 changed file with 52 additions and 24 deletions.
76 changes: 52 additions & 24 deletions qtensor/contraction_backends/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,18 +228,20 @@ def get_result_data(self, result):

class TorchBackendMatm(TorchBackend):

def _get_index_sizes(self, *ixs):
def _get_index_sizes(self, *ixs, size_dict = None):
if size_dict is not None:
return [size_dict[i] for i in ixs]
try:
sizes = [ i.size for i in ixs ]
except AttributeError:
sizes = [2] * len(ixs)
return sizes

def _get_index_space_size(self, *ixs):
sizes = self._get_index_sizes(*ixs)
def _get_index_space_size(self, *ixs, size_dict = None):
sizes = self._get_index_sizes(*ixs, size_dict = size_dict)
return reduce(np.multiply, sizes, 1)

def pairwise_sum_contract(self, ixa, a, ixb, b, ixout):
def pairwise_sum_contract(self, ixa, a, ixb, b, ixout, size_dict = None):
out = ixout
common = set(ixa).intersection(set(ixb))
# -- sum indices that are in one tensor only
Expand Down Expand Up @@ -267,29 +269,34 @@ def pairwise_sum_contract(self, ixa, a, ixb, b, ixout):
kix = common - set(out)
fix = common - kix
common = list(kix) + list(fix)
#print(f'{ixa=} {ixb=} {ixout=}; {common=} {mix=} {nix=}')
a = tensors[0].permute(*[
list(ixs[0]).index(x) for x in common + list(mix)
])

b = tensors[1].permute(*[
list(ixs[1]).index(x) for x in common + list(nix)
])

k, f, m, n = [self._get_index_space_size(*ix)
#print(f'{ixa=} {ixb=} {ixout=}; {common=} {mix=} {nix=}, {size_dict=}')
if tensors[0].numel() > 1:
a = tensors[0].permute(*[
list(ixs[0]).index(x) for x in common + list(mix)
])

if tensors[1].numel() > 1:
b = tensors[1].permute(*[
list(ixs[1]).index(x) for x in common + list(nix)
])

k, f, m, n = [self._get_index_space_size(*ix, size_dict=size_dict)
for ix in (kix, fix, mix, nix)
]
a = a.reshape(k, f, m)
b = b.reshape(k, f, n)
c = torch.einsum('kfm, kfn -> fmn', a, b)
if len(out):
#print('out ix', out, 'kfmnix', kix, fix, mix, nix)
c = c.reshape(*self._get_index_sizes(*out))
c = c.reshape(*self._get_index_sizes(*out, size_dict=size_dict))
#print('outix', out, 'res', c.shape, 'kfmn',kix, fix, mix, nix)

current_ord_ = list(fix) + list(mix) + list(nix)
if len(out):
c = c.permute(*[current_ord_.index(i) for i in out])
else:
c = c.flatten()
#print(f'c shape {c.shape}')
return c

def process_bucket(self, bucket, no_sum=False):
Expand All @@ -303,17 +310,24 @@ def process_bucket(self, bucket, no_sum=False):

ixr = list(map(int, result_indices))
ixt = list(map(int, tensor.indices))
result_indices = tuple(sorted(
out_indices = tuple(sorted(
set(result_indices + tensor.indices),
key=int, reverse=True
)
)
ixout = list(map(int, result_indices))

logger.trace('Before contract. expr: {}, {} ->', ixr, ixt, ixout)
result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout)
ixout = list(map(int, out_indices))

logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout)
size_dict = {}
for i in result_indices:
size_dict[int(i)] = i.size
for i in tensor.indices:
size_dict[int(i)] = i.size
logger.debug("result_indices: {}", result_indices)
result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout, size_dict = size_dict)
result_indices = out_indices
#result_data = torch.einsum(expr, result_data, tensor.data)
logger.trace("Data: {}, -> {}", result_data, tensor.data, result_data_new)
logger.trace("Data: {}, {} -> {}", result_data.shape, tensor.data.shape, result_data_new.shape)
result_data = result_data_new

# Merge and sort indices and shapes
Expand All @@ -334,10 +348,24 @@ def process_bucket(self, bucket, no_sum=False):
))[:-1]
ixout = list(map(int, result_indices))

logger.trace('Before contract. expr: {}, {} ->', ixr, ixt, ixout)
result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout)
logger.trace('Before contract. expr: {}, {} -> {}', ixr, ixt, ixout)
size_dict = {}
for i in result_indices:
size_dict[int(i)] = i.size
for i in tensor.indices:
size_dict[int(i)] = i.size
#logger.debug("result_indices: {}", result_indices)
result_data_new = self.pairwise_sum_contract(ixr, result_data, ixt, tensor.data, ixout, size_dict = size_dict)
#result_data = torch.einsum(expr, result_data, tensor.data)
logger.trace("Data: {}, -> {}", result_data, tensor.data, result_data_new)
logger.trace("Data: {}, {} -> {}", result_data.mean(), tensor.data.mean(), result_data_new.mean())
#if result_data_new.mean() == 0:
# logger.warning("Result is zero")
# logger.debug("result_indices: {}", result_indices)
# logger.debug("result_data: {}", result_data)
# logger.debug("tensor: {}", tensor)
# logger.debug("tensor_data: {}", tensor.data)
# logger.debug("result_data_new: {}", result_data_new)
# raise ValueError("Result is zero")
result_data = result_data_new
else:
result_data = result_data.sum(axis=-1)
Expand Down

0 comments on commit 334148c

Please sign in to comment.