# -*- coding: utf-8 -*-
"""
DisCopy's legacy drawing: turn a diagram into a directed graph then plot it.
Summary
-------
.. autosummary::
:template: class.rst
:nosignatures:
:toctree:
Node
Backend
TikzBackend
MatBackend
Equation
.. admonition:: Functions
.. autosummary::
:template: function.rst
:nosignatures:
:toctree:
diagram2nx
"""
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from math import sqrt
from tempfile import NamedTemporaryFile, TemporaryDirectory
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.patches import PathPatch
from matplotlib.path import Path
from discopy.utils import assert_isinstance
from discopy.config import (
DRAWING_ATTRIBUTES as ATTRIBUTES,
DRAWING_DEFAULT as DEFAULT, COLORS, SHAPES)
class Node:
    """
    Node in a :class:`networkx.Graph`, can hold arbitrary data.

    Parameters
    ----------
    kind : str
        The kind of node; this module uses :code:`"input"`,
        :code:`"output"`, :code:`"box"`, :code:`"dom"` and :code:`"cod"`.
    data : any
        Arbitrary keyword data. It is kept as the dictionary
        :code:`self.data` and each item is also set as an attribute.

    Note
    ----
    Equality compares :code:`kind` and :code:`data`, hashing goes through
    :meth:`__repr__`, so two nodes built from the same arguments are
    interchangeable as dictionary keys.
    """
    def __init__(self, kind, **data):
        self.kind, self.data = kind, data
        # Expose every data item as an attribute, e.g. ``node.box``.
        for key, value in data.items():
            setattr(self, key, value)

    def __eq__(self, other):
        return isinstance(other, Node)\
            and (self.kind, self.data) == (other.kind, other.data)

    def __repr__(self):
        # Keys are sorted so that the representation (and hence the hash)
        # does not depend on the order of the keyword arguments.
        str_data = ", ".join(
            f"{key}={repr(value)}"
            for key, value in sorted(self.data.items())
        )
        return f"Node({repr(self.kind)}, {str_data})"

    def __hash__(self):
        return hash(repr(self))

    __str__ = __repr__
def diagram2nx(diagram):
    """
    Builds a networkx graph, called by :meth:`Diagram.draw`.

    Parameters
    ----------
    diagram : discopy.monoidal.Diagram
        any diagram.

    Returns
    -------
    graph : networkx.DiGraph
        with nodes for inputs, outputs, boxes and their co/domain.
    positions : Mapping[Node, Tuple[float, float]]
        from nodes to pairs of floats.
    """
    import networkx as nx

    graph, pos = nx.DiGraph(), {}

    def add_node(node, position):
        """ Register a node in the graph together with its position. """
        graph.add_node(node)
        pos[node] = position

    def add_box(scan, box, off, depth, x_pos):
        """
        Add a box at the given depth with its dom and cod wire nodes,
        return the updated scan, i.e. the list of currently open wires.
        """
        bubble_opening = getattr(box, "bubble_opening", False)
        bubble_closing = getattr(box, "bubble_closing", False)
        bubble = bubble_opening or bubble_closing
        node = Node("box", box=box, depth=depth)
        # Each layer is one unit high: the box sits in the middle (.5),
        # its dom nodes at .25 and its cod nodes at .75 below the layer top.
        add_node(node, (x_pos, len(diagram) - depth - .5))
        for i, obj in enumerate(box.dom.inside):
            wire, position = Node("dom", obj=obj, i=i, depth=depth), (
                pos[scan[off + i]][0], len(diagram) - depth - .25)
            add_node(wire, position)
            graph.add_edge(scan[off + i], wire)
            # For bubbles only the outermost wires are attached to the box.
            if not bubble or bubble_closing and i in [0, len(box.dom) - 1]:
                graph.add_edge(wire, node)
        for i, obj in enumerate(box.cod.inside):
            # Keep wires straight when dom and cod have the same length,
            # otherwise spread the cod wires evenly around the box.
            position = (
                pos[scan[off + i]][0] if len(box.dom) == len(box.cod)
                else x_pos - len(box.cod[1:]) / 2 + i,
                len(diagram) - depth - .75)
            wire = Node("cod", obj=obj, i=i, depth=depth)
            add_node(wire, position)
            if not bubble or bubble_opening and i in [0, len(box.cod) - 1]:
                graph.add_edge(node, wire)
        if bubble_opening or bubble_closing:
            # The inner wires of a bubble go straight from dom to cod,
            # shifted by one to skip the extra wire on the left.
            source_ty, target_ty = (box.dom, box.cod[1:-1]) if bubble_opening\
                else (box.dom[1:-1], box.cod)
            for i, (source_obj, target_obj) in enumerate(zip(
                    source_ty.inside, target_ty.inside)):
                source_i, target_i = (i, i + 1) if bubble_opening\
                    else (i + 1, i)
                source = Node("dom", obj=source_obj, i=source_i, depth=depth)
                target = Node("cod", obj=target_obj, i=target_i, depth=depth)
                graph.add_edge(source, target)
        # Replace the consumed dom wires by the freshly created cod wires.
        return scan[:off]\
            + [Node("cod", obj=obj, i=i, depth=depth)
               for i, obj in enumerate(box.cod.inside)]\
            + scan[off + len(box.dom):]

    def make_space(scan, box, off):
        """
        Compute the horizontal position of a box and push the nodes that
        are already placed to its left and right out of the way.
        """
        if not scan:
            return 0
        half_width = len(box.cod[:-1]) / 2 + 1
        if not box.dom:
            if not off:
                x_pos = pos[scan[0]][0] - half_width
            elif off == len(scan):
                x_pos = pos[scan[-1]][0] + half_width
            else:
                right = pos[scan[off + len(box.dom)]][0]
                x_pos = (pos[scan[off - 1]][0] + right) / 2
        else:
            right = pos[scan[off + len(box.dom) - 1]][0]
            x_pos = (pos[scan[off]][0] + right) / 2
        if off and pos[scan[off - 1]][0] > x_pos - half_width:
            limit = pos[scan[off - 1]][0]
            pad = limit - x_pos + half_width
            # Only values of existing keys are reassigned, so iterating
            # over the dictionary while updating it is safe.
            for node, position in pos.items():
                if position[0] <= limit:
                    pos[node] = (pos[node][0] - pad, pos[node][1])
        if off + len(box.dom) < len(scan)\
                and pos[scan[off + len(box.dom)]][0] < x_pos + half_width:
            limit = pos[scan[off + len(box.dom)]][0]
            pad = x_pos + half_width - limit
            for node, position in pos.items():
                if position[0] >= limit:
                    pos[node] = (pos[node][0] + pad, pos[node][1])
        return x_pos

    scan = []
    for i, obj in enumerate(diagram.dom.inside):
        node = Node("input", obj=obj, i=i)
        add_node(node, (i, len(diagram) or 1))
        scan.append(node)
    for depth, (box, off) in enumerate(zip(diagram.boxes, diagram.offsets)):
        x_pos = make_space(scan, box, off)
        scan = add_box(scan, box, off, depth, x_pos)
    for i, obj in enumerate(diagram.cod.inside):
        node = Node("output", obj=obj, i=i)
        add_node(node, (pos[scan[i]][0], 0))
        graph.add_edge(scan[i], node)
    return graph, pos
