TF LSTM: save state from training session for prediction session later

I am trying to save the last state of the LSTM from training so that I can reuse it in the prediction phase later. The problem I am running into is that in the TF LSTM model, the state is passed from one training iteration to the next via a combination of a placeholder and a numpy array, neither of which seems to be included in the default graph when the session is saved...

To work around this, I create a dedicated TF variable that holds the latest version of the state, so that it is added to the session graph, for example:

# latest State from last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now add to TF variable:
savedState = tf.Variable(ostate, dtype=tf.float32, name='savedState')
tf.variables_initializer([savedState]).run()
save_path = saver.save(sess, pathModel + '/my_model.ckpt')

This seems to add a variable savedState to the saved session graph nicely, and it is easily restored later along with the rest of the session.

The problem is that the only way I was able to use this variable later in the restored session is if I initialize all the variables in the session AFTER restoring it (which seems to reset all the trainable variables, including the weights/biases/etc.!). If I initialize the variables first and THEN restore the session (which works great in terms of restoring the trained variables), I get an error saying that I am trying to access an uninitialized variable.

I know there is a way to initialize a specific, separate variable (which I do when saving it), but the problem is that when we restore, we refer to the variables by name as strings rather than passing the variable object itself?!

# This produces an error: 'trying to use an uninitialized variable'
tf.global_variables_initializer().run()
new_saver = tf.train.import_meta_graph(pathModel + '/my_model.ckpt.meta')
new_saver.restore(sess, pathModel + '/my_model.ckpt')
fullState = sess.run('savedState:0')


What is the correct way to do this? As a workaround, I am currently storing the state in a CSV file as a numpy array and restoring it the same way. It works fine, but it is clearly not the cleanest solution, given that every other aspect of saving/restoring a TF session works fine.
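For reference, the CSV workaround mentioned above could look roughly like the following NumPy sketch. The sizes, the random array standing in for ostate, and the file name are hypothetical; np.savetxt and np.loadtxt handle a 2-D [BATCHSIZE, CELL_SIZE * LAYERS] state directly.

```python
import os
import tempfile

import numpy as np

# Hypothetical sizes standing in for BATCHSIZE and CELL_SIZE * LAYERS.
BATCHSIZE, STATE_WIDTH = 4, 6

# Random stand-in for `ostate` returned by sess.run(...) in the last training step.
ostate = np.random.rand(BATCHSIZE, STATE_WIDTH).astype(np.float32)

# Hypothetical file location.
path = os.path.join(tempfile.mkdtemp(), "lstm_state.csv")

# Write the 2-D state as comma-separated text (one row per sequence in the batch)...
np.savetxt(path, ostate, delimiter=",")

# ...and read it back later, restoring the float32 dtype and keeping it 2-D.
restored = np.loadtxt(path, delimiter=",", dtype=np.float32, ndmin=2)
print(restored.shape)
```

This works, but the state lives outside the checkpoint, which is exactly what the question is trying to avoid.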

Any suggestions are appreciated!

EDIT: Here's the code that works well, following the accepted answer below:

import numpy as np
import tensorflow as tf

# make sure to define the State variable before the Saver:
# (trainable=False keeps it out of the optimizer's variable list)
savedState = tf.get_variable('savedState', shape=[BATCHSIZE, CELL_SIZE * LAYERS], trainable=False)
saver = tf.train.Saver(max_to_keep=1)
# last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now save the State and the whole model (the assign op must actually be run):
assignOp = tf.assign(savedState, ostate)
sess.run(assignOp)
save_path = saver.save(sess, pathModel + '/my_model.ckpt')


# later on, in some other program, recover the model and the State:
# make sure to initialize all variables BEFORE recovering the model!
tf.global_variables_initializer().run()
local_saver = tf.train.import_meta_graph(pathModel + '/my_model.ckpt.meta')
local_saver.restore(sess, pathModel + '/my_model.ckpt')
# recover the state from training and take its last row
# (the state of the last sequence in the batch)
fullState = sess.run('savedState:0')
h = fullState[-1]
h = np.reshape(h, [1, -1])


I haven't tested yet to see if this approach involves inadvertently initializing any other variables in the saved session, but can't see why it should, since we're only running the specific one.
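To make the last three lines of that code concrete, here is a small NumPy sketch with made-up sizes (3 sequences, CELL_SIZE * LAYERS = 4) and a synthetic array standing in for the restored savedState. It shows that fullState[-1] picks the last row of the batch, and that the reshape turns it into a batch of one, presumably so it can be fed as the initial state at prediction time.

```python
import numpy as np

# Hypothetical sizes: 3 sequences in the batch, CELL_SIZE * LAYERS = 4 values each.
BATCHSIZE, CELL_SIZE, LAYERS = 3, 2, 2

# Synthetic stand-in for fullState = sess.run('savedState:0'): one row per sequence.
fullState = np.arange(BATCHSIZE * CELL_SIZE * LAYERS, dtype=np.float32).reshape(
    BATCHSIZE, CELL_SIZE * LAYERS)

# fullState[-1] selects the LAST ROW (the last sequence in the batch),
# not the last dimension; the result is 1-D with CELL_SIZE * LAYERS entries.
h = fullState[-1]
print(h.shape)  # (4,)

# Reshape into a batch of one: shape (1, CELL_SIZE * LAYERS).
h = np.reshape(h, [1, -1])
print(h.shape)  # (1, 4)
```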


1 answer


The problem is that creating a new tf.Variable after creating the Saver means that the Saver doesn't know about the new variable. The variable is still saved in the metagraph, but not in the checkpoint:

import tensorflow as tf
with tf.Graph().as_default():
  var_a = tf.get_variable("a", shape=[])
  saver = tf.train.Saver()
  var_b = tf.get_variable("b", shape=[])
  print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
  initializer = tf.global_variables_initializer()
  with tf.Session() as session:
    session.run([initializer])
    saver.save(session, "/tmp/model", global_step=0)
with tf.Graph().as_default():
  new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
  print([v.name for v in tf.global_variables()]) # ['a:0', 'b:0']: "b" is still in the metagraph
  with tf.Session() as session:
    new_saver.restore(session, "/tmp/model-0") # Only var_a gets restored!


Above is a quick annotated reproduction of your problem, showing which variables the Saver knows about.
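To confirm this directly, one can list what the checkpoint file actually contains. The sketch below is a variant of the reproduction under a few assumptions: it uses tensorflow.compat.v1 with v2 behavior disabled (so it runs in TF1-style graph mode on TF 1.15 or 2.x), a temporary directory instead of /tmp/model, and tf.train.list_variables to read the names stored in the checkpoint.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TF1-style graph mode (also runs on TF 2.x through compat.v1).
tf.disable_v2_behavior()

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model")

with tf.Graph().as_default():
    tf.get_variable("a", shape=[])
    saver = tf.train.Saver()        # collects the variables that exist now: only "a"
    tf.get_variable("b", shape=[])  # created after the Saver, so it is not tracked
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saved_path = saver.save(session, ckpt_prefix, global_step=0)

# Names of the tensors actually stored in the checkpoint: "a" is there, "b" is not.
ckpt_names = [name for name, _shape in tf.train.list_variables(saved_path)]
print(ckpt_names)

# The metagraph, however, still declares both variables.
with tf.Graph().as_default():
    tf.train.import_meta_graph(saved_path + ".meta")
    graph_names = [v.name for v in tf.global_variables()]
    print(graph_names)
```

Restoring from such a checkpoint therefore leaves "b" uninitialized, which is the error from the question.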



The solution is relatively simple: create the Variable before the Saver, and then use tf.assign to update its value (make sure you actually run the op that tf.assign returns). The assigned value will be saved in checkpoints and restored just like the other variables.
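A minimal end-to-end sketch of that fix follows, assuming tensorflow.compat.v1 in graph mode and a small made-up state shape in place of [BATCHSIZE, CELL_SIZE * LAYERS]. It differs from the advice above only in feeding the value through a placeholder, so the assign op is built once instead of creating a new tf.assign op (with the state baked in as a constant) on every save; trainable=False is an extra assumption that keeps the holder out of the optimizer.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: TF1-style graph mode with ref variables, so 'savedState:0' can be fetched.
tf.disable_v2_behavior()

# Hypothetical state shape standing in for [BATCHSIZE, CELL_SIZE * LAYERS].
STATE_SHAPE = [2, 3]
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "my_model.ckpt")

# --- Training program ---
with tf.Graph().as_default():
    # Create the holder variable BEFORE the Saver so the Saver tracks it.
    saved_state = tf.get_variable("savedState", shape=STATE_SHAPE, trainable=False)
    # Build the assign op once; the actual state is fed at run time.
    state_in = tf.placeholder(tf.float32, shape=STATE_SHAPE)
    assign_state = tf.assign(saved_state, state_in)
    saver = tf.train.Saver(max_to_keep=1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Stand-in for `ostate` from the last training iteration.
        ostate = np.arange(6, dtype=np.float32).reshape(STATE_SHAPE)
        sess.run(assign_state, feed_dict={state_in: ostate})  # the op must be run
        saver.save(sess, ckpt_prefix)

# --- Prediction program: fresh graph, restore only, no initializer ---
with tf.Graph().as_default():
    restorer = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_prefix)  # fills savedState from the checkpoint
        full_state = sess.run("savedState:0")

print(full_state)
```

Because savedState exists before the Saver is built, it is written to the checkpoint and restore() sets its value, so the prediction program does not need to run an initializer for it.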

The Saver could handle this better as a special case when None is passed as its var_list constructor argument (i.e. it could automatically pick up new variables). Feel free to open a feature request on GitHub for this.
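The current behavior can be seen in a short sketch (again assuming tensorflow.compat.v1 in graph mode and a temporary directory; the variable and file names are hypothetical): with var_list=None, a Saver collects the global variables that exist when it is constructed, so a second Saver built after the late variable does include it.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TF1-style graph mode (also runs on TF 2.x through compat.v1).
tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()

with tf.Graph().as_default():
    tf.get_variable("a", shape=[])
    early_saver = tf.train.Saver()  # var_list=None: snapshots the globals now ("a")
    tf.get_variable("b", shape=[])
    late_saver = tf.train.Saver()   # built after "b" exists: tracks "a" and "b"
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        early_path = early_saver.save(sess, os.path.join(ckpt_dir, "early"))
        late_path = late_saver.save(sess, os.path.join(ckpt_dir, "late"))

# Compare which variable names each checkpoint actually stores.
early_names = sorted(name for name, _shape in tf.train.list_variables(early_path))
late_names = sorted(name for name, _shape in tf.train.list_variables(late_path))
print(early_names, late_names)
```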
