alice-loves-bob

This experiment was presented at QNLP 2019.

[1]:
from discopy import Ty, Word

# Basic pregroup types: `n` is assigned to the nouns, `s` is the type a
# grammatical sentence reduces to.
s = Ty('s')
n = Ty('n')

# Nouns have type n; the transitive verb expects a noun on each side (n.r, n.l).
Alice = Word('Alice', n)
loves = Word('loves', n.r @ s @ n.l)
Bob = Word('Bob', n)

print("Vocabulary:\n" + '\n'.join(repr(word) for word in (Alice, loves, Bob)))
Vocabulary:
Word('Alice', Ty('n'))
Word('loves', Ty(Ob('n', z=1), 's', Ob('n', z=-1)))
Word('Bob', Ty('n'))
[2]:
from discopy import Diagram, Id, Cup
from discopy.grammar import draw

# Reduction n @ n.r @ s @ n.l @ n -> s: cup the subject with the verb's left
# input and the verb's right input with the object.
grammar = Cup(n, n.r) @ Id(s) @ Cup(n.l, n)

# Map each subject-verb-object sentence string to its grammatical diagram.
parsing = dict(
    ("{} {} {}.".format(subj, verb, obj), subj @ verb @ obj >> grammar)
    for subj in (Alice, Bob)
    for verb in (loves,)
    for obj in (Alice, Bob))

diagram = parsing['Alice loves Bob.']
print("Diagram for 'Alice loves Bob':")
draw(diagram, draw_type_labels=True)
Diagram for 'Alice loves Bob':
../_images/notebooks_alice-loves-bob_2_1.png
[3]:
# Sentence strings in the order they were inserted into `parsing`.
sentences = list(parsing)
print("Grammatical sentences:\n" + '\n'.join(sentences))
Grammatical sentences:
Alice loves Alice.
Alice loves Bob.
Bob loves Alice.
Bob loves Bob.
[4]:
from discopy.quantum import Ket, H, Rx, CX, sqrt

def verb_ansatz(phase):
    """Two-qubit circuit used as the meaning of the transitive verb.

    Starts from Ket(0, 0), applies H on the first qubit and Rx(phase) on the
    second (alongside the scalar sqrt(2)), then entangles them with CX. For
    phase 0 the printed tensor is the unnormalised state [1, 0, 0, 1].
    """
    layer = H @ sqrt(2) @ Rx(phase)
    return Ket(0, 0) >> layer >> CX

print(verb_ansatz(phase=0).eval())
Tensor(dom=Dim(1), cod=Dim(2, 2), array=[1.+0.j, 0.+0.j, 0.+0.j, 1.+0.j])
[5]:
# Overlap of the phase-0 verb state with the phase-0.5 one: compose one state
# with the dagger of the other and measure. The output is numerically zero.
(verb_ansatz(0) >> verb_ansatz(0.5).dagger()).measure()
[5]:
array([1.49975978e-32])
[6]:
# Overlap of the phase-0 verb state with itself. The state is unnormalised
# (its eval is [1, 0, 0, 1]), so the raw measurement is rescaled by .25,
# which brings it to 1 as the output shows.
.25 * (verb_ansatz(0) >> verb_ansatz(0).dagger()).measure()
[6]:
array([1.])
[7]:
from discopy import CircuitFunctor, qubit

# Number of qubits per basic type: sentences get none, nouns get one.
ob = {s: 0, n: 1}


def ar(params):
    """Circuit for each word; only 'loves' depends on the trainable params."""
    return {Alice: Ket(0), Bob: Ket(1), loves: verb_ansatz(params['loves'])}


def F(params):
    """Circuit functor sending grammar diagrams to circuits for `params`."""
    return CircuitFunctor(ob, ar(params))


params0 = {'loves': 0.5}

print("Circuit for 'Alice loves Bob':")
F(params0)(parsing['Alice loves Bob.']).draw(
    aspect='auto', draw_type_labels=False, figsize=(5, 5))
Circuit for 'Alice loves Bob':
../_images/notebooks_alice-loves-bob_7_1.png
[8]:
# Header line, then the tensor obtained by evaluating the sentence circuit.
print("Amplitude for 'Alice loves Bob':",
      F(params0)(parsing['Alice loves Bob.']).eval(),
      sep='\n')
