Ordinal regression

Ordinal regression aims to fit a model to some data \((X, Y)\), where \(Y\) is an ordinal variable, i.e. a discrete variable whose categories have a natural order. To do so, we use a VGP (variational Gaussian process) model with a specific likelihood (gpflow.likelihoods.Ordinal).

[1]:
import gpflow

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
# Default size for every figure; individual cells below may override it via figsize.
plt.rcParams["figure.figsize"] = (12, 6)

np.random.seed(123)  # for reproducibility (NOTE: the data-generation cell reseeds with np.random.seed(1))
[2]:
# Make a one-dimensional ordinal regression problem.
#
# generate_data returns the inputs X, the latent (quantitative) GP values f,
# and the ordinal labels Y derived from f.


def generate_data(num_data):
    """Sample a toy one-dimensional ordinal regression dataset.

    Args:
        num_data: number of points to draw.

    Returns:
        A tuple ``(X, f, Y)`` of arrays, each of shape ``(num_data, 1)``:
        inputs drawn uniformly from [0, 1), latent GP values at those inputs,
        and integer-valued float64 ordinal labels whose minimum is 0.
    """
    inputs = np.random.rand(num_data, 1)

    # One draw from a zero-mean GP with a short-lengthscale squared-exponential
    # kernel, evaluated through the kernel matrix at the sampled inputs.
    covariance = gpflow.kernels.SquaredExponential(lengthscales=0.1)(inputs)
    latent = np.random.multivariate_normal(
        mean=np.zeros(num_data), cov=covariance
    ).reshape(-1, 1)

    # Discretize: shift, scale by 3, round to integers, then subtract the
    # smallest label so labels start at 0. The additive latent.min() term only
    # moves where the rounding boundaries fall; the final shift fixes the offset.
    labels = np.round((latent + latent.min()) * 3)
    labels = np.asarray(labels - labels.min(), np.float64)

    return inputs, latent, labels


np.random.seed(1)  # reseed so generate_data samples the same dataset on every run
num_data = 20
X, f, Y = generate_data(num_data)

# Plot the latent GP values (left y-axis) and the ordinal labels derived from
# them (right y-axis) against the inputs.
plt.figure(figsize=(11, 6))
plt.plot(X, f, ".")
plt.ylabel("latent function value")

plt.twinx()  # second y-axis sharing the x-axis; the plt calls below draw on it
plt.plot(X, Y, "kx", mew=1.5)
plt.ylabel("observed data value")
[2]:
Text(0, 0.5, 'observed data value')
../../_images/notebooks_advanced_ordinal_regression_2_1.png
[3]:
# Construct the ordinal likelihood.
# gpflow.likelihoods.Ordinal models (bin_edges.size + 1) classes labelled
# 0, 1, ..., bin_edges.size, so labels 0..K-1 need exactly K - 1 bin edges.
# Counting classes from Y.max() (rather than np.unique(Y).size) keeps every
# observed label in range even if some intermediate level never occurs in Y.
# The edges are spaced one unit apart and centered on zero.
num_classes = int(Y.max()) + 1
bin_edges = np.arange(num_classes - 1, dtype=float)
bin_edges = bin_edges - bin_edges.mean()
likelihood = gpflow.likelihoods.Ordinal(bin_edges)

# build a model with this likelihood
m = gpflow.models.VGP(data=(X, Y), kernel=gpflow.kernels.Matern32(), likelihood=likelihood)

