Ordinal regression¶

Ordinal regression aims to fit a model to some data \((X, Y)\), where \(Y\) is an ordinal variable. To do so, we use a VGP (variational Gaussian process) model with an ordinal likelihood (gpflow.likelihoods.Ordinal).

[1]:
import gpflow

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
# Default figure size for plots that do not pass their own figsize.
plt.rcParams["figure.figsize"] = (12, 6)

# Seed NumPy's global RNG; the data-generation cell re-seeds with
# np.random.seed(1) before sampling, so that later seed determines the dataset.
np.random.seed(123)  # for reproducibility
[2]:
# make a one-dimensional ordinal regression problem

# This function generates a set of inputs X,
# quantitative output f (latent) and ordinal values Y


def generate_data(num_data):
    """Sample a toy one-dimensional ordinal regression dataset.

    Draws from NumPy's global random state, so seed it beforehand for
    reproducible output.

    Args:
        num_data: number of data points to generate.

    Returns:
        A tuple ``(X, f, Y)`` of arrays, each of shape ``(num_data, 1)``:
        uniform random inputs in [0, 1), latent GP values at those inputs, and
        float64 ordinal labels whose smallest value is 0.
    """
    # Inputs: uniform on [0, 1).
    inputs = np.random.rand(num_data, 1)

    # Latent function: one draw from a zero-mean GP prior with a
    # squared-exponential kernel (lengthscale 0.1) evaluated at the inputs.
    kernel = gpflow.kernels.SquaredExponential(lengthscales=0.1)
    covariance = kernel(inputs)
    latent = np.random.multivariate_normal(
        mean=np.zeros(num_data),
        cov=covariance,
    ).reshape(-1, 1)

    # Labels: scale the latent values by 3 and round to integers. The offset
    # added before rounding only moves the rounding boundaries, because the
    # result is re-based so that the lowest label is 0.
    levels = np.round((latent + latent.min()) * 3)
    levels = levels - levels.min()

    return inputs, latent, np.asarray(levels, np.float64)


# Fix the seed so the sampled dataset is the same on every run.
np.random.seed(1)
num_data = 20
X, f, Y = generate_data(num_data)

# Latent values on the left axis, ordinal labels on a twin right axis.
fig, latent_ax = plt.subplots(figsize=(11, 6))
latent_ax.plot(X, f, ".")
latent_ax.set_ylabel("latent function value")

label_ax = latent_ax.twinx()
label_ax.plot(X, Y, "kx", mew=1.5)
label_ax.set_ylabel("observed data value")
[2]:
Text(0, 0.5, 'observed data value')
../../_images/notebooks_advanced_ordinal_regression_2_1.png
[3]:
# Ordinal likelihood: unit-spaced bin edges, one more than the number of
# distinct labels in Y, shifted so that they are centred on zero.
# NOTE(review): GPflow's Ordinal docs describe K edges for labels 0..K; this
# passes unique(Y).size + 1 edges - confirm the extra bins are intended.
num_levels = np.unique(Y).size
bin_edges = np.arange(num_levels + 1, dtype=float)
bin_edges -= bin_edges.mean()
likelihood = gpflow.likelihoods.Ordinal(bin_edges)

# Variational GP with a Matern-3/2 kernel and the ordinal likelihood.
kernel = gpflow.kernels.Matern32()
m = gpflow.models.VGP(data=(X, Y), kernel=kernel, likelihood=likelihood)

