-
Notifications
You must be signed in to change notification settings - Fork 983
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Made the Density Matrix Plotter (#4493)
Implements a Quirk like plotting for Density Matrices. The plotting is static, no dynamic sweeps are a part of this. Closes #4485.
- Loading branch information
1 parent
aaf969d
commit 50271af
Showing
3 changed files
with
275 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tool to visualize the magnitudes and phases in the density matrix""" | ||
|
||
from typing import Optional | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from matplotlib import lines, patches | ||
|
||
from cirq.qis.states import validate_density_matrix | ||
|
||
|
||
def _plot_element_of_density_matrix(ax, x, y, r, phase, show_rect=False, show_text=False): | ||
"""Plots a single element of a density matrix | ||
Args: | ||
x: x coordinate of the cell we are plotting | ||
y: y coordinate of the cell we are plotting | ||
r: the amplitude of the qubit in that cell | ||
phase: phase of the qubit in that cell, in radians | ||
show_rect: Boolean on if to show the amplitude rectangle, used for diagonal elements | ||
show_text: Boolean on if to show text labels or not | ||
ax: The axes to plot on | ||
""" | ||
# Setting up a few magic numbers for graphics | ||
_half_cell_size_after_padding = (1 / 1.1) * 0.5 | ||
_rectangle_margin = 0.01 | ||
_image_opacity = 0.8 if not show_text else 0.4 | ||
|
||
circle_out = plt.Circle( | ||
(x + 0.5, y + 0.5), radius=1 / _half_cell_size_after_padding, fill=False, color='#333333' | ||
) | ||
circle_in = plt.Circle( | ||
(x + 0.5, y + 0.5), | ||
radius=r / _half_cell_size_after_padding, | ||
fill=True, | ||
color='IndianRed', | ||
alpha=_image_opacity, | ||
) | ||
line = lines.Line2D( | ||
(x + 0.5, x + 0.5 + np.cos(phase) / _half_cell_size_after_padding), | ||
(y + 0.5, y + 0.5 + np.sin(phase) / _half_cell_size_after_padding), | ||
color='#333333', | ||
alpha=_image_opacity, | ||
) | ||
ax.add_artist(circle_in) | ||
ax.add_artist(circle_out) | ||
ax.add_artist(line) | ||
if show_rect: | ||
rect = patches.Rectangle( | ||
(x + _rectangle_margin, y + _rectangle_margin), | ||
1.0 - 2 * _rectangle_margin, | ||
r * (1 - 2 * _rectangle_margin), | ||
alpha=0.25, | ||
) | ||
ax.add_artist(rect) | ||
if show_text: | ||
plt.text( | ||
x + 0.5, | ||
y + 0.5, | ||
f"{np.round(r, decimals=2)}\n{np.round(phase * 180 / np.pi, decimals=2)} deg", | ||
horizontalalignment='center', | ||
verticalalignment='center', | ||
) | ||
|
||
|
||
def plot_density_matrix( | ||
matrix: np.ndarray, | ||
ax: Optional[plt.Axes] = None, | ||
*, | ||
show_text: bool = False, | ||
title: Optional[str] = None, | ||
) -> plt.Axes: | ||
"""Generates a plot for a given density matrix. | ||
1. Each entry of the density matrix, a complex number, is plotted as an | ||
Argand Diagram where the partially filled red circle represents the magnitude | ||
and the line represents the phase angle, going anti-clockwise from positive x - axis. | ||
2. The blue rectangles on the diagonal elements represent the probability | ||
of measuring the system in state $|i\rangle$. | ||
Rendering scheme is inspired from https://algassert.com/quirk | ||
Args: | ||
matrix: The density matrix to visualize | ||
show_text: If true, the density matrix values are also shown as text labels | ||
ax: The axes to plot on | ||
title: Title of the plot | ||
""" | ||
plt.style.use('ggplot') | ||
|
||
_padding_around_plot = 0.001 | ||
|
||
matrix = matrix.astype(np.complex128) | ||
num_qubits = int(np.log2(matrix.shape[0])) | ||
validate_density_matrix(matrix, qid_shape=(2 ** num_qubits,)) | ||
|
||
if ax is None: | ||
_, ax = plt.subplots(figsize=(10, 10)) | ||
ax.set_xlim(0 - _padding_around_plot, 2 ** num_qubits + _padding_around_plot) | ||
ax.set_ylim(0 - _padding_around_plot, 2 ** num_qubits + _padding_around_plot) | ||
|
||
for i in range(matrix.shape[0]): | ||
for j in range(matrix.shape[1]): | ||
_plot_element_of_density_matrix( | ||
ax, | ||
i, | ||
j, | ||
np.abs(matrix[i][-j - 1]), | ||
np.angle(matrix[i][-j - 1]), | ||
show_rect=(i == matrix.shape[1] - j - 1), | ||
show_text=show_text, | ||
) | ||
|
||
ticks, labels = np.arange(0.5, matrix.shape[0]), [ | ||
f"{'0'*(num_qubits - len(f'{i:b}'))}{i:b}" for i in range(matrix.shape[0]) | ||
] | ||
ax.set_xticks(ticks) | ||
ax.set_xticklabels(labels) | ||
ax.set_yticks(ticks) | ||
ax.set_yticklabels(reversed(labels)) | ||
ax.set_facecolor('#eeeeee') | ||
if title is not None: | ||
ax.set_title(title) | ||
return ax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright 2021 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for Density Matrix Plotter.""" | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from matplotlib import lines, patches, text, spines, axis | ||
from matplotlib import pyplot as plt | ||
|
||
import cirq.testing | ||
from cirq.vis.density_matrix import plot_density_matrix | ||
from cirq.vis.density_matrix import _plot_element_of_density_matrix | ||
|
||
|
||
@pytest.mark.parametrize('show_text', [True, False]) | ||
@pytest.mark.parametrize('size', [2, 4, 8, 16]) | ||
def test_density_matrix_plotter(size, show_text): | ||
matrix = cirq.testing.random_density_matrix(size) | ||
# Check that the title shows back up | ||
ax = plot_density_matrix(matrix, show_text=show_text, title='Test Density Matrix Plot') | ||
assert ax.get_title() == 'Test Density Matrix Plot' | ||
# Check that the objects in the plot are only those we expect and nothing new was added | ||
for obj in ax.get_children(): | ||
assert isinstance( | ||
obj, | ||
( | ||
patches.Circle, | ||
spines.Spine, | ||
axis.XAxis, | ||
axis.YAxis, | ||
lines.Line2D, | ||
patches.Rectangle, | ||
text.Text, | ||
), | ||
) | ||
|
||
|
||
@pytest.mark.parametrize('show_text', [True, False]) | ||
@pytest.mark.parametrize('size', [2, 4, 8, 16]) | ||
def test_density_matrix_circle_sizes(size, show_text): | ||
matrix = cirq.testing.random_density_matrix(size) | ||
# Check that the correct title is being shown | ||
ax = plot_density_matrix(matrix, show_text=show_text, title='Test Density Matrix Plot') | ||
# Check that the radius of all the circles in the matrix is correct | ||
circles = list(filter(lambda x: isinstance(x, patches.Circle), ax.get_children())) | ||
mean_radius = np.mean([c.radius for c in circles if c.fill]) | ||
mean_value = np.mean(np.abs(matrix)) | ||
circles = np.array(sorted(circles, key=lambda x: (x.fill, x.center[0], -x.center[1]))).reshape( | ||
(2, size, size) | ||
) | ||
for i in range(size): | ||
for j in range(size): | ||
assert np.isclose( | ||
np.abs(matrix[i, j]) * mean_radius / mean_value, circles[1, i, j].radius | ||
) | ||
|
||
# Check that all the rectangles are of the right height, and only on the diagonal elements | ||
rects = list( | ||
filter( | ||
lambda x: isinstance(x, patches.Rectangle) and x.get_alpha() is not None, | ||
ax.get_children(), | ||
) | ||
) | ||
assert len(rects) == size | ||
mean_size = np.mean([r.get_height() for r in rects]) | ||
mean_value = np.trace(np.abs(matrix)) / size | ||
rects = np.array(sorted(rects, key=lambda x: x.get_x())) | ||
for i in range(size): | ||
# Ensuring that the rectangle is the right height | ||
assert np.isclose(np.abs(matrix[i, i]) * mean_size / mean_value, rects[i].get_height()) | ||
rect_points = rects[i].get_bbox().get_points() | ||
# Checking for the x position of the rectangle corresponding x of the center of the circle | ||
assert np.isclose((rect_points[0, 0] + rect_points[1, 0]) / 2, circles[1, i, i].center[0]) | ||
# Asserting that only the diagonal elements are on | ||
assert ( | ||
np.abs((rect_points[0, 1] + rect_points[1, 1]) / 2 - circles[1, i, i].center[1]) | ||
<= circles[0, i, i].radius * 1.5 | ||
) | ||
|
||
|
||
@pytest.mark.parametrize('show_rect', [True, False]) | ||
@pytest.mark.parametrize('value', [0.0, 1.0, 0.5 + 0.3j, 0.2 + 0.1j, 0.5 + 0.5j]) | ||
def test_density_element_plot(value, show_rect): | ||
_, ax = plt.subplots(figsize=(10, 10)) | ||
_plot_element_of_density_matrix( | ||
ax, 0, 0, np.abs(value), np.angle(value), show_rect=False, show_text=False | ||
) | ||
# Check that the right phase is being plotted | ||
plotted_lines = list(filter(lambda x: isinstance(x, lines.Line2D), ax.get_children())) | ||
assert len(plotted_lines) == 1 | ||
line_position = plotted_lines[0].get_xydata() | ||
angle = np.arctan( | ||
(line_position[1, 1] - line_position[0, 1]) / (line_position[1, 0] - line_position[0, 0]) | ||
) | ||
assert np.isclose(np.angle(value), angle) | ||
# Check if the circles are the right size ratio, given the value of the element | ||
circles_in = list(filter(lambda x: isinstance(x, patches.Circle) and x.fill, ax.get_children())) | ||
assert len(circles_in) == 1 | ||
circles_out = list( | ||
filter(lambda x: isinstance(x, patches.Circle) and not x.fill, ax.get_children()) | ||
) | ||
assert len(circles_out) == 1 | ||
assert np.isclose(circles_in[0].radius, circles_out[0].radius * np.abs(value)) | ||
# Check the rectangle is show if show_rect is on and it's filled if we are showing | ||
# the rectangle. If show_rect = False, the lack of a rectangle is not tested because | ||
# there are other rectangles on the plot that turn up with the axes, that get | ||
# checked when counting and matching the rectangles to the diagonal circles in | ||
# `test_density_matrix_circle_sizes` | ||
if show_rect: | ||
rectangles = list(filter(lambda x: isinstance(x, patches.Rectangle), ax.get_children())) | ||
assert len(rectangles) == 1 | ||
assert rectangles[0].fill | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'matrix', | ||
[ | ||
np.random.random(size=(4, 4, 4)), | ||
np.random.random((3, 3)) * np.exp(np.random.random((3, 3)) * 2 * np.pi * 1j), | ||
np.random.random((4, 8)) * np.exp(np.random.random((4, 8)) * 2 * np.pi * 1j), | ||
], | ||
) | ||
def test_density_matrix_type_error(matrix): | ||
with pytest.raises(ValueError, match="Incorrect shape for density matrix:*"): | ||
plot_density_matrix(matrix) |