How do TensorArray and while_loop work together in tensorflow?

I am trying to create a very simple example for the combination of TensorArray and while_loop:

# 1000 sequences of length 100
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix")
matrix_rows = tf.shape(matrix)[0]
ta = tf.TensorArray(tf.float32, size=matrix_rows)
ta = ta.unstack(matrix)

init_state = (0, ta)
condition = lambda i, _: i < n
body = lambda i, ta: (i + 1, ta.write(i,ta.read(i)*2))

# run the graph
with tf.Session() as sess:
    (n, ta_final) = sess.run(tf.while_loop(condition, body, init_state),feed_dict={matrix: tf.ones(tf.float32, shape=(100,1000))})
    print (ta_final.stack())

But I am getting the following error:

ValueError: Tensor("while/LoopCond:0", shape=(), dtype=bool) must be from the same graph as Tensor("Merge:0", shape=(), dtype=float32).

Anyone have any idea what the problem is?


1 answer


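There are several issues in the code. First, you do not need to unstack the matrix into a TensorArray in order to use it inside the loop. You can safely refer to the matrix Tensor inside the body and index it with matrix[i].

Another problem is the mismatch between the data types of your matrix (tf.int32) and your TensorArray (tf.float32). Your code multiplies the int32 values of the matrix by 2 and writes the results into the array, so the array has to be int32 as well.

Finally, to read the final result of the loop, the correct operation is TensorArray.stack(). Build that op in the graph and pass it to your session.run call.

There are two smaller issues as well. The condition compares i against n, but n is only assigned after sess.run returns. When tf.while_loop builds the loop, n either does not exist or still refers to a leftover value, and a tensor left over from another graph can trigger exactly this kind of "must be from the same graph" error. The working example therefore compares against matrix_rows. Also, feed_dict values must be concrete values such as NumPy arrays, not tensors like tf.ones(...), so the example feeds np.ones(...) instead.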

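To make the loop mechanics concrete without a TensorFlow install, here is a pure-Python analogy. ToyTensorArray and toy_while_loop are hypothetical names invented for this sketch, not TensorFlow APIs. A small 3x3 list stands in for the 100x1000 placeholder. The analogy mirrors the points above: the body indexes matrix[i] directly instead of unstacking, and the array returned by write() is threaded back through the loop variables. That matches how tf.TensorArray.write returns a new TensorArray object that must be used for subsequent operations. The immutability here is a simplification of TensorFlow's dataflow behavior.

```python
# Hypothetical pure-Python analogy of tf.while_loop + tf.TensorArray (NOT the TensorFlow API).
from typing import Any, Callable, List, Optional, Tuple


class ToyTensorArray:
    """Stand-in for tf.TensorArray, modelled as immutable: write() returns a NEW array."""

    def __init__(self, size: int, items: Optional[Tuple[Any, ...]] = None):
        self.size = size
        self._items = items if items is not None else (None,) * size

    def write(self, index: int, value: Any) -> "ToyTensorArray":
        # Copy the contents, set one slot, and hand back a fresh array.
        items = list(self._items)
        items[index] = value
        return ToyTensorArray(self.size, tuple(items))

    def read(self, index: int) -> Any:
        return self._items[index]

    def stack(self) -> List[Any]:
        # Collect every written element into one list (like stacking along axis 0).
        return list(self._items)


def toy_while_loop(cond: Callable[..., bool], body: Callable[..., tuple], loop_vars: tuple) -> tuple:
    # Re-bind all loop variables to whatever body returns, as tf.while_loop does.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


matrix = [[1, 1, 1], [2, 2, 2], [3, 3, 3]]  # small stand-in for the 100x1000 input
rows = len(matrix)

condition = lambda i, _: i < rows
# Return the array produced by write(); discarding it would lose the update.
body = lambda i, ta: (i + 1, ta.write(i, [x * 2 for x in matrix[i]]))

n, ta_final = toy_while_loop(condition, body, (0, ToyTensorArray(rows)))
result = ta_final.stack()
print(n, result)  # 3 [[2, 2, 2], [4, 4, 4], [6, 6, 6]]
```

If the body returned the old ta instead of ta.write(...), every slot would still be None after the loop. That is why the TensorFlow body has to return the written array.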
Here's a working example:



import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session)

# 1000 sequences of length 100
matrix = tf.placeholder(tf.int32, shape=(100, 1000), name="input_matrix")
matrix_rows = tf.shape(matrix)[0]
# The array stores int32 results, so its dtype must match the matrix
ta = tf.TensorArray(dtype=tf.int32, size=matrix_rows)

init_state = (0, ta)
condition = lambda i, _: i < matrix_rows
# Read row i straight from the matrix; write() returns a new TensorArray that must be returned
body = lambda i, ta: (i + 1, ta.write(i, matrix[i] * 2))
n, ta_final = tf.while_loop(condition, body, init_state)
# get the final result: build the stack() op in the graph before running the session
ta_final_result = ta_final.stack()

# run the graph
with tf.Session() as sess:
    # print the output of ta_final_result; feed a NumPy array, not a tensor
    print(sess.run(ta_final_result, feed_dict={matrix: np.ones(shape=(100, 1000), dtype=np.int32)}))

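As a sanity check on what sess.run(ta_final_result, ...) should print, the loop can be reproduced in plain NumPy. This sketch uses NumPy only and does not run TensorFlow. The feed matches the example: an int32 matrix of ones. Each iteration writes matrix[i] * 2 into slot i, so stacking the rows gives the whole matrix doubled. The dtype stays int32, which is why the TensorArray has to be declared int32.

```python
import numpy as np

# Same feed as the working example: an int32 matrix of ones, shape (100, 1000).
feed = np.ones(shape=(100, 1000), dtype=np.int32)

# Row-by-row equivalent of the loop body: out[i] = matrix[i] * 2.
out = np.empty_like(feed)
for i in range(feed.shape[0]):
    out[i] = feed[i] * 2

# Stacking the 100 written rows is the same as doubling the whole matrix.
assert np.array_equal(out, feed * 2)
print(out.shape, out.dtype, out.min(), out.max())  # (100, 1000) int32 2 2
```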