How to feed a boolean placeholder using TensorFlowInferenceInterface.java?

I am trying to run a Keras-trained TensorFlow graph using the Java TensorFlow API.
In addition to the standard input image placeholder, this graph contains a "keras_learning_phase" placeholder, which must be fed a boolean value.

The problem is that TensorFlowInferenceInterface has no method for feeding boolean values; it can only feed float, double, int, or byte data.

Predictably, when I try to feed an int to this tensor with this code:

inferenceInterface.fillNodeInt("keras_learning_phase",  
                               new int[]{1}, new int[]{0});

I get

tensorflow_inference_jni.cc:207 Error during output: Internal: Output 0 of type int32 does not match the declared output type bool for node _recv_keras_learning_phase_0 = _Recv[client_terminated=true, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=4742451733276497694, tensor_name="keras_learning_phase", tensor_type=DT_BOOL, _device="/job:localhost/replica:0/task:0/cpu:0"]

Is there a way to get around this?
Perhaps there is some way to explicitly convert the Placeholder node into a Constant in the graph?
Or perhaps it is possible to avoid creating this Placeholder in the graph in the first place?

2 answers


The TensorFlowInferenceInterface class is essentially a convenience wrapper on top of the full TensorFlow Java API, which does support booleans.
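For comparison, here is a minimal sketch of feeding a scalar boolean through the full API (Session.Runner) without going through the wrapper. It assumes the TensorFlow Java 1.x library, version 1.4 or later (for the generic Tensor<T> and Tensors.create(boolean)), is on the classpath. The two-node graph (a bool Placeholder named keras_learning_phase feeding a LogicalNot node named not_phase) is a hypothetical stand-in for the real Keras graph, and BoolFeedDemo is an illustrative name.

```java
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

public class BoolFeedDemo {
  // Builds a hypothetical stand-in graph: a bool Placeholder named like the
  // Keras one, followed by a LogicalNot node so there is something to fetch.
  public static boolean runWithLearningPhase(boolean phase) {
    try (Graph g = new Graph()) {
      Output<?> placeholder = g.opBuilder("Placeholder", "keras_learning_phase")
          .setAttr("dtype", DataType.BOOL)
          .build()
          .output(0);
      g.opBuilder("LogicalNot", "not_phase").addInput(placeholder).build();

      // Tensors.create(boolean) yields a DT_BOOL scalar, so no int/bool mismatch.
      try (Session s = new Session(g);
           Tensor<Boolean> feed = Tensors.create(phase)) {
        // Feed the placeholder by op name and fetch the dependent node.
        List<Tensor<?>> outputs = s.runner()
            .feed("keras_learning_phase", feed)
            .fetch("not_phase")
            .run();
        try (Tensor<?> out = outputs.get(0)) {
          return out.booleanValue();
        }
      }
    }
  }

  public static void main(String[] args) {
    System.out.println(runWithLearningPhase(false)); // prints "true" (LogicalNot of false)
  }
}
```

With a real model you would fetch the model's output node instead of not_phase; feeding false corresponds to the Keras test (inference) phase.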

You could add a method to TensorFlowInferenceInterface to accomplish what you want. Similar to fillNodeInt, you could add the following (with the caveat that TensorFlow represents each boolean as one byte):

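// Add inside TensorFlowInferenceInterface (needs imports for java.nio.ByteBuffer,
// org.tensorflow.DataType and org.tensorflow.Tensor).
public void fillNodeBool(String inputName, int[] dims, boolean[] src) {
  // TensorFlow stores each bool as a single byte: 1 = true, 0 = false.
  byte[] b = new byte[src.length];
  for (int i = 0; i < src.length; ++i) {
    b[i] = src[i] ? (byte) 1 : (byte) 0;
  }
  addFeed(inputName, Tensor.create(DataType.BOOL, mkDims(dims), ByteBuffer.wrap(b)));
}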

Hope it helps. If it works, I would encourage you to contribute to the TensorFlow codebase.
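The method above has to live inside TensorFlowInferenceInterface, since it calls that class's addFeed and mkDims helpers. The byte encoding itself can be checked on its own; below is a minimal stand-alone sketch (BoolBytes and pack are hypothetical names, not TensorFlow API) that packs booleans one byte per element and also verifies that the array length matches the shape, a check the method above leaves out.

```java
import java.nio.ByteBuffer;

public final class BoolBytes {
  // Hypothetical helper: packs booleans one byte per element (1 = true,
  // 0 = false) after checking the element count against the shape.
  // An empty shape means a scalar, i.e. exactly one element.
  public static ByteBuffer pack(boolean[] src, long... dims) {
    long expected = 1;
    for (long d : dims) {
      expected *= d;
    }
    if (expected != src.length) {
      throw new IllegalArgumentException(
          "shape holds " + expected + " elements but got " + src.length);
    }
    byte[] b = new byte[src.length];
    for (int i = 0; i < src.length; i++) {
      b[i] = src[i] ? (byte) 1 : (byte) 0;
    }
    return ByteBuffer.wrap(b);
  }

  public static void main(String[] args) {
    ByteBuffer scalar = pack(new boolean[] {false});               // scalar, e.g. learning phase 0
    ByteBuffer batch = pack(new boolean[] {true, false, true}, 3); // shape [3]
    System.out.println(scalar.remaining() + " " + batch.get(0) + " " + batch.get(1)); // prints "1 1 0"
  }
}
```

The explicit (byte) casts matter: with plain 1 and 0 the conditional expression has type int, and assigning it to a byte element does not compile.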

In addition to ash's answer: since the TensorFlow API has changed a bit, this worked for me:

public void feed(String inputName, boolean[] src, long... dims) {
  byte[] b = new byte[src.length];
  for (int i = 0; i < src.length; i++) {
    b[i] = src[i] ? (byte) 1 : (byte) 0;
  }
  addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
}
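To check that the byte-based construction in this answer really yields a boolean tensor, the following minimal sketch builds a scalar the same way and reads it back. It assumes TensorFlow Java 1.4 or later, where Tensor.create(Class<T>, long[], ByteBuffer) is available; BoolTensorCheck and scalarBool are hypothetical names. Keras itself feeds this placeholder a single value (0 for test, 1 for training), so a scalar is the natural shape, and it is what feed(...) above produces when called without any dims.

```java
import java.nio.ByteBuffer;
import org.tensorflow.Tensor;

public class BoolTensorCheck {
  // Builds a scalar bool tensor the same way as the feed(...) method above:
  // one byte per element, element type selected by Boolean.class.
  public static Tensor<Boolean> scalarBool(boolean value) {
    byte[] b = {value ? (byte) 1 : (byte) 0};
    // An empty shape (new long[0]) denotes a scalar.
    return Tensor.create(Boolean.class, new long[0], ByteBuffer.wrap(b));
  }

  public static void main(String[] args) {
    try (Tensor<Boolean> t = scalarBool(false)) {
      // prints "BOOL 0 false": dtype, rank, value
      System.out.println(t.dataType() + " " + t.numDimensions() + " " + t.booleanValue());
    }
  }
}
```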