# Minimise the training loss over all trainable variables, capped at 100
# optimizer iterations.
opt = gpflow.optimizers.Scipy()
opt.minimize(m.training_loss, m.trainable_variables, options={"maxiter": 100})
[3]:
      fun: 25.487472798476816
 hess_inv: <233x233 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 1.26630780e-01, -8.47299722e-03, -9.87628877e-03, -6.69841403e-02,
       -2.59031694e-02, -3.78472213e-02, -4.31504148e-02, -6.44750050e-02,
       -2.92352943e-02, -6.93855308e-02, -5.43500163e-04, -1.27124189e-02,
        2.09378892e-02,  5.12045175e-03,  1.45692449e-02,  6.96025879e-03,
       -2.53312856e-02, -9.29268988e-03, -2.19144904e-02, -2.08969607e-04,
       -1.24218419e-02, -2.20854645e-02, -1.66570380e-03,  2.76646957e-04,
       -4.08209217e-04, -5.73584493e-09,  2.48830294e-06, -1.86776012e-09,
       -7.61031049e-06,  1.96119310e-11, -6.41098687e-03,  1.95834642e-08,
        2.46990175e-06, -2.31089622e-07, -2.03312573e-05,  1.53738973e-04,
       -3.34047794e-03,  2.67954865e-04,  5.67574553e-05,  1.62344131e-04,
       -2.00401506e-04, -1.50249742e-04,  3.27914360e-04, -2.62417672e-02,
        7.39984029e-03, -3.48418303e-09,  1.01813597e-06, -2.08650993e-10,
        4.50254194e-04,  1.28152833e-11, -1.32069792e-03,  1.41518437e-08,
        1.90648833e-07, -4.63440920e-07, -1.69056086e-05,  9.67227819e-05,
        2.69422762e-03, -1.22101187e-03, -1.31790906e-02,  1.60016487e-03,
       -1.42217443e-03, -3.85743062e-04, -8.30398643e-04,  1.35366613e-02,
        2.96747867e-02, -1.36892981e-02, -5.81398736e-04, -5.34192251e-03,
       -2.84947443e-09, -1.52792066e-04, -6.50144585e-07, -8.77041563e-03,
        4.66299000e-03, -2.30852540e-02, -9.33642929e-04, -7.98738236e-04,
       -5.75583603e-05, -4.16481858e-04,  7.37808003e-05, -7.72747937e-03,
       -3.16381784e-04,  2.45430934e-03,  3.15929009e-03, -2.38239209e-02,
       -8.57209294e-03, -7.20111306e-02, -1.95958076e-05, -8.28084742e-05,
       -2.66189690e-08, -2.21994768e-06,  3.83027124e-06, -1.26654819e-04,
        7.31945892e-05, -3.69249435e-04, -1.84623242e-05, -1.44475677e-05,
        1.74228760e-06, -6.48824090e-06,  1.79388845e-06, -1.20430484e-04,
       -4.65508755e-06,  3.99744686e-05, -1.92299823e-05,  2.39203892e-03,
       -2.29972607e-02, -3.45639964e-02,  1.18960036e-02,  2.41496163e-02,
       -1.65863662e-09,  4.17333940e-05, -3.24068363e-07,  1.78886233e-02,
        8.77039047e-04, -5.94929993e-03, -6.26924753e-04, -4.25307108e-03,
       -2.65105181e-05,  6.12656280e-04,  3.60872313e-04, -5.10717851e-03,
        8.74760991e-04, -4.38628235e-03, -1.20962043e-05, -2.78636891e-02,
       -2.95973305e-02, -6.36798344e-02,  4.74019182e-02, -1.48920939e-02,
       -1.27518126e-03,  7.44537116e-11,  8.95105409e-05, -2.87881674e-08,
       -2.57384755e-06,  1.28785247e-06, -3.11926593e-06,  2.11457495e-04,
       -1.20190328e-03,  2.71797675e-04,  7.45882071e-03, -3.09273928e-03,
       -1.23014859e-02,  3.05745982e-04, -6.46302112e-04, -2.21216377e-02,
       -7.59899525e-04,  5.30109425e-03, -7.78239929e-03,  1.54178504e-02,
        1.23195034e-02, -4.67136844e-03,  8.19216385e-09,  1.51758447e-02,
       -8.25631043e-05,  3.71363612e-03,  9.12003046e-04, -3.74580987e-03,
        6.95949575e-04,  1.35962134e-03,  7.79591599e-05,  1.08771815e-02,
       -1.18915546e-03,  4.84597007e-03,  6.08031556e-03,  1.21099543e-02,
       -6.88028906e-03, -9.37326791e-03,  2.29962345e-03, -1.28438011e-02,
        7.20029469e-03, -2.45899375e-02,  1.15349087e-02,  4.09760456e-07,
       -9.57363692e-05, -6.97849298e-06, -8.81710743e-04,  3.20445134e-03,
       -1.12092977e-02,  4.37909192e-03,  1.39810165e-03,  1.01165869e-03,
       -1.33714136e-03, -1.71438418e-03,  3.33268654e-03,  6.93842280e-03,
        2.21889057e-06,  3.12105565e-04, -1.47655905e-02, -5.03520008e-03,
        8.22491081e-05,  5.93187962e-03,  7.36095061e-03, -9.07779414e-03,
       -5.97672151e-04,  4.96539646e-03, -2.80376878e-03, -2.15197845e-02,
        7.78330047e-06,  3.71652455e-03,  1.89958329e-03, -2.11308181e-02,
        5.60044480e-03,  1.17242569e-02, -3.78082795e-03,  1.03516895e-03,
        6.50295849e-04,  3.71926016e-04,  1.31393913e-02, -3.73952177e-04,
        7.37557595e-04, -3.35896907e-04,  2.96072282e-03,  1.19082266e-02,
       -8.36535499e-04,  8.26537235e-04,  1.70216390e-04,  8.15149939e-04,
        4.27677533e-05, -7.73087866e-05, -6.79329288e-05,  1.67186806e-03,
       -1.65784027e-04, -7.42618276e-04, -2.64956175e-03,  6.55034951e-03,
        2.53913936e-03, -3.33975122e-04, -7.13821092e-02,  1.55999668e-03,
       -2.05976618e-03,  3.71547024e-03, -7.30499096e-03, -7.48254595e-03,
       -3.34095620e-04])
  message: b'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
     nfev: 116
      nit: 100
   status: 1
  success: False
        x: array([-1.97913484e+00,  5.46658209e+00, -1.44968137e+00, -1.99982278e+00,
       -2.36189672e-01, -8.93120677e-04,  7.61982094e-01, -1.98076790e-01,
       -1.59225658e+00,  7.93689473e-01, -3.86904489e-01, -1.13966290e-01,
        3.78630615e-01, -8.97866674e-02,  3.46432560e-01, -7.61947134e-02,
        1.59572083e+00, -2.61494712e-01,  8.54412397e-01,  6.56737196e-03,
        4.45926316e-01, -3.20460248e-01,  6.87339056e-04,  9.88815622e-01,
       -2.80913333e-03,  6.18711529e-08,  1.28581109e-05,  1.00541958e-08,
        1.33416242e-04,  5.90164893e-11, -6.29036837e-02,  7.58188501e-09,
       -1.38940328e-04, -1.17769119e-05, -7.71312443e-04,  2.21523158e-03,
       -6.31975828e-02,  1.79694338e-03, -3.69473159e-03, -2.79556923e-03,
        1.03593977e-05,  4.44417259e-06,  3.13875110e-05,  2.24985144e-01,
        9.45886831e-01,  3.60343959e-08,  6.94486163e-06,  6.70618486e-09,
        8.43388082e-03,  3.49555348e-11, -2.24914389e-02,  7.54271695e-09,
       -4.44100827e-05, -4.05103666e-06, -2.51046444e-04,  7.33499813e-04,
        2.99450370e-02, -3.30783356e-02, -2.28192519e-01,  1.38901616e-03,
       -7.88805572e-04,  1.12743764e-05, -6.45095561e-06, -1.92145631e-02,
        1.42007880e-01,  6.14726363e-01, -9.80906090e-03, -1.31931221e-01,
       -4.48072108e-08, -6.30090821e-04, -1.33371978e-05,  1.21183668e-02,
        7.00040768e-02, -3.90308015e-01, -3.17414300e-02, -8.18792052e-03,
       -2.86522161e-04, -3.00806388e-05, -4.58413311e-07,  1.74465752e-03,
       -5.30634356e-05,  2.23627216e-02,  4.28617634e-02, -5.45282558e-03,
        3.62847380e-04,  1.37470125e-01,  9.99917288e-01, -2.04397618e-03,
       -4.50459313e-07, -9.61443390e-06,  3.05039839e-05,  1.96961923e-04,
        1.07911888e-03, -6.05890751e-03, -4.88538334e-04, -1.28194812e-04,
        2.44421629e-05,  2.28052619e-07,  1.20135304e-05,  2.73272181e-05,
       -6.99659194e-07,  3.48303083e-04, -1.75675719e-04, -1.52916574e-01,
        3.76430384e-03, -1.23770584e-02,  1.69783051e-01,  7.18719328e-01,
       -1.66896389e-08, -4.40700092e-03, -4.47486717e-06, -4.24393497e-01,
        2.30597482e-02, -8.93970313e-02, -1.37478575e-02, -3.75282668e-03,
       -9.54779277e-05,  3.75859466e-05,  2.32742161e-05,  7.90180704e-04,
        3.77991328e-05,  4.30321020e-02,  1.76153861e-02,  3.15382092e-02,
       -7.25749949e-04, -6.34079088e-02, -7.06802899e-02,  1.09874088e-01,
        5.73032005e-01, -2.04452533e-10,  1.13190850e-02, -1.84137254e-08,
        2.12019663e-05,  2.17224037e-06,  1.15445117e-04, -2.49098133e-04,
       -2.39786661e-02, -1.74627924e-01,  4.90975126e-02, -4.30884815e-04,
       -3.00798263e-01, -9.05084869e-05,  2.23224736e-04, -1.28650161e-02,
       -1.12822995e-04, -1.25988935e-01,  3.07027062e-02, -1.41190296e-01,
        3.20196118e-01,  1.62205482e-01, -2.34514630e-07,  5.21934395e-02,
        1.25330732e-03, -2.19922262e-03, -9.53800651e-04, -5.52749570e-04,
        1.86046456e-05,  7.82260908e-05,  2.12619983e-05,  1.99555680e-04,
        2.44103402e-05, -1.13169262e-01,  8.95998798e-04,  3.52992520e-02,
       -7.36590738e-04, -2.02150513e-02, -9.81384792e-02, -2.11290176e-01,
        1.18273528e-01,  3.07515339e-01,  6.84405129e-01, -2.35104885e-06,
       -2.49189402e-03, -2.14426375e-04, -1.36711230e-02,  3.84302881e-02,
       -4.85622977e-01,  6.70824291e-03, -3.72226336e-02, -6.34296425e-02,
        7.17535793e-03, -3.51832489e-05,  3.23650768e-04, -2.12532464e-01,
        4.21368471e-03, -9.00828479e-04, -2.67044807e-01, -9.13178263e-03,
        1.82247159e-03,  5.99680554e-02,  4.20803774e-01,  2.69420802e-01,
        3.30824332e-02, -6.13340584e-02, -2.50638725e-02, -7.39068892e-03,
       -4.91971204e-05,  2.43659553e-04,  1.14272021e-04,  9.27500948e-04,
        3.08166986e-04, -2.73138927e-01,  2.97717883e-02, -8.59384957e-02,
        5.73040053e-03,  3.07274984e-04,  8.21930467e-03, -6.95093330e-04,
        2.04875175e-04, -1.22862446e-02, -5.97410083e-02,  9.07757777e-01,
        9.94612062e-01,  4.55064404e-02,  5.61429245e-03,  1.55827712e-03,
       -2.06746073e-03,  2.31687467e-05, -1.27972007e-04, -5.97905784e-04,
        3.59743749e-05,  5.40408201e-03, -2.59325112e-02, -2.08575293e-01,
       -4.35768469e-02, -2.10053999e-05, -1.19036764e-02, -1.90810280e-04,
       -4.23905777e-05, -2.17840598e-04,  5.09721001e-02,  2.04349612e-01,
        1.67917416e-01])
