# Source code for discopy.cat

```# -*- coding: utf-8 -*-

"""
The free (dagger) category
with formal sums, unary operators and symbolic variables.

Summary
-------

.. autosummary::
:template: class.rst
:nosignatures:
:toctree:

Ob
Arrow
Box
Id
Sum
Bubble
Category
Functor
Composable

.. autosummary::
:template: function.rst
:nosignatures:
:toctree:

factory
dumps

Axioms
------

We can create boxes with objects as domain and codomain:

>>> x, y, z = Ob('x'), Ob('y'), Ob('z')
>>> f, g, h = Box('f', x, y), Box('g', y, z), Box('h', z, x)

We can create arbitrary arrows with identity and composition:

>>> arrow = Arrow.id(x).then(f, g, h)
>>> assert arrow == f >> g >> h == h << g << f

We can create dagger functors from the free category to itself:

>>> ob = {x: z, y: y, z: x}
>>> ar = {f: g[::-1], g: f[::-1], h: h[::-1]}
>>> F = Functor(ob, ar)
>>> assert F(arrow) == (h >> f >> g)[::-1]

We can check the axioms of dagger (i.e. a contravariant involutive
identity-on-objects endofunctor):

>>> x, y, z = Ob('x'), Ob('y'), Ob('z')
>>> f, g = Box('f', x, y), Box('g', y, z)
>>> assert f[::-1][::-1] == f
>>> assert Id(x)[::-1] == Id(x)
>>> assert (f >> g)[::-1] == g[::-1] >> f[::-1]

We can check the axioms of dagger functors.

>>> assert F(Id(x)) == Id(F(x))
>>> assert F(f >> g) == F(f) >> F(g)
>>> assert F(f[::-1]) == F(f)[::-1]
>>> assert F(f.dom) == F(f).dom and F(f.cod) == F(f).cod

Functors are bubble-preserving.

>>> assert F(f.bubble()) == F(f).bubble()
"""

from __future__ import annotations

from dataclasses import dataclass
from functools import total_ordering, cached_property
from typing import (
Callable, Mapping, Iterable, Optional, Type, TYPE_CHECKING)

from discopy import messages, utils
from discopy.utils import (
factory,
factory_name,
from_tree,
rsubs,
unbiased,
MappingOrCallable,
Composable,
assert_isinstance,
assert_iscomposable,
assert_isparallel,
get_origin,
)

if TYPE_CHECKING:
import sympy

[docs]
@total_ordering
class Ob:
"""
An object with a string as :code:`name`.

Parameters:
name : The name of the object.

Example
-------
>>> x, x_, y = Ob('x'), Ob('x'), Ob('y')
>>> assert x == x_ and x != y
"""
def __setstate__(self, state):
if "name" not in state and "_name" in state:
state["name"] = state["_name"]
del state["_name"]
self.__dict__.update(state)

def __init__(self, name: str = ""):
assert_isinstance(name, str)
self.name = name

def __repr__(self):
return f"{factory_name(type(self))}({repr(self.name)})"

def __str__(self):
return str(self.name)

def __eq__(self, other):
return isinstance(other, type(self)) and self.name == other.name

def __hash__(self):
return hash(repr(self))

def __lt__(self, other):
return self.name < other.name

[docs]
def to_tree(self) -> dict:
"""
Serialise a DisCoPy object, see :func:`dumps`.

Example
-------
>>> Ob('x').to_tree()
{'factory': 'cat.Ob', 'name': 'x'}
"""
return {'factory': factory_name(type(self)), 'name': self.name}

[docs]
@classmethod
def from_tree(cls, tree: dict) -> Ob:
"""
Decode a serialised DisCoPy object, see :func:`loads`.

Parameters:
tree : DisCoPy serialisation.

Example
-------
>>> x = Ob('x')
>>> assert Ob.from_tree(x.to_tree()) == x
"""
return cls(tree['name'])

[docs]
@factory
class Arrow(Composable[Ob]):
"""
An arrow is a tuple of composable boxes :code:`inside` with a pair of
objects :code:`dom` and :code:`cod` as domain and codomain.

Parameters:
inside: The tuple of boxes inside an arrow.
dom: The domain of an arrow, i.e. its input.
cod: The codomain of an arrow, i.e. output
_scan: Whether to check composition.

.. autosummary::

id
then
dagger
bubble

Tip
---
For code clarity, it is recommended not to initialise arrows directly but
to use :meth:`Arrow.id` and :meth:`Arrow.then` instead. For example:

>>> x, y, z = map(Ob, "xyz")
>>> f, g = Box('f', x, y), Box('g', y, z)
>>> arrow = Arrow.id(x).then(f, g)  # Do this...
>>> arrow_ = Arrow((f, g), x, z)    # ...rather than that!
>>> assert arrow == arrow_

Note
----
Arrows can be indexed and sliced using square brackets. Indexing behaves
like that of strings, i.e. when we index an arrow we get an arrow back.

>>> assert (f >> g)[0] == f and (f >> g)[1] == g
>>> assert f[:0] == Arrow.id(f.dom)
>>> assert f[1:] == Arrow.id(f.cod)

Note
----
If ``dom`` or ``cod`` are not instances of ``ty_factory``, they are
automatically cast. This means one can use e.g. ``int`` instead of ``Ob``,
see :class:`monoidal.PRO`.
"""
ty_factory = Ob

def __setstate__(self, state):
if 'inside' not in state:  # Backward compatibility
self.dom, self.cod, self.inside = (
state['_dom'], state['_cod'], tuple(state['_boxes']))
del state['_dom'], state['_cod'], state['_boxes']
self.__dict__.update(state)

def __init__(self, inside: tuple[Box, ...], dom: Ob | str, cod: Ob | str,
_scan: bool = True) -> None:
ty_factory = type(self).ty_factory
dom = dom if isinstance(dom, ty_factory) else ty_factory(dom)
cod = cod if isinstance(cod, ty_factory) else ty_factory(cod)
self.dom, self.cod, self.inside = dom, cod, inside
if _scan:
for box in inside:
assert_isinstance(box, Box)
for f, g in zip((Id(dom), ) + inside, inside + (Id(cod), )):
assert_iscomposable(f, g)

def __iter__(self):
for box in self.inside:
yield box

def __getitem__(self, key):
if isinstance(key, slice):
if key.step == -1:
inside = tuple(box.dagger() for box in self.inside[key])
return self.factory(inside, self.cod, self.dom, _scan=False)
if (key.step or 1) != 1:
raise IndexError
inside = self.inside[key]
if not inside:
if (key.start or 0) >= len(self):
return self.id(self.cod)
if (key.start or 0) <= -len(self):
return self.id(self.dom)
return self.id(self.inside[key.start or 0].dom)
return self.factory(
inside, inside[0].dom, inside[-1].cod, _scan=False)
if isinstance(key, int):
if key >= len(self) or key < -len(self):
raise IndexError
if key < 0:
return self[len(self) + key]
return self[key:key + 1]
raise TypeError

def __len__(self):
return len(self.inside)

def __repr__(self):
if not self.inside:  # i.e. self is identity.
return f"{factory_name(type(self))}.id({repr(self.dom)})"
return f"{factory_name(self.factory)}(inside={repr(self.inside)}, " \
f"dom={repr(self.dom)}, cod={repr(self.cod)})"

def __str__(self):
return ' >> '.join(map(str, self.inside)) or f"Id({self.dom})"

def __eq__(self, other):
return isinstance(other, self.factory)\
and self.is_parallel(other) and self.inside == other.inside

def __hash__(self):
return hash(repr(self))

return self.sum_factory((self, )) + other

return self if other == 0 else NotImplemented

[docs]
@classmethod
def id(cls: Type[Arrow], dom: Optional[Ob] = None) -> Arrow:
"""
The identity arrow with the empty tuple inside, called with ``Id``.

Parameters:
dom : The domain (and codomain) of the identity.

Note
----
If ``dom`` is not provided, we use the default value of ``ty_factory``.

Example
-------
>>> assert Arrow.id() == Id() == Id(Ob())
>>> assert Arrow.id('x') == Id('x') == Id(Ob('x'))
"""
dom = cls.ty_factory() if dom is None else dom
return cls.factory((), dom, dom, _scan=False)

[docs]
def then(self, *others: Arrow) -> Arrow:
"""
Sequential composition, called with :code:`>>` and :code:`<<`.

Parameters:
others : The other arrows to compose.

Raises:
AxiomError : Whenever `self` and `others` do not compose.
"""
if any(isinstance(other, Sum) for other in others):
return self.sum_factory((self, )).then(*others)
inside, dom, cod = self.inside, self.dom, self.cod
for other in others:
assert_isinstance(other, self.factory)
assert_isinstance(self, other.factory)
inside, cod = inside + other.inside, other.cod
return self.factory(inside, dom, cod)

[docs]
def dagger(self) -> Arrow:
""" Contravariant involution, called with :code:`[::-1]`. """
return self[::-1]

[docs]
@classmethod
def zero(cls, dom, cod):
"""
Return the empty sum with a given domain and codomain.

Parameters:
dom : The domain of the empty sum.
cod : The codomain of the empty sum.
"""
return cls.sum_factory((), dom, cod)

[docs]
def bubble(self, *args, **kwargs) -> Bubble:
""" Unary operator on homsets. """
return self.bubble_factory(self, *args, **kwargs)

@property
def free_symbols(self) -> "set[sympy.Symbol]":
"""
The free :code:`sympy` symbols in an arrow.

Example
-------

>>> from sympy.abc import phi, psi
>>> x, y = Ob('x'), Ob('y')
>>> f = Box('f', x, y, data={"Alice": [phi + 1]})
>>> g = Box('g', y, x, data={"Bob": [psi / 2]})
>>> diagram = (f >> g).bubble() + Id(x)
>>> assert diagram.free_symbols == {phi, psi}
"""
return {x for box in self.inside for x in box.free_symbols}

[docs]
def subs(self, *args) -> Arrow:
"""
Substitute a variable by an expression.

Parameters:
var (sympy.Symbol) : The subtituted variable.
expr (sympy.Expr) : The substituting expression.

Tip
---
You can give a list of :code:`(var, expr)` for multiple substitution.

Example
-------
>>> from sympy.abc import phi, psi
>>> x, y = Ob('x'), Ob('y')
>>> f = Box('f', x, y, data={"Alice": [phi + 1]})
>>> g = Box('g', y, x, data={"Bob": [psi / 2]})
>>> assert (f >> g).subs(phi, phi + 1) == f.subs(phi, phi + 1) >> g
>>> assert (f >> g).subs(phi, 1) == f.subs(phi, 1) >> g
>>> assert (f >> g).subs(psi, 1) == f >> g.subs(psi, 1)
"""
inside = tuple(box.subs(*args) for box in self.inside)
return self.factory(inside, self.dom, self.cod, _scan=False)

[docs]
def lambdify(self, *symbols: "sympy.Symbol", **kwargs) -> Callable:
"""
Turn a symbolic diagram into a function from parameters to diagram.

Parameters:
symbols : The inputs of the function.
kwargs : Passed to :code:`sympy.lambdify`.

Example
-------
>>> from sympy.abc import phi, psi
>>> x, y, z = Ob('x'), Ob('y'), Ob('z')
>>> f, g = Box('f', x, y, data=phi), Box('g', y, z, data=psi)
>>> assert f.lambdify(psi)(42) == f
>>> assert (f >> g).lambdify(phi, psi)(42, 43)\\
...     == Box('f', x, y, data=42) >> Box('g', y, z, data=43)
"""
return lambda *xs: self.factory(
dom=self.dom, cod=self.cod, inside=tuple(
box.lambdify(*symbols, **kwargs)(*xs) for box in self.inside))

[docs]
def to_tree(self) -> dict:
"""
Serialise a DisCoPy arrow, see :func:`discopy.utils.dumps`.

Example
-------
>>> from pprint import PrettyPrinter
>>> pprint = PrettyPrinter(indent=4, width=70, sort_dicts=False).pprint
>>> f = Box('f', 'x', 'y', data=42)
>>> pprint((f >> f[::-1]).to_tree())
{   'factory': 'cat.Arrow',
'inside': [   {   'factory': 'cat.Box',
'name': 'f',
'dom': {'factory': 'cat.Ob', 'name': 'x'},
'cod': {'factory': 'cat.Ob', 'name': 'y'},
'data': 42},
{   'factory': 'cat.Box',
'name': 'f',
'dom': {'factory': 'cat.Ob', 'name': 'y'},
'cod': {'factory': 'cat.Ob', 'name': 'x'},
'is_dagger': True,
'data': 42}],
'dom': {'factory': 'cat.Ob', 'name': 'x'},
'cod': {'factory': 'cat.Ob', 'name': 'x'}}
"""
return {
'factory': factory_name(type(self)),
'inside': [box.to_tree() for box in self.inside],
'dom': self.dom.to_tree(), 'cod': self.cod.to_tree()}

[docs]
@classmethod
def from_tree(cls, tree: dict) -> Arrow:
"""
Decode a serialised DisCoPy arrow, see :func:`discopy.utils.loads`.

Parameters:
tree : DisCoPy serialisation.

Example
-------
>>> f = Box('f', 'x', 'y', data=42)
>>> assert Arrow.from_tree((f >> f[::-1]).to_tree()) == f >> f[::-1]
"""
dom, cod = map(from_tree, (tree['dom'], tree['cod']))
inside = tuple(map(from_tree, tree['inside']))
return cls(inside, dom, cod, _scan=False)

[docs]
@total_ordering
class Box(Arrow):
"""
A box is an arrow with a :code:`name` and the tuple of just itself inside.

Parameters:
name : The name of the box.
dom : The domain of the box, i.e. the input.
cod : The codomain of the box, i.e. the output.
data (any) : Extra data in the box, default is :code:`None`.
is_dagger : Whether the box is dagger.

Example
-------

>>> x, y = Ob('x'), Ob('y')
>>> f = Box('f', x, y, data=[42])
>>> assert f.inside == (f, )
"""
def __setstate__(self, state):
if 'inside' not in state:  # Backward compatibility
self.name, self.data, self.is_dagger = (
state['_name'], state['_data'], state['_dagger'])
del state['_name'], state['_data'], state['_dagger']
super().__setstate__(state)

def __init__(
self, name: str, dom: Ob, cod: Ob, data=None, is_dagger=False):
assert_isinstance(name, str)
self.name, self.data, self.is_dagger = name, data, is_dagger
Arrow.__init__(self, (self, ), dom, cod, _scan=False)

@cached_property
def free_symbols(self) -> "set[sympy.Symbol]":
def recursive_free_symbols(data):
if isinstance(data, Mapping):
data = data.values()
if isinstance(data, Iterable):
# Handles numpy 0-d arrays, which are actually not iterable.
if not hasattr(data, "shape") or data.shape != ():
return set().union(*map(recursive_free_symbols, data))
return getattr(data, "free_symbols", set())
return recursive_free_symbols(self.data)

def subs(self, *args) -> Box:
if not any(var in self.free_symbols for var in (
{var for var, _ in args[0]} if len(args) == 1 else {args[0]})):
return self
return type(self)(
self.name, self.dom, self.cod, is_dagger=self.is_dagger,
data=rsubs(self.data, *args))

def lambdify(self, *symbols: "sympy.Symbol", **kwargs) -> Callable:
if not any(x in self.free_symbols for x in symbols):
return lambda *xs: self
from sympy import lambdify
return lambda *xs: type(self)(
self.name, self.dom, self.cod, is_dagger=self.is_dagger,
data=lambdify(symbols, self.data, **kwargs)(*xs))

def dagger(self) -> Box:
return type(self)(
self.name, self.cod, self.dom,
data=self.data, is_dagger=not self.is_dagger)

def __getitem__(self, key):
if key == slice(None, None, -1):
return self.dagger()
return super().__getitem__(key)

def __repr__(self):
if self.is_dagger:
return repr(self.dagger()) + ".dagger()"
str_data = '' if self.data is None else ", data=" + repr(self.data)
return factory_name(type(self))\
+ f"({repr(self.name)}, {repr(self.dom)}, " \
f"{repr(self.cod)}{str_data})"

def __str__(self):
return str(self.name) + ("[::-1]" if self.is_dagger else '')

def __hash__(self):
return hash(Arrow.__repr__(self))

def __eq__(self, other):
if isinstance(other, Box):
return type(self) is type(other)\
and self.name == other.name\
and self.is_parallel(other)\
and self.is_dagger == other.is_dagger\
and bool(self.data == other.data)
return isinstance(other, Arrow)\
and self >> self.id(self.cod) == other  # cast box as diagram

def __lt__(self, other):
return self.name < other.name

def to_tree(self) -> dict:
tree = {
'factory': factory_name(type(self)),
'name': self.name,
'dom': self.dom.to_tree(),
'cod': self.cod.to_tree()}
if self.is_dagger:
tree['is_dagger'] = True
if self.data is not None:
tree['data'] = self.data
return tree

@classmethod
def from_tree(cls, tree: dict) -> Box:
name = tree['name']
dom, cod = map(from_tree, (tree['dom'], tree['cod']))
data, is_dagger = tree.get('data', None), 'is_dagger' in tree
return cls(name=name, dom=dom, cod=cod, data=data, is_dagger=is_dagger)

[docs]
class Sum(Box):
"""
A sum is a tuple of arrows :code:`terms` with the same domain and codomain.

Parameters:
terms : The terms of the formal sum.
dom : The domain of the formal sum.
cod : The codomain of the formal sum.

Example
-------
>>> x, y, z = Ob('x'), Ob('y'), Ob('z')
>>> f, g = Box('f', x, y), Box('g', y, z)
>>> unit = Sum((), x, y)
>>> assert f + unit == f == unit + f
>>> assert f >> (g + g) == (f >> g) + (f >> g) == (f + f) >> g

Important
---------
Domain and codomain are optional only if the terms are non-empty.

Note
----
The sum is non-commutative, i.e. :code:`Sum([f, g]) != Sum([g, f])`.
"""
def __init__(
self, terms: tuple[Arrow, ...], dom: Ob = None, cod: Ob = None):
if not terms and (dom is None or cod is None):
raise ValueError(messages.MISSING_TYPES_FOR_EMPTY_SUM)
dom = terms[0].dom if dom is None else dom
cod = terms[0].cod if cod is None else cod
for arrow in terms:
assert_isparallel(Sum((), dom, cod), arrow)
str_args = f", dom={repr(dom)}, cod={repr(cod)}" if not terms else ""
name = f"{factory_name(type(self))}(terms={repr(terms)}{str_args})"
self.terms = terms
super().__init__(name, dom, cod)

def __eq__(self, other):
if isinstance(other, Sum):
return (self.dom, self.cod, self.terms)\
== (other.dom, other.cod, other.terms)
return len(self.terms) == 1 and self.terms[0] == other

def __hash__(self):
return hash(repr(self))

def __repr__(self):
return self.name

def __str__(self):
return " + ".join(f"({arrow})" for arrow in self.terms)\
if self.terms else\
f"{factory_name(type(self))}((), {self.dom}, {self.cod})"

assert_isparallel(self, other)
other = other if isinstance(other, Sum)\
else self.sum_factory((other, ))
return self.sum_factory(self.terms + other.terms, self.dom, self.cod)

def __iter__(self):
for arrow in self.terms:
yield arrow

def __len__(self):
return len(self.terms)

@unbiased
def then(self, other):
other = other if isinstance(other, Sum)\
else self.sum_factory((other, ))
terms = tuple(f.then(g) for f in self.terms for g in other.terms)
return self.sum_factory(terms, self.dom, other.cod)

def dagger(self):
terms = tuple(f.dagger() for f in self.terms)
return self.sum_factory(terms, self.cod, self.dom)

@property
def free_symbols(self):
return {x for box in self.terms for x in box.free_symbols}

def subs(self, *args):
terms = tuple(f.subs(*args) for f in self.terms)
return self.sum_factory(terms, self.dom, self.cod)

def lambdify(self, *symbols, **kwargs):
return lambda *xs: self.sum_factory(
tuple(box.lambdify(*symbols, **kwargs)(*xs) for box in self.terms),
dom=self.dom, cod=self.cod)

def to_tree(self):
return {
'factory': factory_name(type(self)),
'terms': [t.to_tree() for t in self.terms],
'dom': self.dom.to_tree(),
'cod': self.cod.to_tree()}

@classmethod
def from_tree(cls, tree):
dom, cod = map(from_tree, (tree['dom'], tree['cod']))
terms = tuple(map(from_tree, tree['terms']))
return cls(terms=terms, dom=dom, cod=cod)

[docs]
class Bubble(Box):
"""
A bubble is a box with an arrow :code:`arg` inside and an optional pair of
objects :code:`dom` and :code:`cod`.

Parameters:
args : The arrows inside the bubble.
dom : The domain of the bubble, default is that of :code:`other`.
cod : The codomain of the bubble, default is that of :code:`other`.
name (str) : An optional name for the bubble.
method (str) : The method to call when a functor is applied to it.
kwargs : Passed to the `__init__` of :class:`Box`.
"""
def __init__(self, *args: Arrow, dom: Ob = None, cod: Ob = None,
name="Bubble", method="bubble", **kwargs):
dom, = set(arg.dom for arg in args) if dom is None else (dom, )
cod, = set(arg.cod for arg in args) if cod is None else (cod, )
self.args, self.method = args, method
Box.__init__(self, name, dom, cod, **kwargs)

@property
def arg(self):
""" The arrow inside the bubble if there is exactly one. """
if len(self.args) == 1:
return self.args[0]
raise ValueError(f"{self} has multiple args.")

@property
def is_id_on_objects(self):
""" Whether the bubble is identity on objects. """
return len(self.args) == 1 and (
self.dom, self.cod) == (self.arg.dom, self.arg.cod)

def __eq__(self, other):
if isinstance(other, Bubble):
return all(getattr(self, x) == getattr(other, x) for x in (
"args", "dom", "cod", "name", "method"))
return not isinstance(other, Box) and super().__eq__(other)

def __hash__(self):
return hash(tuple(getattr(self, x) for x in [
"args", "dom", "cod", "name", "method"]))

def __str__(self):
str_args = ",".join(map(str, self.args))
str_dom_cod = '' if self.is_id_on_objects else (
f'dom={self.dom}, cod={self.cod}')
return f"({str_args}).bubble({str_dom_cod})"

def __repr__(self):
repr_args = ", ".join(map(repr, self.args))
repr_dom_cod = "" if self.is_id_on_objects else (
f", dom={repr(self.dom)}, cod={repr(self.cod)}")
return factory_name(type(self)) + (f"({repr_args}{repr_dom_cod})")

@property
def free_symbols(self):
return super().free_symbols.union(*[f.free_symbols for f in self.args])

def to_tree(self):
return {
'factory': factory_name(type(self)),
'args': [f.to_tree() for f in self.args],
'dom': self.dom.to_tree(),
'cod': self.cod.to_tree()}

@classmethod
def from_tree(cls, tree):
args = [tree['arg']] if 'args' not in tree else tree['args']
dom, cod = map(from_tree, (tree['dom'], tree['cod']))
return cls(*map(from_tree, args), dom=dom, cod=cod)

[docs]
@dataclass
class Category:
"""
A category is just a pair of Python types :code:`ob` and :code:`ar` with
appropriate methods :code:`dom`, :code:`cod`, :code:`id` and :code:`then`.

Parameters:
ob : The objects of the category, default is :class:`Ob`.
ar : The arrows of the category, default is :class:`Arrow`.

Example
-------
>>> Category()
Category(cat.Ob, cat.Arrow)
>>> CAT
Category(cat.Category, cat.Functor)
"""
ob, ar = Ob, Arrow

def __init__(self, ob: type = None, ar: type = None):
self.ob, self.ar = (ob or type(self).ob), (ar or type(self).ar)

def __repr__(self):
return f"Category({factory_name(self.ob)}, {factory_name(self.ar)})"

def __eq__(self, other):
return isinstance(other, Category)\
and (self.ob, self.ar) == (other.ob, other.ar)

def __hash__(self):
return hash((self.ob, self.ar))

[docs]
class Functor(Composable[Category]):
"""
A functor is a pair of maps :code:`ob` and :code:`ar` and an optional
codomain category :code:`cod`.

Parameters:
ob : Mapping from :class:`Ob` to :code:`cod.ob`.
ar : Mapping from :class:`Box` to :code:`cod.ar`.
cod : The codomain, :code:`Category(Ob, Arrow)` by default.

Example
-------
>>> x, y, z = Ob('x'), Ob('y'), Ob('z')
>>> f, g = Box('f', x, y), Box('g', y, z)
>>> ob, ar = {x: y, y: z, z: y}, {f: g, g: g[::-1]}
>>> F = Functor(ob, ar)
>>> assert F(x) == y and F(f) == g

Tip
---
Both :code:`ob` and :code:`ar` can be a function rather than a dictionary.
In conjunction with :attr:`Box.data`, this can be used to create a
:class:`Functor` from a free category with infinitely many generators.

>>> ob = lambda x: x
>>> ar = lambda f: Box(f.name, f.dom, f.cod, data=f.data + 1)
>>> F = Functor(ob, ar)
>>> h = Box('h', x, x, data=42)
>>> assert F(h).data == 43 and F(F(h)).data == 44

If :attr:`Box.data` is a mutable object, then so can be the image of a
:class:`Functor` on it.

>>> ar = lambda f: f if all(f.data) else f[::-1]
>>> F = Functor(ob, ar)
>>> m = Box('m', x, x, data=[True])
>>> assert F(m) == m
>>> m.data.append(False)
>>> assert F(m) == m[::-1]
"""
dom = cod = Category(Ob, Arrow)

[docs]
@classmethod
def id(cls, dom: Category = None) -> Functor:
"""
The identity functor on a given category ``dom``.

Parameters:
dom : The domain of the functor.
"""
return cls(lambda x: x, lambda f: f, dom=dom, cod=dom)

[docs]
def then(self, other: Functor) -> Functor:
"""
The composition of functor with another.

Parameters:
other : The other functor with which to compose.

Note
----
Functor composition is unital only on the left. Indeed, we cannot check
equality of functors defined with functions instead of dictionaries.

Example
-------
>>> x, y = Ob('x'), Ob('y')
>>> F, G = Functor({x: y}, {}), Functor({y: x}, {})
>>> print(F >> G)
cat.Functor(ob={cat.Ob('x'): cat.Ob('x')}, ar={})
>>> assert F >> Functor.id() == F != Functor.id() >> F
>>> print(Functor.id() >> F)  # doctest: +ELLIPSIS
cat.Functor(ob=<function ...>, ar=...)
"""
assert_isinstance(other, Functor)
assert_iscomposable(self, other)
ob, ar = self.ob.then(other), self.ar.then(other)
return type(self)(ob, ar, dom=self.dom, cod=other.cod)

def __init__(
self,
ob: Mapping[Ob, Ob] | Callable[[Ob], Ob] | None = None,
ar: Mapping[Box, Arrow] | Callable[[Box], Arrow] | None = None,
dom: Category = None, cod: Category = None):
self.dom, self.cod = dom or type(self).dom, cod or type(self).cod
self.ob: MappingOrCallable[Ob, Ob] = MappingOrCallable(ob or {})
self.ar: MappingOrCallable[Box, Arrow] = MappingOrCallable(ar or {})

def __eq__(self, other):
return type(self) is type(other)\
and (self.ob, self.ar, self.cod) == (other.ob, other.ar, other.cod)

def __repr__(self):
cod_repr = "" if self.cod == type(self).cod else f", cod={self.cod}"
return factory_name(type(self))\
+ f"(ob={self.ob}, ar={self.ar}{cod_repr})"

def __call__(self, other):
if isinstance(other, Ob):
result, origin = self.ob[other], get_origin(self.cod.ob)
if isinstance(result, origin):
return result
return (result, ) if origin == tuple else self.cod.ob(result)
if isinstance(other, Sum):
return sum(map(self, other.terms),
self.cod.ar.zero(self(other.dom), self(other.cod)))
if isinstance(other, Bubble) and hasattr(self.cod.ar, other.method):
dom, cod = map(self, (other.dom, other.cod))
return getattr(self.cod.ar, other.method)(
*map(self, other.args), dom=dom, cod=cod)
if isinstance(other, Box) and other.is_dagger:
return self(other.dagger()).dagger()
if isinstance(other, Box):
result = self.ar[other]
# This allows some nice syntactic sugar for the ar mapping.
return result if isinstance(result, self.cod.ar)\
else self.cod.ar(result, self(other.dom), self(other.cod))
assert_isinstance(other, Arrow)
result = self.cod.ar.id(self(other.dom))
for box in other.inside:
result = result >> self(box)
return result

Arrow.sum_factory = Sum
Arrow.bubble_factory = Bubble
CAT = Category(Category, Functor)
Id = Arrow.id
```