import re
from collections.abc import Iterable
from copy import deepcopy
import numpy as np
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from matplotlib.patches import Circle, Polygon, Wedge
from matplotlib.text import Text
from .layout import Layout
Coordinates = tuple[float | int, float | int]
CoordRange = tuple[float | int, float | int]
# Regex to filter qubit labels for TeX text.
RE_FILTER = re.compile("([a-zA-Z]+)([0-9]+)")
# Order in which to draw the elements.
# We want the text to be the last drawn object so that is it on top and thus readable.
# Wedge is used to draw the logical support.
ZORDERS = dict(circle=3, wedge=2, patch=0, line=1, text=4)
# Define the default colors
COLORS = {
"red": "#e41a1cff",
"red_light": "#f07f80ff",
"green": "#4daf4aff",
"green_light": "#9dd49bff",
"blue": "#377eb8ff",
"blue_light": "#90bbdfff",
"orange": "#ff9933ff",
"orange_light": "#ffb770ff",
"purple": "#984ea3ff",
"purple_light": "#ca9ed1ff",
"yellow": "#f2c829ff",
"yellow_light": "#f7dc78ff",
}
def clockwise_sort(coordinates: Iterable[Coordinates]) -> list[Coordinates]:
    """Sorts a sequence of coordinates in clockwise order.

    The points are ordered by their angle around their centroid so that
    they can be used to correctly draw a ``matplotlib.patches.Polygon``.

    Parameters
    ----------
    coordinates
        The 2D coordinates to sort.

    Returns
    -------
    sorted_coords
        The sorted coordinates. An empty input gives an empty list.

    Raises
    ------
    ValueError
        If the coordinates are not 2D.
    """
    coords = list(coordinates)
    if not coords:
        # nothing to sort; unpacking ``zip(*coords)`` below would fail.
        return []

    x_coords, y_coords = zip(*coords)
    x_vectors = np.asarray(x_coords) - np.mean(x_coords)
    y_vectors = np.asarray(y_coords) - np.mean(y_coords)

    # ``arctan2(x, y)`` (instead of the usual ``arctan2(y, x)``) measures the
    # angle from the +y axis towards the +x axis, so that increasing angles
    # correspond to a clockwise rotation.
    angles = np.arctan2(x_vectors, y_vectors)
    return [coords[ind] for ind in np.argsort(angles)]
def get_label(qubit: str, coords: Coordinates, **kwargs) -> Text:
    """Builds the text label of a qubit.

    Labels made of letters followed by digits (e.g. ``D12``) are rendered
    as TeX with the digits as an upright subscript. Any other label is
    shown verbatim.

    Parameters
    ----------
    qubit
        The qubit label.
    coords
        The coordinates of the qubit.
    **kwargs
        Extra arguments for ``matplotlib.text.Text``.

    Returns
    -------
    label
        The text artist located at ``coords``.
    """
    text = qubit
    # Only use TeX when the *whole* label has the form "<letters><digits>".
    # A prefix match would silently drop any trailing characters
    # (e.g. "D12b" would be displayed as "D_12").
    match = RE_FILTER.fullmatch(qubit)
    if match is not None:
        name, ind = match.groups()
        text = f"${name}_\\mathrm{{{ind}}}$"

    x, y = coords
    return Text(x, y, text, zorder=ZORDERS["text"], **kwargs)
def get_circle(center: Coordinates, radius: float, **kwargs) -> Circle:
    """Builds a ``matplotlib.patches.Circle`` with the given specifications.

    The circle is placed at the z-order reserved for circles in ``ZORDERS``.

    Parameters
    ----------
    center
        The coordinates of the centre of the circle.
    radius
        The radius of the circle.
    **kwargs
        Extra arguments for ``matplotlib.patches.Circle``.

    Returns
    -------
    Circle
        The circle with the given specifications.
    """
    return Circle(center, radius=radius, zorder=ZORDERS["circle"], **kwargs)
def get_wedge(
    center: Coordinates, r: float, theta1: float, theta2: float, **kwargs
) -> Wedge:
    """Builds a ``matplotlib.patches.Wedge`` with the given specifications.

    The wedge is placed at the z-order reserved for wedges in ``ZORDERS``.

    Parameters
    ----------
    center
        The coordinates of the centre of the wedge.
    r
        The radius of the wedge.
    theta1
        The angle at which the wedge starts.
    theta2
        The angle at which the wedge ends.
    **kwargs
        Extra arguments for ``matplotlib.patches.Wedge``.

    Returns
    -------
    Wedge
        The wedge with the given specifications.
    """
    return Wedge(
        center,
        r=r,
        theta1=theta1,
        theta2=theta2,
        zorder=ZORDERS["wedge"],
        **kwargs,
    )