class Backend(ABC):
    """
    Abstract drawing backend.

    The drawing methods of this base class only keep track of
    :code:`max_width`, the largest horizontal coordinate drawn so far.
    Subclasses do the actual drawing and then call :code:`super()`.

    Parameters
    ----------
    linewidth : float, optional
        Accepted for signature compatibility, it is not stored here:
        subclasses that need it set their own attribute.
    """
    def __init__(self, linewidth=1):
        self.max_width = 0

    def draw_text(self, text, i, j, **params):
        """ Draws a piece of text at a given position. """
        self.max_width = max(self.max_width, i)

    def draw_node(self, i, j, **params):
        """ Draws a node for a given position, color and shape. """
        self.max_width = max(self.max_width, i)

    def draw_polygon(self, *points, color=DEFAULT["color"]):
        """ Draws a polygon given a list of points. """
        # Guard against an empty polygon, max() would raise on no points.
        if points:
            self.max_width = max(
                self.max_width, max(i for i, _ in points))

    def draw_wire(self, source, target,
                  bend_out=False, bend_in=False, style=None):
        """ Draws a wire from source to target, possibly with a Bezier. """
        self.max_width = max(self.max_width, source[0], target[0])

    def draw_spiders(self, graph, positions, draw_box_labels=True, **params):
        """ Draws a list of boxes depicted as spiders. """
        spider_widths = [
            positions[n][0] for n in graph.nodes
            if n.kind == 'box' and n.box.draw_as_spider]
        if spider_widths:
            self.max_width = max(self.max_width, max(spider_widths))

    @abstractmethod
    def output(self, path=None, show=True, **params):
        """ Output the drawing. """
