# Source code for discopy.closed

# -*- coding: utf-8 -*-

"""
The free closed monoidal category, i.e. with exponential objects.

Summary
-------

.. autosummary::
    :template: class.rst
    :nosignatures:
    :toctree:

    Ty
    Exp
    Over
    Under
    Diagram
    Box
    Eval
    Curry
    Sum
    Category
    Functor

Axioms
------

:meth:`Diagram.curry` and :meth:`Diagram.uncurry` are inverses.

>>> x, y, z = map(Ty, "xyz")
>>> f, g, h = Box('f', x, z << y), Box('g', x @ y, z), Box('h', y, x >> z)

>>> from discopy.drawing import Equation
>>> Equation(f.uncurry().curry(), f).draw(
...     path='docs/_static/closed/curry-left.png', margins=(0.1, 0.05))

.. image:: /_static/closed/curry-left.png
    :align: center

>>> Equation(h.uncurry(left=False).curry(left=False), h).draw(
...     path='docs/_static/closed/curry-right.png', margins=(0.1, 0.05))

.. image:: /_static/closed/curry-right.png
    :align: center

>>> Equation(
...     g.curry().uncurry(), g, g.curry(left=False).uncurry(left=False)).draw(
...         path='docs/_static/closed/uncurry.png')

.. image:: /_static/closed/uncurry.png
    :align: center
"""

from __future__ import annotations

from discopy import cat, monoidal
from discopy.cat import Category, factory
from discopy.utils import (
    factory_name,
    from_tree,
)


@factory
class Ty(monoidal.Ty):
    """
    A closed type is a monoidal type that can be exponentiated.

    Parameters:
        inside (Ty) : The objects inside the type.

    Note
    ----
    We can take exponentials of types.

    >>> x, y, z = Ty(*"xyz")
    >>> print((x ** y) ** z)
    ((x ** y) ** z)

    We can also distinguish left- and right-exponentials.

    >>> print((x >> y) << z)
    ((x >> y) << z)
    """
    def __pow__(self, other: Ty) -> Ty:
        # Raising to a closed type builds an exponential; any other operand
        # is delegated to the monoidal type's own ``**``.
        if not isinstance(other, Ty):
            return super().__pow__(other)
        return Exp(self, other)

    def __lshift__(self, other):
        # ``self << other``: ``self`` is the base, ``other`` the exponent.
        return Over(self, other)

    def __rshift__(self, other):
        # ``self >> other``: ``other`` is the base, ``self`` the exponent.
        return Under(other, self)

    def __repr__(self):
        contents = ", ".join(repr(obj) for obj in self.inside)
        return f"{factory_name(type(self))}({contents})"

    @property
    def left(self) -> Ty:
        """
        The type written on the left of the exponential operator,
        or ``None`` when this type is not a single exponential.
        """
        if not self.is_exp:
            return None
        return self.inside[0].left

    @property
    def right(self) -> Ty:
        """
        The type written on the right of the exponential operator,
        or ``None`` when this type is not a single exponential.
        """
        if not self.is_exp:
            return None
        return self.inside[0].right

    @property
    def is_exp(self):
        """
        Whether the type is an :class:`Exp` object.

        Example
        -------
        >>> x, y = Ty('x'), Ty('y')
        >>> assert (x ** y).is_exp and (x ** y @ Ty()).is_exp
        """
        if len(self) != 1:
            return False
        return isinstance(self.inside[0], Exp)

    @property
    def is_under(self):
        """
        Whether the type is an :class:`Under` object.

        Example
        -------
        >>> x, y = Ty('x'), Ty('y')
        >>> assert (x >> y).is_under and (x >> y @ Ty()).is_under
        """
        if len(self) != 1:
            return False
        return isinstance(self.inside[0], Under)

    @property
    def is_over(self):
        """
        Whether the type is an :class:`Over` object.

        Example
        -------
        >>> x, y = Ty('x'), Ty('y')
        >>> assert (x << y).is_over and (x << y @ Ty()).is_over
        """
        if len(self) != 1:
            return False
        return isinstance(self.inside[0], Over)
