Source code for discopy.python

# -*- coding: utf-8 -*-

"""
The category of Python functions with tuple as monoidal product.

Summary
-------

.. autosummary::
    :template: class.rst
    :nosignatures:
    :toctree:

    Ty
    Function

.. admonition:: Functions

    .. autosummary::
        :template: function.rst
        :nosignatures:
        :toctree:

        exp
        tuplify
        untuplify
        is_tuple
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from contextlib import contextmanager

from discopy.cat import Composable, assert_iscomposable, assert_isinstance
from discopy.utils import (
    tuplify, untuplify, Whiskerable, classproperty, get_origin)


Ty = tuple[type, ...]


[docs] def exp(base: Ty, exponent: Ty) -> Ty: """ The exponential of a tuple of Python types by another. Parameters: base (python.Ty) : The base type. exponent (python.Ty) : The exponent type. """ return (Callable[list(exponent), tuple[base]], )
[docs] def is_tuple(typ: type) -> bool: """ Whether a given type is tuple or a paramaterised tuple. Parameters: typ : The type to check for equality with tuple. """ return get_origin(typ) is tuple
[docs] @dataclass class Function(Composable[Ty], Whiskerable): """ A function is a callable :code:`inside` with a pair of types :code:`dom` and :code:`cod`. Parameters: inside : The callable Python object inside the function. dom : The domain of the function, i.e. its input type. cod : The codomain of the function, i.e. its output type. .. admonition:: Summary .. autosummary:: id then tensor swap copy discard ev curry uncurry fix trace """ inside: Callable dom: Ty cod: Ty type_checking = True def __init__(self, inside: Callable, dom: Ty, cod: Ty): dom, cod = map(tuplify, (dom, cod)) self.inside, self.dom, self.cod = inside, dom, cod
[docs] def id(dom: Ty) -> Function: """ The identity function on a given tuple of types :code:`dom`. Parameters: dom (python.Ty) : The typle of types on which to take the identity. """ return Function(lambda *xs: untuplify(xs), dom, dom)
[docs] def then(self, other: Function) -> Function: """ The sequential composition of two functions, called with :code:`>>`. Parameters: other : The other function to compose in sequence. """ assert_iscomposable(self, other) return Function( lambda *args: other(*tuplify(self(*args))), self.dom, other.cod)
@classproperty @contextmanager def no_type_checking(cls): tmp, cls.type_checking = cls.type_checking, False try: yield finally: cls.type_checking = tmp def __call__(self, *xs): if self.type_checking: if len(xs) != len(self.dom): raise ValueError for (x, t) in zip(xs, self.dom): callable(x) or assert_isinstance(x, t) ys = self.inside(*xs) if self.type_checking: if len(self.cod) != 1 and ( not isinstance(ys, tuple) or len(self.cod) != len(ys)): raise RuntimeError for (y, t) in zip(tuplify(ys), self.cod): callable(y) or assert_isinstance(y, t) return ys
[docs] def tensor(self, other: Function) -> Function: """ The parallel composition of two functions, called with :code:`@`. Parameters: other : The other function to compose in sequence. """ def inside(*xs): left, right = xs[:len(self.dom)], xs[len(self.dom):] return untuplify(tuplify(self(*left)) + tuplify(other(*right))) return Function(inside, self.dom + other.dom, self.cod + other.cod)
[docs] @staticmethod def swap(x: Ty, y: Ty) -> Function: """ The function for swapping two tuples of types :code:`x` and :code:`y`. Parameters: x : The tuple of types on the left. y : The tuple of types on the right. """ def inside(*xs): return untuplify(tuplify(xs)[len(x):] + tuplify(xs)[:len(x)]) return Function(inside, dom=x + y, cod=y + x)
braid = swap
[docs] @staticmethod def copy(x: Ty, n=2) -> Function: """ The function for making :code:`n` copies of a tuple of types :code:`x`. Parameters: x : The tuple of types to copy. n : The number of copies. """ return Function(lambda *xs: n * xs, dom=x, cod=n * x)
[docs] @staticmethod def discard(dom: Ty) -> Function: """ The function discarding a tuple of types, i.e. making zero copies. Parameters: dom : The tuple of types to discard. """ return Function.copy(dom, 0)
[docs] @staticmethod def ev(base: Ty, exponent: Ty, left=True) -> Function: """ The evaluation function, i.e. take a function and apply it to an argument. Parameters: base : The output type. exponent : The input type. left : Whether to take the function on the left or right. """ if left: dom, cod = Function.exp(base, exponent) + exponent, base return Function(lambda f, *xs: f(*xs), dom, cod) dom, cod = exponent + Function.exp(base, exponent), base return Function(lambda *xs: xs[-1](*xs[:-1]), dom, cod)
[docs] def curry(self, n=1, left=True) -> Function: """ Currying, i.e. turn a binary function into a function-valued function. Parameters: n : The number of types to curry. left : Whether to curry on the left or right. """ if left: dom = self.dom[:len(self.dom) - n] cod = Function.exp(self.cod, self.dom[len(self.dom) - n:]) else: dom, cod = self.dom[n:], Function.exp(self.cod, self.dom[:n]) return Function(dom=dom, cod=cod, inside=lambda *xs: lambda *ys: self(*(xs + ys) if left else (ys + xs)))
[docs] def uncurry(self, left=True) -> Function: """ Uncurrying, i.e. turn a function-valued function into a binary function. Parameters: left : Whether to uncurry on the left or right. """ base, exponent = self.cod[0].__args__[-1], self.cod[0].__args__[:-1] base = tuple(base.__args__) if is_tuple(base) else (base, ) return self @ exponent >> Function.ev(base, exponent) if left\ else exponent @ self >> Function.ev(base, exponent, left=False)
[docs] def fix(self, n=1) -> Function: """ The parameterised fixed point of a function. Parameters: n : The number of types to take the fixed point over. """ def inside(*xs, y=None): result = self.inside(*xs + (() if y is None else (y, ))) return y if result == y else inside(*xs, y=result) return self if n == 0\ else Function(inside, self.dom[:-1], self.cod).fix(n - 1)
[docs] def trace(self, n=1, left=False): """ The trace of a function. Parameters: n : The number of types to trace over. """ if left: raise NotImplementedError dom, cod, traced = self.dom[:-n], self.cod[:-n], self.dom[-n:] fixed = (self >> self.discard(cod) @ traced).fix() return self.copy(dom) >> dom @ fixed\ >> self >> cod @ self.discard(traced)
exp = over = under = staticmethod(lambda x, y: exp(x, y))
@dataclass class Dict(Composable[int], Whiskerable): inside: dict[int, int] dom: int cod: int def __getitem__(self, key): return self.inside[key] @staticmethod def id(x: int = 0): return Dict({i: i for i in range(x)}, x, x) def then(self, other: Dict) -> Dict: inside = {i: self[other[i]] for i in range(other.cod)} return Dict(inside, self.dom, other.cod) def tensor(self, other: Dict) -> Dict: inside = {i: self[i] for i in range(self.cod)} inside.update({ self.cod + i: self.dom + other[i] for i in range(other.cod)}) return Dict(inside, self.dom + other.dom, self.cod + other.cod) @staticmethod def swap(x: int, y: int) -> Dict: inside = {i: i + x if i < x else i - x for i in range(x + y)} return Dict(inside, x + y, x + y) @staticmethod def copy(x: int, n=2) -> Dict: return Dict({i: i % x for i in range(n * x)}, x, n * x) class Category: ob = Ty ar = Function