class TikzBackend(Backend):
    """
    Tikz drawing backend.

    Parameters
    ----------
    use_tikzstyles : bool, optional
        Whether to emit named :code:`\\tikzstyle` definitions instead of
        inline options, default is :code:`DEFAULT["use_tikzstyles"]`.
    """
    def __init__(self, use_tikzstyles=None):
        self.use_tikzstyles = DEFAULT["use_tikzstyles"]\
            if use_tikzstyles is None else use_tikzstyles
        self.node_styles, self.edge_styles = [], []
        # ``nodes`` maps a position to the id of the last node added there,
        # ``nodelayer`` and ``edgelayer`` hold the generated tikz lines.
        self.nodes, self.nodelayer, self.edgelayer = {}, [], []
        super().__init__()

    @staticmethod
    def format_color(color):
        """
        Format a color for use in a :code:`\\tikzstyle` definition.

        The color is looked up in :code:`COLORS`; if the result reads as a
        hexadecimal code :code:`#rrggbb` it is returned as an
        :code:`{rgb,255: ...}` specification, otherwise the color is
        returned unchanged.
        """
        # NOTE(review): assumes COLORS maps names to '#rrggbb' strings, as
        # the matplotlib backend passes its values straight to matplotlib.
        try:
            hexcode = COLORS[color]
        except KeyError:
            return str(color)
        digits = str(hexcode).lstrip("#")
        try:
            red, green, blue = (
                int(digits[k:k + 2], 16) for k in (0, 2, 4))
        except ValueError:
            return str(color)
        return f"{{rgb,255: red,{red}; green,{green}; blue,{blue}}}"

    def add_node(self, i, j, text=None, options=None):
        """ Add a node to the tikz picture, return its unique id. """
        # Count the emitted nodes rather than the distinct positions:
        # ``self.nodes`` does not grow when a position is reused, which
        # would hand out the same id twice.
        node = len(self.nodelayer) + 1
        text = "" if text is None else text
        self.nodelayer.append(
            f"\\node [{options or ''}] ({node}) at ({i}, {j}) {{{text}}};\n")
        self.nodes.update({(i, j): node})
        return node

    def draw_node(self, i, j, text=None, **params):
        options = []
        if 'shape' in params:
            options.append(params['shape'])
        if 'color' in params:
            options.append(params['color'])
        self.add_node(i, j, text, options=", ".join(options))
        super().draw_node(i, j, **params)

    def draw_text(self, text, i, j, **params):
        options = "style=none, fill=white"
        if params.get('horizontalalignment', 'center') == 'left':
            options += ", anchor=west"
        if params.get("verticalalignment", "center") == "top":  # wire labels
            options += ", right"
        if 'fontsize' in params and params['fontsize'] is not None:
            options += f", scale={params['fontsize']}"
        self.add_node(i, j, text, options)
        super().draw_text(text, i, j, **params)

    def draw_polygon(self, *points, color=DEFAULT["color"]):
        nodes = []
        for point in points:
            nodes.append(self.add_node(*point))
        nodes.append(nodes[0])  # close the path
        if self.use_tikzstyles:
            style_name = "box" if color == DEFAULT["color"]\
                else f"{color}_box"
            style = f"\\tikzstyle{{{style_name}}}=" \
                    f"[-, fill={self.format_color(color)}]\n"
            if style not in self.edge_styles:
                self.edge_styles.append(style)
            options = f"style={style_name}"
        else:
            options = f"-, fill={{{color}}}"
        str_connections = " to ".join(f"({node}.center)" for node in nodes)
        self.edgelayer.append(f"\\draw [{options}] {str_connections};\n")
        super().draw_polygon(*points, color=color)

    def draw_wire(self, source, target,
                  bend_out=False, bend_in=False, style=None):
        out = -90 if not bend_out or source[0] == target[0]\
            else (180 if source[0] > target[0] else 0)
        inp = 90 if not bend_in or source[0] == target[0]\
            else (180 if source[0] < target[0] else 0)
        looseness = 1
        if not (source[0] == target[0] or source[1] == target[1]):
            dx, dy = abs(source[0] - target[0]), abs(source[1] - target[1])
            length = sqrt(dx * dx + dy * dy)
            distance = min(dx, dy)
            looseness = round(distance / length * 2.1, 4)
        if looseness != 1:
            # Options are comma-separated: do not glue the looseness onto
            # an existing style such as '->'.
            extra = f'looseness={looseness}'
            style = extra if not style else f'{style}, {extra}'
        cmd = (
            "\\draw [in={}, out={}{}] "
            "({}.center) to ({}.center);\n")
        if source not in self.nodes:
            self.add_node(*source)
        if target not in self.nodes:
            self.add_node(*target)
        self.edgelayer.append(cmd.format(
            inp, out,
            f", {style}" if style is not None else "",
            self.nodes[source], self.nodes[target]))
        super().draw_wire(source, target, bend_out=bend_out, bend_in=bend_in)

    def draw_spiders(self, graph, positions, draw_box_labels=True, **params):
        spiders = [(node, node.box.color, node.box.shape)
                   for node in graph.nodes
                   if node.kind == "box" and node.box.draw_as_spider]
        for node, color, shape in spiders:
            i, j = positions[node]
            text = node.box.drawing_name if draw_box_labels else ""
            if self.use_tikzstyles:
                style = f"\\tikzstyle{{{node.box.tikzstyle_name}}}=" \
                        f"[fill={self.format_color(color)}]\n"
                if style not in self.node_styles:
                    self.node_styles.append(style)
                options = f"style={node.box.tikzstyle_name}"
            else:
                options = f"{shape}, fill={color}"
            if params.get("nodesize", 1) != 1:
                options +=\
                    f", scale={params.get('nodesize')}"  # pragma: no cover
            self.add_node(i, j, text, options)
        super().draw_spiders(graph, positions, draw_box_labels)

    def output(self, path=None, show=True, **params):
        """
        Write the tikz picture to :code:`path` or print it.

        Parameters
        ----------
        path : str, optional
            Where to write the picture. When tikz styles are in use they
            are written next to it with the extension `.tikzstyles`.
        show : bool, optional
            Whether to print the picture when no path is given.
        params : any, optional
            Reads :code:`baseline`, :code:`tikz_options` and
            :code:`output_tikzstyle`.
        """
        baseline = params.get("baseline", 0)
        tikz_options = params.get("tikz_options", None)
        output_tikzstyle = self.use_tikzstyles\
            and params.get("output_tikzstyle", True)
        options = "baseline=(0.base)" if tikz_options is None\
            else "baseline=(0.base), " + tikz_options
        begin = [f"\\begin{{tikzpicture}}[{options}]\n"]
        # Node 0 is reserved for the baseline, see ``add_node``.
        nodes = ["\\begin{pgfonlayer}{nodelayer}\n",
                 f"\\node (0) at (0, {baseline}) {{}};\n"]\
            + self.nodelayer + ["\\end{pgfonlayer}\n"]
        edges = ["\\begin{pgfonlayer}{edgelayer}\n"] + self.edgelayer\
            + ["\\end{pgfonlayer}\n"]
        end = ["\\end{tikzpicture}\n"]
        if path is not None:
            if output_tikzstyle:
                # splitext only strips an extension of the file name, so
                # dotted directories and extension-less paths are handled.
                style_path = os.path.splitext(path)[0] + '.tikzstyles'
                with open(style_path, 'w+') as file:
                    file.writelines(["% Node styles\n"] + self.node_styles)
                    file.writelines(["% Edge styles\n"] + self.edge_styles)
            with open(path, 'w+') as file:
                file.writelines(begin + nodes + edges + end)
        elif show:  # pragma: no cover
            if output_tikzstyle:
                print(''.join(self.node_styles + self.edge_styles))
            print(''.join(begin + nodes + edges + end))
