Skip to content

Commit

Permalink
Made the Density Matrix Plotter (#4493)
Browse files Browse the repository at this point in the history
Implements a Quirk like plotting for Density Matrices.
The plotting is static, no dynamic sweeps are a part of this.

Closes #4485.
  • Loading branch information
AnimeshSinha1309 committed Oct 14, 2021
1 parent aaf969d commit 50271af
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 0 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/vis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
from cirq.vis.histogram import integrated_histogram

from cirq.vis.state_histogram import get_state_histogram, plot_state_histogram
from cirq.vis.density_matrix import plot_density_matrix

from cirq.vis.vis_utils import relative_luminance
137 changes: 137 additions & 0 deletions cirq-core/cirq/vis/density_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tool to visualize the magnitudes and phases in the density matrix"""

from typing import Optional

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import lines, patches

from cirq.qis.states import validate_density_matrix


def _plot_element_of_density_matrix(ax, x, y, r, phase, show_rect=False, show_text=False):
"""Plots a single element of a density matrix
Args:
x: x coordinate of the cell we are plotting
y: y coordinate of the cell we are plotting
r: the amplitude of the qubit in that cell
phase: phase of the qubit in that cell, in radians
show_rect: Boolean on if to show the amplitude rectangle, used for diagonal elements
show_text: Boolean on if to show text labels or not
ax: The axes to plot on
"""
# Setting up a few magic numbers for graphics
_half_cell_size_after_padding = (1 / 1.1) * 0.5
_rectangle_margin = 0.01
_image_opacity = 0.8 if not show_text else 0.4

circle_out = plt.Circle(
(x + 0.5, y + 0.5), radius=1 / _half_cell_size_after_padding, fill=False, color='#333333'
)
circle_in = plt.Circle(
(x + 0.5, y + 0.5),
radius=r / _half_cell_size_after_padding,
fill=True,
color='IndianRed',
alpha=_image_opacity,
)
line = lines.Line2D(
(x + 0.5, x + 0.5 + np.cos(phase) / _half_cell_size_after_padding),
(y + 0.5, y + 0.5 + np.sin(phase) / _half_cell_size_after_padding),
color='#333333',
alpha=_image_opacity,
)
ax.add_artist(circle_in)
ax.add_artist(circle_out)
ax.add_artist(line)
if show_rect:
rect = patches.Rectangle(
(x + _rectangle_margin, y + _rectangle_margin),
1.0 - 2 * _rectangle_margin,
r * (1 - 2 * _rectangle_margin),
alpha=0.25,
)
ax.add_artist(rect)
if show_text:
plt.text(
x + 0.5,
y + 0.5,
f"{np.round(r, decimals=2)}\n{np.round(phase * 180 / np.pi, decimals=2)} deg",
horizontalalignment='center',
verticalalignment='center',
)


def plot_density_matrix(
matrix: np.ndarray,
ax: Optional[plt.Axes] = None,
*,
show_text: bool = False,
title: Optional[str] = None,
) -> plt.Axes:
"""Generates a plot for a given density matrix.
1. Each entry of the density matrix, a complex number, is plotted as an
Argand Diagram where the partially filled red circle represents the magnitude
and the line represents the phase angle, going anti-clockwise from positive x - axis.
2. The blue rectangles on the diagonal elements represent the probability
of measuring the system in state $|i\rangle$.
Rendering scheme is inspired from https://algassert.com/quirk
Args:
matrix: The density matrix to visualize
show_text: If true, the density matrix values are also shown as text labels
ax: The axes to plot on
title: Title of the plot
"""
plt.style.use('ggplot')

_padding_around_plot = 0.001

matrix = matrix.astype(np.complex128)
num_qubits = int(np.log2(matrix.shape[0]))
validate_density_matrix(matrix, qid_shape=(2 ** num_qubits,))

if ax is None:
_, ax = plt.subplots(figsize=(10, 10))
ax.set_xlim(0 - _padding_around_plot, 2 ** num_qubits + _padding_around_plot)
ax.set_ylim(0 - _padding_around_plot, 2 ** num_qubits + _padding_around_plot)

for i in range(matrix.shape[0]):
for j in range(matrix.shape[1]):
_plot_element_of_density_matrix(
ax,
i,
j,
np.abs(matrix[i][-j - 1]),
np.angle(matrix[i][-j - 1]),
show_rect=(i == matrix.shape[1] - j - 1),
show_text=show_text,
)

ticks, labels = np.arange(0.5, matrix.shape[0]), [
f"{'0'*(num_qubits - len(f'{i:b}'))}{i:b}" for i in range(matrix.shape[0])
]
ax.set_xticks(ticks)
ax.set_xticklabels(labels)
ax.set_yticks(ticks)
ax.set_yticklabels(reversed(labels))
ax.set_facecolor('#eeeeee')
if title is not None:
ax.set_title(title)
return ax
137 changes: 137 additions & 0 deletions cirq-core/cirq/vis/density_matrix_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Density Matrix Plotter."""

import numpy as np
import pytest

from matplotlib import lines, patches, text, spines, axis
from matplotlib import pyplot as plt

import cirq.testing
from cirq.vis.density_matrix import plot_density_matrix
from cirq.vis.density_matrix import _plot_element_of_density_matrix


@pytest.mark.parametrize('show_text', [True, False])
@pytest.mark.parametrize('size', [2, 4, 8, 16])
def test_density_matrix_plotter(size, show_text):
matrix = cirq.testing.random_density_matrix(size)
# Check that the title shows back up
ax = plot_density_matrix(matrix, show_text=show_text, title='Test Density Matrix Plot')
assert ax.get_title() == 'Test Density Matrix Plot'
# Check that the objects in the plot are only those we expect and nothing new was added
for obj in ax.get_children():
assert isinstance(
obj,
(
patches.Circle,
spines.Spine,
axis.XAxis,
axis.YAxis,
lines.Line2D,
patches.Rectangle,
text.Text,
),
)


@pytest.mark.parametrize('show_text', [True, False])
@pytest.mark.parametrize('size', [2, 4, 8, 16])
def test_density_matrix_circle_sizes(size, show_text):
matrix = cirq.testing.random_density_matrix(size)
# Check that the correct title is being shown
ax = plot_density_matrix(matrix, show_text=show_text, title='Test Density Matrix Plot')
# Check that the radius of all the circles in the matrix is correct
circles = list(filter(lambda x: isinstance(x, patches.Circle), ax.get_children()))
mean_radius = np.mean([c.radius for c in circles if c.fill])
mean_value = np.mean(np.abs(matrix))
circles = np.array(sorted(circles, key=lambda x: (x.fill, x.center[0], -x.center[1]))).reshape(
(2, size, size)
)
for i in range(size):
for j in range(size):
assert np.isclose(
np.abs(matrix[i, j]) * mean_radius / mean_value, circles[1, i, j].radius
)

# Check that all the rectangles are of the right height, and only on the diagonal elements
rects = list(
filter(
lambda x: isinstance(x, patches.Rectangle) and x.get_alpha() is not None,
ax.get_children(),
)
)
assert len(rects) == size
mean_size = np.mean([r.get_height() for r in rects])
mean_value = np.trace(np.abs(matrix)) / size
rects = np.array(sorted(rects, key=lambda x: x.get_x()))
for i in range(size):
# Ensuring that the rectangle is the right height
assert np.isclose(np.abs(matrix[i, i]) * mean_size / mean_value, rects[i].get_height())
rect_points = rects[i].get_bbox().get_points()
# Checking for the x position of the rectangle corresponding x of the center of the circle
assert np.isclose((rect_points[0, 0] + rect_points[1, 0]) / 2, circles[1, i, i].center[0])
# Asserting that only the diagonal elements are on
assert (
np.abs((rect_points[0, 1] + rect_points[1, 1]) / 2 - circles[1, i, i].center[1])
<= circles[0, i, i].radius * 1.5
)


@pytest.mark.parametrize('show_rect', [True, False])
@pytest.mark.parametrize('value', [0.0, 1.0, 0.5 + 0.3j, 0.2 + 0.1j, 0.5 + 0.5j])
def test_density_element_plot(value, show_rect):
_, ax = plt.subplots(figsize=(10, 10))
_plot_element_of_density_matrix(
ax, 0, 0, np.abs(value), np.angle(value), show_rect=False, show_text=False
)
# Check that the right phase is being plotted
plotted_lines = list(filter(lambda x: isinstance(x, lines.Line2D), ax.get_children()))
assert len(plotted_lines) == 1
line_position = plotted_lines[0].get_xydata()
angle = np.arctan(
(line_position[1, 1] - line_position[0, 1]) / (line_position[1, 0] - line_position[0, 0])
)
assert np.isclose(np.angle(value), angle)
# Check if the circles are the right size ratio, given the value of the element
circles_in = list(filter(lambda x: isinstance(x, patches.Circle) and x.fill, ax.get_children()))
assert len(circles_in) == 1
circles_out = list(
filter(lambda x: isinstance(x, patches.Circle) and not x.fill, ax.get_children())
)
assert len(circles_out) == 1
assert np.isclose(circles_in[0].radius, circles_out[0].radius * np.abs(value))
# Check the rectangle is show if show_rect is on and it's filled if we are showing
# the rectangle. If show_rect = False, the lack of a rectangle is not tested because
# there are other rectangles on the plot that turn up with the axes, that get
# checked when counting and matching the rectangles to the diagonal circles in
# `test_density_matrix_circle_sizes`
if show_rect:
rectangles = list(filter(lambda x: isinstance(x, patches.Rectangle), ax.get_children()))
assert len(rectangles) == 1
assert rectangles[0].fill


@pytest.mark.parametrize(
'matrix',
[
np.random.random(size=(4, 4, 4)),
np.random.random((3, 3)) * np.exp(np.random.random((3, 3)) * 2 * np.pi * 1j),
np.random.random((4, 8)) * np.exp(np.random.random((4, 8)) * 2 * np.pi * 1j),
],
)
def test_density_matrix_type_error(matrix):
with pytest.raises(ValueError, match="Incorrect shape for density matrix:*"):
plot_density_matrix(matrix)

0 comments on commit 50271af

Please sign in to comment.