# fit the model
# NOTE: with only 100 iterations the optimizer may stop before converging;
# check the `success` field of the returned result and raise maxiter if needed.
opt = gpflow.optimizers.Scipy()
opt.minimize(m.training_loss, m.trainable_variables, options=dict(maxiter=100))
[3]:
      fun: 25.487470692631103
 hess_inv: <233x233 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 1.25391599e-01, -8.34295277e-03, -1.00211335e-02, -6.66170050e-02,
       -2.31394989e-02, -4.06923307e-02, -4.41550939e-02, -6.73552354e-02,
       -2.97403459e-02, -6.97145755e-02, -4.77751499e-04, -1.28216684e-02,
        2.15922407e-02,  5.09282300e-03,  1.45541404e-02,  6.89161252e-03,
       -2.42551292e-02, -9.46174862e-03, -2.19788314e-02, -2.05849614e-04,
       -1.22270587e-02, -2.20568134e-02, -1.67064620e-03,  2.77479368e-04,
       -4.07729955e-04, -5.72601731e-09,  2.48544749e-06, -1.87536411e-09,
       -7.61127053e-06,  1.94609649e-11, -6.40312905e-03,  1.97847188e-08,
        2.42345491e-06, -2.42523213e-07, -2.06798908e-05,  1.54133417e-04,
       -3.35396136e-03,  2.68666349e-04,  4.78041484e-05,  1.84643471e-04,
       -1.88612338e-04, -1.51946389e-04,  3.47144241e-04, -2.59907342e-02,
        7.38221708e-03, -3.48179191e-09,  1.01572602e-06, -2.16084277e-10,
        4.48441878e-04,  1.26913517e-11, -1.31563966e-03,  1.42353635e-08,
        1.74814860e-07, -4.67925517e-07, -1.70253006e-05,  9.67451077e-05,
        2.67988765e-03, -1.21413305e-03, -1.31564174e-02,  1.63179439e-03,
       -1.35535335e-03, -4.00252721e-04, -8.41964464e-04,  1.31397553e-02,
        3.09364365e-02, -1.37787493e-02, -5.81673766e-04, -5.26846909e-03,
       -2.84811574e-09, -1.51924096e-04, -6.49922216e-07, -8.72306548e-03,
        4.66235026e-03, -2.30818057e-02, -9.02137991e-04, -7.13773002e-04,
       -5.79371428e-05, -4.16300160e-04,  6.44358255e-05, -7.50279360e-03,
       -3.52310132e-04,  2.49648175e-03,  3.23837926e-03, -2.18791294e-02,
       -9.92876068e-03, -6.72232277e-02, -1.95555497e-05, -8.16703888e-05,
       -2.65370410e-08, -2.20651318e-06,  3.82604403e-06, -1.25913596e-04,
        7.31895737e-05, -3.69200739e-04, -1.79764322e-05, -1.31292157e-05,
        1.74249903e-06, -6.48589831e-06,  1.65192475e-06, -1.16962877e-04,
       -5.22060884e-06,  4.06241765e-05, -1.79762027e-05,  3.60713274e-03,
       -2.37341419e-02, -3.26068422e-02,  1.55263457e-02,  2.43220905e-02,
       -1.65293066e-09,  4.13170098e-05, -3.23374772e-07,  1.79247303e-02,
        8.80094534e-04, -5.92594557e-03, -6.07655646e-04, -4.28655638e-03,
       -2.69801462e-05,  6.31039562e-04,  3.54680328e-04, -4.96803600e-03,
        8.70662761e-04, -4.20890133e-03, -8.30145094e-06, -2.68520935e-02,
       -3.04943386e-02, -5.96560291e-02,  5.02371567e-02, -1.69166591e-02,
       -1.32372157e-03,  7.45804806e-11,  8.92332466e-05, -2.86717502e-08,
       -2.63068434e-06,  1.28441070e-06, -3.40292958e-06,  2.12371825e-04,
       -1.21228773e-03,  3.38571389e-04,  7.35553121e-03, -3.13474610e-03,
       -1.21251354e-02,  2.16192945e-04, -4.84555412e-04, -2.24971955e-02,
       -9.87632251e-04,  5.78167843e-03, -8.25627093e-03,  1.52010564e-02,
        1.24230637e-02, -4.64049936e-03,  9.34359631e-09,  1.48549889e-02,
       -8.01916346e-05,  3.53200001e-03,  8.62321238e-04, -3.95215111e-03,
        7.04611228e-04,  1.40365047e-03,  1.07643337e-04,  1.04302417e-02,
       -1.04896478e-03,  5.78069169e-03,  6.11629668e-03,  1.30020687e-02,
       -6.94244419e-03, -8.84817043e-03,  3.34723465e-03, -1.32459503e-02,
        7.23139143e-03, -2.54784029e-02,  1.14937963e-02,  4.13756913e-07,
       -9.61173523e-05, -7.11609527e-06, -8.85242568e-04,  3.20688390e-03,
       -1.13941190e-02,  4.38210159e-03,  1.31185877e-03,  1.25030271e-03,
       -1.21411408e-03, -1.73244727e-03,  3.52865236e-03,  6.83882529e-03,
       -9.65871334e-05,  3.09981879e-04, -1.50597018e-02, -4.95287187e-03,
        7.53994774e-05,  5.97282657e-03,  7.35930957e-03, -8.55451495e-03,
       -5.86463790e-04,  5.16916450e-03, -2.71680623e-03, -2.17781172e-02,
        3.56526824e-06,  3.80997845e-03,  1.87128518e-03, -2.05686141e-02,
        5.59811585e-03,  1.24521617e-02, -3.91121541e-03,  9.83636686e-04,
        5.82430233e-04,  4.37681708e-04,  1.27240352e-02, -3.54740360e-04,
        7.33627933e-04, -3.42282311e-04,  2.81604361e-03,  1.18179632e-02,
       -8.33583433e-04,  8.21443778e-04,  1.63396235e-04,  8.15674456e-04,
        4.19834758e-05, -8.09516348e-05, -6.62651024e-05,  1.62659465e-03,
       -1.61287338e-04, -7.76576872e-04, -2.65153064e-03,  6.93281112e-03,
        2.95066488e-03, -6.15601280e-04, -6.94388981e-02,  1.46701144e-03,
       -2.02808367e-03,  3.70975454e-03, -6.66313084e-03, -7.06742860e-03,
       -6.50258509e-05])
  message: b'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
     nfev: 116
      nit: 100
     njev: 116
   status: 1
  success: False
        x: array([-1.97907896e+00,  5.46701841e+00, -1.44968251e+00, -1.99974441e+00,
       -2.36114785e-01, -8.89346881e-04,  7.62040483e-01, -1.98107145e-01,
       -1.59229303e+00,  7.93682187e-01, -3.86891236e-01, -1.13951583e-01,
        3.78707088e-01, -8.97993258e-02,  3.46386924e-01, -7.61859281e-02,
        1.59571005e+00, -2.61504320e-01,  8.54336838e-01,  6.56865340e-03,
        4.45993591e-01, -3.20375086e-01,  6.90178780e-04,  9.88817619e-01,
       -2.80834123e-03,  6.18777424e-08,  1.28538846e-05,  1.00556831e-08,
        1.33406480e-04,  5.89613429e-11, -6.28892129e-02,  7.67118176e-09,
       -1.38966302e-04, -1.17831607e-05, -7.71493927e-04,  2.21552227e-03,
       -6.31942359e-02,  1.79722079e-03, -3.69548268e-03, -2.79570052e-03,
        1.02947016e-05,  4.52858319e-06,  3.22317997e-05,  2.24969374e-01,
        9.45886593e-01,  3.60353928e-08,  6.94200829e-06,  6.70488000e-09,
        8.43144084e-03,  3.49000142e-11, -2.24847923e-02,  7.57775783e-09,
       -4.44187611e-05, -4.05326522e-06, -2.51103004e-04,  7.33510546e-04,
        2.99292300e-02, -3.30696035e-02, -2.28136147e-01,  1.38894803e-03,
       -7.88064516e-04,  1.12092345e-05, -8.75652983e-06, -1.92217027e-02,
        1.41993514e-01,  6.14727174e-01, -9.80895071e-03, -1.31909989e-01,
       -4.48104348e-08, -6.29752498e-04, -1.33374631e-05,  1.21332066e-02,
        7.00032626e-02, -3.90308586e-01, -3.17492873e-02, -8.18879284e-03,
       -2.86415789e-04, -3.04367137e-05, -1.98340685e-07,  1.74927857e-03,
       -5.40886221e-05,  2.23700128e-02,  4.28743491e-02, -5.43021010e-03,
        3.54070954e-04,  1.37477679e-01,  9.99917310e-01, -2.04367064e-03,
       -4.50347360e-07, -9.60929364e-06,  3.04966567e-05,  1.97203358e-04,
        1.07912433e-03, -6.05898924e-03, -4.88669133e-04, -1.28207798e-04,
        2.44421870e-05,  2.22259342e-07,  1.20151932e-05,  2.73988571e-05,
       -7.15996682e-07,  3.48414611e-04, -1.75315103e-04, -1.52910942e-01,
        3.75898860e-03, -1.23786929e-02,  1.69801687e-01,  7.18765815e-01,
       -1.66883129e-08, -4.40748149e-03, -4.47482678e-06, -4.24396472e-01,
        2.30613258e-02, -8.94036616e-02, -1.37515629e-02, -3.75799190e-03,
       -9.54182585e-05,  3.85537780e-05,  2.32816689e-05,  7.92365000e-04,
        3.84561758e-05,  4.30450840e-02,  1.76195986e-02,  3.15240495e-02,
       -7.24815381e-04, -6.33892533e-02, -7.06781036e-02,  1.09865191e-01,
        5.73035294e-01, -2.04466270e-10,  1.13192204e-02, -1.83561014e-08,
        2.11830325e-05,  2.17039191e-06,  1.15339266e-04, -2.48827292e-04,
       -2.39818533e-02, -1.74613404e-01,  4.90783939e-02, -4.37070727e-04,
       -3.00789593e-01, -9.20354547e-05,  2.29335226e-04, -1.28968332e-02,
       -1.18931538e-04, -1.25981628e-01,  3.06821811e-02, -1.41213195e-01,
        3.20208052e-01,  1.62204016e-01, -2.34521519e-07,  5.21911744e-02,
        1.25361863e-03, -2.20249104e-03, -9.54694730e-04, -5.61824052e-04,
        1.88092516e-05,  8.03988680e-05,  2.18812712e-05,  1.90943218e-04,
        2.86054642e-05, -1.13137858e-01,  8.99244627e-04,  3.53436956e-02,
       -7.32408008e-04, -2.02198519e-02, -9.81261254e-02, -2.11316436e-01,
        1.18282455e-01,  3.07494683e-01,  6.84414964e-01, -2.34998534e-06,
       -2.49202432e-03, -2.14499377e-04, -1.36725336e-02,  3.84318867e-02,
       -4.85656131e-01,  6.70920040e-03, -3.72242135e-02, -6.34281601e-02,
        7.17546339e-03, -3.43835861e-05,  3.31866584e-04, -2.12536596e-01,
        4.21398018e-03, -9.04589147e-04, -2.67073328e-01, -9.12547593e-03,
        1.82209987e-03,  5.99702239e-02,  4.20817391e-01,  2.69430820e-01,
        3.30852545e-02, -6.13284122e-02, -2.50654421e-02, -7.41504006e-03,
       -4.92777439e-05,  2.49109651e-04,  1.13733326e-04,  9.35333056e-04,
        3.12446746e-04, -2.73080273e-01,  2.97714507e-02, -8.59073823e-02,
        5.73162086e-03,  3.08240155e-04,  8.21208844e-03, -6.95159214e-04,
        2.05128711e-04, -1.22898730e-02, -5.97399251e-02,  9.07757485e-01,
        9.94613648e-01,  4.55102917e-02,  5.61527426e-03,  1.55941350e-03,
       -2.06781146e-03,  2.30229768e-05, -1.28019734e-04, -5.98667892e-04,
        3.59218412e-05,  5.40233361e-03, -2.59294432e-02, -2.08567182e-01,
       -4.35793828e-02, -2.70732924e-05, -1.18749821e-02, -1.90956743e-04,
       -4.28419532e-05, -2.18219196e-04,  5.09939094e-02,  2.04367843e-01,
        1.67917634e-01])
