Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for np.random.Generator #6566

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@
MeasurementKey,
MeasurementType,
PeriodicValue,
PRNG_OR_SEED_LIKE,
RANDOM_STATE_OR_SEED_LIKE,
state_vector_to_probabilities,
SympyCondition,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
'QUANTUM_STATE_LIKE',
'QubitOrderOrList',
'RANDOM_STATE_OR_SEED_LIKE',
'PRNG_OR_SEED_LIKE',
'STATE_VECTOR_LIKE',
'Sweepable',
'TParamKey',
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/value/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@
from cirq.value.type_alias import TParamKey, TParamVal, TParamValComplex

from cirq.value.value_equality_attr import value_equality


from cirq.value.prng import parse_prng, PRNG_OR_SEED_LIKE
66 changes: 66 additions & 0 deletions cirq-core/cirq/value/prng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import numbers
import numpy as np

from cirq._doc import document
from cirq.value.random_state import RANDOM_STATE_OR_SEED_LIKE

PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator]

document(
PRNG_OR_SEED_LIKE,
"""A pseudorandom number generator or object that can be converted to one.

If is an integer or None, turns into a `np.random.Generator` seeded with that value.
If is an instance of `np.random.Generator` or a subclass of it, return as is.
If is an instance of `np.random.RandomState` or has a `randint` method, returns
Comment on lines +29 to +31
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If is an integer or None, turns into a `np.random.Generator` seeded with that value.
If is an instance of `np.random.Generator` or a subclass of it, return as is.
If is an instance of `np.random.RandomState` or has a `randint` method, returns
If an integer or None, turns into a `np.random.Generator` seeded with that value.
If an instance of `np.random.Generator` or a subclass of it, return as is.
If an instance of `np.random.RandomState` or has a `randint` method, returns

`np.random.default_rng(rs.randint(2**31))`
""",
)


def parse_prng(
prng_or_seed: Union[PRNG_OR_SEED_LIKE, RANDOM_STATE_OR_SEED_LIKE]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just have only the PRNG_OR_SEED_LIKE type?

RANDOM_STATE_OR_SEED_LIKE is Any so it turns off type checking of the argument.

) -> np.random.Generator:
"""Interpret an object as a pseudorandom number generator.

If `prng_or_seed` is an `np.random.Generator`, return it unmodified.
If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`.
If `prng_or_seed` is an instance of `np.random.RandomState` or has a `randint` method,
returns `np.random.default_rng(prng_or_seed.randint(2**31))`.

Args:
prng_or_seed: The object to be used as or converted to a pseudorandom
number generator.

Returns:
The pseudorandom number generator object.

Raises:
TypeError: If `prng_or_seed` is can't be converted to an np.random.Generator.
"""
if isinstance(prng_or_seed, np.random.Generator):
return prng_or_seed
if prng_or_seed is None or isinstance(prng_or_seed, numbers.Integral):
return np.random.default_rng(prng_or_seed if prng_or_seed is None else int(prng_or_seed))
Comment on lines +59 to +60
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we return a singleton Generator object for None?
The None arg is going to be frequently used as a default for optional arguments.
Singleton would prevent creation of potentially large number of Generator objects.

if isinstance(prng_or_seed, np.random.RandomState):
return np.random.default_rng(prng_or_seed.randint(2**31))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can reuse the bit generator for a more genuine conversion.

Suggested change
return np.random.default_rng(prng_or_seed.randint(2**31))
return np.random.default_rng(prng_or_seed._bit_generator)

randint = getattr(prng_or_seed, "randint", None)
if randint is not None:
return np.random.default_rng(randint(2**31))
raise TypeError(f"{prng_or_seed} can't be converted to a pseudorandom number generator")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit - maybe state the actual class here ?

Suggested change
raise TypeError(f"{prng_or_seed} can't be converted to a pseudorandom number generator")
raise TypeError(f"{prng_or_seed} cannot be converted to the numpy.random.Generator")

48 changes: 48 additions & 0 deletions cirq-core/cirq/value/prng_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import pytest
import numpy as np

import cirq


def _sample(prng):
return tuple(prng.random(10))
Comment on lines +23 to +24
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this. One output from random() is enough to check if 2 generators are at the same seed.



def test_parse_rng() -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_parse_rng() -> None:
def test_parse_prng() -> None:

eq = cirq.testing.EqualsTester()

# An `np.random.Generator` or a seed.
group_inputs: List[Union[int, np.random.Generator]] = [42, np.random.default_rng(42)]
group: List[np.random.Generator] = [cirq.value.parse_prng(s) for s in group_inputs]
eq.add_equality_group(*[_sample(g) for g in group])
Comment on lines +30 to +33
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us not check cross-group inequality. Following the test_parse_random_state style is a bit more readable

