Importing a TensorFlow Model in Java

I am trying to import and use my trained model (TensorFlow, Python) in Java.

I was able to save the model in Python but ran into problems when trying to make predictions using the same model in Java.

Here, you can see the Python code for initializing, training, and saving the model.

Here, you can see the Java code for importing the model and making predictions for the input values.

The error message I receive is:

Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
  [[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:285)
    at org.tensorflow.Session$Runner.run(Session.java:235)
    at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)

I believe the problem is somewhere in the Python code, but I couldn't find it.

Any help is appreciated!

Thanks,

Peter



3 answers


The Java importGraphDef() function only imports the computational graph (the one written by tf.train.write_graph in your Python code); it does not load the values of the trained variables (which are stored in the checkpoint), so you get an error complaining about uninitialized variables.
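To make that concrete, here is a minimal Java sketch of the graph-only route; the path /tmp/graph.pb and the initializer op name "init" are assumptions for illustration (such an op would have to exist, and be named that way, in the exported graph):

import java.nio.file.Files;
import java.nio.file.Paths;
import org.tensorflow.Graph;
import org.tensorflow.Session;

public class GraphDefOnly {
  public static void main(String[] args) throws Exception {
    // Read the GraphDef written by tf.train.write_graph: structure only, no variable values.
    byte[] graphDef = Files.readAllBytes(Paths.get("/tmp/graph.pb"));
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef);
      try (Session s = new Session(g)) {
        // At this point every variable is still uninitialized; running the model now
        // throws "Attempting to use uninitialized value ...". You would first have to
        // run an initializer (or restore) op that exists in the graph, for example:
        s.runner().addTarget("init").run(); // "init" is an assumed op name
      }
    }
  }
}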

The TensorFlow SavedModel format, on the other hand, includes all the information about the model (graph, checkpoint state, other metadata). To use it in Java, you want SavedModelBundle.load to create a session that is already initialized with the trained variable values.

To export a model in this format from Python, you can take a look at the related question about exporting a retrained SavedModel for Google Cloud ML Engine.

In your case, this should mean the following in Python:

def save_model(session, input_tensor, output_tensor):
  # Describe the model's inputs and outputs so callers can look them up by name.
  signature = tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
      outputs={'output': tf.saved_model.utils.build_tensor_info(output_tensor)})
  # Write the graph *and* the current variable values to /tmp/model.
  b = tf.saved_model.builder.SavedModelBuilder('/tmp/model')
  b.add_meta_graph_and_variables(
      session,
      [tf.saved_model.tag_constants.SERVING],
      signature_def_map={
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
  b.save()



You would then call it as save_model(session, x, yhat).

And then in Java, load the model using:

try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
  // b.session() is already initialized with the trained variable values.
  // b.session().run(...)
}
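For completeness, a minimal sketch of what that run call could look like; the tensor names "x:0" and "yhat:0" and the input shape are assumptions and have to be replaced with the actual names from your graph (or looked up in the SavedModel's SignatureDef):

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

public class Predict {
  public static void main(String[] args) {
    try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
      // "x:0" and "yhat:0" are placeholders for the real input/output tensor names.
      try (Tensor in = Tensor.create(new float[][] {{1.0f}});
           Tensor out = b.session().runner()
               .feed("x:0", in)
               .fetch("yhat:0")
               .run()
               .get(0)) {
        float[][] prediction = new float[1][1];
        out.copyTo(prediction);
        System.out.println("Prediction: " + prediction[0][0]);
      }
    }
  }
}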

      

Hope it helps.



FWIW, Deeplearning4j allows importing models trained with Keras 1.0 on TensorFlow (Keras 2.0 support is on the way).

https://deeplearning4j.org/model-import-keras
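As a rough sketch of that import path, assuming a Keras Sequential model saved from Python to a hypothetical HDF5 file model.h5 (the input shape below is illustrative):

import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KerasImportExample {
  public static void main(String[] args) throws Exception {
    // "model.h5" is a hypothetical Keras HDF5 file saved with model.save(...) in Python.
    MultiLayerNetwork net = KerasModelImport.importKerasSequentialModelAndWeights("model.h5");
    INDArray input = Nd4j.create(new float[][] {{1.0f}}); // shape must match the Keras model's input
    INDArray output = net.output(input);
    System.out.println(output);
  }
}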



We've also built a library called Jumpy, a wrapper around NumPy and Pyjnius that passes pointers to arrays instead of copying data, which makes it more efficient than Py4j when dealing with tensors.

https://deeplearning4j.org/jumpy



Your Python model will, of course, fail here:

sess.run(init) #<---this will fail
save_model(sess)
error = tf.reduce_mean(tf.square(prediction - y))

#accuracy = tf.reduce_mean(tf.cast(error, 'float'))
print('Error:', error)

      

init is not defined in your model, so sess.run(init) has nothing to run; if the intent is to initialize the variables, you would need something like init = tf.global_variables_initializer() before that call. I'm not sure what you want to achieve at this point, but that should give you a starting point.







