Ordinal regression

Ordinal regression aims to fit a model to some data \((X, Y)\), where \(Y\) is an ordinal variable (discrete, ordered labels). To do so, we use a variational GP model (VGP) with a dedicated likelihood, gpflow.likelihoods.Ordinal.

[1]:
import gpflow

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
# Default size for all figures created below.
plt.rcParams["figure.figsize"] = (12, 6)

# Global seed for reproducibility. NOTE: the data-generation cell reseeds with
# np.random.seed(1) right before sampling the dataset.
np.random.seed(123)  # for reproducibility
[2]:
# make a one-dimensional ordinal regression problem

# This function generates a set of inputs X,
# quantitative output f (latent) and ordinal values Y


def generate_data(num_data):
    """Simulate a one-dimensional ordinal regression dataset.

    Args:
        num_data: number of points to draw.

    Returns:
        A tuple ``(X, f, Y)`` of ``(num_data, 1)`` arrays: uniform inputs on
        [0, 1), latent values from a single joint draw of a zero-mean GP with a
        SquaredExponential kernel (lengthscale 0.1), and float64 ordinal labels
        obtained by discretising the latent values (integer-valued, minimum 0).
    """
    inputs = np.random.rand(num_data, 1)

    # One joint draw of the latent GP evaluated at all inputs.
    cov = gpflow.kernels.SquaredExponential(lengthscales=0.1)(inputs)
    latent = np.random.multivariate_normal(mean=np.zeros(num_data), cov=cov)
    latent = latent.reshape(-1, 1)

    # Discretise: scale by 3 and round to integers, then shift so the lowest
    # level is 0. The `+ latent.min()` offset only moves where the rounding
    # boundaries fall; the final shift is what anchors the labels at zero.
    levels = np.round((latent + latent.min()) * 3)
    levels = levels - levels.min()

    return inputs, latent, np.asarray(levels, np.float64)


# Simulate the training set; reseed so the dataset is reproducible.
np.random.seed(1)
num_data = 20
X, f, Y = generate_data(num_data)

# Latent values on the left axis, ordinal observations on a twin right axis.
fig_data, ax_latent = plt.subplots(figsize=(11, 6))
ax_latent.plot(X, f, ".")
ax_latent.set_ylabel("latent function value")

ax_observed = ax_latent.twinx()
ax_observed.plot(X, Y, "kx", mew=1.5)
ax_observed.set_ylabel("observed data value")
[2]:
Text(0, 0.5, 'observed data value')
../../_images/notebooks_advanced_ordinal_regression_2_1.png
[3]:
# Construct the ordinal likelihood. gpflow.likelihoods.Ordinal with K bin edges
# models K + 1 classes labelled 0 .. K. Our labels run from 0 to Y.max(), so we
# need exactly Y.max() edges (one boundary between each pair of adjacent
# levels). The edges are spaced one unit apart and centred on zero.
num_classes = int(Y.max()) + 1
bin_edges = np.arange(num_classes - 1, dtype=float)
bin_edges = bin_edges - bin_edges.mean()
likelihood = gpflow.likelihoods.Ordinal(bin_edges)

# build a variational GP model with this likelihood
m = gpflow.models.VGP(data=(X, Y), kernel=gpflow.kernels.Matern32(), likelihood=likelihood)

