Skip to content

Commit

Permalink
fix torch backend no_sum bug
Browse files Browse the repository at this point in the history
  • Loading branch information
danlkv committed Mar 21, 2024
1 parent 7e0286f commit 53770f2
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions qtensor/contraction_backends/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def slice_torch_tensor(data:np.ndarray, indices_in, indices_out, slice_dict):
indices_sliced = [
i for sl, i in zip(slice_bounds, indices_in) if not isinstance(sl, int)
]
#print(f'indicies_in {indices_in}, slice_dict {slice_dict}, bounds {slice_bounds}, slicedix {indices_sliced}, sshape {s_data.shape}')
#print(f'{indices_in=}, {indices_sliced=} {slice_dict=}, {slice_bounds=}, slicedix {indices_sliced}, sshape {s_data.shape}')
indices_sized = [v.copy(size=size) for v, size in zip(indices_sliced, s_data.shape)]
indices_out = [v for v in indices_out if not isinstance(slice_dict.get(v, None), int)]
assert len(indices_sized) == len(s_data.shape)
Expand Down Expand Up @@ -130,7 +130,7 @@ def process_bucket(self, bucket, no_sum=False):
tensor = bucket[-1]
expr = get_einsum_expr(
list(map(int, result_indices)), list(map(int, tensor.indices))
, contract = 1
, contract = 0 if no_sum else 1
)
logger.trace('Before contract. Expr: {}, inputs: {}, {}', expr, result_data, tensor)
result_data = torch.einsum(expr, result_data, tensor.data)
Expand All @@ -146,10 +146,10 @@ def process_bucket(self, bucket, no_sum=False):
result_data = result_data



if len(result_indices) > 0:
first_index = result_indices[-1]
result_indices = result_indices[:-1]
if not no_sum:
result_indices = result_indices[:-1]
tag = first_index.identity
else:
tag = 'f'
Expand Down

0 comments on commit 53770f2

Please sign in to comment.