[4]:
# Plot the predictive mean of Y with a band of +/- 2 predictive standard
# deviations, drawn as if the (actually discrete) predictive were Gaussian.
plt.figure(figsize=(11, 6))
X_data = m.data[0].numpy()
Y_data = m.data[1].numpy()
Xtest = np.linspace(X_data.min(), X_data.max(), 100).reshape(-1, 1)
mu, var = m.predict_y(Xtest)
two_std = 2 * np.sqrt(var)
(line,) = plt.plot(Xtest, mu, lw=2)
col = line.get_color()
plt.plot(Xtest, mu + two_std, "--", lw=2, color=col)
plt.plot(Xtest, mu - two_std, "--", lw=2, color=col)
plt.plot(X_data, Y_data, "kx", mew=2)
[4]:
[<matplotlib.lines.Line2D at 0x7f4740415ba8>]
../../_images/notebooks_advanced_ordinal_regression_4_1.png
[5]:
## To see the predictive density, predict every possible discrete value of Y.
def pred_log_density(m, num_points=100):
    """Evaluate the predictive log density of every ordinal label over a grid.

    The test grid spans the model's own training inputs, and the candidate
    labels are 0, 1, ..., max(Y) of the model's own training outputs, so the
    function no longer depends on globals defined in other cells.

    Args:
        m: a trained model exposing ``data = (X, Y)`` and
            ``predict_log_density((Xnew, Ynew))`` (e.g. a GPflow VGP).
        num_points: number of evenly spaced test inputs (default 100).

    Returns:
        ``np.vstack`` of one row per candidate label; row ``i`` holds the
        predicted log density of label ``i`` at each test input. Assumes
        ``predict_log_density`` returns one value per row of ``Xnew``, which
        gives shape ``(max(Y) + 1, num_points)``.
    """
    # np.asarray accepts both eager TensorFlow tensors and NumPy arrays.
    x_train = np.asarray(m.data[0])
    y_train = np.asarray(m.data[1])
    Xtest = np.linspace(x_train.min(), x_train.max(), num_points).reshape(-1, 1)
    ys = np.arange(y_train.max() + 1)
    # predict_log_density needs X and Y with matching rows, so pair every
    # test input with the same candidate label y.
    densities = [m.predict_log_density((Xtest, np.full_like(Xtest, y))) for y in ys]
    return np.vstack(densities)
