Importing a TensorFlow Model in Java
I am trying to import and use my trained model (TensorFlow, Python) in Java.
I was able to save the model in Python, but ran into problems when trying to make predictions with the same model in Java.
Here you can see the Python code for initializing, training, and saving the model.
Here you can see the Java code for importing the model and making predictions for the input values.
The error message I receive is:
Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
[[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)
I believe the problem is somewhere in the Python code, but I couldn't find it.
Any help is appreciated!
Thanks,
Peter
The Java function importGraphDef() only imports the computational graph (written by tf.train.write_graph in your Python code); it doesn't load the values of the trainable variables (stored in the checkpoint), so you get an error complaining about uninitialized variables.
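The same failure can be reproduced in Python, which shows it is about missing variable values rather than anything Java-specific. This is a minimal sketch, assuming the TF 1.x graph API via `tf.compat.v1` and a hypothetical one-variable graph with tensors named `x`, `w` and `yhat`: it writes only the GraphDef, re-imports it into a fresh graph (what `importGraphDef()` does on the Java side), and tries to run the output.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x graph-mode API (also shipped with TF 2.x)


def graph_only_roundtrip():
    """Export only the GraphDef, re-import it, and try to run the output."""
    export_dir = tempfile.mkdtemp()

    # Hypothetical one-variable graph: yhat = x @ w.
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 1], name="x")
        w = tf1.Variable([[2.0]], name="w")
        tf.matmul(x, w, name="yhat")
        with tf1.Session(graph=g) as sess:
            sess.run(tf1.global_variables_initializer())  # w = 2.0 in THIS session only
            # write_graph serializes the graph structure, not the value of w.
            tf1.train.write_graph(sess.graph_def, export_dir, "graph.pb", as_text=False)

    # Re-import the GraphDef into a fresh graph, like Graph.importGraphDef() in Java.
    graph_def = tf1.GraphDef()
    with open(os.path.join(export_dir, "graph.pb"), "rb") as f:
        graph_def.ParseFromString(f.read())
    fresh = tf.Graph()
    with fresh.as_default():
        tf1.import_graph_def(graph_def, name="")

    with tf1.Session(graph=fresh) as sess:
        try:
            sess.run("yhat:0", feed_dict={"x:0": [[1.0]]})
            return "ran"
        except tf.errors.FailedPreconditionError as e:
            # The variable exists in the graph but has no value in this session.
            return type(e).__name__
```

With TF 1.x reference variables this should be the same "Attempting to use uninitialized value" error as in the Java stack trace; with TF 2.x resource variables the message wording differs, but the exception type is still `FailedPreconditionError`.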
The TensorFlow SavedModel format, on the other hand, includes all the information about the model (graph, checkpoint state, other metadata). To use it in Java, you want SavedModelBundle.load, which creates a session initialized with the trained variable values.
To export a model in this format from Python, you can take a look at the related question about deploying a retrained SavedModel to Google Cloud ML Engine.
In your case, that would look something like this in Python:
import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder

def save_model(session, input_tensor, output_tensor):
    # Describe the model's input/output so the exported graph has a serving signature
    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
        outputs={'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
    )
    b = saved_model_builder.SavedModelBuilder('/tmp/model')
    # Saves the graph AND the current variable values from the session
    b.add_meta_graph_and_variables(session,
                                   [tf.saved_model.tag_constants.SERVING],
                                   signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
    b.save()
You would call it as save_model(session, x, yhat).
Then, in Java, load the model with:
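Before moving to Java, it can help to check in Python that the export really carries the trained weights. The following is a minimal sketch under stated assumptions: the TF 1.x API via `tf.compat.v1`, a hypothetical toy linear model that learns y = 3x, and hypothetical tensor names `x` and `yhat`. It trains, exports with the same kind of signature as `save_model()`, then reloads the SavedModel into a brand-new session, which is roughly what `SavedModelBundle.load` does in Java.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API; also available in TF 2.x


def export_and_reload(export_dir):
    """Train a toy model, save it as a SavedModel, reload it in a fresh session."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 1], name="x")
        y = tf1.placeholder(tf.float32, shape=[None, 1], name="y")
        w = tf1.Variable([[0.0]], name="w")
        yhat = tf.matmul(x, w, name="yhat")
        loss = tf.reduce_mean(tf.square(yhat - y))
        train_op = tf1.train.GradientDescentOptimizer(0.1).minimize(loss)

        with tf1.Session(graph=g) as sess:
            sess.run(tf1.global_variables_initializer())
            for _ in range(200):  # hypothetical data: learn y = 3x
                sess.run(train_op, feed_dict={x: [[1.0], [2.0]], y: [[3.0], [6.0]]})

            signature = tf1.saved_model.signature_def_utils.build_signature_def(
                inputs={"input": tf1.saved_model.utils.build_tensor_info(x)},
                outputs={"output": tf1.saved_model.utils.build_tensor_info(yhat)},
                method_name=tf1.saved_model.signature_constants.PREDICT_METHOD_NAME,
            )
            builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                sess,
                [tf1.saved_model.tag_constants.SERVING],  # the "serve" tag used from Java
                signature_def_map={
                    tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
                },
            )
            builder.save()

    # Fresh graph and session: no initializer is run, values come from the export.
    with tf.Graph().as_default(), tf1.Session() as sess:
        tf1.saved_model.loader.load(sess, [tf1.saved_model.tag_constants.SERVING], export_dir)
        return sess.run("yhat:0", feed_dict={"x:0": [[10.0]]})
```

Pass a directory that does not exist yet (for example `os.path.join(tempfile.mkdtemp(), "model")`), since `SavedModelBuilder` refuses to write into an existing, non-empty export directory. The result should be close to 30, because the reloaded `w` holds the trained value rather than the initial 0.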
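import org.tensorflow.SavedModelBundle;

// The path must match the directory passed to SavedModelBuilder; "serve" is the SERVING tag
try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
    // b.session().run(...)
}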
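As far as I know, the 1.x Java `Session.Runner` feeds and fetches by graph operation/tensor name rather than by signature key, so you need to know which graph tensors the `'input'` and `'output'` keys point to. A small Python helper (a sketch assuming the `tf.compat.v1` loader and the default `"serve"` tag and `"serving_default"` signature key used above; the function name is hypothetical) can read them back out of the export:

```python
import tensorflow as tf

tf1 = tf.compat.v1


def signature_tensor_names(export_dir, tag="serve", signature_key="serving_default"):
    """Map signature keys to graph tensor names for a SavedModel's inputs and outputs."""
    with tf.Graph().as_default(), tf1.Session() as sess:
        # loader.load returns the MetaGraphDef, which holds the signature map.
        meta_graph = tf1.saved_model.loader.load(sess, [tag], export_dir)
    sig = meta_graph.signature_def[signature_key]
    inputs = {key: info.name for key, info in sig.inputs.items()}
    outputs = {key: info.name for key, info in sig.outputs.items()}
    return inputs, outputs
```

For a model exported with `save_model(session, x, yhat)`, this would return something like `{'input': 'x:0'}` and `{'output': 'yhat:0'}` if those tensors were given those names; the same information is printed by `saved_model_cli show --dir /tmp/model --all`.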
Hope it helps.
FWIW, Deeplearning4j can import models trained with Keras 1.0 on a TensorFlow backend (Keras 2.0 support is in progress).
https://deeplearning4j.org/model-import-keras
We've also built a library called Jumpy, a wrapper around NumPy arrays and Pyjnius that uses pointers instead of copying data, which makes it more efficient than Py4J when working with tensors.
Your Python code will, of course, fail here:
sess.run(init) #<---this will fail
save_model(sess)
error = tf.reduce_mean(tf.square(prediction - y))
#accuracy = tf.reduce_mean(tf.cast(error, 'float'))
print('Error:', error)
init is not defined anywhere in your code. I'm not sure what you want to achieve at this point, but that should give you a starting point.
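For reference, here is a minimal sketch of that training flow with `init` defined. The question's actual model isn't shown, so the linear model, placeholder names and optimizer are hypothetical stand-ins, using the TF 1.x API via `tf.compat.v1`. It also evaluates `error` with `sess.run()`: printing the `error` tensor directly, as in the snippet above, prints the symbolic `Tensor` object rather than its value.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def train_and_report(xs, ys, steps=100):
    """Hypothetical reconstruction of the training flow with `init` defined."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 1], name="x")
        y = tf1.placeholder(tf.float32, shape=[None, 1], name="y")
        w = tf1.Variable([[0.0]], name="w")
        prediction = tf.matmul(x, w, name="yhat")
        error = tf.reduce_mean(tf.square(prediction - y))
        train_op = tf1.train.GradientDescentOptimizer(0.1).minimize(error)

        # Create the initializer only after every variable exists, including any
        # optimizer slot variables (e.g. Adam's) that minimize() adds.
        init = tf1.global_variables_initializer()

        with tf1.Session(graph=g) as sess:
            sess.run(init)
            for _ in range(steps):
                sess.run(train_op, feed_dict={x: xs, y: ys})
            # `error` is a symbolic Tensor; sess.run() evaluates it to a number.
            final_error = sess.run(error, feed_dict={x: xs, y: ys})
            print("Error:", final_error)
            # Export here, while the session still holds the trained variable values.
            return float(final_error)
```

The export call belongs inside the `with tf1.Session(...)` block: once the session is closed, the trained values it held are gone.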