[4]:
# Plot the predictive mean of Y with a band of +/- 2 standard deviations,
# treating the predictive distribution as if it were Gaussian.
plt.figure(figsize=(11, 6))
X_data = m.data[0].numpy()
Y_data = m.data[1].numpy()

# 100 evenly spaced test inputs spanning the range of the training inputs.
Xtest = np.linspace(X_data.min(), X_data.max(), 100)[:, None]
mu, var = m.predict_y(Xtest)
two_std = 2 * np.sqrt(var)

(line,) = plt.plot(Xtest, mu, lw=2)
col = line.get_color()
# Upper band first, then lower, both dashed in the colour of the mean line.
for band in (mu + two_std, mu - two_std):
    plt.plot(Xtest, band, "--", lw=2, color=col)
plt.plot(X_data, Y_data, "kx", mew=2)
[4]:
[<matplotlib.lines.Line2D at 0x7f319835b470>]
../../_images/notebooks_advanced_ordinal_regression_4_1.png
[5]:
# To see the predictive density, evaluate it at every possible discrete value of Y.
def pred_log_density(m, num_test=100):
    """Evaluate the model's predictive log density for every ordinal level.

    The test grid and the set of levels are derived from the model's own
    training data (``m.data``), so the function does not rely on notebook
    globals being defined by an earlier cell.

    Args:
        m: model exposing ``data`` as an ``(X, Y)`` pair whose elements have a
            ``.numpy()`` method, and a ``predict_log_density((X, Y))`` method.
        num_test: number of evenly spaced test inputs between the smallest and
            largest training input.

    Returns:
        The per-level results stacked with ``np.vstack``, in level order
        ``0, 1, ..., Y.max()``.
    """
    X_train = m.data[0].numpy()
    Y_train = m.data[1].numpy()
    Xtest = np.linspace(X_train.min(), X_train.max(), num_test).reshape(-1, 1)
    levels = np.arange(Y_train.max() + 1)
    # For each candidate level, ask how likely that label is at every test input.
    return np.vstack(
        [m.predict_log_density((Xtest, np.full_like(Xtest, level))) for level in levels]
    )