Suggested change
# An `np.random.Generator` or a seed.
group_inputs: List[Union[int, np.random.Generator]] = [42, np.random.default_rng(42)]
group: List[np.random.Generator] = [cirq.value.parse_prng(s) for s in group_inputs]
eq.add_equality_group(*[_sample(g) for g in group])
# An `np.random.Generator` or a seed.
prngs = [
cirq.value.parse_prng(42),
cirq.value.parse_prng(np.int32(42)),
cirq.value.parse_prng(np.random.default_rng(42)),
]
vals = [prng.random() for prng in prngs]
eq = cirq.testing.EqualsTester()
eq.add_equality_group(*vals)


# A None seed.
prng = cirq.value.parse_prng(None)
eq.add_equality_group(_sample(prng))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a noop check for a single value. Perhaps replace with

assert prng is cirq.value.parse_prng(None)

if you are OK with the previous suggestion to have a singleton generator for None.


# RandomState PRNG.
prng = cirq.value.parse_prng(np.random.RandomState(42))
eq.add_equality_group(_sample(prng))
Comment on lines +39 to +41
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can check reproducibility here -

Suggested change
# RandomState PRNG.
prng = cirq.value.parse_prng(np.random.RandomState(42))
eq.add_equality_group(_sample(prng))
# RandomState PRNG.
prngs = [
cirq.value.parse_prng(np.random.RandomState(42)),
cirq.value.parse_prng(np.random.RandomState(42)),
]
vals = [prng.random() for prng in prngs]
eq = cirq.testing.EqualsTester()
eq.add_equality_group(*vals)


# np.random module
prng = cirq.value.parse_prng(np.random)
eq.add_equality_group(_sample(prng))
Comment on lines +43 to +45
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not support creation of generator from a module, not a good practice.
The use of np.random module was causing pickle havoc in #3717.

I don't quite see a need for it, users can pass None for a default generator.
I'd be open to throwing TypeError for a module argument.

Suggested change
# np.random module
prng = cirq.value.parse_prng(np.random)
eq.add_equality_group(_sample(prng))


with pytest.raises(TypeError):
_ = cirq.value.parse_prng(1.0)
12 changes: 10 additions & 2 deletions cirq-core/cirq/value/random_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@
If an integer, turns into a `np.random.RandomState` seeded with that
integer.

If `random_state` is an instance of `np.random.Generator`, returns a
`np.random.RandomState` seeded with `random_state.bit_generator`.

If none of the above, it is used unmodified. In this case, it is assumed
that the object implements whatever methods are required for the use case
at hand. For example, it might be an existing instance of
`np.random.RandomState` or a custom pseudorandom number generator
implementation.

Note: prefer to use cirq.PRNG_OR_SEED_LIKE.
""",
)

Expand All @@ -41,8 +46,9 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran
"""Interpret an object as a pseudorandom number generator.

If `random_state` is None, returns the module `np.random`.
If `random_state` is an integer, returns
`np.random.RandomState(random_state)`.
If `random_state` is an integer, returns `np.random.RandomState(random_state)`.
If `random_state` is an instance of `np.random.Generator`, returns a
`np.random.RandomState` seeded with `random_state.bit_generator`.
Otherwise, returns `random_state` unmodified.

Args:
Expand All @@ -56,5 +62,7 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran
return cast(np.random.RandomState, np.random)
elif isinstance(random_state, int):
return np.random.RandomState(random_state)
elif isinstance(random_state, np.random.Generator):
return np.random.RandomState(random_state.bit_generator)
else:
return cast(np.random.RandomState, random_state)
2 changes: 2 additions & 0 deletions cirq-core/cirq/value/random_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ def rand(prng):
vals = [prng.rand() for prng in prngs]
eq = cirq.testing.EqualsTester()
eq.add_equality_group(*vals)

eq.add_equality_group(cirq.value.parse_random_state(np.random.default_rng(0)).rand())
Comment on lines +45 to +46
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us follow the above style here. Creating a new EqualsTester will only check equality within the one group, we don't need to check inequality from other groups.

Suggested change
eq.add_equality_group(cirq.value.parse_random_state(np.random.default_rng(0)).rand())
prngs = [
cirq.value.parse_random_state(np.random.default_rng(0)),
cirq.value.parse_random_state(np.random.default_rng(0)),
]
vals = [prng.rand() for prng in prngs]
eq = cirq.testing.EqualsTester()
eq.add_equality_group(*vals)```