How to feed a boolean placeholder using TensorFlowInferenceInterface.java?
I am trying to run a TensorFlow graph trained in Keras using the TensorFlow Java API.
In addition to the standard input image placeholder, this graph contains the "keras_learning_phase" placeholder, which must be fed a boolean value.
The problem is that TensorFlowInferenceInterface has no method for boolean values: you can only feed it float, double, int, or byte values.
Unsurprisingly, when I try to feed this tensor an int with this code:
    inferenceInterface.fillNodeInt("keras_learning_phase",
            new int[]{1}, new int[]{0});
I get
tensorflow_inference_jni.cc:207 Error during output: Internal: Output 0 of type int32 does not match the declared output type bool for node _recv_keras_learning_phase_0 = _Recv[client_terminated=true, recv_device="/job:localhost/replica:0/task:0/CPU:0", send_device="/job:localhost/replica:0/task:0/CPU:0", send_device_incarnation=4742451733276497694, tensor_name="keras_learning_phase", tensor_type=DT_BOOL, _device="/job:localhost/replica:0/task:0/CPU:0"]
Is there a way to get around this?
Perhaps there is some way to explicitly convert the Placeholder node in the graph to a Constant?
Or perhaps it is possible to avoid creating this Placeholder in the graph in the first place?
The class TensorFlowInferenceInterface is essentially a convenience wrapper on top of the full TensorFlow Java API, which does support booleans.
You could add a method to TensorFlowInferenceInterface to accomplish what you want. Similar to fillNodeInt, you can add the following (note the caveat that booleans are represented as one byte each in TensorFlow):
public void fillNodeBool(String inputName, int[] dims, boolean[] src) {
    // TensorFlow stores each boolean as a single byte (1 = true, 0 = false).
    byte[] b = new byte[src.length];
    for (int i = 0; i < src.length; ++i) {
        b[i] = src[i] ? (byte) 1 : (byte) 0;
    }
    addFeed(inputName, Tensor.create(DataType.BOOL, mkDims(dims), ByteBuffer.wrap(b)));
}
Hope it helps. If it works, I would encourage you to contribute it back to the TensorFlow codebase.
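To make the one-byte-per-boolean caveat concrete, here is a minimal standalone sketch of just the packing step, without any TensorFlow dependency. The class name BoolPacking and the method toByteBuffer are hypothetical names chosen for illustration; they are not part of TensorFlowInferenceInterface.

```java
import java.nio.ByteBuffer;

// Hypothetical helper for illustration only; not part of the TensorFlow API.
public class BoolPacking {

    /** Packs each boolean into one byte: 1 for true, 0 for false. */
    static ByteBuffer toByteBuffer(boolean[] src) {
        byte[] b = new byte[src.length];
        for (int i = 0; i < src.length; ++i) {
            // The cast is explicit so the int result of the ternary fits a byte.
            b[i] = (byte) (src[i] ? 1 : 0);
        }
        return ByteBuffer.wrap(b);
    }

    public static void main(String[] args) {
        ByteBuffer buf = toByteBuffer(new boolean[] {true, false, true});
        // One byte per element, so three booleans occupy three bytes.
        System.out.println(buf.remaining());
        System.out.println(buf.get(0) + " " + buf.get(1) + " " + buf.get(2));
    }
}
```

The resulting buffer is what would be handed to Tensor.create in the method above; the sketch only shows that the buffer length equals the number of booleans.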
Adding to ash's answer: the TensorFlow API has changed a bit since then. This worked for me:
public void feed(String inputName, boolean[] src, long... dims) {
    byte[] b = new byte[src.length];
    for (int i = 0; i < src.length; i++) {
        b[i] = src[i] ? (byte) 1 : (byte) 0;
    }
    addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
}
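Because dims is a varargs parameter here, a scalar placeholder such as "keras_learning_phase" would presumably be fed with no dims at all, for example feed("keras_learning_phase", new boolean[] {false}). That is my assumption about the placeholder's shape, so check it against your graph. The sketch below shows a sanity check that the number of booleans matches the shape before building the tensor; ShapeCheck, numElements, and checkShape are hypothetical names, not part of the TensorFlow API.

```java
// Hypothetical helper for illustration only; not part of the TensorFlow API.
public class ShapeCheck {

    /** Number of elements implied by dims; empty dims means a scalar, i.e. 1 element. */
    static long numElements(long... dims) {
        long n = 1;
        for (long d : dims) {
            n *= d;
        }
        return n;
    }

    /** Throws if the source array does not hold exactly the elements the shape implies. */
    static void checkShape(boolean[] src, long... dims) {
        long expected = numElements(dims);
        if (expected != src.length) {
            throw new IllegalArgumentException(
                    "shape implies " + expected + " elements, got " + src.length);
        }
    }

    public static void main(String[] args) {
        checkShape(new boolean[] {false});          // scalar: no dims, one value
        checkShape(new boolean[6], 2, 3);           // 2 x 3 tensor, six values
        System.out.println(numElements());          // scalar element count
        System.out.println(numElements(2, 3));      // 2 * 3
    }
}
```

Calling checkShape(src, dims) at the top of feed would turn a mismatched buffer into a clear Java exception instead of a native-side error.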