Using a new op when importing a graph in TensorFlow
I am new to TensorFlow. I am trying to import a trained TensorFlow network from checkpoint files. The network has a custom op that works fine when I use it from Python. However, I need to freeze the graph because I have to use the C++ API. I am calling freeze_graph with the following command from the TensorFlow base directory:
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=<local path>/data/graph_vgg.pb --input_checkpoint=<local path>/data/VGGnet_fast_rcnn_iter_70000.ckpt --output_node_names="cls_prob,bbox_pred" --output_graph=<local path>/graph_frozen.pb
But I get the following error when I try to freeze the graph:
Traceback (most recent call last):
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 202, in <module>
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 134, in main
    FLAGS.variable_names_blacklist)
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 99, in freeze_graph
    _ = importer.import_graph_def(input_graph_def, name="")
  File "<local path>/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/framework/importer.py", line 260, in import_graph_def
    raise ValueError('No op named %s in defined operations.' % node.op)
ValueError: No op named RoiPool in defined operations.
The input graph has a node with op type RoiPool that TensorFlow does not recognize. I looked at the code that throws this error, and it seems to mean that the op is not registered with TensorFlow. I do have the compiled .so file for the op. Should I copy it somewhere? I couldn't find anything about this online, and I have already spent a lot of time on it, so any help or pointers would be great. The code works fine in Python, and the layer that uses the op is in the project directory. Please help me understand what I need to do to make this work.
Edit: This is the custom op code used in the network.
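The traceback shows the import step checking each node's op type (node.op) against the ops TensorFlow knows about in the current process. The toy model below is only an illustration of that check: every name in it is hypothetical and it is not TensorFlow's API. It shows why the same graph imports once the library defining RoiPool has been loaded first.

```python
# Toy model only: these names are hypothetical and are NOT TensorFlow's API.
# They mimic graph import checking each node's op type against the ops
# registered in the current process.

# Ops "built in" to this toy runtime.
OP_REGISTRY = {"Const", "Placeholder", "MatMul"}


def load_op_library(op_names):
    """Stand-in for loading a custom-op .so: registers the ops it defines."""
    OP_REGISTRY.update(op_names)


def import_graph_def(node_ops):
    """Stand-in for graph import: each node's op type must already be registered."""
    for op in node_ops:
        if op not in OP_REGISTRY:
            raise ValueError("No op named %s in defined operations." % op)
    return len(node_ops)


graph = ["Placeholder", "RoiPool", "MatMul"]

try:
    import_graph_def(graph)
except ValueError as err:
    print(err)  # No op named RoiPool in defined operations.

load_op_library(["RoiPool"])  # in this toy, what loading roi_pooling_op.so does
print(import_graph_def(graph))  # 3
```

In other words, having the .so on disk is not enough; it has to be loaded in the process that does the import before the import happens.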
I'm not familiar with this particular RoiPooling implementation, but the way I usually set up a custom op that needs to survive freezing is to put roi_pooling_op.cc and its associated Python file (which defines the gradient and imports the *.so) in //tensorflow/user_ops.
The BUILD file in the //tensorflow/user_ops directory should contain:
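load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")  # needed once per BUILD file

# Builds the shared library that registers the RoiPool op.
tf_custom_op_library(
    name = "roi_pooling_op.so",
    srcs = ["roi_pooling_op.cc"],
)

# Python wrapper that loads the .so and registers the gradient; listing the
# .so under data makes it available alongside roi_pooling.py at runtime.
py_library(
    name = "roi_pooling_op_py",
    srcs = ["roi_pooling.py"],
    data = [":roi_pooling_op.so"],
    srcs_version = "PY2AND3",
)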
The data = [":roi_pooling_op.so"] line is not mentioned in the TensorFlow docs, but it means you don't have to dig through your local bazel-bin directory; instead, you can use tf.resource_loader.get_path_to_datafile to locate the *.so and import it:
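import tensorflow as tf
from tensorflow.python.framework import ops
# Load the .so shipped next to this file through the py_library data dependency.
_roi_pooling_module = tf.load_op_library(tf.resource_loader.get_path_to_datafile("roi_pooling_op.so"))
roi_pool = _roi_pooling_module.roi_pool
roi_pool_grad = _roi_pooling_module.roi_pool_grad

# RoiPool has two outputs, so two incoming gradients arrive; the second is unused.
@ops.RegisterGradient("RoiPool")
def _roi_pool_grad(op, grad, _):
    grad_out = roi_pool_grad(...)
    # One gradient per input: the first input gets grad_out, the ROIs get None.
    return grad_out, None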
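As far as I know, in TensorFlow 1.x get_path_to_datafile joins the name you pass onto the directory of the Python file that calls it, which is why declaring the .so as a data dependency of the py_library (so it lands beside roi_pooling.py) is enough. If you would rather not rely on that helper, a stdlib-only approximation looks like this; path_next_to is a hypothetical name, not a TensorFlow function.

```python
import os


def path_next_to(anchor_file, filename):
    """Resolve `filename` in the same directory as `anchor_file`.

    Hypothetical helper approximating get_path_to_datafile for a data
    dependency that sits beside the calling module.
    """
    return os.path.join(os.path.dirname(os.path.abspath(anchor_file)), filename)


# Inside roi_pooling.py you would pass __file__ as the anchor, e.g.:
#   so_path = path_next_to(__file__, "roi_pooling_op.so")
#   _roi_pooling_module = tf.load_op_library(so_path)
print(path_next_to("/opt/tf/user_ops/roi_pooling.py", "roi_pooling_op.so"))
```

On a POSIX system this prints /opt/tf/user_ops/roi_pooling_op.so. Anchoring on __file__ rather than the current working directory keeps the lookup working no matter where the script is launched from.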
Next, update the freeze_graph rule in the BUILD file in //tensorflow/python/tools: add "//tensorflow/user_ops:roi_pooling_op_py" as a dependency.
Finally, rebuild and install everything (the custom op, freeze_graph, and the pip package/wheel):
bazel build --config opt //tensorflow/user_ops:roi_pooling_op.so
bazel build --config opt //tensorflow/user_ops:roi_pooling_op_py
bazel build --config opt //tensorflow/python/tools:freeze_graph
bazel build --config opt //tensorflow/tools/pip_package:build_pip_package
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
pip install --ignore-installed --upgrade /tmp/tensorflow_pkg/tensorflow-1.2.1-py2-none-any.whl
You can now use it in your Python code with:
from tensorflow.user_ops import roi_pooling
You should now be able to freeze the graph without any problems.
I followed Jared's answer, and I think it gets most of the way there, but I needed one last piece from fooobar.com/questions/1006263/... : I added tf.load_op_library('/path/to/custom_op.so') directly in freeze_graph.py, right before the call to import_graph_def. After that, I was able to freeze the graph.
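If you would rather not edit freeze_graph.py in place, the same trick (load the library before import_graph_def runs) could live in a small wrapper. This is a sketch under assumptions: the wrapper and its --custom_op_library flag are hypothetical, and it assumes your TensorFlow 1.x install ships tensorflow.python.tools.freeze_graph as a runnable module. Because tf.load_op_library registers ops for the whole process, loading the .so first and then running freeze_graph as __main__ in the same process should let it resolve RoiPool.

```python
import argparse
import runpy
import sys


def split_custom_op_args(argv):
    """Separate our (hypothetical) --custom_op_library flags from freeze_graph's flags."""
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    parser.add_argument("--custom_op_library", action="append", default=None)
    known, rest = parser.parse_known_args(argv)
    return known.custom_op_library or [], rest


def main(argv=None):
    argv = sys.argv[1:] if argv is None else argv
    so_paths, freeze_args = split_custom_op_args(argv)

    # Imported here so the argument handling above works without TensorFlow.
    import tensorflow as tf

    # Register the custom ops (e.g. RoiPool) in this process before
    # freeze_graph reaches import_graph_def.
    for path in so_paths:
        tf.load_op_library(path)

    # Run the stock freeze_graph script in this same process with the remaining
    # flags. Its own __main__ block ends by calling app.run, which exits.
    sys.argv = ["freeze_graph"] + freeze_args
    runpy.run_module("tensorflow.python.tools.freeze_graph", run_name="__main__")
```

Saved as, say, freeze_with_custom_ops.py with an if __name__ == "__main__": main() guard at the bottom, it would be invoked with your usual freeze_graph flags plus one --custom_op_library=<path>/roi_pooling_op.so per library to load.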