class Exp(Ty, cat.Ob):
    """
    A :code:`base` type to an :code:`exponent` type, called with :code:`**`.

    Parameters:
        base : The base type.
        exponent : The exponent type.
    """
    __ambiguous_inheritance__ = (cat.Ob, )

    def __init__(self, base: Ty, exponent: Ty):
        self.base, self.exponent = base, exponent
        # An exponential is a closed type whose single object is itself.
        super().__init__(self)

    @property
    def left(self):
        """ The type written on the left of the operator. """
        return self.exponent if isinstance(self, Under) else self.base

    @property
    def right(self):
        """ The type written on the right of the operator. """
        return self.base if isinstance(self, Under) else self.exponent

    def __eq__(self, other):
        """
        Two exponentials are equal iff they have the exact same class and the
        same base and exponent. A closed type wrapping this single
        exponential is also equal to it.
        """
        if isinstance(other, Exp):
            # Compare exact classes rather than ``isinstance(other,
            # type(self))``: the latter made ``Exp(x, y) == Over(x, y)`` true
            # but ``Over(x, y) == Exp(x, y)`` false, and equal objects with
            # different hashes (``__hash__`` hashes the repr, which contains
            # the class name). This also avoids an infinite loop between
            # ``Over(x, y) == Under(x, y)``.
            return type(other) is type(self) and (
                (self.base, self.exponent) == (other.base, other.exponent))
        return isinstance(other, Ty) and other.inside == (self, )

    def __hash__(self):
        return hash(repr(self))

    def __str__(self):
        return f"({self.base} ** {self.exponent})"

    def __repr__(self):
        return factory_name(type(self))\
            + f"({repr(self.base)}, {repr(self.exponent)})"

    def to_tree(self):
        """ Serialise the exponential as a dict of its factory and parts. """
        return {
            'factory': factory_name(type(self)),
            'base': self.base.to_tree(),
            'exponent': self.exponent.to_tree()}

    @classmethod
    def from_tree(cls, tree):
        """ Inverse of :meth:`to_tree`. """
        return cls(*map(from_tree, (tree['base'], tree['exponent'])))
class Over(Exp):
    """
    A :code:`base` type over an :code:`exponent` type, called with
    :code:`<<`, i.e. :code:`base << exponent`.

    Parameters:
        base : The base type, written on the left of :code:`<<`.
        exponent : The exponent type, written on the right of :code:`<<`.
    """
    def __str__(self):
        base, exponent = self.base, self.exponent
        return f"({base} << {exponent})"
class Under(Exp):
    """
    An :code:`exponent` type under a :code:`base` type, called with
    :code:`>>`, i.e. :code:`exponent >> base`.

    Parameters:
        base : The base type, written on the right of :code:`>>`.
        exponent : The exponent type, written on the left of :code:`>>`.
    """
    def __str__(self):
        base, exponent = self.base, self.exponent
        # The exponent is printed on the left of ``>>``.
        return f"({exponent} >> {base})"
@factory
class Diagram(monoidal.Diagram):
    """
    A closed diagram is a monoidal diagram
    with :class:`Curry` and :class:`Eval` boxes.

    Parameters:
        inside(Layer) : The layers inside the diagram.
        dom (Ty) : The domain of the diagram, i.e. its input.
        cod (Ty) : The codomain of the diagram, i.e. its output.
    """
    __ambiguous_inheritance__ = True

    ty_factory = Ty

    def curry(self, n=1, left=True) -> Diagram:
        """
        Wrapper around :class:`Curry` called by :class:`Functor`.

        Parameters:
            n : The number of atomic types to curry.
            left : Whether to curry on the left or right.
        """
        return self.curry_factory(self, n, left)

    @classmethod
    def ev(cls, base: Ty, exponent: Ty, left=True) -> Eval:
        """
        Wrapper around :class:`Eval` called by :class:`Functor`.

        Parameters:
            base : The base of the exponential type to evaluate.
            exponent : The exponent of the exponential type to evaluate.
            left : Whether to evaluate on the left (``base << exponent``)
                or right (``exponent >> base``).
        """
        return cls.eval_factory(
            base << exponent if left else exponent >> base)

    def uncurry(self: Diagram, left=True) -> Diagram:
        """
        Uncurry a closed diagram by composing it with :meth:`Diagram.ev`.

        The codomain must be a single exponential: either an :class:`Exp`
        itself or any closed type ``t`` with ``t.is_exp``, such as
        ``(x << y) @ Ty()``. ``left`` should agree with its kind
        (``True`` for :class:`Over`, ``False`` for :class:`Under`).

        Parameters:
            left : Whether to uncurry on the left or right.
        """
        cod = self.cod
        if getattr(cod, "is_exp", False):
            # A closed Ty may wrap the exponential rather than be one;
            # unwrap it, as ``Ty.left`` and ``Ty.right`` do. For an Exp this
            # is its own single object, so the unwrap is harmless.
            cod = cod.inside[0]
        base, exponent = cod.base, cod.exponent
        if left:
            return (self @ exponent) >> self.ev(base, exponent, True)
        return (exponent @ self) >> self.ev(base, exponent, False)
class Box(monoidal.Box, Diagram):
    """
    A box in a closed diagram: a monoidal box that is also a closed
    :class:`Diagram`.

    Parameters:
        name (str) : The name of the box.
        dom (Ty) : The domain of the box, i.e. its input.
        cod (Ty) : The codomain of the box, i.e. its output.
    """
    # Flags monoidal.Box as an ambiguous base for discopy's inheritance
    # machinery (exact semantics defined outside this module).
    __ambiguous_inheritance__ = (monoidal.Box,)
