"""
Implements the semantic category :py:class:`.Matrix`.

In this category, a box with domain :py:class:`PRO(n) <.monoidal.PRO>`
and codomain :py:class:`PRO(m) <.monoidal.PRO>` represents
an :math:`n \\times m` matrix.

The ``>>`` and ``<<`` operations correspond to matrix multiplication
and the ``@`` operation corresponds to the direct sum of matrices:

.. math::

    \\mathbf{A} \\oplus \\mathbf{B}
    = \\begin{pmatrix} \\mathbf{A} & 0 \\\\ 0 & \\mathbf{B} \\end{pmatrix}

Example
-------
>>> x = Matrix(PRO(2), PRO(1), [2, 4])
>>> x.array
array([[2],
       [4]])
>>> x @ x
Matrix(dom=PRO(4), cod=PRO(2), array=[2, 0, 4, 0, 0, 2, 0, 4])
>>> (x @ x).array
array([[2, 0],
       [4, 0],
       [0, 2],
       [0, 4]])

:py:class:`.Matrix` can be used to evaluate
:py:class:`.optics.Diagram` s from :py:mod:`.quantum.optics`.
"""
from discopy import messages, monoidal
from discopy.cat import AxiomError
from discopy.monoidal import PRO, Sum
from discopy.tensor import array2string
from discopy.utils import unbiased
import numpy as np
class Matrix(monoidal.Box):
    """
    Implements a matrix with dom, cod and numpy array.

    Parameters
    ----------
    dom
        Domain, e.g. ``PRO(n)``; ``len(dom)`` is the number of rows.
    cod
        Codomain, e.g. ``PRO(m)``; ``len(cod)`` is the number of columns.
    array
        Array-like holding the entries; it is reshaped to
        ``(len(dom), len(cod))``.

    Examples
    --------
    >>> m = Matrix(PRO(2), PRO(2), [0, 1, 1, 0])
    >>> v = Matrix(PRO(1), PRO(2), [0, 1])
    >>> assert (str(v) == repr(v)
    ...         == 'Matrix(dom=PRO(1), cod=PRO(2), array=[0, 1])')
    >>> v >> m >> v.dagger()
    Matrix(dom=PRO(1), cod=PRO(1), array=[0])
    >>> m + m
    Matrix(dom=PRO(2), cod=PRO(2), array=[0, 2, 2, 0])
    >>> assert m.then(m, m, m, m) == m == m >> m >> m >> m >> m

    The monoidal product for :py:class:`.Matrix` is the direct sum:

    >>> x = Matrix(PRO(2), PRO(1), [2, 4])
    >>> x.array
    array([[2],
           [4]])
    >>> x @ x
    Matrix(dom=PRO(4), cod=PRO(2), array=[2, 0, 4, 0, 0, 2, 0, 4])
    >>> (x @ x).array
    array([[2, 0],
           [4, 0],
           [0, 2],
           [0, 4]])
    """

    def __init__(self, dom, cod, array):
        # Rows are indexed by the domain and columns by the codomain, so
        # composition below is a plain ``np.matmul`` in diagrammatic order.
        self._array = np.array(array).reshape((len(dom), len(cod)))
        super().__init__("O Tensor", dom, cod)

    @property
    def array(self):
        """ Numpy array of shape ``(len(dom), len(cod))``. """
        return self._array

    def __repr__(self):
        return "Matrix(dom={!r}, cod={!r}, array={})".format(
            self.dom, self.cod, array2string(self.array.flatten()))

    def __str__(self):
        return repr(self)

    @unbiased
    def then(self, other):
        """
        Sequential composition, computed as a matrix product.

        A :class:`Sum` is delegated to :meth:`monoidal.Diagram.then`.

        Raises
        ------
        TypeError
            If ``other`` is neither a :class:`Sum` nor a :class:`Matrix`.
        AxiomError
            If ``self.cod != other.dom``.
        """
        if isinstance(other, Sum):
            return monoidal.Diagram.then(self, other)
        if not isinstance(other, Matrix):
            raise TypeError(messages.type_err(Matrix, other))
        if self.cod != other.dom:
            raise AxiomError(messages.does_not_compose(self, other))
        array = np.matmul(self.array, other.array)
        return Matrix(self.dom, other.cod, array)

    def tensor(self, *others):
        """
        Monoidal product, computed as the direct sum (block diagonal).

        Calls with zero or several arguments, or with a :class:`Sum` among
        them, are delegated to :meth:`monoidal.Diagram.tensor`.

        Raises
        ------
        TypeError
            If the single argument is not a :class:`Matrix`.
        """
        if len(others) != 1 or any(isinstance(other, Sum) for other in others):
            return monoidal.Diagram.tensor(self, *others)
        other = others[0]
        if not isinstance(other, Matrix):
            raise TypeError(messages.type_err(Matrix, other))
        dom, cod = self.dom @ other.dom, self.cod @ other.cod
        array = block_diag(self.array, other.array)
        return Matrix(dom, cod, array)

    def __add__(self, other):
        """
        Entry-wise sum of two matrices with the same domain and codomain.

        ``0`` acts as the additive identity, which together with
        :meth:`__radd__` lets the builtin ``sum`` start from its default 0.

        Raises
        ------
        TypeError
            If ``other`` is neither ``0`` nor a :class:`Matrix`.
        AxiomError
            If the domains or codomains differ.
        """
        if other == 0:
            return self
        if not isinstance(other, Matrix):
            raise TypeError(messages.type_err(Matrix, other))
        if (self.dom, self.cod) != (other.dom, other.cod):
            raise AxiomError(messages.cannot_add(self, other))
        return Matrix(self.dom, self.cod, self.array + other.array)

    def __radd__(self, other):
        return self.__add__(other)

    def dagger(self):
        """ Conjugate transpose, with domain and codomain exchanged. """
        array = np.conjugate(np.transpose(self.array))
        return Matrix(self.cod, self.dom, array)

    @staticmethod
    def id(dom=PRO()):
        """ Identity matrix on ``dom``. """
        return Matrix(dom, dom, np.identity(len(dom)))

    @staticmethod
    def swap(left, right):
        """
        Matrix exchanging the two summands of ``left @ right``.

        With ``n = len(left)`` and ``m = len(right)`` this is the block
        permutation matrix ``[[0, I_n], [I_m, 0]]``, so that a row vector
        ``(x, y)`` is sent to ``(y, x)``.

        >>> Matrix.swap(PRO(1), PRO(1)).array
        array([[0, 1],
               [1, 0]])
        >>> Matrix.swap(PRO(1), PRO(2)).array
        array([[0, 0, 1],
               [1, 0, 0],
               [0, 1, 0]])
        """
        n_left, n_right = len(left), len(right)
        size = n_left + n_right
        # Integer dtype keeps the 1, 1 case equal to the former hard-coded
        # ``np.array([0, 1, 1, 0])``.
        array = np.zeros((size, size), dtype=int)
        # The first ``n_left`` inputs land after the ``n_right`` others ...
        array[:n_left, n_right:] = np.identity(n_left, dtype=int)
        # ... and the last ``n_right`` inputs move to the front.
        array[n_left:, :n_right] = np.identity(n_right, dtype=int)
        return Matrix(left @ right, right @ left, array)
