Skip to content

Commit

Permalink
Fix evoformer test
Browse files Browse the repository at this point in the history
  • Loading branch information
gahdritz committed Jun 24, 2023
1 parent d5da89c commit 7fdb503
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/test_evoformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,10 @@ def test_shape(self):
ckpt=False,
inf=inf,
eps=eps,
).eval()
).eval().cuda()

m = torch.rand((batch_size, s_t, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
m = torch.rand((batch_size, s_t, n_res, c_m), device="cuda")
z = torch.rand((batch_size, n_res, n_res, c_z), device="cuda")
msa_mask = torch.randint(
0,
2,
Expand All @@ -205,6 +205,7 @@ def test_shape(self):
s_t,
n_res,
),
device="cuda",
)
pair_mask = torch.randint(
0,
Expand All @@ -214,6 +215,7 @@ def test_shape(self):
n_res,
n_res,
),
device="cuda",
)

shape_z_before = z.shape
Expand Down

0 comments on commit 7fdb503

Please sign in to comment.