TensorFlow: how to access all RNN states, not just the last state

My understanding is that tf.nn.dynamic_rnn returns the output of an RNN cell (e.g. an LSTM) at every time step, as well as the final state. How can I access the cell states at all time steps, not just the last one? For example, I want to be able to average all the hidden states and then use the result in a subsequent layer.

Here's how I define an LSTM cell and then unroll it with tf.nn.dynamic_rnn. But this only gives me the last state of the LSTM cell.

import tensorflow as tf
import numpy as np

# [batch_size, sequence_length, dimensions]
X = np.random.randn(2, 10, 8)
X[1, 6:] = 0          # zero-pad the second sequence after step 6
X_lengths = [10, 6]   # true lengths of the two sequences

cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True)

outputs, last_state = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# out: (2, 10, 64); last: LSTMStateTuple of c and h, each (2, 64)
out, last = sess.run([outputs, last_state], feed_dict=None)

      

2 answers


Something like this should work.

import tensorflow as tf
import numpy as np


class CustomRNN(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        kwargs['state_is_tuple'] = False  # force the use of a concatenated state
        super(CustomRNN, self).__init__(*args, **kwargs)  # create an LSTM cell
        self._output_size = self._state_size  # make the cell output its full state

    def __call__(self, inputs, state):
        output, next_state = super(CustomRNN, self).__call__(inputs, state)
        # return the state as the output (instead of just h), plus the state itself
        return next_state, next_state

X = np.random.randn(2, 10, 8)
X[1, 6:] = 0
X_lengths = [10, 10]  # note: both sequences run the full 10 steps here

cell = CustomRNN(num_units=64)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
states, last_state = sess.run([outputs, last_states], feed_dict=None)

      



This uses a concatenated state, as I am not sure whether you can stack an arbitrary number of tuple states over time. The stacked state tensor returned as outputs has shape (batch_size, max_time, state_size), where state_size = 2 * num_units (c and h concatenated).
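
To recover the per-step c and h from that stacked state, you can split the last axis in two; a minimal sketch, assuming the [c, h] concatenation order that LSTMCell uses when state_is_tuple=False:

# outputs has shape (batch_size, max_time, 2 * num_units): c and h concatenated
c_states, h_states = tf.split(outputs, num_or_size_splits=2, axis=2)

# e.g. average the hidden state over time for use in a subsequent layer
mean_h = tf.reduce_mean(h_states, axis=1)  # (batch_size, num_units)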


I would point you to this thread (emphasis mine):

You can write a variant of the LSTMCell that returns both state tensors as part of the output, if you need both the c and h state for each time step. If you just need the h state, that is the output of each time step.



As @jasekp wrote in his comment, the output is indeed the h part of the state. dynamic_rnn then simply stacks all the h pieces over time (see the docstring of _dynamic_rnn_loop in that file):

def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  """Internal implementation of Dynamic RNN.
  [...]
  Returns:
    Tuple `(final_outputs, final_state)`.
    final_outputs:
      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
      objects, then this returns a (possibly nested) tuple of Tensors matching
      the corresponding shapes.
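
Since outputs already stacks the h state for every time step, averaging the hidden states (as the question asks) does not require a custom cell at all. A minimal sketch of a masked mean over the question's original outputs, assuming the variable-length setup from the question (padded steps must be excluded from the average):

# build a (batch, time) mask so padded steps don't contribute to the average
mask = tf.sequence_mask(X_lengths, maxlen=10, dtype=tf.float64)
masked_outputs = outputs * tf.expand_dims(mask, axis=2)  # zero out padded steps

# sum over time and divide by each sequence's true length
mean_h = tf.reduce_sum(masked_outputs, axis=1) / tf.expand_dims(
    tf.reduce_sum(mask, axis=1), axis=1)  # (batch_size, num_units)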

      
