# Source code for discopy.python.additive

# -*- coding: utf-8 -*-

"""
The category of Python functions with disjoint union as monoidal product.

Summary
-------

.. autosummary::
    :template: class.rst
    :nosignatures:
    :toctree:

    Ty
    Function
"""

from __future__ import annotations

from functools import cache

from discopy.utils import assert_isinstance, tuplify
from discopy.python import function


#: Lists of types interpreted as disjoint union.
Ty = tuple[type, ...]


class Function(function.Function):
    """
    Python functions with disjoint union as tensor.

    An element of a disjoint union is an object together with an integer
    :code:`tag`, the index of its type in the list of types. When a domain
    has exactly one type, :code:`inside` is called with the object alone;
    when a codomain has exactly one type, :code:`inside` returns the object
    alone rather than an :code:`(obj, tag)` pair.

    Parameters:
        inside : The callable Python object inside the function.
        dom : The domain of the function, i.e. its list of input types.
        cod : The codomain of the function, i.e. its list of output types.
        is_swap_of : The pair :code:`(x, y)` if the function was built by
            :meth:`swap`, :code:`None` otherwise. Read by :meth:`dagger`.

    .. admonition:: Summary

        .. autosummary::

            tensor
            swap
            trace
    """
    __ambiguous_inheritance__ = True

    ty_factory = Ty

    def __init__(self, inside, dom, cod, is_swap_of=None):
        self.is_swap_of = is_swap_of
        super().__init__(inside, dom, cod)

    def __call__(self, obj, tag=0):
        """
        Apply the function to an object of type :code:`self.dom[tag]`.

        Parameters:
            obj : The input object.
            tag : The index in the domain of the type of the input.

        Returns:
            Whatever :code:`inside` returns: by convention the output object
            if the codomain has exactly one type, otherwise a pair
            :code:`(obj, tag)` where the tag is an index in the codomain.
        """
        # NOTE(review): type_checking is not set in this class, presumably
        # it is inherited from the base class.
        if self.type_checking:
            assert_isinstance(obj, self.dom[tag])
        # The tag is not passed along when the domain has exactly one type.
        args = (obj, ) if len(self.dom) == 1 else (obj, tag)
        result = self.inside(*args)
        if self.type_checking:
            obj, tag = (result, 0) if len(self.cod) == 1 else result
            assert_isinstance(obj, self.cod[tag])
        return result

    def tensor(self, other: Function) -> Function:
        """
        The disjoint union of two functions, called with :code:`@`.

        Inputs tagged below :code:`len(self.dom)` go through :code:`self`,
        the others go through :code:`other` with their tag shifted down.
        Output tags of :code:`other` are shifted up by :code:`len(self.cod)`.

        Parameters:
            other : The other function to take the disjoint union with.
        """
        dom, cod = self.dom + other.dom, self.cod + other.cod

        def inside(obj, tag=0):
            if tag < len(self.dom):
                result = self(obj, tag)
                obj, tag = (result, 0) if len(self.cod) == 1 else result
            else:
                result = other(obj, tag - len(self.dom))
                obj, tag = (result, 0) if len(other.cod) == 1 else result
                tag += len(self.cod)
            return obj if len(cod) == 1 else (obj, tag)
        return Function(inside, dom, cod)

    @staticmethod
    @cache
    def swap(x: Ty, y: Ty) -> Function:
        """
        Swap the tags of a disjoint union from `x + y` to `y + x`.

        The result is cached on the arguments as given, so they need to be
        hashable.

        Parameters:
            x : The tuple of types on the left.
            y : The tuple of types on the right.
        """
        x, y = map(tuplify, (x, y))

        def inside(obj, tag=0):
            new_tag = tag + len(y) if tag < len(x) else tag - len(x)
            if len(x + y) == 1:
                # A single type in total: the output carries no tag.
                assert new_tag == 0
                return obj
            return (obj, new_tag)
        return Function(inside, dom=x + y, cod=y + x, is_swap_of=(x, y))

    def dagger(self):
        """
        The swap in the opposite direction, from `y + x` back to `x + y`.

        Raises:
            ValueError : If the function was not built by :meth:`swap`.
        """
        if self.is_swap_of is None:
            raise ValueError("Function is not a swap, it has no dagger.")
        return Function.swap(*self.is_swap_of[::-1])

    def trace(self, n=1, left=False):
        """
        The additive trace of a function.

        The function is applied repeatedly: as long as the output lands in
        one of the last :code:`n` types of the codomain, it is fed back as
        an input of the corresponding type among the last :code:`n` types
        of the domain. The loop does not terminate if the output never
        leaves the traced types.

        Parameters:
            n : The number of types to trace over.
            left : Whether to trace on the left, which is not implemented.

        Raises:
            NotImplementedError : If :code:`left` is true.
        """
        if left:
            raise NotImplementedError
        # ``seq[:-0]`` is empty, tracing over no types must keep everything.
        dom, cod = (
            (self.dom[:-n], self.cod[:-n]) if n else (self.dom, self.cod))
        # The traced types start at index len(dom) among the inputs but at
        # index len(cod) among the outputs, so tags are shifted on feedback.
        feedback_shift = len(dom) - len(cod)

        def inside(obj, tag=0):
            while True:
                result = self(obj, tag)
                obj, tag = (result, 0) if len(self.cod) == 1 else result
                if tag < len(cod):
                    return obj if len(cod) == 1 else result
                tag += feedback_shift
        return Function(inside, dom, cod)

    @staticmethod
    def merge(x: Ty, n=2) -> Function:
        """
        Merge :code:`n` copies of `x` into one, i.e. from `n * x` to `x`.

        The object is left unchanged and its tag is taken modulo
        :code:`len(x)`, forgetting which copy it came from.

        Parameters:
            x : The tuple of types to merge.
            n : The number of copies in the domain.
        """
        def inside(obj, tag=0):
            if len(x) == 1:
                # A single type in the codomain: the output carries no tag.
                assert tag % len(x) == 0
                return obj
            return (obj, tag % len(x))
        return Function(inside, n * x, x)
# Module-level shorthands. ``braid`` and ``twist`` are also attached to
# ``Function`` as other names for its ``swap`` and ``id``.
Swap = Function.swap
Function.braid = Swap
Id = Function.id
Function.twist = Id
Merge = Function.merge


class Category:
    """ Objects are :code:`Ty` and arrows are :class:`Function`. """
    ob = Ty
    ar = Function