# fit the model (capped at 100 optimizer iterations)
opt = gpflow.optimizers.Scipy()
opt.minimize(m.training_loss, m.trainable_variables, options=dict(maxiter=100))
[3]:
      fun: 25.48747181534459
 hess_inv: <233x233 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 1.26152196e-01, -8.41224623e-03, -9.93727410e-03, -6.68680311e-02,
       -2.47128560e-02, -3.90863360e-02, -4.36104289e-02, -6.57310044e-02,
       -2.94547316e-02, -6.95339820e-02, -5.18409931e-04, -1.27567520e-02,
        2.12060902e-02,  5.10858094e-03,  1.45583925e-02,  6.92877794e-03,
       -2.48692447e-02, -9.36685225e-03, -2.19435231e-02, -2.07624925e-04,
       -1.23378319e-02, -2.20723607e-02, -1.66797389e-03,  2.77045988e-04,
       -4.07993387e-04, -5.73153719e-09,  2.48701864e-06, -1.87131978e-09,
       -7.61084282e-06,  1.95440893e-11, -6.40743023e-03,  1.96717474e-08,
        2.44989380e-06, -2.36112948e-07, -2.04825774e-05,  1.53901841e-04,
       -3.34610551e-03,  2.68233576e-04,  5.28237186e-05,  1.72004444e-04,
       -1.95482386e-04, -1.51001568e-04,  3.36074376e-04, -2.61784695e-02,
        7.39189341e-03, -3.48309123e-09,  1.01706802e-06, -2.12059577e-10,
        4.49455581e-04,  1.27598923e-11, -1.31846180e-03,  1.41883895e-08,
        1.83815767e-07, -4.65411983e-07, -1.69575962e-05,  9.67286542e-05,
        2.68790924e-03, -1.21799250e-03, -1.31688203e-02,  1.61413673e-03,
       -1.39405353e-03, -3.92109395e-04, -8.36125347e-04,  1.33759882e-02,
        3.02330881e-02, -1.37287609e-02, -5.81481338e-04, -5.30825089e-03,
       -2.84863822e-09, -1.52401716e-04, -6.50002479e-07, -8.74922236e-03,
        4.66238033e-03, -2.30820544e-02, -9.20358661e-04, -7.62067846e-04,
       -5.77553675e-05, -4.16414679e-04,  6.97308437e-05, -7.63116327e-03,
       -3.31714098e-04,  2.47244274e-03,  3.19345849e-03, -2.30163127e-02,
       -9.16326473e-03, -6.99927451e-02, -1.95768484e-05, -8.22869408e-05,
       -2.65827574e-08, -2.21390750e-06,  3.82836028e-06, -1.26323228e-04,
        7.31873060e-05, -3.69202145e-04, -1.82574213e-05, -1.38785404e-05,
        1.74176275e-06, -6.48739309e-06,  1.73236956e-06, -1.18944099e-04,
       -4.89635341e-06,  4.02537922e-05, -1.86915679e-05,  2.87939664e-03,
       -2.33175454e-02, -3.37415061e-02,  1.34777177e-02,  2.42273910e-02,
       -1.65600144e-09,  4.15510506e-05, -3.23739873e-07,  1.79038020e-02,
        8.78259587e-04, -5.93853814e-03, -6.18745344e-04, -4.26758542e-03,
       -2.67589464e-05,  6.20611307e-04,  3.58206945e-04, -5.04746058e-03,
        8.73170883e-04, -4.30841902e-03, -9.98885633e-06, -2.74696566e-02,
       -2.99910145e-02, -6.19882581e-02,  4.86420338e-02, -1.57700316e-02,
       -1.29721471e-03,  7.45061567e-11,  8.93714103e-05, -2.87343274e-08,
       -2.59934363e-06,  1.28626551e-06, -3.24626970e-06,  2.11857175e-04,
       -1.20636043e-03,  3.01520454e-04,  7.41327654e-03, -3.11098798e-03,
       -1.22258157e-02,  2.66834160e-04, -5.77313324e-04, -2.22873063e-02,
       -8.59041039e-04,  5.50580340e-03, -7.98580793e-03,  1.53251841e-02,
        1.23638607e-02, -4.66014857e-03,  8.67519113e-09,  1.50394423e-02,
       -8.15618660e-05,  3.63733104e-03,  8.90499008e-04, -3.83498171e-03,
        6.99589634e-04,  1.37842553e-03,  9.06914470e-05,  1.06859549e-02,
       -1.12950401e-03,  5.24432343e-03,  6.09674509e-03,  1.24893681e-02,
       -6.90802984e-03, -9.15412275e-03,  2.75348541e-03, -1.30203425e-02,
        7.21250627e-03, -2.49634754e-02,  1.15169738e-02,  4.11520908e-07,
       -9.58987630e-05, -7.03900052e-06, -8.83240952e-04,  3.20539135e-03,
       -1.12879286e-02,  4.38003790e-03,  1.36017528e-03,  1.11508780e-03,
       -1.28581270e-03, -1.72238245e-03,  3.41572042e-03,  6.89257814e-03,
       -4.01043009e-05,  3.11272748e-04, -1.48929272e-02, -4.99911941e-03,
        7.94435229e-05,  5.94885114e-03,  7.35840121e-03, -8.85628541e-03,
       -5.92934998e-04,  5.05437053e-03, -2.76643700e-03, -2.16315863e-02,
        5.76633626e-06,  3.75703035e-03,  1.88744607e-03, -2.08894712e-02,
        5.60037684e-03,  1.20461395e-02, -3.83438272e-03,  1.01245552e-03,
        6.20995715e-04,  3.99885132e-04,  1.29613111e-02, -3.65630516e-04,
        7.35866892e-04, -3.38583229e-04,  2.89798475e-03,  1.18681526e-02,
       -8.35164287e-04,  8.24129729e-04,  1.67323538e-04,  8.15391105e-04,
        4.24443656e-05, -7.88873449e-05, -6.72175178e-05,  1.65245291e-03,
       -1.63905514e-04, -7.57546037e-04, -2.65058528e-03,  6.71521731e-03,
        2.71789776e-03, -4.53788676e-04, -7.05487493e-02,  1.51970089e-03,
       -2.04612759e-03,  3.71262166e-03, -7.02766316e-03, -7.30399148e-03,
       -2.12642742e-04])
  message: b'STOP: TOTAL NO. of ITERATIONS REACHED LIMIT'
     nfev: 116
      nit: 100
     njev: 116
   status: 1
  success: False
        x: array([-1.97911023e+00,  5.46677555e+00, -1.44968151e+00, -1.99978856e+00,
       -2.36156700e-01, -8.91373417e-04,  7.62007697e-01, -1.98089892e-01,
       -1.59227266e+00,  7.93685832e-01, -3.86898740e-01, -1.13959783e-01,
        3.78664184e-01, -8.97922860e-02,  3.46412307e-01, -7.61912349e-02,
        1.59571557e+00, -2.61498927e-01,  8.54379536e-01,  6.56793847e-03,
        4.45956063e-01, -3.20422721e-01,  6.88570533e-04,  9.88816502e-01,
       -2.80878280e-03,  6.18741008e-08,  1.28562383e-05,  1.00547938e-08,
        1.33412006e-04,  5.89916929e-11, -6.28972851e-02,  7.62098751e-09,
       -1.38951664e-04, -1.17796589e-05, -7.71391647e-04,  2.21535534e-03,
       -6.31960765e-02,  1.79706148e-03, -3.69506365e-03, -2.79563387e-03,
        1.03297753e-05,  4.48026631e-06,  3.17563117e-05,  2.24978162e-01,
        9.45886615e-01,  3.60348805e-08,  6.94360319e-06,  6.70557218e-09,
        8.43280902e-03,  3.49308012e-11, -2.24885162e-02,  7.55800323e-09,
       -4.44138698e-05, -4.05201455e-06, -2.51071070e-04,  7.33502871e-04,
        2.99380932e-02, -3.30745035e-02, -2.28167700e-01,  1.38899681e-03,
       -7.88485490e-04,  1.12462514e-05, -7.46173920e-06, -1.92175811e-02,
        1.42001646e-01,  6.14726579e-01, -9.80900047e-03, -1.31921465e-01,
       -4.48085528e-08, -6.29938594e-04, -1.33373003e-05,  1.21251052e-02,
        7.00036169e-02, -3.90307794e-01, -3.17449303e-02, -8.18832196e-03,
       -2.86470232e-04, -3.02352092e-05, -3.44827858e-07,  1.74662434e-03,
       -5.35072658e-05,  2.23658974e-02,  4.28672661e-02, -5.44312958e-03,
        3.59041497e-04,  1.37473184e-01,  9.99917298e-01, -2.04383560e-03,
       -4.50410123e-07, -9.61212219e-06,  3.05007420e-05,  1.97071295e-04,
        1.07911979e-03, -6.05893662e-03, -4.88596546e-04, -1.28200793e-04,
        2.44422300e-05,  2.25537592e-07,  1.20142519e-05,  2.73577408e-05,
       -7.06729888e-07,  3.48351667e-04, -1.75517253e-04, -1.52914450e-01,
        3.76203003e-03, -1.23777370e-02,  1.69791138e-01,  7.18740166e-01,
       -1.66890155e-08, -4.40721213e-03, -4.47484156e-06, -4.24394957e-01,
        2.30604160e-02, -8.93998103e-02, -1.37494988e-02, -3.75506810e-03,
       -9.54510862e-05,  3.80046481e-05,  2.32784560e-05,  7.91116923e-04,
        3.80864893e-05,  4.30378440e-02,  1.76172349e-02,  3.15319746e-02,
       -7.25387106e-04, -6.34001268e-02, -7.06792825e-02,  1.09870234e-01,
        5.73033404e-01, -2.04459886e-10,  1.13191372e-02, -1.83875241e-08,
        2.11935036e-05,  2.17140143e-06,  1.15397913e-04, -2.48980565e-04,
       -2.39800515e-02, -1.74621351e-01,  4.90889408e-02, -4.33639509e-04,
       -3.00794133e-01, -9.11717166e-05,  2.25884961e-04, -1.28788980e-02,
       -1.15490417e-04, -1.25985797e-01,  3.06938926e-02, -1.41200163e-01,
        3.20201275e-01,  1.62204795e-01, -2.34518284e-07,  5.21925357e-02,
        1.25344305e-03, -2.20057554e-03, -9.54189108e-04, -5.56672948e-04,
        1.86885420e-05,  7.91600486e-05,  2.15279013e-05,  1.95877431e-04,
        2.62049864e-05, -1.13155861e-01,  8.97392085e-04,  3.53186742e-02,
       -7.34797422e-04, -2.02171794e-02, -9.81333102e-02, -2.11301734e-01,
        1.18277352e-01,  3.07506756e-01,  6.84409359e-01, -2.35058010e-06,
       -2.49194977e-03, -2.14458478e-04, -1.36717338e-02,  3.84309348e-02,
       -4.85637323e-01,  6.70860534e-03, -3.72233702e-02, -6.34290568e-02,
        7.17537949e-03, -3.48412464e-05,  3.27229111e-04, -2.12534126e-01,
        4.21384801e-03, -9.02527141e-04, -2.67057449e-01, -9.12901724e-03,
        1.82231900e-03,  5.99689631e-02,  4.20809671e-01,  2.69425029e-01,
        3.30836743e-02, -6.13316321e-02, -2.50645620e-02, -7.40122602e-03,
       -4.92381166e-05,  2.46019748e-04,  1.14043418e-04,  9.30885081e-04,
        3.10041332e-04, -2.73113289e-01,  2.97716767e-02, -8.59248419e-02,
        5.73094667e-03,  3.07678366e-04,  8.21620631e-03, -6.95125860e-04,
        2.04982904e-04, -1.22878321e-02, -5.97405319e-02,  9.07757443e-01,
        9.94612781e-01,  4.55080953e-02,  5.61473040e-03,  1.55877317e-03,
       -2.06761496e-03,  2.31054840e-05, -1.27993477e-04, -5.98232443e-04,
        3.59505436e-05,  5.40332134e-03, -2.59311479e-02, -2.08571871e-01,
       -4.35780186e-02, -2.36028474e-05, -1.18914145e-02, -1.90867493e-04,
       -4.25856487e-05, -2.18028359e-04,  5.09815613e-02,  2.04357521e-01,
        1.67917535e-01])