def get_patch(patch_coords: Iterable[Coordinates], **kwargs) -> Polygon:
    """Builds a closed ``matplotlib.patches.Polygon`` from its vertices.

    The polygon is placed at the z-order reserved for patches in ``ZORDERS``.

    Parameters
    ----------
    patch_coords
        The coordinates of the vertices of the patch.
    **kwargs
        Extra arguments for ``matplotlib.patches.Polygon``.

    Returns
    -------
    Polygon
        The closed polygon with the given vertices.
    """
    return Polygon(patch_coords, closed=True, zorder=ZORDERS["patch"], **kwargs)
def get_line(coordinates: Iterable[Coordinates], **kwargs) -> Line2D:
    """Builds the line that connects the given qubit coordinates.

    The line is placed at the z-order reserved for lines in ``ZORDERS``.

    Parameters
    ----------
    coordinates
        The 2D coordinates of the qubits to connect.
    **kwargs
        Extra arguments for ``matplotlib.lines.Line2D``.

    Returns
    -------
    line
        Line going through the given coordinates.
    """
    # transpose [(x0, y0), (x1, y1), ...] into (x0, x1, ...), (y0, y1, ...)
    xs, ys = zip(*coordinates)
    return Line2D(xs, ys, zorder=ZORDERS["line"], **kwargs)
def qubit_labels(layout: Layout, label_fontsize: float | int = 11) -> Iterable[Text]:
    """Yields the text label of every qubit in a layout.

    The default text style can be overridden per qubit through the ``"text"``
    entry of the qubit's ``"metaparams"`` parameter.

    Parameters
    ----------
    layout
        The layout to draw the qubit labels of.
    label_fontsize
        Default value of the font size for the labels.

    Yields
    ------
    label
        The text artist of each qubit.

    Raises
    ------
    ValueError
        If the coordinates of a qubit are not 2D.
    """
    base_style = {
        "color": "black",
        "verticalalignment": "center",
        "horizontalalignment": "center",
        "fontsize": label_fontsize,
    }

    for qubit in layout.qubits:
        position: Coordinates = layout.param("coords", qubit)
        ndim = len(position)
        if ndim != 2:
            raise ValueError(
                "Coordinates must be 2D to be plotted, "
                f"but {ndim}D were given for qubit {qubit}"
            )

        overrides = layout.param("metaparams", qubit)
        style = deepcopy(base_style)
        if isinstance(overrides, dict):
            # qubit-specific settings take precedence over the defaults
            style.update(overrides.get("text", {}))

        yield get_label(qubit, position, **style)
def qubit_connections(layout: Layout) -> Iterable[Line2D]:
    """Yields the lines connecting each ancilla qubit with its neighbors.

    Lines are dashed and colored according to the stabilizer type of the
    ancilla (blue for ``"z_type"``, red for ``"x_type"``, green otherwise).
    The style can be overridden per ancilla through the ``"line"`` entry of
    its ``"metaparams"`` parameter.

    Parameters
    ----------
    layout
        The layout to draw the connections of.

    Yields
    ------
    line
        One line per (ancilla, neighbor) pair.

    Raises
    ------
    ValueError
        If the coordinates of an ancilla qubit are not 2D.
    """
    base_style = {"linestyle": "--"}

    for ancilla in layout.anc_qubits:
        ancilla_pos: Coordinates = layout.param("coords", ancilla)
        ndim = len(ancilla_pos)
        if ndim != 2:
            raise ValueError(
                "Coordinates must be 2D to be plotted, "
                f"but {ndim}D were given for qubit {ancilla}."
            )

        overrides = layout.param("metaparams", ancilla)
        stab_type = layout.param("stab_type", ancilla)
        color_name = (
            "blue"
            if stab_type == "z_type"
            else "red" if stab_type == "x_type" else "green"
        )

        style = deepcopy(base_style)
        style["color"] = COLORS[color_name]
        if isinstance(overrides, dict):
            # ancilla-specific settings take precedence over the defaults
            style.update(overrides.get("line", {}))

        for neighbor in layout.get_neighbors(ancilla):
            endpoints = (ancilla_pos, layout.param("coords", neighbor))
            yield get_line(endpoints, **style)