class Eval(Box):
    """
    The evaluation box of an exponential type.

    For an :class:`Over` ``x = base << exponent`` the box goes from
    ``x @ exponent`` to ``base``; for any other exponential it goes from
    ``exponent @ x`` to ``base``.

    Parameters:
        x : The exponential type to evaluate.
    """
    def __init__(self, x: Exp):
        self.base, self.exponent = x.base, x.exponent
        self.left = isinstance(x, Over)
        if self.left:
            dom = x @ self.exponent
        else:
            dom = self.exponent @ x
        cod = self.base
        super().__init__("Eval" + str(x), dom, cod)
class Curry(monoidal.Bubble, Box):
    """
    The currying of a closed diagram.

    With ``left=True`` the last ``n`` types of the domain move into the
    codomain as ``arg.cod << arg.dom[-n:]``. With ``left=False`` the first
    ``n`` types move as ``arg.dom[:n] >> arg.cod``.

    Parameters:
        arg : The diagram to curry.
        n : The number of atomic types to curry,
            with ``0 <= n <= len(arg.dom)``.
        left : Whether to curry on the left or right.

    Raises:
        ValueError : If ``n`` is out of range for ``arg.dom``.
    """
    def __init__(self, arg: Diagram, n=1, left=True):
        # Without this check, ``n > len(arg.dom)`` (or ``n < 0``) makes the
        # left-curry slice index negative and silently picks the wrong types.
        if not 0 <= n <= len(arg.dom):
            raise ValueError(
                f"Cannot curry {n} types out of a domain of length "
                f"{len(arg.dom)}.")
        self.n, self.left = n, left
        name = f"Curry({arg}, {n}, {left})"
        if left:
            split = len(arg.dom) - n
            dom = arg.dom[:split]
            cod = arg.cod << arg.dom[split:]
        else:
            dom, cod = arg.dom[n:], arg.dom[:n] >> arg.cod
        monoidal.Bubble.__init__(
            self, arg, dom=dom, cod=cod, drawing_name="$\\Lambda$")
        Box.__init__(self, name, dom, cod)
class Sum(monoidal.Sum, Box):
    """
    A formal sum of closed diagrams: a monoidal sum that is also a closed
    :class:`Box`.

    Parameters:
        terms (tuple[Diagram, ...]) : The terms of the formal sum.
        dom (Ty) : The domain of the formal sum.
        cod (Ty) : The codomain of the formal sum.
    """
    # Flags monoidal.Sum as an ambiguous base for discopy's inheritance
    # machinery (exact semantics defined outside this module).
    __ambiguous_inheritance__ = (monoidal.Sum,)
# Closed diagrams expose the exponential type constructors as static methods.
Diagram.over = staticmethod(Over)
Diagram.under = staticmethod(Under)
Diagram.exp = staticmethod(Exp)
Diagram.sum_factory = Sum
Id = Diagram.id
class Category(monoidal.Category):
    """
    A closed category: a monoidal category whose arrows provide
    :code:`exp` (:code:`over` and / or :code:`under`), :code:`ev` and
    :code:`curry`.

    Parameters:
        ob : The type of objects.
        ar : The type of arrows.
    """
    ob = Ty
    ar = Diagram
class Functor(monoidal.Functor):
    """
    A closed functor is a monoidal functor
    that preserves evaluation and currying.

    Parameters:
        ob (Mapping[Ty, Ty]) : Map from atomic :class:`Ty` to :code:`cod.ob`.
        ar (Mapping[Box, Diagram]) : Map from :class:`Box` to :code:`cod.ar`.
        cod (Category) : The codomain of the functor.
    """
    dom = cod = Category(Ty, Diagram)

    def __call__(self, other):
        # Over and Under subclass Exp, so they must be tried before Exp.
        for kind, attr in ((Over, "over"), (Under, "under"), (Exp, "exp")):
            if not isinstance(other, kind):
                continue
            make = getattr(self.cod.ar, attr)
            return make(self(other.base), self(other.exponent))
        if isinstance(other, Curry):
            arg = self(other.arg)
            n_curried = len(self(other.cod.exponent))
            return self.cod.ar.curry(arg, n_curried, other.left)
        if isinstance(other, Eval):
            base = self(other.base)
            exponent = self(other.exponent)
            return self.cod.ar.ev(base, exponent, other.left)
        return super().__call__(other)
def to_rigid(self):
    """
    Translate a closed diagram (or type) into the rigid category.

    Each atomic type is sent to a :code:`rigid.Ty` with the name of its
    single object. Each box is sent to a :code:`rigid.Box` with the same
    name, whose domain and codomain are translated recursively.
    Exponentials, evaluations and curried diagrams are mapped by
    :meth:`Functor.__call__` using the arrows of :code:`rigid.Category()`.
    """
    from discopy import rigid

    def ob(x):
        return rigid.Ty(x.inside[0].name)

    def ar(f):
        return rigid.Box(
            f.name, Diagram.to_rigid(f.dom), Diagram.to_rigid(f.cod))

    translate = Functor(ob=ob, ar=ar, cod=rigid.Category())
    return translate(self)


Id = Diagram.id

Diagram.to_rigid = to_rigid
Diagram.curry_factory = Curry
Diagram.eval_factory = Eval