Amplitude for 'Alice loves Bob':
Tensor(dom=Dim(1), cod=Dim(1), array=[0.-1.j])
[9]:
def evaluate(F, sentence):
    """Measure the circuit that the functor `F` assigns to `sentence`.

    `F` is a circuit functor such as `F(params0)`. Inside this function the
    parameter shadows the global `F` factory. `sentence` must be a key of
    `parsing`. Returns the result of `.measure()` on the sentence circuit.
    """
    return F(parsing[sentence]).measure()

# Compare the measured value to a threshold rather than testing its
# truthiness: any nonzero float, even a round-off residue like 1e-32, is
# truthy and would wrongly answer "Yes".
print("Does Alice love Bob?\n{}".format(
    "Yes" if evaluate(F(params0), 'Alice loves Bob.') > .5 else "No"))
Does Alice love Bob?
Yes
[10]:
# Label every sentence with its measured value under the reference params0.
corpus = dict((sentence, evaluate(F(params0), sentence)) for sentence in sentences)

# Tolerance for calling a measured value true (near 1) or false (near 0).
epsilon = 1e-2

print("True sentences:\n{}\n".format('\n'.join(
    [sentence for sentence in corpus if corpus[sentence] > 1 - epsilon])))
print("False sentences:\n{}".format('\n'.join(
    [sentence for sentence in corpus if corpus[sentence] < epsilon])))
True sentences:
Alice loves Bob.
Bob loves Alice.

False sentences:
Alice loves Alice.
Bob loves Bob.
[11]:
import jax.numpy as np
from jax import grad

from discopy import Tensor

# Switch the array module used by discopy's Tensor to jax.numpy so that
# Circuit.eval (and hence `evaluate`) can be differentiated with respect to
# phases by jax.grad / vmap. The attribute is set on the class, not on an
# instance, so it presumably applies to every later Tensor computation in
# this session.
Tensor.np = np

def mean_squared(y_true, y_pred):
    """Mean of the squared element-wise differences between targets and predictions.

    Both arguments are converted with `np.array` (jax.numpy here) before
    subtracting, so lists of arrays or scalars are accepted.
    """
    residuals = np.array(y_true) - np.array(y_pred)
    return np.mean(residuals ** 2)

def f(phase):
    """Mean squared error over all sentences when 'loves' has the given phase.

    Targets are the labels in `corpus`, which were measured with `params0`.
    They are reused here rather than re-simulating every circuit under
    `params0` on each call and on each grad/vmap trace.
    """
    y_true = [corpus[sentence] for sentence in sentences]
    y_pred = [evaluate(F({'loves': phase}), sentence) for sentence in sentences]
    return mean_squared(y_true, y_pred)

grad(f)(0.75)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[11]:
DeviceArray(3.141591, dtype=float32)
[12]:
from jax import vmap

# Sample the loss f and its derivative on a grid of phases in [0, 1).
x = np.arange(0.0, 1.0, step=0.01)
y, dy = vmap(f)(x), vmap(grad(f))(x)
[13]:
import matplotlib.pyplot as plt

# Two stacked panels sharing the phase axis: the loss landscape on top and its
# gradient below, both sampled on the grid `x`. The figure size is passed to
# this figure only instead of mutating the global rcParams.
fig, (ax_loss, ax_grad) = plt.subplots(2, 1, sharex=True, figsize=(12, 5))

ax_loss.plot(x, y)
ax_loss.set_title("functorial landscape for 'loves'")
ax_loss.set_ylabel('loss')

ax_grad.plot(x, dy)
ax_grad.set_xlabel('phase')
ax_grad.set_ylabel('grad')

plt.show()
../_images/notebooks_alice-loves-bob_13_0.png
[14]:
from sklearn.model_selection import train_test_split

# Hold out a quarter of the sentences, with a fixed random_state for reproducibility.
sentence_train, sentence_test = train_test_split(
    sentences, test_size=0.25, random_state=42)

print("Training set:\n{}\n".format('\n'.join(sentence_train)))
# Show each held-out sentence as a question, e.g. "Does Alice love Bob?".
print("Testing set:\n{}".format('\n'.join(
    'Does ' + sentence.replace('loves', 'love').replace('.', '?')
    for sentence in sentence_test)))