def qubit_artists(layout: Layout) -> Iterable[Circle]:
    """Yields the circle representing every qubit in a layout.

    Data qubits are filled in white. Any other qubit is colored according
    to its stabilizer type (blue for ``"z_type"``, red for ``"x_type"``,
    green otherwise). The style and the radius can be overridden per qubit
    through the ``"circle"`` entry of its ``"metaparams"`` parameter.

    Parameters
    ----------
    layout
        The layout to draw the qubits of.

    Yields
    ------
    circle
        The circle of each qubit.

    Raises
    ------
    ValueError
        If the coordinates of a qubit are not 2D.
    """
    default_radius = 0.3
    base_style = {"edgecolor": "black"}

    for qubit in layout.qubits:
        position: Coordinates = layout.param("coords", qubit)
        ndim = len(position)
        if ndim != 2:
            raise ValueError(
                "Coordinates must be 2D to be plotted, "
                f"but {ndim}D were given for qubit {qubit}."
            )

        overrides = layout.param("metaparams", qubit)

        if layout.param("role", qubit) == "data":
            facecolor = "white"
        else:
            stab_type = layout.param("stab_type", qubit)
            color_name = (
                "blue"
                if stab_type == "z_type"
                else "red" if stab_type == "x_type" else "green"
            )
            facecolor = COLORS[color_name]

        style = deepcopy(base_style)
        style["facecolor"] = facecolor
        radius = default_radius
        if isinstance(overrides, dict):
            style.update(overrides.get("circle", {}))
            # "radius" is a positional argument of ``get_circle``, so it
            # must not stay in the keyword arguments.
            radius = style.pop("radius", default_radius)

        yield get_circle(position, radius, **style)
def logical_artists(layout: Layout) -> Iterable[Wedge]:
    """Yields the wedges highlighting the logical support of a layout.

    The full circle is split evenly between the logical qubits, so that the
    ``k``-th logical qubit uses the angles ``[k, k + 1] * 360 / n``.
    The support of the logical X operators is drawn in red and the one of
    the logical Z operators in blue. If a qubit in a Z support also belongs
    to the X support of any logical qubit, its Z wedge is drawn one ring
    further out so that it does not overlap with the X wedge.

    Parameters
    ----------
    layout
        The layout to draw the logical support of.

    Yields
    ------
    wedge
        One wedge per (logical qubit, Pauli type, support qubit).

    Raises
    ------
    ValueError
        If the coordinates of a qubit in the support are not 2D.
    """
    default_radius = 0.3
    width = 0.1

    logical_qubits = layout.logical_qubits
    if len(logical_qubits) == 0:
        return

    angle = 360 / len(logical_qubits)
    support: dict[str, dict[str, tuple[str]]] = {
        "z": {l: layout.logical_param("log_z", l) for l in logical_qubits},
        "x": {l: layout.logical_param("log_x", l) for l in logical_qubits},
    }
    # a set makes the membership test in the innermost loop O(1).
    x_support_qubits = {q for supp in support["x"].values() for q in supp}

    for k, logical_qubit in enumerate(logical_qubits):
        theta1, theta2 = k * angle, (k + 1) * angle
        for pauli in ("z", "x"):
            for qubit in support[pauli][logical_qubit]:
                coords: Coordinates = layout.param("coords", qubit)
                if len(coords) != 2:
                    raise ValueError(
                        "Coordinates must be 2D to be plotted, "
                        f"but {len(coords)}D were given for qubit {qubit}."
                    )

                facecolor = COLORS["red"]
                radius = default_radius + width
                if pauli == "z":
                    facecolor = COLORS["blue"]
                    if qubit in x_support_qubits:
                        # the inner ring is used by the X wedge
                        radius = default_radius + 2 * width

                yield get_wedge(
                    coords,
                    radius,
                    theta1=theta1,
                    theta2=theta2,
                    width=width,
                    edgecolor="none",
                    facecolor=facecolor,
                )
