Skip to content

Commit

Permalink
Add lint changes
Browse files Browse the repository at this point in the history
  • Loading branch information
atharva-satpute committed May 1, 2024
1 parent cd03721 commit 511b5c7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
10 changes: 8 additions & 2 deletions qiskit/primitives/containers/bit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def num_shots(self) -> int:
@staticmethod
def _bytes_to_bitstring(data: bytes, num_bits: int, mask: int) -> str:
val = int.from_bytes(data, "big") & mask
return bin(val)[2:].zfill(num_bits)
return bin(val)[2:].zfill(num_bits) if mask else ""

@staticmethod
def _bytes_to_int(data: bytes, mask: int) -> int:
Expand All @@ -143,7 +143,13 @@ def _bytes_to_int(data: bytes, mask: int) -> int:
def _get_counts(
self, *, loc: int | tuple[int, ...] | None, converter: Callable[[bytes], str | int]
) -> dict[str, int] | dict[int, int]:
arr = self._array.reshape(-1, self._array.shape[-1]) if loc is None else self._array[loc]
if loc is None:
_order = self._array.shape[-1]
arr = self._array.reshape(-1, _order) if _order else self._array
else:
arr = self._array[loc]
if isinstance(arr, np.uint8):
arr = np.array([self._array[loc]])

counts = defaultdict(int)
for shot_row in arr:
Expand Down
26 changes: 26 additions & 0 deletions test/python/primitives/containers/test_bit_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ def test_get_counts(self):
# test that providing no location takes the union over all shots
self.assertEqual(bit_array.get_counts(), {bs1: 2, bs2: 1, bs3: 1, bs4: 2})

bit_array = BitArray(np.zeros([1024, 2], dtype=np.uint8), num_bits=16)
bs5 = "00000000" + "00000000"
self.assertEqual(bit_array.get_counts(), {bs5: 1024})
self.assertEqual(bit_array.get_counts(1), {bs5: 2})
self.assertEqual(bit_array.get_counts((0, 1)), {bs5: 1})

# test with no classical register
bit_array = BitArray(np.zeros([1024, 0], dtype=np.uint8), num_bits=0)
self.assertEqual(bit_array.get_counts(), {"": 1024})
self.assertEqual(bit_array.get_counts(1), {})
with self.assertRaises(IndexError):
bit_array.get_counts((0, 1))

def test_get_int_counts(self):
"""Test conversion to int counts."""
# note that [234, 100] requires 16 bits, not 15; we are testing that get_counts ignores the
Expand All @@ -108,6 +121,19 @@ def test_get_int_counts(self):
# test that providing no location takes the union over all shots
self.assertEqual(bit_array.get_int_counts(), {val1: 2, val2: 1, val3: 1, val4: 2})

bit_array = BitArray(np.zeros([1024, 2], dtype=np.uint8), num_bits=16)
val5 = 0
self.assertEqual(bit_array.get_int_counts(), {val5: 1024})
self.assertEqual(bit_array.get_int_counts(1), {val5: 2})
self.assertEqual(bit_array.get_int_counts((0, 1)), {val5: 1})

# test with no classical register
bit_array = BitArray(np.zeros([1024, 0], dtype=np.uint8), num_bits=0)
self.assertEqual(bit_array.get_int_counts(), {0: 1024})
self.assertEqual(bit_array.get_int_counts(1), {})
with self.assertRaises(IndexError):
bit_array.get_int_counts((0, 1))

def test_get_bitstrings(self):
"""Test conversion to bitstrings."""
# note that [234, 100] requires 16 bits, not 15; we are testing that get_counts ignores the
Expand Down

0 comments on commit 511b5c7

Please sign in to comment.