class MatBackend(Backend):
    """
    Matplotlib drawing backend.

    Parameters
    ----------
    axis : matplotlib axis, optional
        The axis to draw on, a new white figure is created if not given.
    figsize : tuple, optional
        Figure size, only used when a new figure is created.
    linewidth : float, optional
        Line width for polygons and wires, default is :code:`1`.
    """
    def __init__(self, axis=None, figsize=None, linewidth=1):
        self.axis = axis or plt.subplots(figsize=figsize, facecolor='white')[1]
        self.linewidth = linewidth
        super().__init__()

    def draw_text(self, text, i, j, **params):
        # Callers pass ``fontsize=None`` explicitly to mean "not set",
        # so a plain ``params.get`` default would never kick in.
        if params.get('fontsize') is None:
            params['fontsize'] = DEFAULT['fontsize']
        self.axis.text(i, j, text, **params)
        super().draw_text(text, i, j, **params)

    def draw_node(self, i, j, **params):
        self.axis.scatter(
            [i], [j],
            c=COLORS[params.get("color", "black")],
            marker=SHAPES[params.get("shape", "circle")],
            s=300 * params.get("nodesize", 1),
            edgecolors=params.get("edgecolor", None))
        super().draw_node(i, j, **params)

    def draw_polygon(self, *points, color=DEFAULT["color"]):
        codes = [Path.MOVETO]
        codes += len(points[1:]) * [Path.LINETO] + [Path.CLOSEPOLY]
        # The first point is repeated as the vertex of the CLOSEPOLY code.
        path = Path(points + points[:1], codes)
        self.axis.add_patch(PathPatch(
            path, facecolor=COLORS[color], linewidth=self.linewidth))
        super().draw_polygon(*points, color=color)

    def draw_wire(self, source, target,
                  bend_out=False, bend_in=False, style=None):
        if style == '->':  # pragma: no cover
            self.axis.arrow(
                *(source + (target[0] - source[0], target[1] - source[1])),
                head_width=.02, color="black")
        else:
            # Control point of the quadratic Bezier curve.
            mid = (target[0], source[1])\
                if bend_out else (source[0], target[1])
            path = Path([source, mid, target],
                        [Path.MOVETO, Path.CURVE3, Path.CURVE3])
            self.axis.add_patch(PathPatch(
                path, facecolor='none', linewidth=self.linewidth))
        super().draw_wire(source, target, bend_out=bend_out, bend_in=bend_in)

    def draw_spiders(self, graph, positions, draw_box_labels=True, **params):
        import networkx as nx
        spiders = [node for node in graph.nodes
                   if node.kind == "box" and node.box.draw_as_spider]
        # networkx takes a single marker per call: group spiders by shape.
        by_shape = {}
        for node in spiders:
            by_shape.setdefault(node.box.shape, []).append(node)
        for shape, shape_nodes in by_shape.items():
            nx.draw_networkx_nodes(
                graph, positions, nodelist=shape_nodes,
                node_color=[COLORS[node.box.color] for node in shape_nodes],
                node_shape=SHAPES[shape], ax=self.axis,
                node_size=300 * params.get("nodesize", 1))
        if draw_box_labels:
            # Label every spider, not only those of the last shape drawn,
            # and draw on the axis of this backend.
            labels = {node: node.box.drawing_name for node in spiders}
            nx.draw_networkx_labels(graph, positions, labels, ax=self.axis)
        super().draw_spiders(graph, positions, draw_box_labels)

    def output(self, path=None, show=True, **params):
        """
        Save the figure to :code:`path` if given, then show it if asked.

        Parameters
        ----------
        path : str, optional
            Where to save the figure, which is then closed.
        show : bool, optional
            Whether to call :code:`plt.show()`.
        params : any, optional
            Reads :code:`xlim`, :code:`ylim`, :code:`margins` and
            :code:`aspect`.
        """
        xlim, ylim = params.get("xlim", None), params.get("ylim", None)
        margins = params.get("margins", DEFAULT['margins'])
        aspect = params.get("aspect", DEFAULT['aspect'])
        plt.margins(*margins)
        plt.subplots_adjust(
            top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        self.axis.set_aspect(aspect)
        plt.axis('off')
        if xlim is not None:
            self.axis.set_xlim(*xlim)
        if ylim is not None:
            self.axis.set_ylim(*ylim)
        if path is not None:
            plt.savefig(path)
            plt.close()
        if show:
            plt.show()
def needs_asymmetry(diagram):
    """
    Whether a diagram has a box that is a dagger, conjugate or transpose.

    Anything with a :code:`terms` attribute is checked term by term.
    """
    if hasattr(diagram, "terms"):
        for term in diagram.terms:
            if needs_asymmetry(term):
                return True
        return False
    for box in diagram.boxes:
        if (box.is_dagger
                or getattr(box, "is_conjugate", False)
                or getattr(box, "is_transpose", False)):
            return True
    return False
def draw(diagram, **params):
    """
    Draws a diagram using networkx and matplotlib.

    Parameters
    ----------
    diagram : discopy.monoidal.Diagram
        The diagram to draw, it is first passed through
        :code:`to_drawing()`.
    draw_as_nodes : bool, optional
        Whether to draw boxes as nodes, default is :code:`False`.
    color : string, optional
        Color of the box or node, default is white (:code:`'#ffffff'`) for
        boxes and red (:code:`'#ff0000'`) for nodes.
    textpad : pair of floats, optional
        Padding between text and wires, default is :code:`(0.1, 0.1)`.
    draw_type_labels : bool, optional
        Whether to draw type labels, default is :code:`True`.
    draw_box_labels : bool, optional
        Whether to draw box labels, default is :code:`True`.
    aspect : string, optional
        Aspect ratio, one of :code:`['auto', 'equal']`.
    margins : tuple, optional
        Margins, default is :code:`(0.05, 0.05)`.
    nodesize : float, optional
        Node size for spiders and controlled gates.
    fontsize : int, optional
        Font size for the boxes, default is :code:`12`.
    fontsize_types : int, optional
        Font size for the types, default is :code:`12`.
    figsize : tuple, optional
        Figure size.
    linewidth : float, optional
        Line width for the matplotlib backend, default is :code:`1`.
    scale : pair of floats, optional
        Horizontal and vertical scaling, default is :code:`(1, 1)`.
    pad : pair of floats, optional
        Horizontal and vertical offset, default is :code:`(0, 0)`.
    path : str, optional
        Where to save the image, if `None` we call :code:`plt.show()`.
    show : bool, optional
        Passed to the backend's :code:`output`, default is :code:`True`.
    to_tikz : bool, optional
        Whether to output tikz code instead of matplotlib.
    use_tikzstyles : bool, optional
        Passed to :class:`TikzBackend` when :code:`to_tikz` is set.
    tikz_options : str, optional
        Passed to the backend's :code:`output`.
    backend : Backend, optional
        An existing backend to draw on instead of creating a new one.
    asymmetry : float, optional
        Make a box and its dagger mirror images, default is
        :code:`.25 * any(box.is_dagger for box in diagram.boxes)`.
    """
    diagram = diagram.to_drawing()

    # The first method whose attribute is set on the box is used,
    # ``draw_box`` is the fallback.
    drawing_methods = [
        ("draw_as_brakets", draw_brakets),
        ("draw_as_controlled", draw_controlled_gate),
        ("draw_as_discards", draw_discard),
        ("draw_as_measures", draw_measure),
        (None, draw_box)]

    params['asymmetry'] = params.get(
        'asymmetry', .25 * needs_asymmetry(diagram))

    def inside_a_box(node):
        """ Whether the node is a box drawn as an actual box. """
        return node.kind == "box"\
            and not node.box.draw_as_wires\
            and not node.box.draw_as_spider

    def shift(position, signs, shadow):
        """ Move a position by the braid shadow in the given directions. """
        return tuple(
            x + sign * offset
            for x, sign, offset in zip(position, signs, shadow))

    def draw_wires(backend, graph, positions):
        for source, target in graph.edges():
            if inside_a_box(source) or inside_a_box(target):
                continue  # no need to draw wires inside a box
            braid_shadow = DEFAULT["braid_shadow"]
            source_position = positions[source]
            target_position = positions[target]
            bend_out, bend_in = source.kind == "box", target.kind == "box"
            # One of the two wires of a braid stops short of the crossing
            # point, which one depends on whether the braid is a dagger.
            if source.kind == "box" and source.box.draw_as_braid:
                if source.box.is_dagger and target.i == 0:
                    source_position = shift(
                        source_position, [-1, -1], braid_shadow)
                if not source.box.is_dagger and target.i == 1:
                    source_position = shift(
                        source_position, [1, -1], braid_shadow)
            if target.kind == "box" and target.box.draw_as_braid:
                if target.box.is_dagger and source.i == 1:
                    target_position = shift(
                        target_position, [1, 1], braid_shadow)
                if not target.box.is_dagger and source.i == 0:
                    target_position = shift(
                        target_position, [-1, 1], braid_shadow)
            backend.draw_wire(
                source_position, target_position, bend_out, bend_in)
            if source.kind in ["input", "cod"]\
                    and (params.get('draw_type_labels', True)
                         or getattr(source.obj, "always_draw_label", False)
                         and params.get('draw_box_labels', True)):
                i, j = positions[source]
                pad_i, pad_j = params.get('textpad', DEFAULT['textpad'])
                pad_j = 0 if source.kind == "input" else pad_j
                backend.draw_text(
                    str(source.obj), i + pad_i, j - pad_j,
                    fontsize=params.get('fontsize_types',
                                        params.get('fontsize', None)),
                    verticalalignment='top')
        return backend

    def scale_and_pad(graph, pos, scale, pad):
        if len(pos) == 0:
            return pos
        widths, heights = zip(*pos.values())
        min_width, min_height = min(widths), min(heights)
        pos = {n: ((x - min_width) * scale[0] + pad[0],
                   (y - min_height) * scale[1] + pad[1])
               for n, (x, y) in pos.items()}
        # Bring the dom and cod nodes back to .25 away from their box,
        # so that boxes keep their height whatever the vertical scale.
        for box_node in graph.nodes:
            if box_node.kind == "box":
                for i, obj in enumerate(box_node.box.dom.inside):
                    node = Node("dom", obj=obj, i=i, depth=box_node.depth)
                    pos[node] = (
                        pos[node][0], pos[node][1] - .25 * (scale[1] - 1))
                for i, obj in enumerate(box_node.box.cod.inside):
                    node = Node("cod", obj=obj, i=i, depth=box_node.depth)
                    pos[node] = (
                        pos[node][0], pos[node][1] + .25 * (scale[1] - 1))
        return pos

    scale, pad = params.get('scale', (1, 1)), params.get('pad', (0, 0))
    graph, positions = diagram2nx(diagram)
    positions = scale_and_pad(graph, positions, scale, pad)

    # The backend is popped so that it is not forwarded with ``**params``.
    if 'backend' in params:
        backend = params.pop('backend')
    elif params.get('to_tikz', False):
        backend = TikzBackend(
            use_tikzstyles=params.get('use_tikzstyles', None))
    else:
        backend = MatBackend(figsize=params.get('figsize', None),
                             linewidth=params.get('linewidth', 1))

    # Nodes get smaller as the drawing gets bigger.
    min_size = 0.01
    max_v = max([v for p in positions.values() for v in p] + [min_size])
    params['nodesize'] = round(params.get('nodesize', 1.) / sqrt(max_v), 3)

    backend = draw_wires(backend, graph, positions)
    backend.draw_spiders(graph, positions, **params)
    box_nodes = [node for node in graph.nodes if node.kind == "box"]
    for node in box_nodes:
        if node.box.draw_as_spider or node.box.draw_as_wires:
            continue
        for attr, drawing_method in drawing_methods:
            if attr is None or getattr(node.box, attr, False):
                backend = drawing_method(backend, positions, node, **params)
                break
    return backend.output(
        path=params.get('path', None),
        baseline=len(box_nodes) / 2 or .5,
        tikz_options=params.get('tikz_options', None),
        show=params.get('show', True),
        margins=params.get('margins', DEFAULT['margins']),
        aspect=params.get('aspect', DEFAULT['aspect']))
def draw_box(backend, positions, node, **params):
    """
    Draws a box node on a given backend.

    Parameters
    ----------
    backend : Backend
        The backend to draw on, it is returned.
    positions : Mapping[Node, Tuple[float, float]]
        The positions of the box node and of its dom and cod nodes.
    node : Node
        The box node, with attributes :code:`box` and :code:`depth`.
    params : any, optional
        Reads :code:`asymmetry` (defaults to no asymmetry),
        :code:`draw_box_labels` and :code:`fontsize`.
    """
    box, depth = node.box, node.depth
    asymmetry = params.get('asymmetry') or 0

    def wire_x(kind, inside, i):
        """ The horizontal position of a dom or cod node of the box. """
        return positions[Node(kind, obj=inside[i], i=i, depth=depth)][0]

    if not box.dom and not box.cod:
        left, right = positions[node][0], positions[node][0]
    elif not box.dom:
        left, right = (
            wire_x("cod", box.cod.inside, i) for i in [0, len(box.cod) - 1])
    elif not box.cod:
        left, right = (
            wire_x("dom", box.dom.inside, i) for i in [0, len(box.dom) - 1])
    else:
        top_left, top_right = (
            wire_x("dom", box.dom.inside, i) for i in [0, len(box.dom) - 1])
        bottom_left, bottom_right = (
            wire_x("cod", box.cod.inside, i) for i in [0, len(box.cod) - 1])
        left = min(top_left, bottom_left)
        right = max(top_right, bottom_right)
    height = positions[node][1] - .25
    left, right = left - .25, right + .25
    # Corners in order: bottom left, bottom right, top right, top left.
    points = [[left, height], [right, height],
              [right, height + .5], [left, height + .5]]
    # Read the flags into locals: the box belongs to the caller's diagram
    # and must not be modified by drawing it.
    is_conjugate = getattr(box, "is_conjugate", False)
    is_transpose = getattr(box, "is_transpose", False)
    # Stretch one corner so that a box and its dagger, conjugate and
    # transpose can be told apart.
    if is_transpose:
        points[0][0] -= asymmetry
    elif is_conjugate:
        points[3][0] -= asymmetry
    elif box.is_dagger:
        points[1][0] += asymmetry
    else:
        points[2][0] += asymmetry
    backend.draw_polygon(*points, color=box.color)
    if params.get('draw_box_labels', True):
        backend.draw_text(box.drawing_name, *positions[node],
                          ha='center', va='center',
                          fontsize=params.get('fontsize', None))
    return backend
def to_gif(diagram, *diagrams, **params):  # pragma: no cover
    """
    Builds a gif with the normalisation steps.

    Parameters
    ----------
    diagrams : :class:`Diagram`, optional
        Sequence of diagrams to draw.
    path : str
        Where to save the image, if :code:`None` a gif gets created.
    timestep : int, optional
        Time step in milliseconds, default is :code:`500`.
    loop : bool, optional
        Whether to loop, default is :code:`False`
    params : any, optional
        Passed to :meth:`Diagram.draw`.

    Returns
    -------
    html : IPython.display.HTML or str
        An image tag pointing to the gif, as a plain string when IPython
        is not installed.
    """
    path = params.pop("path", None)
    # Pop the gif options so that they do not get passed on to ``draw``.
    timestep = params.pop("timestep", 500)
    loop = params.pop("loop", False)
    steps = (diagram, ) + diagrams
    if not path:
        # Only the unique name is needed: the placeholder file is removed
        # when the context exits and the gif is then saved under that name.
        with NamedTemporaryFile(
                suffix='.gif', prefix='tmp_', dir='.') as placeholder:
            path = os.path.basename(placeholder.name)
    with TemporaryDirectory() as directory:
        frames = []
        try:
            for i, _diagram in enumerate(steps):
                tmp_path = os.path.join(directory, f'{i}.png')
                _diagram.draw(path=tmp_path, **params)
                frames.append(Image.open(tmp_path))
            sequence = frames + frames[::-1] if loop else frames
            sequence[0].save(
                path, format='GIF', append_images=sequence[1:],
                save_all=True, duration=timestep,
                **({'loop': 0} if loop else {}))
        finally:
            # Release the png files before their directory gets removed.
            for frame in frames:
                frame.close()
    try:
        from IPython.display import HTML
        return HTML(f'<img src="{path}">')
    except ImportError:
        return f'<img src="{path}">'
class Equation:
    """
    An equation is a list of diagrams with a dedicated draw method.

    Parameters:
        terms : The terms of the equation.
        symbol : The symbol between the terms.
        space : The space between the terms.

    Example
    -------
    >>> from discopy.tensor import Spider, Swap, Dim, Id
    >>> dim = Dim(2)
    >>> mu, eta = Spider(2, 1, dim), Spider(0, 1, dim)
    >>> delta, upsilon = Spider(1, 2, dim), Spider(1, 0, dim)
    >>> special = Equation(mu >> delta, Id(dim))
    >>> frobenius = Equation(
    ...     delta @ Id(dim) >> Id(dim) @ mu,
    ...     mu >> delta,
    ...     Id(dim) @ delta >> mu @ Id(dim))
    >>> Equation(special, frobenius, symbol=', ').draw(
    ...     aspect='equal', draw_type_labels=False, figsize=(8, 2),
    ...     path='docs/_static/drawing/frobenius-axioms.png')

    .. image:: /_static/drawing/frobenius-axioms.png
        :align: center
    """
    def __init__(
            self, *terms: "discopy.monoidal.Diagram", symbol="=", space=1):
        self.terms, self.symbol, self.space = terms, symbol, space

    def __repr__(self):
        return f"Equation({', '.join(map(repr, self.terms))})"

    def __str__(self):
        return f" {self.symbol} ".join(map(str, self.terms))

    def draw(self, path=None, space=None, **params):
        """
        Drawing an equation.

        Parameters:
            path : Where to save the drawing.
            space : The amount of space between the terms,
                    default is the space given at initialisation.
            params : Passed to :meth:`discopy.monoidal.Diagram.draw`.
        """
        def height(term):
            # i.e. if isinstance(diagram, (Sum, Equation))
            if hasattr(term, "terms"):
                return max((height(d) for d in term.terms), default=1)
            return len(term.to_drawing()) or 1

        params['asymmetry'] = params.get(
            'asymmetry', .25 * needs_asymmetry(self))
        # An explicit ``space=0`` is a valid request, only fall back to
        # the default when no space is given at all.
        space = self.space if space is None else space
        max_height = max(map(height, self.terms), default=1)
        pad = params.get('pad', (0, 0))
        scale_x, scale_y = params.get('scale', (1, 1))
        if 'backend' in params:
            backend = params['backend']
        elif params.get('to_tikz', False):
            backend = TikzBackend(
                use_tikzstyles=params.get('use_tikzstyles', None))
        else:
            backend = MatBackend(figsize=params.get('figsize', None),
                                 linewidth=params.get('linewidth', 1))

        for i, term in enumerate(self.terms):
            # Stretch shorter terms so that all terms have the same height.
            scale = (scale_x, scale_y * max_height / height(term))
            term.draw(**dict(
                params, show=False, path=None,
                backend=backend, scale=scale, pad=pad))
            # The next thing is drawn to the right of everything so far.
            pad = (backend.max_width + space, 0)
            if i < len(self.terms) - 1:
                backend.draw_text(
                    self.symbol, pad[0], scale_y * max_height / 2)
                pad = (pad[0] + space, pad[1])

        return backend.output(
            path=path,
            baseline=max_height / 2,
            tikz_options=params.get('tikz_options', None),
            show=params.get("show", True),
            margins=params.get('margins', DEFAULT['margins']),
            aspect=params.get('aspect', DEFAULT['aspect']))

    def __bool__(self):
        return all(term == self.terms[0] for term in self.terms)
def draw_discard(backend, positions, node, **params):
    """
    Draws a :class:`discopy.quantum.circuit.Discard` box.

    Each wire of the domain ends in three horizontal bars, each one
    shorter and lower than the one before.
    """
    box, depth = node.box, node.depth
    for index in range(len(box.dom)):
        wire = Node("dom", obj=box.dom.inside[index], depth=depth, i=index)
        centre = positions[wire]
        left, right = centre[0] - .25, centre[0] + .25
        height = positions[node][1] + .25
        for step in range(3):
            bar_start = (left + .1 * step, height - .1 * step)
            bar_end = (right - .1 * step, height - .1 * step)
            backend.draw_wire(bar_start, bar_end)
    return backend
def draw_measure(backend, positions, node, **params):
    """
    Draws a :class:`discopy.quantum.circuit.Measure` box.

    That is an unlabelled box with two bent wires and an arrow inside.
    """
    unlabelled = dict(params, draw_box_labels=False)
    backend = draw_box(backend, positions, node, **unlabelled)
    x, y = positions[node]
    # The two bent wires meet at the top, just above the centre.
    backend.draw_wire((x - .15, y - .1), (x, y + .1), bend_in=True)
    backend.draw_wire((x, y + .1), (x + .15, y - .1), bend_out=True)
    backend.draw_wire((x, y - .1), (x + .05, y + .15), style='->')
    return backend
def draw_brakets(backend, positions, node, **params):
    """
    Draws a :class:`discopy.quantum.gates.Ket` box.

    One triangle is drawn per digit: it points down on the dom wires if
    the box has a domain, otherwise it points up on the cod wires.
    """
    box, depth = node.box, node.depth
    is_bra = len(box.dom) > 0
    kind, wires = ("dom", box.dom) if is_bra else ("cod", box.cod)
    for i, bit in enumerate(box._digits):
        wire = Node(kind, obj=wires.inside[i], depth=depth, i=i)
        x, y = positions[wire]
        base_left, base_right = (x - .25, y), (x + .25, y)
        tip = (x, y - .5) if is_bra else (x, y + .5)
        backend.draw_polygon(base_left, base_right, tip, color=box.color)
        label_y = y - .25 if is_bra else y + .2
        backend.draw_text(
            bit, x, label_y,
            ha='center', va='center', fontsize=params.get('fontsize', None))
    return backend
def draw_controlled_gate(backend, positions, node, **params):
    """
    Draws a :class:`discopy.quantum.gates.Controlled` gate.

    A black node is drawn on one wire, the gate :code:`box.controlled` is
    drawn on another, either as a circled plus sign when its name is
    :code:`"X"` or by drawing it with fake positions, and the two are
    joined by a horizontal wire. The function calls itself when
    :code:`box.controlled` has an attribute :code:`draw_as_controlled`.

    Parameters
    ----------
    backend : Backend
        The backend to draw on, it is returned.
    positions : Mapping[Node, Tuple[float, float]]
        The positions of the dom and cod nodes of the box.
    node : Node
        The box node, with attributes :code:`box` and :code:`depth`.
    params : any, optional
        Reads :code:`nodesize`, the rest is passed on to the drawing of
        :code:`box.controlled`.
    """
    box, depth = node.box, node.depth
    distance = box.distance
    c_size = len(box.controlled.dom)
    # index[0] is the wire of the black node, index[1] the first wire of
    # the controlled gate. The sign of the distance selects the layout.
    index = (0, distance) if distance > 0 else (c_size - distance - 1, 0)
    dom = Node("dom", obj=box.dom.inside[0], i=index[0], depth=depth)
    cod = Node("cod", obj=box.cod.inside[0], i=index[0], depth=depth)
    # Position of the black node, halfway between its dom and cod nodes.
    middle = positions[dom][0], (positions[dom][1] + positions[cod][1]) / 2
    controlled_box = box.controlled.to_drawing()
    controlled = Node("box", box=controlled_box, depth=depth)
    # TODO select obj properly for classical gates
    c_dom = Node("dom", obj=box.dom.inside[0], i=index[1], depth=depth)
    c_cod = Node("cod", obj=box.cod.inside[0], i=index[1], depth=depth)
    c_middle =\
        positions[c_dom][0], (positions[c_dom][1] + positions[c_cod][1]) / 2
    # Centre of the controlled gate, assuming its wires are one unit apart.
    target = (positions[c_dom][0] + (c_size - 1) / 2,
              (positions[c_dom][1] + positions[c_cod][1]) / 2)
    # Where the horizontal wire from the black node ends.
    target_boundary = target
    if controlled_box.name == "X":  # CX gets drawn as a circled plus sign.
        backend.draw_wire(positions[c_dom], positions[c_cod])
        # NOTE(review): the white circle is nudged by eps, presumably so
        # that it does not sit at the exact same position as the plus sign
        # -- TODO confirm.
        eps = 1e-10
        perturbed_target = target[0], target[1] + eps
        backend.draw_node(
            *perturbed_target,
            shape="circle", color="white", edgecolor="black",
            nodesize=2 * params.get("nodesize", 1))
        backend.draw_node(
            *target, shape="plus",
            nodesize=2 * params.get("nodesize", 1))
    else:
        # Positions for drawing the controlled gate on its own: its box at
        # the target, its wires one unit apart starting from c_dom, c_cod.
        fake_positions = {controlled: target}
        for i in range(c_size):
            dom_node = Node("dom", obj=box.dom.inside[i], i=i, depth=depth)
            x, y = positions[c_dom][0] + i, positions[c_dom][1]
            fake_positions[dom_node] = x, y

            cod_node = Node("cod", obj=box.cod.inside[i], i=i, depth=depth)
            x, y = positions[c_cod][0] + i, positions[c_cod][1]
            fake_positions[cod_node] = x, y

        # Whether the horizontal wire should stop at the edge of a box
        # rather than at the middle of a wire.
        shift_boundary = True
        if hasattr(box.controlled, "draw_as_controlled"):
            backend = draw_controlled_gate(
                backend, fake_positions, controlled, **params)

            # Walk down the nested controlled gates: no shift if one of
            # them has a distance of opposite sign or if the innermost
            # gate is named "X".
            next_box = box.controlled
            while hasattr(next_box, "controlled"):
                if controlled_box.distance * next_box.distance < 0:
                    shift_boundary = False
                    break
                next_box = next_box.controlled
            if next_box.name == "X":
                shift_boundary = False
        else:
            backend = draw_box(backend, fake_positions, controlled, **params)

        if shift_boundary:
            # Stop .25 away from the wire, on the side of the black node.
            if box.distance > 0:
                target_boundary = c_middle[0] - .25, c_middle[1]
            else:
                target_boundary = c_middle[0] + c_size - 1 + .25, c_middle[1]
        else:
            if box.distance > 0:
                target_boundary = c_middle[0], c_middle[1]
            else:
                target_boundary = c_middle[0] + c_size - 1, c_middle[1]
    # The vertical wire going through the black node.
    backend.draw_wire(positions[dom], positions[cod])

    # draw all the other vertical wires
    extra_offset = 1 if distance > 0 else len(box.controlled.dom)
    for i in range(extra_offset, extra_offset + abs(distance) - 1):
        node1 = Node("dom", obj=box.dom.inside[i], i=i, depth=depth)
        node2 = Node("cod", obj=box.cod.inside[i], i=i, depth=depth)
        backend.draw_wire(positions[node1], positions[node2])

    # TODO change bend_in and bend_out for tikz backend
    backend.draw_wire(middle, target_boundary, bend_in=True, bend_out=True)

    backend.draw_node(
        *middle, color="black", shape="circle",
        nodesize=params.get("nodesize", 1))
    return backend