[6]:
# Heat map of the predictive density: the x-axis spans the training input
# range, the y-axis the ordinal levels, with the training data overlaid.
fig = plt.figure(figsize=(14, 6))
density = np.exp(pred_log_density(m))
x_lo, x_hi = X_data.min(), X_data.max()
y_lo, y_hi = -0.5, Y_data.max() + 0.5  # half a unit of padding around the levels
plt.imshow(
    density,
    interpolation="nearest",
    extent=[x_lo, x_hi, y_lo, y_hi],
    origin="lower",
    aspect="auto",
    cmap=plt.cm.viridis,
)
plt.colorbar()
# scalex/scaley=False keeps the axis limits set by imshow.
plt.plot(X, Y, "kx", mew=2, scalex=False, scaley=False)
[6]:
[<matplotlib.lines.Line2D at 0x7f31a00fe828>]
../../_images/notebooks_advanced_ordinal_regression_6_1.png
[7]:
# Predictive density over all levels at the single input x = 0.5.
x_new = 0.5
Y_new = np.arange(Y_data.max() + 1)[:, None]
# predict_log_density needs x and y with the same number of rows, so the
# input is repeated once per candidate level.
X_new = np.full_like(Y_new, x_new)
log_dens_new = m.predict_log_density((X_new, Y_new))
dens_new = np.exp(log_dens_new)
fig = plt.figure(figsize=(8, 4))
plt.bar(x=Y_new.ravel(), height=dens_new.ravel())
[7]:
<BarContainer object of 8 artists>
../../_images/notebooks_advanced_ordinal_regression_7_1.png