Tracking Multiple Losses with Keras

For networks such as VAEs that combine competing losses, it is useful to track each loss independently: that is, to see the total loss as well as the data (reconstruction) term and the KL divergence term separately.

Is this possible in Keras? The individual losses can be retrieved via vae.losses, but those are TensorFlow tensors rather than Keras layers, so they cannot be used directly in Keras (for example, I cannot build a second model that outputs the VAE loss).

It seems like the right way to do this would be to add them to the list of metrics at compile time, but they do not fit the metric interface, which expects a function of the true and predicted values.
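(For reference, a Keras metric is expected to be a callable taking (y_true, y_pred) and returning a tensor, while vae.losses is just a list of backend tensors with no such signature. A minimal sketch of the mismatch, with illustrative names some_metric and some_model:)

def some_metric(y_true, y_pred):   # required signature for a Keras metric
    return K.mean(y_pred)

some_model.compile(optimizer='rmsprop', loss='mse', metrics=[some_metric])

# some_model.losses, by contrast, is a plain list of loss tensors collected
# from add_loss() calls, and those tensors cannot simply be dropped into
# metrics=[...].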

Here is some sample code (sorry for the length); it is lightly adapted from the Keras VAE example. The main difference is that I have explicitly moved the KL divergence calculation into the sampling layer, which feels more natural to me than the original example.

'''This script demonstrates how to build a variational autoencoder with Keras.

Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''    
from keras.layers import Input, Dense, Lambda, Layer
from keras.models import Model
from keras import backend as K
from keras import metrics

batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
epochs = 50
epsilon_std = 1.0


x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

class CustomSamplingLayer(Layer):
    def __init__(self, **kwargs):
        super(CustomSamplingLayer, self).__init__(**kwargs)

    def kl_div_loss(self, z_mean, z_log_var):
        kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(kl_loss)

    def call(self, inputs):
        z_mean = inputs[0]
        z_log_var = inputs[1]
        loss = self.kl_div_loss(z_mean, z_log_var)
        self.add_loss(loss, inputs=inputs)
        epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.,
                                  stddev=epsilon_std)
        return z_mean + K.exp(z_log_var / 2) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
z = CustomSamplingLayer()([z_mean, z_log_var])

# we instantiate these layers separately so as to reuse them later
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

# Custom loss layer
class CustomVariationalLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, x_decoded_mean):
        xent_loss = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
        return K.mean(xent_loss)

    def call(self, inputs):
        x = inputs[0]
        x_decoded_mean = inputs[1]
        loss = self.vae_loss(x, x_decoded_mean)
        self.add_loss(loss, inputs=inputs)
        return x_decoded_mean

y = CustomVariationalLayer()([x, x_decoded_mean])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

      


1 answer


I tried to do something similar with the gumbel-softmax (categorical) VAE implemented in Keras here. The trick for me was to use metrics, as you suggested. Here is the model setup:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from keras.layers import Input, Dense, Lambda
from keras.models import Model, Sequential
from keras import backend as K
from keras.datasets import mnist
from keras.activations import softmax
from keras.objectives import binary_crossentropy as bce


batch_size = 200
data_dim = 784
M = 10
N = 10
nb_epoch = 3
epsilon_std = 0.01

tmp = []

anneal_rate = 0.0003
min_temperature = 0.5

tau = K.variable(5.0, name="temperature")
x = Input(batch_shape=(batch_size, data_dim))
h = Dense(256, activation='relu')(Dense(512, activation='relu')(x))
logits_y = Dense(M*N)(h)

def sampling(logits_y):
    U = K.random_uniform(K.shape(logits_y), 0, 1)
    y = logits_y - K.log(-K.log(U + 1e-20) + 1e-20)
    y = softmax(K.reshape(y, (-1, N, M)) / tau)
    y = K.reshape(y, (-1, N*M))
    return y

z = Lambda(sampling, output_shape=(M*N,))(logits_y)
generator = Sequential()
generator.add(Dense(256, activation='relu', input_shape=(N*M, )))
generator.add(Dense(512, activation='relu'))
generator.add(Dense(data_dim, activation='sigmoid'))
x_hat = generator(z)

      

Here I define the total loss used to optimize the model, and then separate functions for the individual components. Note that KL_loss takes two arguments it never uses: Keras will throw an exception if a metric function does not accept the (y_true, y_pred) pair.

def gumbel_loss(x, x_hat):
    q_y = K.reshape(logits_y, (-1, N, M))
    q_y = softmax(q_y)
    log_q_y = K.log(q_y + 1e-20)
    kl_tmp = q_y * (log_q_y - K.log(1.0/M))
    KL = K.sum(kl_tmp, axis=(1, 2))
    elbo = data_dim * bce(x, x_hat) - KL
    return elbo

def KL_loss(y_true, y_pred):
    q_y = K.reshape(logits_y, (-1, N, M))
    q_y = softmax(q_y)
    log_q_y = K.log(q_y + 1e-20)
    kl_tmp = q_y * (log_q_y - K.log(1.0/M))
    KL = K.sum(kl_tmp, axis=(1, 2))
    return K.mean(-KL)

def bce_loss(y_true, y_pred):
    return K.mean(data_dim * bce(y_true, y_pred))

      



The model is then compiled and trained:

vae = Model(x, x_hat)
vae.compile(optimizer='adam', loss=gumbel_loss,
            metrics = [KL_loss, bce_loss])

# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

for e in range(nb_epoch):
    vae.fit(x_train, x_train,
        shuffle=True,
        epochs=1,
        batch_size=batch_size,
        validation_data=(x_test, x_test))
    out = vae.predict(x_test, batch_size = batch_size)
    K.set_value(tau, np.max([K.get_value(tau) * np.exp(- anneal_rate * e), min_temperature]))

      

I experimented with callbacks and a lot of other things before getting my head around this, so hopefully this helps.
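Applied back to the Gaussian VAE in the question, the same pattern would look roughly like this (an untested sketch: the add_loss() calls in the custom layers would have to be dropped so the KL term is not counted twice, the metric functions close over z_mean and z_log_var, the model is compiled against x_decoded_mean with an explicit total loss so that x_train can be passed as the target, and x_train is prepared as in the MNIST code above):

def vae_loss(x, x_decoded_mean):
    # total loss: reconstruction term plus KL term
    xent = original_dim * metrics.binary_crossentropy(x, x_decoded_mean)
    kl = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(xent + kl)

def KL_loss(y_true, y_pred):
    # ignores its arguments and reads the encoder outputs from the graph
    return K.mean(-0.5 * K.sum(1 + z_log_var - K.square(z_mean)
                               - K.exp(z_log_var), axis=-1))

def xent_loss(y_true, y_pred):
    return K.mean(original_dim * metrics.binary_crossentropy(y_true, y_pred))

vae = Model(x, x_decoded_mean)
vae.compile(optimizer='rmsprop', loss=vae_loss,
            metrics=[KL_loss, xent_loss])
vae.fit(x_train, x_train, shuffle=True,
        epochs=epochs, batch_size=batch_size)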