[4]:
# Plot the predictive mean of Y with a +/- 2 standard-deviation band. Y is
# ordinal, so treating its predictive distribution as Gaussian here is only a
# rough visual summary.
plt.figure(figsize=(11, 6))
X_data = m.data[0].numpy()
Y_data = m.data[1].numpy()
Xtest = np.linspace(X_data.min(), X_data.max(), 100).reshape(-1, 1)
mu, var = m.predict_y(Xtest)
two_std = 2 * np.sqrt(var)

(line,) = plt.plot(Xtest, mu, lw=2)
for band in (mu + two_std, mu - two_std):
    plt.plot(Xtest, band, "--", lw=2, color=line.get_color())
plt.plot(X_data, Y_data, "kx", mew=2)
[4]:
[<matplotlib.lines.Line2D at 0x7f4b280bfba8>]
../../_images/notebooks_advanced_ordinal_regression_4_1.png
[5]:
## to see the predictive density, try predicting every possible discrete value for Y.
def pred_log_density(m, num_points=100):
    """Evaluate the model's predictive log density for every ordinal label.

    The inputs are taken from the model's own training data (``m.data``), so
    the function works for any model, not just the one in this notebook.

    Args:
        m: a model exposing ``data`` as an ``(X, Y)`` pair and a
            ``predict_log_density((X, Y))`` method.
        num_points: number of evenly spaced test inputs spanning the range of
            the training inputs (default 100).

    Returns:
        An array with one row per label ``0 .. max(Y)`` and one column per
        test input, holding the predictive log density of that label.
    """
    # np.asarray accepts both tensors (via __array__) and plain arrays.
    X_train = np.asarray(m.data[0])
    Y_train = np.asarray(m.data[1])

    Xtest = np.linspace(X_train.min(), X_train.max(), num_points).reshape(-1, 1)
    ys = np.arange(Y_train.max() + 1)

    # predict_log_density pairs inputs and labels row by row, so each candidate
    # label y is repeated once per test input.
    densities = [m.predict_log_density((Xtest, np.full_like(Xtest, y))) for y in ys]
    return np.vstack(densities)