def block_diag(*arrs):
    """
    Compute the block diagonal of matrices, taken from scipy.

    Parameters
    ----------
    *arrs
        Array-likes of dimension at most 2; each one is promoted with
        ``np.atleast_2d``. With no argument, a single empty block is used
        and the result has shape ``(1, 0)``.

    Returns
    -------
    numpy.ndarray
        Array with the inputs placed along the diagonal and zeros
        elsewhere, using the common dtype of the inputs.

    Raises
    ------
    ValueError
        If any argument has more than 2 dimensions.
    """
    if not arrs:
        arrs = ([],)
    arrs = [np.atleast_2d(a) for a in arrs]
    bad_args = [k for k, arr in enumerate(arrs) if arr.ndim > 2]
    if bad_args:
        raise ValueError("arguments in the following positions have dimension "
                         "greater than 2: %s" % bad_args)
    shapes = np.array([arr.shape for arr in arrs])
    # ``np.find_common_type`` was removed in NumPy 2.0; ``np.result_type``
    # is the replacement scipy itself uses for this computation.
    out_dtype = np.result_type(*[arr.dtype for arr in arrs])
    out = np.zeros(np.sum(shapes, axis=0), dtype=out_dtype)
    row, col = 0, 0
    for arr, (n_rows, n_cols) in zip(arrs, shapes):
        out[row:row + n_rows, col:col + n_cols] = arr
        # Step past the block just written so the next one sits diagonally.
        row += n_rows
        col += n_cols
    return out