Training set:
Bob loves Bob.
Alice loves Alice.
Bob loves Alice.

Testing set:
Does Alice love Bob?
[15]:
from jax import jit
from time import time

def _sentence_loss(sentence):
    """Return the squared error on `sentence` as a function of `params`.

    Each closure is built in its own call, so it keeps its own `sentence`.
    A lambda written directly inside the dict comprehension would close over
    the comprehension variable, which is read when the lambda runs. Every
    entry would then score the last sentence only.
    """
    return lambda params: mean_squared(corpus[sentence], evaluate(F(params), sentence))


loss = {sentence: _sentence_loss(sentence) for sentence in sentences}

@jit
def testing_loss(params):
    """Average of the per-sentence losses of `params` over `sentence_test`."""
    held_out_losses = np.array([loss[held_out](params) for held_out in sentence_test])
    return np.mean(held_out_losses)

# Time the first call of the jitted testing loss, which includes compilation.
# format() evaluates its arguments left to right, so testing_loss(params0)
# finishes before `time() - start` is read; the template then prints the
# elapsed time ({1}) before the loss value ({0}).
start = time()
print("{1:.3f} seconds to compile the testing loss of params0 ({0})".format(
    testing_loss(params0), time() - start))

# Learning rate of the gradient-descent step applied in `update`.
step_size = 1e-2

@jit
def update(params):
    """One pass of gradient descent over the training sentences.

    For each sentence in `sentence_train`, in order, every parameter takes a
    `step_size` step against the gradient of that sentence's loss. Returns
    the new parameter dict (same keys as `params`).
    """
    for sentence in sentence_train:
        # Compute the gradient once per sentence rather than once per word:
        # every word's step uses the same gradient at the current params.
        gradient = grad(loss[sentence])(params)
        params = {word: phase - step_size * gradient[word]
                  for word, phase in params.items()}
    return params

# Time the first call of the jitted update, which includes compilation. The
# call is the first format() argument, so it completes before
# `time() - start` is read; {1} prints the elapsed time first.
start = time()
print("{1:.3f} seconds to compile the update function just in time:\n{0}".format(
    update(params0), time() - start))
0.189 seconds to compile the testing loss of params0 (3.650718323258376e-30)
0.796 seconds to compile the update function just in time:
{'loves': DeviceArray(0.5, dtype=float32)}
[16]:
from random import random, seed
seed(420)  # fixed seed so the random initial phase is reproducible

print("Random parameter initialisation...")

params = dict(loves=random())
print("Initial parameters: {}".format(params))

print("Initial testing loss: {:.5f}\n".format(testing_loss(params)))

epochs = 7
iterations = 10

# Each epoch runs `iterations` update steps, then reports the wall-clock time
# of those steps, the testing loss and the current phase of 'loves'.
for epoch in range(epochs):
    start = time()
    for i in range(iterations):
        params = update(params)

    print("Epoch {} ({:.3f} milliseconds)".format(epoch, (time() - start) * 1e3))
    print("Testing loss: {:.5f}".format(testing_loss(params)))
    print("params['loves'] = {:.3f}\n".format(params['loves']))
Random parameter initialisation...
Initial parameters: {'loves': 0.026343380459525556}
Initial testing loss: 0.98638

Epoch 0 (756.192 milliseconds)
Testing loss: 0.00321
params['loves'] = 0.424

Epoch 1 (2.225 milliseconds)
Testing loss: 0.00060
params['loves'] = 0.450

Epoch 2 (1.870 milliseconds)
Testing loss: 0.00024
params['loves'] = 0.460

Epoch 3 (1.807 milliseconds)
Testing loss: 0.00013
params['loves'] = 0.466

Epoch 4 (1.769 milliseconds)
Testing loss: 0.00008
params['loves'] = 0.470

Epoch 5 (1.998 milliseconds)
Testing loss: 0.00005
params['loves'] = 0.473

Epoch 6 (1.823 milliseconds)
Testing loss: 0.00004
params['loves'] = 0.475

[17]:
# A sentence counts as true when its measured value exceeds 1 - epsilon,
# the same threshold used to split the corpus into true and false sentences.
print("Does Alice love Bob?",
      "Yes" if evaluate(F(params), 'Alice loves Bob.') > 1 - epsilon else "No",
      sep="\n")
Does Alice love Bob?
Yes