set_backend

discopy.matrix.set_backend(name)

Override the default backend, i.e. the array library used to store the entries of matrices created after the call.

Parameters:

name (Literal['numpy', 'jax', 'pytorch', 'tensorflow']) – The name of the backend.

Return type:

None

Example

>>> from discopy.matrix import Matrix, set_backend
>>> set_backend('jax')
>>> assert type(Matrix([0, 1, 1, 0], 2, 2).array).__module__\
...     == 'jaxlib.xla_extension'
>>> set_backend('numpy')
>>> assert type(Matrix([0, 1, 1, 0], 2, 2).array).__module__\
...     == 'numpy'
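Because the override is global, every matrix built after the call uses the chosen library until the backend is changed again; this is why the example switches back to 'numpy' at the end. The following is a minimal sketch of how such a module-level switch can work, assuming a simple registry from backend names to module names. The names `_MODULES`, `get_backend_module_name` and `load_backend` are hypothetical illustrations, not part of discopy's API.

```python
import importlib
from types import ModuleType
from typing import Literal, get_args

# Same set of names as the documented `name` parameter.
BackendName = Literal["numpy", "jax", "pytorch", "tensorflow"]

# Hypothetical registry: module providing the array namespace for each backend.
_MODULES = {
    "numpy": "numpy",
    "jax": "jax.numpy",
    "pytorch": "torch",
    "tensorflow": "tensorflow",
}

# Module-level state shared by all later array constructions.
_current: str = "numpy"


def set_backend(name: BackendName) -> None:
    """Override the default backend (sketch, not discopy's implementation)."""
    global _current
    if name not in get_args(BackendName):
        raise ValueError(f"unknown backend {name!r}")
    _current = name


def get_backend_module_name() -> str:
    """Hypothetical helper: module name of the current backend."""
    return _MODULES[_current]


def load_backend() -> ModuleType:
    """Hypothetical helper: import the backend lazily, only when needed."""
    return importlib.import_module(get_backend_module_name())
```

Importing lazily in `load_backend` means optional libraries such as jax or torch are only required once they are actually selected.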