# Source code for discopy.rigid

```# -*- coding: utf-8 -*-

"""
The free rigid category, i.e. diagrams with cups and caps.

Summary
-------

.. autosummary::
:template: class.rst
:nosignatures:
:toctree:

Ob
Ty
PRO
Diagram
Box
Cup
Cap
Sum
Category
Functor

Axioms
------

>>> unit, s, n = Ty(), Ty('s'), Ty('n')
>>> t = n.r @ s @ n.l
>>> assert t @ unit == t == unit @ t
>>> assert t.l.r == t == t.r.l
>>> left_snake, right_snake = Id(n.r).transpose(left=True), Id(n.l).transpose()
>>> assert left_snake.normal_form() == Id(n) == right_snake.normal_form()

>>> from discopy.drawing import Equation
>>> Equation(left_snake, Id(n), right_snake).draw(
...     figsize=(4, 2), path='docs/_static/rigid/typed-snake-equation.png')

.. image:: /_static/rigid/typed-snake-equation.png
:align: center
"""

from __future__ import annotations

from collections.abc import Callable

from typing import Iterator

from discopy import cat, monoidal, closed, messages
from discopy.cat import factory
from discopy.utils import (
assert_isinstance,
factory_name,
BinaryBoxConstructor,
AxiomError,
assert_isatomic
)

[docs]
class Ob(cat.Ob):
    """
    A rigid object has adjoints :meth:`Ob.l` and :meth:`Ob.r`.

    Parameters:
        name : The name of the object.
        z : The winding number.

    Example
    -------
    >>> a = Ob('a')
    >>> assert a.l.r == a.r.l == a and a != a.l.l != a.r.r
    """
    __ambiguous_inheritance__ = True

    def __init__(self, name: str, z: int = 0):
        assert_isinstance(z, int)
        self.z = z
        super().__init__(name)

    def __setstate__(self, state):
        # Backward compatibility: older states stored the winding number
        # under the private key ``_z``.
        if '_z' in state:
            self.z = state.pop('_z')
        super().__setstate__(state)

    @property
    def l(self) -> Ob:
        """ The left adjoint of the object. """
        cls = type(self)
        return cls(self.name, self.z - 1)

    @property
    def r(self) -> Ob:
        """ The right adjoint of the object. """
        cls = type(self)
        return cls(self.name, self.z + 1)

    def __eq__(self, other):
        # The base-class comparison is evaluated first so that ``other.z``
        # is only read once that comparison has succeeded.
        same_as_base = cat.Ob.__eq__(self, other)
        return same_as_base and self.z == other.z

    def __hash__(self):
        return hash(repr(self))

    def __repr__(self):
        # The winding number is only shown when it is non-zero.
        args = f"({repr(self.name)}{', z=' + repr(self.z) if self.z else ''})"
        return factory_name(type(self)) + args

    def __str__(self):
        # One ``.l`` per negative winding, one ``.r`` per positive winding.
        if self.z < 0:
            adjoints = -self.z * '.l'
        else:
            adjoints = self.z * '.r'
        return str(self.name) + adjoints

    def to_tree(self):
        tree = super().to_tree()
        if self.z:  # A zero winding number is left implicit.
            tree['z'] = self.z
        return tree

    @classmethod
    def from_tree(cls, tree):
        return cls(name=tree['name'], z=tree.get('z', 0))

[docs]
@factory
class Ty(closed.Ty):
    """
    A rigid type is a closed type with rigid objects inside.

    Parameters:
        inside (tuple[Ob, ...]) : The objects inside the type.

    Example
    -------
    >>> s, n = Ty('s'), Ty('n')
    >>> assert n.l.r == n == n.r.l
    >>> assert (s @ n).l == n.l @ s.l and (s @ n).r == n.r @ s.r
    """
    def __setstate__(self, state):
        if '_z' in state:  # Backward compatibility
            del state['_z']
        super().__setstate__(state)

    def assert_isadjoint(self, other):
        """
        Raise ``AxiomError`` if two rigid types are not adjoints.

        Parameters:
            other : The alleged right adjoint.
        """
        # NOTE(review): the ``def`` line and both ``raise`` statements were
        # missing from this copy of the file; the method name and the message
        # constants below are reconstructed -- confirm against
        # ``discopy.messages``.
        if self.r != other and self != other.r:
            # Neither type is the right adjoint of the other.
            raise AxiomError(messages.NOT_ADJOINT.format(self, other))
        if self.r != other:
            # They are adjoint, but the wrong way around.
            raise AxiomError(messages.NOT_RIGID_ADJOINT.format(self, other))

    @property
    def l(self) -> Ty:
        """ The left adjoint of the type. """
        # Taking an adjoint reverses the order of the objects.
        return self.factory(*[x.l for x in self.inside[::-1]])

    @property
    def r(self) -> Ty:
        """ The right adjoint of the type. """
        return self.factory(*[x.r for x in self.inside[::-1]])

    @property
    def z(self) -> int:
        """ The winding number is only defined for types of length 1. """
        assert_isatomic(self)
        return self.inside[0].z

    def __lshift__(self, other):
        """ ``self << other`` is ``self @ other.l``. """
        return self @ other.l

    def __rshift__(self, other):
        """ ``self >> other`` is ``self.r @ other``. """
        return self.r @ other

    ob_factory = Ob

[docs]
@factory
class PRO(monoidal.PRO, Ty):
    """
    A rigid PRO is a natural number ``n`` seen as a rigid type of length ``n``.

    Parameters
    ----------
    n : int
        The length of the PRO type.
    """
    __ambiguous_inheritance__ = (monoidal.PRO, )

    @property
    def l(self):
        # A PRO type is its own adjoint.
        return self

    # Left and right adjoints share the very same property object.
    r = l

class Layer(monoidal.Layer):
    """
    A rigid layer is a monoidal layer that can be rotated.

    Parameters:
        left : The type on the left of the layer.
        box : The box in the middle of the layer.
        right : The type on the right of the layer.
        more : More boxes and types to the right,
               used by :meth:`Diagram.foliation`.
    """
    def rotate(self, left=False):
        """
        Rebuild the layer from the adjoints of its components, taken in
        reverse order.

        Parameters:
            left : Whether to take left (``.l``) or right (``.r``) adjoints.
        """
        adjoint = 'l' if left else 'r'
        components = reversed(list(self))
        return type(self)(*(getattr(x, adjoint) for x in components))

    @property
    def l(self):
        return self.rotate(left=True)

    @property
    def r(self):
        return self.rotate(left=False)

[docs]
@factory
class Diagram(closed.Diagram):
    """
    A rigid diagram is a closed diagram
    with :class:`Cup` and :class:`Cap` boxes.

    Parameters:
        inside (tuple[Layer, ...]) : The layers of the diagram.
        dom (Ty) : The domain of the diagram, i.e. its input.
        cod (Ty) : The codomain of the diagram, i.e. its output.

    Example
    -------
    >>> I, n, s = Ty(), Ty('n'), Ty('s')
    >>> Alice, jokes = Box('Alice', I, n), Box('jokes', I, n.r @ s)
    >>> d = Alice >> Id(n) @ jokes >> Cup(n, n.r) @ Id(s)
    >>> d.draw(figsize=(3, 2),
    ...        path='docs/_static/rigid/diagram-example.png')

    .. image:: /_static/rigid/diagram-example.png
        :align: center
    """
    __ambiguous_inheritance__ = True

    ty_factory = Ty
    layer_factory = Layer

    # Exponentials are spelled with adjoints: see ``Ty.__lshift__`` and
    # ``Ty.__rshift__``.
    over = staticmethod(lambda base, exponent: base << exponent)
    under = staticmethod(lambda base, exponent: exponent >> base)

    @classmethod
    def ev(cls, base: Ty, exponent: Ty, left=True) -> Diagram:
        """
        Build the evaluation diagram out of nested cups.

        Parameters:
            base : The type that is left untouched.
            exponent : The type that gets cupped with its adjoint.
            left : Whether the cups go on the right of ``base`` (default)
                   or on its left.
        """
        if left:
            return base @ cls.cups(exponent.l, exponent)
        return cls.cups(exponent, exponent.r) @ base

    @classmethod
    def cups(cls, left: Ty, right: Ty) -> Diagram:
        """
        Construct a diagram of nested cups for types ``left`` and ``right``.

        Parameters:
            left : The type left of the cup.
            right : Its right adjoint.

        Example
        -------
        >>> a, b = Ty('a'), Ty('b')
        >>> Diagram.cups(a.l @ b, b.r @ a).draw(figsize=(3, 1),\\
        ... margins=(0.3, 0.05), path='docs/_static/rigid/cups.png')

        .. image:: /_static/rigid/cups.png
            :align: center
        """
        return nesting(cls, cls.cup_factory)(left, right)

    @classmethod
    def caps(cls, left: Ty, right: Ty) -> Diagram:
        """
        Construct a diagram of nested caps for types ``left`` and ``right``.

        Parameters:
            left : The type left of the cap.
            right : Its left adjoint.

        Example
        -------
        >>> a, b = Ty('a'), Ty('b')
        >>> Diagram.caps(a.r @ b, b.l @ a).draw(figsize=(3, 1),\\
        ... margins=(0.3, 0.05), path='docs/_static/rigid/caps.png')

        .. image:: /_static/rigid/caps.png
            :align: center
        """
        return nesting(cls, cls.cap_factory)(left, right)

    def curry(self, n=1, left=True) -> Diagram:
        """
        The curry of a rigid diagram is obtained using cups and caps.

        Parameters:
            n : The number of domain wires to bend.
            left : If true, the last ``n`` wires of the domain are bent and
                   their left adjoint is appended to the codomain; otherwise
                   the first ``n`` wires are bent and their right adjoint is
                   prepended to the codomain.

        >>> from discopy.drawing import Equation as Eq
        >>> x = Ty('x')
        >>> g = Box('g', x @ x, x)
        >>> Eq(Eq(g.curry(left=False), g, symbol="$\\\\mapsfrom$"),
        ...     g.curry(), symbol="$\\\\mapsto$").draw(
        ...         path="docs/_static/rigid/curry.png")

        .. image:: /_static/rigid/curry.png
            :align: center
        """
        if left:
            if n == 0:
                # ``self.dom[:-0]`` is empty and ``self.dom[-0:]`` is the
                # whole domain, so plain negative slicing would bend every
                # wire instead of none.
                base, exponent = self.dom, self.dom[:0]
            else:
                base, exponent = self.dom[:-n], self.dom[-n:]
            return base @ self.caps(exponent, exponent.l) >> self @ exponent.l
        base, exponent = self.dom[n:], self.dom[:n]
        return self.caps(exponent.r, exponent) @ base >> exponent.r @ self

    def rotate(self, left=False):
        """
        The half-turn rotation of a diagram, called with ``.l`` and ``.r``.

        Parameters:
            left : Whether to rotate to the left or to the right.

        Example
        -------
        >>> from discopy import drawing
        >>> x, y = map(Ty, "xy")
        >>> f = Box('f', Ty(), x)
        >>> g = Box('g', Ty(), x.r @ y)
        >>> diagram = f @ g >> Cup(x, x.r) @ y
        >>> LHS = drawing.Equation(diagram.l, diagram, symbol="$\\\\mapsfrom$")
        >>> RHS = drawing.Equation(LHS, diagram.r, symbol="$\\\\mapsto$")
        >>> RHS.draw(figsize=(8, 3), path='docs/_static/rigid/rotate.png')

        .. image:: /_static/rigid/rotate.png
            :align: center
        """
        # Domain and codomain swap roles, and the layers are visited
        # bottom-up, each one rotated in turn.
        dom, cod = (x.l if left else x.r for x in (self.cod, self.dom))
        inside = tuple(
            layer.l if left else layer.r for layer in self.inside[::-1])
        return self.factory(inside, dom, cod, _scan=False)

    l = property(lambda self: self.rotate(left=True))
    r = property(lambda self: self.rotate(left=False))

    def transpose(self, left=False):
        """
        The transpose of a diagram, i.e. its composition with cups and caps.

        Parameters:
            left : Whether to transpose left or right.

        Example
        -------
        >>> from discopy.drawing import Equation
        >>> x, y = map(Ty, "xy")
        >>> f = Box('f', x, y)
        >>> LHS = Equation(f.transpose(left=True), f, symbol="$\\\\mapsfrom$")
        >>> RHS = Equation(LHS, f.transpose(), symbol="$\\\\mapsto$")
        >>> RHS.draw(figsize=(8, 3), path="docs/_static/rigid/transpose.png")

        .. image:: /_static/rigid/transpose.png
        """
        # Three stacked slices: caps on top, the diagram itself in the
        # middle, cups at the bottom.
        if left:
            return (
                self.cod.l @ self.caps(self.dom, self.dom.l)
                >> self.cod.l @ self @ self.dom.l
                >> self.cups(self.cod.l, self.cod) @ self.dom.l)
        return (
            self.caps(self.dom.r, self.dom) @ self.cod.r
            >> self.dom.r @ self @ self.cod.r
            >> self.dom.r @ self.cups(self.cod, self.cod.r))

    def transpose_box(self, i, j=0, left=False):
        """
        Transpose the box at index ``i``.

        Parameters:
            i : The vertical index of the box to transpose.
            j : The horizontal index of the box to transpose, only needed if
                the layer ``i`` has more than one box.
            left : Whether to transpose left or right.

        Example
        -------
        >>> from discopy.drawing import Equation
        >>> x, y, z = Ty(*"xyz")
        >>> f, g = Box('f', x, y), Box('g', y, z)
        >>> d = (f @ g).foliation()
        >>> transpose_l = d.transpose_box(0, 0, left=True)
        >>> transpose_r = d.transpose_box(0, 1, left=False)
        >>> LHS = Equation(transpose_l, d, symbol="$\\\\mapsfrom$")
        >>> RHS = Equation(LHS, transpose_r, symbol="$\\\\mapsto$")
        >>> RHS.draw(
        ...     figsize=(8, 3), path="docs/_static/rigid/transpose_box.png")

        .. image:: /_static/rigid/transpose_box.png
        """
        # A layer alternates types (even positions) and boxes (odd ones),
        # so the ``j``-th box sits at position ``2 * j + 1``.
        box = list(self.inside[i])[2 * j + 1]
        # Rotating one way and transposing the other yields a diagram with
        # the same domain and codomain as the original box.
        transposed_box = (box.r if left else box.l).transpose(left)
        top, bottom = self[:i], self[i + 1:]
        left_boxes_and_types = list(self.inside[i])[:2 * j + 1]
        right_boxes_and_types = list(self.inside[i])[2 * j + 2:]
        # Both halves start with a type, hence identities at even positions.
        left_layer, right_layer = [
            self.id().tensor(
                *(x if k % 2 else self.id(x) for k, x in enumerate(xs)))
            for xs in [left_boxes_and_types, right_boxes_and_types]]
        return top >> left_layer @ transposed_box @ right_layer >> bottom

    def snake_removal(self, left=False) -> Iterator[Diagram]:
        """
        Return a generator which yields normalization steps.

        Parameters:
            left : Passed to :meth:`discopy.monoidal.Diagram.normalize`.

        Example
        -------
        >>> from discopy.rigid import *
        >>> n, s = Ty('n'), Ty('s')
        >>> cup, cap = Cup(n, n.r), Cap(n.r, n)
        >>> f, g, h = Box('f', n, n), Box('g', s @ n, n), Box('h', n, n @ s)
        >>> diagram = g @ cap >> f[::-1] @ Id(n.r) @ f >> cup @ h
        >>> for d in diagram.normalize(): print(d)
        g >> f[::-1] >> n @ Cap(n.r, n) >> n @ n.r @ f >> Cup(n, n.r) @ n >> h
        g >> f[::-1] >> n @ Cap(n.r, n) >> Cup(n, n.r) @ n >> f >> h
        g >> f[::-1] >> f >> h
        """
        from discopy import monoidal
        from discopy.rigid import Cup, Cap

        def follow_wire(diagram, i, j):
            """
            Given a diagram, the index of a box i and the offset j of an output
            wire, returns (i, j, obstructions) where:
            - i is the index of the box which takes this wire as input, or
            len(diagram) if it is connected to the bottom boundary.
            - j is the offset of the wire at its bottom end.
            - obstructions is a pair of lists of indices for the boxes on
            the left and right of the wire we followed.
            """
            left_obstruction, right_obstruction = [], []
            while i < len(diagram) - 1:
                i += 1
                box, off = diagram.boxes[i], diagram.offsets[i]
                if off <= j < off + len(box.dom):
                    # The wire is an input of this box: stop here.
                    return i, j, (left_obstruction, right_obstruction)
                if off <= j:
                    # The box is on the left: it shifts the wire sideways by
                    # the difference between its outputs and its inputs.
                    j += len(box.cod) - len(box.dom)
                    left_obstruction.append(i)
                else:
                    right_obstruction.append(i)
            return len(diagram), j, (left_obstruction, right_obstruction)

        def find_snake(diagram):
            """
            Given a diagram, returns (cup, cap, obstructions, left_snake)
            if there is a yankable pair, otherwise returns None.
            """
            for cap in range(len(diagram)):
                if not isinstance(diagram.boxes[cap], Cap):
                    continue
                # Try both legs of the cap: the left one may close a left
                # snake, the right one a right snake.
                for left_snake, wire in [(True, diagram.offsets[cap]),
                                         (False, diagram.offsets[cap] + 1)]:
                    cup, wire, obstructions =\
                        follow_wire(diagram, cap, wire)
                    not_yankable =\
                        cup == len(diagram)\
                        or not isinstance(diagram.boxes[cup], Cup)\
                        or left_snake and diagram.offsets[cup] + 1 != wire\
                        or not left_snake and diagram.offsets[cup] != wire
                    if not_yankable:
                        continue
                    return cup, cap, obstructions, left_snake
            return None

        def unsnake(diagram, cup, cap, obstructions, left_snake=False):
            """
            Given a diagram and the indices for a cup and cap pair
            and a pair of lists of obstructions on the left and right,
            returns a new diagram with the snake removed.

            A left snake is one of the form Id @ Cap >> Cup @ Id.
            A right snake is one of the form Cap @ Id >> Id @ Cup.
            """
            left_obstruction, right_obstruction = obstructions
            if left_snake:
                for box in left_obstruction:
                    diagram = diagram.interchange(box, cap)
                    yield diagram
                    # Keep the recorded right obstructions in sync with the
                    # interchange that has just been performed.
                    for i, right_box in enumerate(right_obstruction):
                        if right_box < box:
                            right_obstruction[i] += 1
                    cap += 1
                for box in right_obstruction[::-1]:
                    diagram = diagram.interchange(box, cup)
                    yield diagram
                    cup -= 1
            else:
                for box in left_obstruction[::-1]:
                    diagram = diagram.interchange(box, cup)
                    yield diagram
                    for i, right_box in enumerate(right_obstruction):
                        if right_box > box:
                            right_obstruction[i] -= 1
                    cup -= 1
                for box in right_obstruction:
                    diagram = diagram.interchange(box, cap)
                    yield diagram
                    cap += 1
            # Drop every layer from the cap down to the cup, both included.
            inside = diagram.inside[:cap] + diagram.inside[cup + 1:]
            yield diagram.factory(
                inside, diagram.dom, diagram.cod, _scan=False)

        diagram = self
        while True:
            yankable = find_snake(diagram)
            if yankable is None:
                break
            for _diagram in unsnake(diagram, *yankable):
                yield _diagram
                diagram = _diagram
        # Once no snake is left, fall back on monoidal normalisation.
        for _diagram in monoidal.Diagram.normalize(diagram, left=left):
            yield _diagram

    normalize = snake_removal

    def normal_form(self, **params):
        """
        Implements the normalisation of rigid categories,
        see Dunn and Vicary :cite:`DunnVicary19`, definition 2.12.

        Examples
        --------
        >>> a, b = Ty('a'), Ty('b')
        >>> double_snake = Id(a @ b).transpose()
        >>> two_snakes = Id(b).transpose() @ Id(a).transpose()
        >>> double_snake == two_snakes
        False
        >>> *_, two_snakes_nf = monoidal.Diagram.normalize(two_snakes)
        >>> assert double_snake == two_snakes_nf
        >>> f = Box('f', a, b)

        >>> a, b = Ty('a'), Ty('b')
        >>> double_snake = Id(a @ b).transpose(left=True)
        >>> snakes = Id(b).transpose(left=True) @ Id(a).transpose(left=True)
        >>> double_snake == two_snakes
        False
        >>> *_, two_snakes_nf = monoidal.Diagram.normalize(
        ...     snakes, left=True)
        >>> assert double_snake == two_snakes_nf
        """
        return super().normal_form(**params)

[docs]
class Box(closed.Box, Diagram):
    """
    A rigid box is a closed box in a rigid diagram.

    Parameters:
        name : The name of the box.
        dom : The domain of the box, i.e. its input.
        cod : The codomain of the box, i.e. its output.
        z : The winding number of the box,
            i.e. the number of half-turn rotations.

    Example
    -------
    >>> a, b = Ty('a'), Ty('b')
    >>> f = Box('f', a, b.l @ b)
    >>> assert f.l.z == -1 and f.z == 0 and f.r.z == 1
    >>> assert f.r.l == f == f.l.r
    >>> assert f.l.l != f != f.r.r
    """
    __ambiguous_inheritance__ = (closed.Box, )

    def __setstate__(self, state):
        if '_z' in state:  # Backward compatibility
            self.z = state['_z']
            del state['_z']
        super().__setstate__(state)

    def __init__(self, name: str, dom: Ty, cod: Ty, data=None, z=0, **params):
        self.z = z
        closed.Box.__init__(self, name, dom, cod, data=data, **params)

    def __str__(self):
        if not self.z:
            return cat.Box.__str__(self)
        # Print one ``.l`` or ``.r`` suffix per half-turn by recursing on
        # the box rotated one step back towards ``z == 0``.
        if self.z < 0:
            return str(self.r) + '.l'
        return str(self.l) + '.r'

    def __repr__(self):
        if self.is_dagger:
            return closed.Box.__repr__(self)
        # Splice ``z`` in before the closing parenthesis when non-zero.
        return closed.Box.__repr__(self)[:-1] + (
            f', z={self.z})' if self.z else ')')

    def __eq__(self, other):
        if isinstance(other, Box):
            return cat.Box.__eq__(self, other) and self.z == other.z
        return monoidal.Box.__eq__(self, other)

    def __hash__(self):
        return hash(repr(self))

    def rotate(self, left=False):
        """
        Return the box with swapped and adjointed domain and codomain, and
        with ``z`` decremented (left) or incremented (right).

        Parameters:
            left : Whether to rotate to the left or to the right.
        """
        dom, cod = (
            getattr(x, 'l' if left else 'r') for x in (self.cod, self.dom))
        z = self.z + (-1 if left else 1)
        return type(self)(self.name, dom=dom, cod=cod,
                          data=self.data, is_dagger=self.is_dagger, z=z)

    @property
    def is_transpose(self) -> bool:
        """ Whether the box is an odd rotation of a generator. """
        # Wrapped in ``bool`` so that ``z == 0`` gives ``False``, not ``0``.
        return bool(not self.is_dagger and self.z and self.z % 2)

    def to_drawing(self):
        result = super().to_drawing()
        result.is_transpose = self.is_transpose
        return result

[docs]
class Sum(closed.Sum, Box):
    """
    A rigid sum is a closed sum that can be transposed.

    Parameters:
        terms (tuple[Diagram, ...]) : The terms of the formal sum.
        dom (Ty) : The domain of the formal sum.
        cod (Ty) : The codomain of the formal sum.
    """
    __ambiguous_inheritance__ = (closed.Sum, )

    def rotate(self, left=False) -> Sum:
        """
        Rotate every term, and swap and adjoint the domain and codomain.

        Parameters:
            left : Whether to rotate to the left or to the right.
        """
        adjoint = 'l' if left else 'r'
        terms = tuple(getattr(term, adjoint) for term in self.terms)
        dom = getattr(self.cod, adjoint)
        cod = getattr(self.dom, adjoint)
        return self.sum_factory(terms, dom, cod)

[docs]
class Cup(BinaryBoxConstructor, Box):
    """
    The counit of the adjunction for an atomic type.

    Parameters:
        left : The atomic type.
        right : Its right adjoint.

    Example
    -------
    >>> n = Ty('n')
    >>> Cup(n, n.r).draw(figsize=(2,1), margins=(0.5, 0.05),\\
    ... path='docs/_static/rigid/cup.png')

    .. image:: /_static/rigid/cup.png
        :align: center
    """
    def __init__(self, left: Ty, right: Ty):
        for wire in (left, right):
            assert_isatomic(wire, Ty)
        name = f"Cup({left}, {right})"
        # A cup takes both wires as input and outputs the empty type.
        dom = left @ right
        cod = self.ty_factory()
        BinaryBoxConstructor.__init__(self, left, right)
        Box.__init__(self, name, dom, cod, draw_as_wires=True)

    def rotate(self, left=False):
        """ Rotating a cup gives a cap on the swapped, adjointed wires. """
        if left:
            return self.cap_factory(self.right.l, self.left.l)
        return self.cap_factory(self.right.r, self.left.r)

    def dagger(self):
        """
        The dagger of a rigid cup is ill-defined,
        use a :class:`pivotal.Cup` instead.
        """
        raise AxiomError("Rigid cups have no dagger, use pivotal instead.")

[docs]
class Cap(BinaryBoxConstructor, Box):
    """
    The unit of the adjunction for an atomic type.

    Parameters:
        left : The atomic type.
        right : Its left adjoint.

    Example
    -------
    >>> n = Ty('n')
    >>> Cap(n, n.l).draw(figsize=(2,1), margins=(0.5, 0.05),\\
    ... path='docs/_static/rigid/cap.png')

    .. image:: /_static/rigid/cap.png
        :align: center
    """
    def __init__(self, left: Ty, right: Ty):
        for wire in (left, right):
            assert_isatomic(wire, Ty)
        name = f"Cap({left}, {right})"
        # A cap takes the empty type as input and outputs both wires.
        dom = self.ty_factory()
        cod = left @ right
        BinaryBoxConstructor.__init__(self, left, right)
        Box.__init__(self, name, dom, cod, draw_as_wires=True)

    def rotate(self, left=False):
        """ Rotating a cap gives a cup on the swapped, adjointed wires. """
        if left:
            return self.cup_factory(self.right.l, self.left.l)
        return self.cup_factory(self.right.r, self.left.r)

    def dagger(self):
        """
        The dagger of a rigid cap is ill-defined,
        use a :class:`pivotal.Cap` instead.
        """
        raise AxiomError("Rigid caps have no dagger, use pivotal instead.")

[docs]
class Category(closed.Category):
    """
    A rigid category is a monoidal category
    with methods :code:`l`, :code:`r`, :code:`cups` and :code:`caps`.

    Parameters:
        ob : The type of objects.
        ar : The type of arrows.
    """
    # Default object and arrow types of the free rigid category.
    ob = Ty
    ar = Diagram

[docs]
class Functor(closed.Functor):
    """
    A rigid functor is a closed functor that preserves cups and caps.

    Parameters:
        ob (Mapping[Ty, Ty]) : Map from atomic :class:`Ty` to :code:`cod.ob`.
        ar (Mapping[Box, Diagram]) : Map from :class:`Box` to :code:`cod.ar`.
        cod (Category) : The codomain of the functor.

    Example
    -------
    >>> s, n = Ty('s'), Ty('n')
    >>> Alice, Bob = Box("Alice", Ty(), n), Box("Bob", Ty(), n)
    >>> loves = Box('loves', Ty(), n.r @ s @ n.l)
    >>> love_box = Box('loves', n @ n, s)
    >>> ob = {s: s, n: n}
    >>> ar = {Alice: Alice, Bob: Bob}
    >>> ar.update({loves: Cap(n.r, n) @ Cap(n, n.l) >> n.r @ love_box @ n.l})
    >>> F = Functor(ob, ar)
    >>> sentence = Alice @ loves @ Bob >> Cup(n, n.r) @ s @ Cup(n.l, n)
    >>> assert F(sentence).normal_form() == Alice >> Id(n) @ Bob >> love_box

    >>> from discopy.drawing import Equation
    >>> Equation(sentence, F(sentence), symbol='$\\\\mapsto$').draw(
    ...     figsize=(5, 2), path='docs/_static/rigid/functor-example.png')

    .. image:: /_static/rigid/functor-example.png
        :align: center
    """
    dom = cod = Category(Ty, Diagram)

    def __call__(self, other):
        if isinstance(other, Ty):
            return super().__call__(other)
        if isinstance(other, Ob):
            if other.z == 0:
                return super().__call__(other)
            # Step one winding towards zero, then take the adjoint of the
            # image to compensate.
            if other.z < 0:
                return self(other.r).l
            return self(other.l).r
        if isinstance(other, Cup):
            left, right = other.dom[:1], other.dom[1:]
            return self.cod.ar.cups(self(left), self(right))
        if isinstance(other, Cap):
            left, right = other.cod[:1], other.cod[1:]
            return self.cod.ar.caps(self(left), self(right))
        if isinstance(other, Box):
            z = getattr(other, "z", 0)
            if not z:
                return super().__call__(other)
            # Unwind the box down to ``z == 0``, map that box, then wind
            # the image back up by the same number of half-turns.
            unwind, rewind = ('l', 'r') if z > 0 else ('r', 'l')
            for _ in range(abs(z)):
                other = getattr(other, unwind)
            result = super().__call__(other)
            for _ in range(abs(z)):
                result = getattr(result, rewind)
            return result
        return super().__call__(other)

def nesting(cls: type, factory: Callable) -> Callable[[Ty, Ty], Diagram]:
    """
    Take a :code:`factory` for cups or caps of atomic types
    and extends it recursively.

    Parameters:
        cls : A diagram factory, e.g. :class:`Diagram`.
        factory :
            A factory for cups (or caps) of atomic types, e.g. :class:`Cup`.
    """
    def method(left: Ty, right: Ty) -> Diagram:
        if len(left) == 0:
            # Nothing left to nest: identity on the empty type.
            return cls.id(left[:0])
        # The outermost pair joins the first wire of ``left`` to the last
        # wire of ``right``; whatever lies in between is nested inside it.
        outer = factory(left[0], right[-1])
        inner = method(left[1:], right[:-1])
        sandwich = left[0] @ inner @ right[-1]
        if outer.dom:  # We are nesting cups.
            return sandwich >> outer
        return outer >> sandwich

    return method

# Wire the concrete box classes into the diagram factory now that they exist.
Diagram.cup_factory = Cup
Diagram.cap_factory = Cap
Diagram.sum_factory = Sum

# Module-level shorthand for the identity diagram constructor.
Id = Diagram.id
```