Vanilla Deep GP #25

Open · wants to merge 5 commits into base: develop
2 changes: 1 addition & 1 deletion gpflux/architectures/constant_input_dim_deep_gp.py
@@ -86,7 +86,7 @@ def _construct_kernel(input_dim: int, is_last_layer: bool) -> SquaredExponential
    # data) seems a bit weird - that's really long lengthscales? And I remember seeing
    # something where the value scaled with the number of dimensions before
    lengthscales = [2.0] * input_dim
-    return SquaredExponential(lengthscales=lengthscales, variance=variance)
+    return gpflow.kernels.ArcCosine(order=1, weight_variances=1.0 / np.array(lengthscales), variance=variance)


def build_constant_input_dim_deep_gp(X: np.ndarray, num_layers: int, config: Config) -> DeepGP:
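For reference, a minimal sketch of constructing the swapped-in kernel (the `input_dim` and `variance` values here are illustrative; gpflow's ArcCosine takes `weight_variances` in place of lengthscales, so the reciprocal must be taken over an array rather than a Python list):

    import numpy as np
    import gpflow

    input_dim = 5
    lengthscales = np.full(input_dim, 2.0)
    kernel = gpflow.kernels.ArcCosine(order=1, weight_variances=1.0 / lengthscales, variance=1e-6)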
8 changes: 5 additions & 3 deletions gpflux/experiment_support/plotting.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import List, Optional, Sequence
+from typing import List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
@@ -64,9 +64,11 @@ def plot_layer(

def plot_layers(
    X: TensorType, means: List[TensorType], covs: List[TensorType], samples: List[TensorType]
-) -> None:  # pragma: no cover
+) -> Tuple[plt.Figure, np.ndarray]:
    L = len(means)
    fig, axes = plt.subplots(3, L, figsize=(L * 3.33, 10))
    for i in range(L):
        layer_input = X if i == 0 else samples[i - 1][0]
-        plot_layer(X, layer_input, means[i], covs[i], samples[i], i, axes[:, i])
+        axs = axes[:, i] if L > 1 else axes
+        plot_layer(X, layer_input, means[i], covs[i], samples[i], i, axs)
+    return fig, axes
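Since `plot_layers` now returns the figure and axes instead of `None`, callers can post-process the plot. A small usage sketch (`X`, `means`, `covs`, and `samples` assumed to be collected from a model's layers):

    fig, axes = plot_layers(X, means, covs, samples)
    fig.savefig("deep_gp_layers.png", dpi=150)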
2 changes: 1 addition & 1 deletion gpflux/layers/gp_layer.py
@@ -285,7 +285,7 @@ def call(self, inputs: TensorType, *args: List[Any], **kwargs: Dict[str, Any]) -

        # Metric names should be unique; otherwise they get overwritten if you
        # have multiple with the same name
-        name = f"{self.name}_prior_kl" if self.name else "prior_kl"
+        name = f"{self.name}_kl" if self.name else "kl"
        self.add_metric(loss_per_datapoint, name=name, aggregation="mean")

        return outputs
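With this rename, the per-layer KL metric appears in the Keras training history under the shorter key; a hypothetical check, assuming a single layer named "gp_layer":

    history = training_model.fit({"inputs": X, "targets": Y}, epochs=10)
    kl_per_epoch = history.history["gp_layer_kl"]  # previously "gp_layer_prior_kl"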
1 change: 1 addition & 0 deletions gpflux/models/__init__.py
@@ -17,3 +17,4 @@
Base model classes implemented in GPflux
"""
from gpflux.models.deep_gp import DeepGP
+from gpflux.models.vanilla_deep_gp import VanillaDeepGP
27 changes: 8 additions & 19 deletions gpflux/models/deep_gp.py
@@ -16,7 +16,7 @@
""" This module provides the base implementation for DeepGP models. """

import itertools
-from typing import List, Optional, Tuple, Type, Union
+from typing import List, Optional, Sequence, Tuple, Type, Union

import tensorflow as tf

@@ -25,7 +25,7 @@

import gpflux
from gpflux.layers import LayerWithObservations, LikelihoodLayer
-from gpflux.sampling.sample import Sample
+from tensorflow.python.framework.ops import inside_function


class DeepGP(Module):
@@ -75,7 +75,7 @@ def __init__(
            gpflux.layers.LikelihoodLayer, gpflow.likelihoods.Likelihood
        ],  # fully-qualified for autoapi
        *,
-        input_dim: Optional[int] = None,
+        input_dim: Optional[Union[int, Sequence[int]]] = None,
        target_dim: Optional[int] = None,
        default_model_class: Type[tf.keras.Model] = tf.keras.Model,
        num_data: Optional[int] = None,
@@ -96,7 +96,11 @@
        If you do not specify a value for this parameter explicitly, it is automatically
        detected from the :attr:`~gpflux.layers.GPLayer.num_data` attribute in the GP layers.
        """
-        self.inputs = tf.keras.Input((input_dim,), name="inputs")
+        if input_dim is None or isinstance(input_dim, int):
+            self.inputs = tf.keras.Input((input_dim,), name="inputs")
+        else:
+            self.inputs = tf.keras.Input(input_dim, name="inputs")
+
        self.targets = tf.keras.Input((target_dim,), name="targets")
        self.f_layers = f_layers
        if isinstance(likelihood, gpflow.likelihoods.Likelihood):
@@ -267,18 +271,3 @@ def as_prediction_model(
        model_class = self._get_model_class(model_class)
        outputs = self.call(self.inputs)
        return model_class(self.inputs, outputs)
-
-
-def sample_dgp(model: DeepGP) -> Sample:  # TODO: should this be part of a [Vanilla]DeepGP class?
-    function_draws = [layer.sample() for layer in model.f_layers]
-    # TODO: error check that all layers implement .sample()?
-
-    class ChainedSample(Sample):
-        """ This class chains samples from consecutive layers. """
-
-        def __call__(self, X: TensorType) -> tf.Tensor:
-            for f in function_draws:
-                X = f(X)
-            return X
-
-    return ChainedSample()
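The widened `input_dim` now also accepts a full input shape; a short sketch (layer construction elided, the image shape is an illustrative assumption):

    dgp = DeepGP(gp_layers, gpflow.likelihoods.Gaussian(), input_dim=(28, 28, 1))
    # dgp.inputs now has shape (None, 28, 28, 1) rather than rank 2

The removed `sample_dgp` helper reappears as `VanillaDeepGP.sample` in the new module below.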
228 changes: 228 additions & 0 deletions gpflux/models/vanilla_deep_gp.py
@@ -0,0 +1,228 @@
#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
r"""
This module provides the base implementation for vanilla Deep GP models.
By "Vanilla" we refer to a model that can only contain standard :class:`GPLayer`\s,
this model does not support keras layers, latent variable layers, etc.
"""

from typing import List, Optional, Sequence, Tuple, Type, Union

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

import gpflow
from gpflow.base import TensorType

from gpflux.experiment_support.plotting import plot_layers
from gpflux.layers import GPLayer, LikelihoodLayer, TrackableLayer
from gpflux.models.deep_gp import DeepGP
from gpflux.sampling.sample import Sample


class VanillaDeepGP(DeepGP):
    def __init__(
        self,
        gp_layers: List[GPLayer],
        likelihood: Union[
            LikelihoodLayer, gpflow.likelihoods.Likelihood
        ],  # fully-qualified for autoapi
        *,
        input_dim: Optional[Union[int, Sequence[int]]] = None,
        target_dim: Optional[int] = None,
        default_model_class: Type[tf.keras.Model] = tf.keras.Model,
        num_data: Optional[int] = None,
    ):
        """
        :param gp_layers: The layers ``[f₁, f₂, …, fₙ]`` describing the latent
            function ``f(x) = fₙ(⋯ (f₂(f₁(x))))``.
        :param likelihood: The layer for the likelihood ``p(y|f)``. If this is a
            GPflow likelihood, it will be wrapped in a :class:`~gpflux.layers.LikelihoodLayer`.
            Alternatively, you can provide a :class:`~gpflux.layers.LikelihoodLayer` explicitly.
        :param input_dim: The input dimensionality.
        :param target_dim: The target dimensionality.
        :param default_model_class: The default for the *model_class* argument of
            :meth:`as_training_model` and :meth:`as_prediction_model`;
            see the :attr:`default_model_class` attribute.
        :param num_data: The number of points in the training dataset; see the
            :attr:`num_data` attribute.
            If you do not specify a value for this parameter explicitly, it is automatically
            detected from the :attr:`~gpflux.layers.GPLayer.num_data` attribute in the GP layers.
        """
        # if not all(isinstance(layer, GPLayer) for layer in gp_layers):
        #     raise ValueError(
        #         "`VanillaDeepGP` can only be built out of `GPLayer`s. "
        #         "Use `DeepGP` for a hybrid model with, for example, Keras layers, "
        #         "latent variable layers and GP layers."
        #     )

        super().__init__(
            gp_layers,
            likelihood,
            input_dim=input_dim,
            target_dim=target_dim,
            default_model_class=default_model_class,
            num_data=num_data,
        )

    def as_dnn_model(self) -> tf.keras.Model:
        """
        Creates the Neural Network equivalent of the Deep GP model, in which
        each layer is collapsed to its mean.
        """
        outputs = self.inputs
        for layer in self.f_layers:
            original_convert_to_tensor = layer._convert_to_tensor_fn
            # Temporarily make the layer output its mean when converted to a tensor.
            layer._convert_to_tensor_fn = tfp.distributions.Distribution.mean
            outputs = tf.convert_to_tensor(layer(outputs, training=False))
            layer._convert_to_tensor_fn = original_convert_to_tensor

        model = tf.keras.Model(inputs=self.inputs, outputs=outputs)
        return model

    def fit(
        self,
        X: TensorType,
        Y: TensorType,
        *,
        batch_size: Optional[int] = 32,
        learning_rate: float = 0.01,
        epochs: int = 128,
        verbose: int = 0,
    ) -> tf.keras.callbacks.History:
        """
        Compile and fit the model on training data (X, Y). Optimization
        uses the Adam optimizer with a decaying learning rate. This method wraps
        :meth:`tf.keras.Model.compile` and :meth:`tf.keras.Model.fit`.

        .. note::
            Parameter docs copied from the Keras documentation.

        :param X: Input data. A NumPy array (or array-like).
        :param Y: Target data. Like the input data `X`.
        :param batch_size: Integer or `None`.
            Number of samples per gradient update.
            If unspecified, `batch_size` will default to 32.
            If set to `None`, the size of the dataset is used.
        :param learning_rate: Initial learning rate for the Adam optimizer.
        :param epochs: Integer. Number of epochs to train the model.
            An epoch is an iteration over the entire `X` and `Y`
            data provided.
            If unspecified, `epochs` will default to 128.
        :param verbose: 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = one line per epoch.
            Note that the progress bar is not particularly useful when
            logged to a file, so verbose=2 is recommended when not running
            interactively (e.g., in a production environment).

        :returns:
            A `History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).
        """
        training_model = self.as_training_model()
        training_model.compile(optimizer=tf.optimizers.Adam(learning_rate=learning_rate))

        callbacks = [
            tf.keras.callbacks.ReduceLROnPlateau(
                "loss", factor=0.95, patience=3, min_lr=1e-6, verbose=verbose
            ),
        ]

        history = training_model.fit(
            {"inputs": X, "targets": Y},
            batch_size=batch_size or len(X),
            epochs=epochs,
            callbacks=callbacks,
            verbose=verbose,
        )
        return history

    def predict_f(self, X: TensorType) -> Tuple[TensorType, TensorType]:
        prediction_model = self.as_prediction_model()
        output = prediction_model.predict(X)
        return output.f_mean, output.f_var

    def predict_y(self, X: TensorType) -> Tuple[TensorType, TensorType]:
        prediction_model = self.as_prediction_model()
        output = prediction_model.predict(X)
        return output.y_mean, output.y_var

    def sample(self) -> Sample:
        function_draws = [layer.sample() for layer in self.f_layers]

        class ChainedSample(Sample):
            """ This class chains samples from consecutive layers. """

            def __call__(self, X: TensorType) -> tf.Tensor:
                self.inner_layers = []
                for f in function_draws:
                    X = f(X)
                    self.inner_layers.append(X)
                return X

        return ChainedSample()
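    # Hypothetical usage sketch (not part of this file): a single draw from
    # `sample()` is a deterministic function, so repeated evaluations agree:
    #
    #     f = model.sample()
    #     y1 = f(X_test)
    #     y2 = f(X_test)  # same function draw, hence the same values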

    def plot(
        self, X: TensorType, Y: TensorType, num_test_points: int = 100, x_margin: float = 3.0
    ) -> Tuple[plt.Figure, np.ndarray]:

        if X.shape[1] != 1:
            raise NotImplementedError("DeepGP plotting is only supported for 1D models.")

        means, covs, samples = [], [], []

        X_test = np.linspace(X.min() - x_margin, X.max() + x_margin, num_test_points).reshape(-1, 1)
        layer_input = X_test

        for layer in self.f_layers:
            layer.full_cov = True
            layer.num_samples = 5
            layer_output = layer(layer_input)

            mean = layer_output.mean()
            cov = layer_output.covariance()
            sample = tf.convert_to_tensor(layer_output)  # generates num_samples samples...

            layer_input = sample[0]  # for the next layer

            means.append(mean.numpy().T)  # transpose to go from [1, N] to [N, 1]
            covs.append(cov.numpy())
            samples.append(sample.numpy())

        fig, axes = plot_layers(X_test, means, covs, samples)

        if axes.ndim == 1:
            axes[-1].plot(X, Y, "kx")
        elif axes.ndim == 2:
            axes[-1, -1].plot(X, Y, "kx")

        return fig, axes
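An end-to-end usage sketch of the new class (the toy data and Config values are illustrative assumptions; the layers are borrowed from the existing build_constant_input_dim_deep_gp helper):

    import numpy as np
    from gpflux.architectures import Config, build_constant_input_dim_deep_gp
    from gpflux.models import VanillaDeepGP

    X = np.linspace(0, 1, 100).reshape(-1, 1)
    Y = np.sin(10 * X) + 0.1 * np.random.randn(*X.shape)

    config = Config(
        num_inducing=20, inner_layer_qsqrt_factor=1e-5, likelihood_noise_variance=1e-2, whiten=True
    )
    base = build_constant_input_dim_deep_gp(X, num_layers=2, config=config)
    model = VanillaDeepGP(base.f_layers, base.likelihood_layer, input_dim=1, num_data=len(X))

    model.fit(X, Y, epochs=256, verbose=2)
    f_mean, f_var = model.predict_f(X)
    fig, axes = model.plot(X, Y)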