torch matm backend with flattened reshaping
danlkv committed Apr 8, 2024
1 parent 334148c commit 9fb06b9
Showing 2 changed files with 252 additions and 114 deletions.
19 changes: 18 additions & 1 deletion qtensor/contraction_backends/tests/test_torch.py
@@ -1,7 +1,8 @@
 import qtensor
 import pytest
 import numpy as np
-from qtensor.contraction_backends import TorchBackend, NumpyBackend
+from qtensor.contraction_backends import NumpyBackend
+from qtensor.contraction_backends.torch import TorchBackend, TorchBackendMatm, permute_flattened
 from qtensor import QtreeSimulator
 from qtensor.tests import get_test_qaoa_ansatz_circ
 torch = pytest.importorskip('torch')
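
Aside (not part of the diff): a hedged usage sketch of the newly exported TorchBackendMatm backend. It assumes QtreeSimulator accepts a backend argument (as the contract_tn helper in this test file suggests), that sim.simulate(circ) is the entry point, and that get_test_qaoa_ansatz_circ runs with default arguments; none of these calls are confirmed by the diff shown here.

# Hypothetical usage sketch, not part of this commit.
# Assumptions: QtreeSimulator(backend=...) and sim.simulate(circ) exist as in
# typical qtensor usage, and get_test_qaoa_ansatz_circ works with defaults.
from qtensor import QtreeSimulator
from qtensor.contraction_backends.torch import TorchBackendMatm
from qtensor.tests import get_test_qaoa_ansatz_circ

circ = get_test_qaoa_ansatz_circ()                # assumed default signature
sim = QtreeSimulator(backend=TorchBackendMatm())  # new matmul-based backend
amp = sim.simulate(circ)                          # assumed to return a single amplitude
print(amp)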
@@ -61,6 +62,22 @@ def contract_tn(backend, search_len=1, test_problem_kwargs={}):
 
     assert restr.shape == resnp.shape
     assert np.allclose(restr, resnp)
+# -- Testing low-level functions for torch matm backend
+
+def test_torch_matm_permute():
+    K = 5
+    d = 2
+    shape = [5] + [d]*(K-1)
+    x = torch.randn(shape)
+    for i in range(20):
+        perm = list(np.random.permutation(K))
+        y = permute_flattened(x.flatten(), perm, shape)
+        assert y.ndim == 1
+        assert y.numel() == x.numel()
+        print('perm', perm)
+        assert torch.allclose(y, x.permute(perm).flatten())
+
 # -- Testing get_sliced_buckets
 
 def test_torch_get_sliced__slice():
     backend = TorchBackend()
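
The new test pins down the contract of permute_flattened: given a flattened tensor, a permutation of its axes, and the original shape, the result must equal x.permute(perm).flatten(). For reference, a minimal sketch that satisfies this assertion; this is a hypothetical reference implementation, not the permute_flattened added to qtensor/contraction_backends/torch.py in this commit, which presumably operates on the flat buffer directly (the "flattened reshaping" of the commit title).

import torch

def permute_flattened_reference(flat, perm, shape):
    # Sketch of the behaviour asserted by test_torch_matm_permute:
    # view the flat buffer with its original shape, permute the axes,
    # and flatten back to one dimension.
    return flat.reshape(shape).permute(perm).flatten()

For example, with shape = [5, 2, 2, 2, 2] and perm = [1, 0, 2, 3, 4], the result is a 1-D tensor of 80 elements that matches x.permute(perm).flatten() elementwise, which is exactly what the test asserts.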
