Multiclass classification with GPflow

Alexander G. de G. Matthews and James Hensman, 2016

In [2]:
import gpflow
import tensorflow as tf
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')
%matplotlib inline
In [3]:
#make a one dimensional classification problem
np.random.seed(1)
X = np.random.rand(100,1)
K = np.exp(-0.5*np.square(X - X.T)/0.01) + np.eye(100)*1e-6
f = np.dot(np.linalg.cholesky(K), np.random.randn(100,3))

plt.figure(figsize=(12,6))
plt.plot(X, f, '.')
Out[3]:
[<matplotlib.lines.Line2D at 0x11e280c50>,
 <matplotlib.lines.Line2D at 0x11e280dd8>,
 <matplotlib.lines.Line2D at 0x11e28e048>]
../_images/notebooks_multiclass_2_1.png
In [4]:
Y = np.array(np.argmax(f, 1).reshape(-1,1), dtype=float)

Sparse Variational Gaussian approximation

In [5]:
m = gpflow.models.SVGP(
    X, Y, kern=gpflow.kernels.Matern32(1) + gpflow.kernels.White(1, variance=0.01),
    likelihood=gpflow.likelihoods.MultiClass(3),
    Z=X[::5].copy(), num_latent=3, whiten=True, q_diag=True)

m.kern.white.variance.trainable = False
m.feature.trainable = False
m.as_pandas_table()
Out[5]:
class prior transform trainable shape fixed_shape value
SVGP/kern/matern32/variance Parameter None +ve True () True 1.0
SVGP/kern/matern32/lengthscales Parameter None +ve True () True 1.0
SVGP/kern/white/variance Parameter None +ve False () True 0.01
SVGP/feature/Z Parameter None (none) False (20, 1) True [[0.417022004703], [0.0923385947688], [0.41919...
SVGP/q_mu Parameter None (none) True (20, 3) True [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...
SVGP/q_sqrt Parameter None +ve True (20, 3) True [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, ...
In [6]:
opt = gpflow.train.ScipyOptimizer()
opt.minimize(m)
INFO:tensorflow:Optimization terminated with:
  Message: b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
  Objective function value: 75.774953
  Number of iterations: 177
  Number of functions evaluations: 210
In [7]:
def plot(m):
    f = plt.figure(figsize=(12,6))
    a1 = f.add_axes([0.05, 0.05, 0.9, 0.6])
    a2 = f.add_axes([0.05, 0.7, 0.9, 0.1])
    a3 = f.add_axes([0.05, 0.85, 0.9, 0.1])


    xx = np.linspace(m.X.read_value().min(), m.X.read_value().max(), 200).reshape(-1,1)
    mu, var = m.predict_f(xx)
    mu, var = mu.copy(), var.copy()
    p, _ = m.predict_y(xx)

    a3.set_xticks([])
    a3.set_yticks([])


    a3.set_xticks([])
    a3.set_yticks([])

    for i in range(m.likelihood.num_classes):
        x = m.X.read_value()[m.Y.read_value().flatten()==i]
        points, = a3.plot(x, x*0, '.')
        color=points.get_color()
        a1.plot(xx, mu[:,i], color=color, lw=2)
        a1.plot(xx, mu[:,i] + 2*np.sqrt(var[:,i]), '--', color=color)
        a1.plot(xx, mu[:,i] - 2*np.sqrt(var[:,i]), '--', color=color)
        a2.plot(xx, p[:,i], '-', color=color, lw=2)

    a2.set_ylim(-0.1, 1.1)
    a2.set_yticks([0, 1])
    a2.set_xticks([])
In [8]:
plot(m)
../_images/notebooks_multiclass_8_0.png
In [16]:
m.kern.as_pandas_table()
Out[16]:
class prior transform trainable shape fixed_shape value
SGPMC/kern/matern32/variance Parameter Ga([ 1.],[ 1.]) +ve True () True 1.016544676754959
SGPMC/kern/matern32/lengthscales Parameter Ga([ 2.],[ 2.]) +ve True () True 0.1213019253370915
SGPMC/kern/white/variance Parameter None +ve True () True 0.009559277736662014

Sparse MCMC

In [10]:
with gpflow.defer_build():
    m = gpflow.models.SGPMC(X, Y,
                     kern=gpflow.kernels.Matern32(1, lengthscales=0.1) + gpflow.kernels.White(1, variance=0.01),
                     likelihood=gpflow.likelihoods.MultiClass(3),
                     Z=X[::5].copy(), num_latent=3)
    m.kern.matern32.variance.prior = gpflow.priors.Gamma(1.,1.)
    m.kern.matern32.lengthscales.prior = gpflow.priors.Gamma(2.,2.)
    m.kern.white.variance.trainables = False

m.compile()
In [11]:
opt = gpflow.train.ScipyOptimizer()
opt.minimize(m, maxiter=10)
INFO:tensorflow:Optimization terminated with:
  Message: b'STOP: TOTAL NO. of ITERATIONS EXCEEDS LIMIT'
  Objective function value: 150.701047
  Number of iterations: 11
  Number of functions evaluations: 14
In [12]:
hmc = gpflow.train.HMC()
samples = hmc.sample(m, num_samples=500, epsilon=0.04, lmax=15, logprobs=False) # verbose=True
In [13]:
def plot_from_samples(m, samples):
    f = plt.figure(figsize=(12,6))
    a1 = f.add_axes([0.05, 0.05, 0.9, 0.6])
    a2 = f.add_axes([0.05, 0.7, 0.9, 0.1])
    a3 = f.add_axes([0.05, 0.85, 0.9, 0.1])

    xx = np.linspace(m.X.read_value().min(), m.X.read_value().max(), 200).reshape(-1,1)

    Fpred, Ypred = [], []
    for s in samples[100::10].iterrows(): # burn 100, thin 10
        m.assign(s[1])
        Ypred.append(m.predict_y(xx)[0])
        Fpred.append(m.predict_f_samples(xx, 1).squeeze())

    for i in range(m.likelihood.num_classes):
        x = m.X.read_value()[m.Y.read_value().flatten() == i]
        points, = a3.plot(x, x*0, '.')
        color = points.get_color()
        for F in Fpred:
            a1.plot(xx, F[:,i], color=color, lw=0.2, alpha=1.0)
        for Y in Ypred:
            a2.plot(xx, Y[:,i], color=color, lw=0.5, alpha=1.0)

    a2.set_ylim(-0.1, 1.1)
    a2.set_yticks([0, 1])
    a2.set_xticks([])

    a3.set_xticks([])
    a3.set_yticks([])


plot_from_samples(m, samples)
../_images/notebooks_multiclass_14_0.png
In [14]:
samples.head()
Out[14]:
SGPMC/V SGPMC/feature/Z SGPMC/kern/matern32/lengthscales SGPMC/kern/matern32/variance SGPMC/kern/white/variance
0 [[1.77515741069, -0.916314713091, -0.858842697... [[0.328250501281], [0.125453980297], [0.419753... 0.121302 1.016545 0.009559
1 [[1.77515741069, -0.916314713091, -0.858842697... [[0.328250501281], [0.125453980297], [0.419753... 0.121302 1.016545 0.009559
2 [[1.77515741069, -0.916314713091, -0.858842697... [[0.328250501281], [0.125453980297], [0.419753... 0.121302 1.016545 0.009559
3 [[1.77515741069, -0.916314713091, -0.858842697... [[0.328250501281], [0.125453980297], [0.419753... 0.121302 1.016545 0.009559
4 [[1.77515741069, -0.916314713091, -0.858842697... [[0.328250501281], [0.125453980297], [0.419753... 0.121302 1.016545 0.009559
In [15]:
_ = plt.hist(np.vstack(samples['SGPMC/kern/matern32/lengthscales']).flatten(), 50, normed=True)
plt.xlabel('lengthscale')
Out[15]:
<matplotlib.text.Text at 0x1248b7c88>
../_images/notebooks_multiclass_16_1.png