[6]:
# Heat map of the predictive probability of each ordinal level (rows) across
# the input range (columns), with the training data overlaid.
density = np.exp(pred_log_density(m))
plot_extent = [X_data.min(), X_data.max(), -0.5, Y_data.max() + 0.5]

fig = plt.figure(figsize=(14, 6))
plt.imshow(
    density,
    extent=plot_extent,
    cmap=plt.cm.viridis,
    interpolation="nearest",
    aspect="auto",
    origin="lower",
)
plt.colorbar()
plt.plot(X, Y, "kx", mew=2, scalex=False, scaley=False)
[6]:
[<matplotlib.lines.Line2D at 0x7f4b187b24a8>]
../../_images/notebooks_advanced_ordinal_regression_6_1.png
[7]:
# Predictive probability of every ordinal level at a single input x = 0.5.
x_new = 0.5
Y_new = np.arange(Y_data.max() + 1).reshape(-1, 1)
# predict_log_density pairs inputs and labels row by row, so x_new is repeated
# once per candidate label. X_new is built explicitly as float:
# np.full_like(Y_new, x_new) would inherit Y_new's dtype and truncate 0.5 to 0
# whenever the labels are stored as integers.
X_new = np.full(Y_new.shape, x_new, dtype=float)
dens_new = np.exp(m.predict_log_density((X_new, Y_new)))
fig = plt.figure(figsize=(8, 4))
plt.bar(x=Y_new.flatten(), height=dens_new.flatten())
[7]:
<BarContainer object of 8 artists>
../../_images/notebooks_advanced_ordinal_regression_7_1.png