def patch_artists(layout: Layout) -> Iterable[Polygon]:
    """Yields the polygon representing every stabilizer of a layout.

    The vertices of a patch are the neighbors of the ancilla qubit. Patches
    are colored according to the stabilizer type of the ancilla (light blue
    for ``"z_type"``, light red for ``"x_type"``, light green otherwise).
    The style can be overridden per ancilla through the ``"patch"`` entry of
    its ``"metaparams"`` parameter.

    Parameters
    ----------
    layout
        The layout to draw the patches of.

    Yields
    ------
    patch
        The polygon of each ancilla qubit.

    Raises
    ------
    ValueError
        If the coordinates of an ancilla qubit are not 2D.
    """
    base_style = {"edgecolor": "black"}

    for ancilla in layout.anc_qubits:
        ancilla_pos: Coordinates = layout.param("coords", ancilla)
        ndim = len(ancilla_pos)
        if ndim != 2:
            raise ValueError(
                "Coordinates must be 2D to be plotted, "
                f"but {ndim}D were given for qubit {ancilla}."
            )

        neighbors = layout.get_neighbors(ancilla)
        vertices: list[Coordinates] = [
            layout.param("coords", neighbor) for neighbor in neighbors
        ]
        # an ancilla connected to only two other qubits is itself one of
        # the vertices of the stabilizer patch.
        if len(neighbors) == 2:
            vertices.append(ancilla_pos)

        # the vertices must be ordered so that a correct polygon is drawn.
        polygon_coords = clockwise_sort(vertices)

        overrides = layout.param("metaparams", ancilla)
        stab_type = layout.param("stab_type", ancilla)
        color_name = (
            "blue_light"
            if stab_type == "z_type"
            else "red_light" if stab_type == "x_type" else "green_light"
        )

        style = deepcopy(base_style)
        style["facecolor"] = COLORS[color_name]
        if isinstance(overrides, dict):
            # ancilla-specific settings take precedence over the defaults
            style.update(overrides.get("patch", {}))

        yield get_patch(polygon_coords, **style)
def get_coord_range(layout: Layout) -> tuple[CoordRange, CoordRange]:
    """Computes the extent of the qubit coordinates along each axis.

    Parameters
    ----------
    layout
        Layout of which to compute the coordinate range.

    Returns
    -------
    ranges
        The tuple ``((x_min, x_max), (y_min, y_max))``.

    Raises
    ------
    ValueError
        If the coordinates of any qubit are not 2D.
    """
    points: list[Coordinates] = []
    for qubit in layout.qubits:
        points.append(layout.param("coords", qubit))

    # validate everything before unpacking into the two axes
    for point in points:
        ndim = len(point)
        if ndim != 2:
            raise ValueError(
                f"Coordinates must be 2D to be plotted, but {ndim}D were given."
            )

    xs, ys = zip(*points)
    return (min(xs), max(xs)), (min(ys), max(ys))
def plot(
    ax: Axes,
    *layouts: Layout,
    add_labels: bool = True,
    add_patches: bool = True,
    add_connections: bool = True,
    add_logicals: bool = True,
    pad: float = 1,
    stim_orientation: bool = True,
    label_fontsize: float | int = 11,
) -> Axes:
    """Plots one or more layouts on the same axis.

    Parameters
    ----------
    ax
        The axis to plot the layouts on.
    *layouts
        The layouts to plot.
    add_labels
        Flag to add qubit labels, by default ``True``.
    add_patches
        Flag to plot stabilizer patches, by default ``True``.
    add_connections
        Flag to plot lines indicating the connectivity, by default ``True``.
    add_logicals
        Flag to highlight the logical support on the data qubits, by default ``True``.
    pad
        The padding added around the qubit coordinates on every side of
        the axis, by default ``1``.
    stim_orientation
        Flag to orient the layout and axis as stim does for ``diagram``,
        i.e. with the Y axis inverted. By default ``True``.
    label_fontsize
        Default font size of the qubit labels. If ``layout`` has information
        about the font size, then this argument is ignored. The purpose
        of this argument is to easily scale down the label size for
        large codes.

    Returns
    -------
    ax
        The axis the layouts were plotted on.
    """
    x_min, x_max = np.inf, -np.inf
    y_min, y_max = np.inf, -np.inf

    for layout in layouts:
        for artist in qubit_artists(layout):
            ax.add_artist(artist)

        if add_logicals:
            for artist in logical_artists(layout):
                ax.add_artist(artist)

        if add_patches:
            for artist in patch_artists(layout):
                ax.add_artist(artist)

        if add_connections:
            for artist in qubit_connections(layout):
                ax.add_artist(artist)

        if add_labels:
            for artist in qubit_labels(layout, label_fontsize):
                ax.add_artist(artist)

        # grow the bounds so that they enclose all the layouts
        x_range, y_range = get_coord_range(layout)
        x_min, x_max = min(x_min, x_range[0]), max(x_max, x_range[1])
        y_min, y_max = min(y_min, y_range[0]), max(y_max, y_range[1])

    # without layouts the bounds are still infinite, which is not a valid
    # axis limit, so the current limits of the axis are kept.
    if layouts:
        ax.set_xlim(x_min - pad, x_max + pad)
        ax.set_ylim(y_min - pad, y_max + pad)

    ax.set_xlabel("$x$ coordinate")
    ax.set_ylabel("$y$ coordinate")
    ax.set_aspect("equal")

    if stim_orientation:
        ax.invert_yaxis()

    return ax