Reuse value of TensorFlow variable between sessions without writing to disk

In sklearn, I'm used to having a model that I can run fit and then predict on. However, with TensorFlow, I am having problems loading the parameters learned in fit when I call predict. It boils down to the fact that I don't know how to reuse the value of a variable between sessions. For example,

import tensorflow as tf

x = tf.Variable(0.0)

# fit code
with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    sess1.run(tf.assign(x, 1.0)) # at end of training, x = 1.0

# predict code
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    print(sess2.run(x)) # want this to be 1.0, but is 0.0
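For context: with the default in-process session, a tf.Variable object is only a node in the graph. Its current value is stored inside the Session that initialized it and is discarded when that session closes. The sketch below assumes TF 1.13+ or TF 2.x (it uses tensorflow.compat.v1 in graph mode). The names value_in_sess1 and uninitialized are illustrative. It shows that skipping the initializer in sess2 does not bring back the value from sess1. The variable is simply uninitialized in the new session.

```python
import tensorflow.compat.v1 as tf  # TF 1.x-style API; assumed TF 1.13+ or TF 2.x

tf.disable_v2_behavior()  # graph mode with explicit sessions, as in the question

x = tf.Variable(0.0)

with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    sess1.run(tf.assign(x, 1.0))
    value_in_sess1 = sess1.run(x)  # 1.0, but only inside sess1

with tf.Session() as sess2:
    # No initializer here: sess2 has its own, empty variable state,
    # so reading x fails instead of returning sess1's value.
    try:
        sess2.run(x)
        uninitialized = False
    except tf.errors.FailedPreconditionError:
        uninitialized = True
```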

I can think of one workaround, but it seems really hacky and will be annoying if there are multiple variables I want to reuse:

import tensorflow as tf

x = tf.Variable(0.0)

# fit code
with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    sess1.run(tf.assign(x, 1.0)) # at end of training, x = 1.0
    learned_x = sess1.run(x) # remember value of learned x at end of session

# predict code
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    sess2.run(tf.assign(x, learned_x))
    print(sess2.run(x)) # prints 1.0

How can I reuse variables between sessions without writing to disk (i.e. using tf.train.Saver)? Is the workaround I wrote above suitable for this?
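The workaround can be generalized to any number of variables. The idea is to fetch all of them with one sess.run call, keep the resulting NumPy arrays in memory, and push them into the new session with assign ops. Below is a minimal sketch, assuming TF 1.13+ or TF 2.x via tensorflow.compat.v1. The helpers snapshot_variables and make_restore_op are hypothetical names, not TensorFlow API. It builds the placeholders and assign ops once, before any session is opened. By contrast, calling tf.assign(x, learned_x) as in the workaround adds a new constant and assign op to the graph every time it runs.

```python
import tensorflow.compat.v1 as tf  # assumed TF 1.13+ or TF 2.x

tf.disable_v2_behavior()


def snapshot_variables(sess, variables):
    """Hypothetical helper: fetch every variable's value as a NumPy array."""
    return sess.run(variables)


def make_restore_op(variables):
    """Hypothetical helper: one grouped assign op, fed from one placeholder per variable."""
    placeholders = [
        tf.placeholder(v.dtype.base_dtype, shape=v.get_shape()) for v in variables
    ]
    assigns = [v.assign(p) for v, p in zip(variables, placeholders)]
    return tf.group(*assigns), placeholders


x = tf.Variable(0.0)
w = tf.Variable([0.0, 0.0])
variables = tf.global_variables()  # [x, w], in creation order
restore_op, restore_phs = make_restore_op(variables)  # built once, before any session

# fit code
with tf.Session() as sess1:
    sess1.run(tf.global_variables_initializer())
    sess1.run([x.assign(1.0), w.assign([2.0, 3.0])])  # stand-in for training
    learned = snapshot_variables(sess1, variables)  # kept in memory, not on disk

# predict code
with tf.Session() as sess2:
    # No initializer needed: the grouped assign op sets every variable.
    sess2.run(restore_op, feed_dict=dict(zip(restore_phs, learned)))
    x_val, w_val = sess2.run([x, w])
```

This still copies values through host memory, so it suits small models. If fit and predict can share one process, keeping a single session open avoids the copy entirely.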

1 answer


To mimic sklearn's model, just wrap the session in a class so that it can be shared between methods, e.g.

import tensorflow as tf

class Model:
    def __init__(self):
        self.graph = self.build_graph()
        self.session = tf.Session()
        self.session.run(tf.global_variables_initializer())

    def build_graph(self):
        return {'x': tf.Variable(0.0)}

    def fit(self):
        self.session.run(tf.assign(self.graph['x'], 1.0))

    def predict(self):
        print(self.session.run(self.graph['x']))

    def close(self):
        # Close the session first: resetting the default graph while a
        # session is still open results in undefined behavior.
        self.session.close()
        tf.reset_default_graph()

m = Model()
m.fit()
m.predict()  # prints 1.0
m.close()

Make sure you close the session manually and handle exceptions.
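One way to handle both points is to give the class its own tf.Graph and make it a context manager. Then the session is closed even if fit or predict raises, and several models can coexist without calling tf.reset_default_graph(). Below is a minimal sketch, assuming TF 1.13+ or TF 2.x via tensorflow.compat.v1. The target placeholder and the fit(value) signature are hypothetical stand-ins for real training inputs.

```python
import tensorflow.compat.v1 as tf  # assumed TF 1.13+ or TF 2.x

tf.disable_v2_behavior()


class Model:
    """sklearn-style wrapper: one private graph plus one long-lived session."""

    def __init__(self):
        self.graph = tf.Graph()  # private graph: no tf.reset_default_graph() needed
        with self.graph.as_default():
            self.x = tf.Variable(0.0)
            self.target = tf.placeholder(tf.float32, shape=())  # hypothetical input
            self.fit_op = self.x.assign(self.target)  # stand-in for a training step
            init = tf.global_variables_initializer()
        self.session = tf.Session(graph=self.graph)
        self.session.run(init)

    def fit(self, value=1.0):
        self.session.run(self.fit_op, feed_dict={self.target: value})
        return self

    def predict(self):
        return self.session.run(self.x)

    def close(self):
        self.session.close()

    # Context-manager support: close the session even if the body raises.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # do not swallow the exception


with Model() as m:
    m.fit()
    print(m.predict())  # prints 1.0
```

Because each instance builds its ops once in __init__, fit and predict only run existing ops and never grow the graph.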