[6]:
# Heat map of the predicted probability of each ordinal label (rows, one per
# label returned by pred_log_density) across the test inputs (columns), with
# the training data overlaid without rescaling the axes.
fig = plt.figure(figsize=(14, 6))
heatmap_kwargs = dict(
    interpolation="nearest",
    extent=[X_data.min(), X_data.max(), -0.5, Y_data.max() + 0.5],
    origin="lower",
    aspect="auto",
    cmap=plt.cm.viridis,
)
plt.imshow(np.exp(pred_log_density(m)), **heatmap_kwargs)
plt.colorbar()
plt.plot(X, Y, "kx", mew=2, scalex=False, scaley=False)
[6]:
[<matplotlib.lines.Line2D at 0x7f47402b35c0>]
../../_images/notebooks_advanced_ordinal_regression_6_1.png
[7]:
# Predictive distribution over all ordinal labels at the single input x = 0.5.
x_new = 0.5
# One row per candidate label 0, 1, ..., max(Y).
Y_new = np.arange(Y_data.max() + 1).reshape(-1, 1)
# predict_log_density needs X and Y with the same number of rows, so repeat the
# input once per label. dtype=float matters: np.full_like otherwise copies
# Y_new's dtype, and integer-typed labels would silently truncate 0.5 to 0.
X_new = np.full_like(Y_new, x_new, dtype=float)
dens_new = np.exp(m.predict_log_density((X_new, Y_new)))
fig = plt.figure(figsize=(8, 4))
plt.bar(x=Y_new.flatten(), height=dens_new.flatten())
[7]:
<BarContainer object of 8 artists>
../../_images/notebooks_advanced_ordinal_regression_7_1.png