Using a loop to fill a matrix in TensorFlow
I am trying to fill a matrix in TensorFlow. The size of the matrix depends on the input, so I am using `TensorArray` for it. The NumPy equivalent is basically this:
```python
areas = np.zeros((len(rows)-1, len(cols)-1))
for r in range(len(rows)-1):
    for c in range(len(cols)-1):
        areas[r, c] = (rows[r+1]-rows[r])*(cols[c+1]-cols[c])
```
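Written out, each entry is `(rows[r+1]-rows[r]) * (cols[c+1]-cols[c])`, so the whole matrix is the outer product of the consecutive differences of `rows` and `cols`. The sketch below compares the loop with a vectorized form; the function names `fill_areas_loop` and `fill_areas_np` are hypothetical, and `rows`/`cols` are assumed to be 1-D sequences of numbers.

```python
import numpy as np

def fill_areas_loop(rows, cols):
    # The double loop from the question, wrapped in a function.
    areas = np.zeros((len(rows) - 1, len(cols) - 1))
    for r in range(len(rows) - 1):
        for c in range(len(cols) - 1):
            areas[r, c] = (rows[r + 1] - rows[r]) * (cols[c + 1] - cols[c])
    return areas

def fill_areas_np(rows, cols):
    rows = np.asarray(rows, dtype=float)
    cols = np.asarray(cols, dtype=float)
    # np.diff gives rows[r+1] - rows[r] and cols[c+1] - cols[c];
    # np.outer multiplies every row difference by every column difference.
    return np.outer(np.diff(rows), np.diff(cols))

rows = [0.0, 1.0, 3.0]
cols = [0.0, 2.0, 5.0]
print(fill_areas_np(rows, cols))
print(np.allclose(fill_areas_np(rows, cols), fill_areas_loop(rows, cols)))
```

Since the result only depends on `rows[1:] - rows[:-1]` and `cols[1:] - cols[:-1]`, the same computation can also be expressed in TensorFlow with slicing and broadcasting (for example, `tf.expand_dims(row_diff, 1) * tf.expand_dims(col_diff, 0)`), which avoids `tf.while_loop` entirely.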
I tried to implement this in TensorFlow with `tf.while_loop` and `tf.TensorArray`:
```python
i = tf.constant(0)
areas = tf.TensorArray(dtype='float32', size=length_rc-1)
while_condition = lambda i, rows, areas: tf.less(i, length_rc-1)

def row_loop(i, rows, areas):
    j = tf.constant(0)
    area = tf.TensorArray(dtype='float32', size=length_rc-1)
    while_condition = lambda j, cols, area: tf.less(j, length_rc-1)

    def col_loop(j, cols, area):
        area = area.write(j, tf.multiply(tf.subtract(rows[i+1], rows[i]), tf.subtract(cols[j+1], cols[j])))
        return [tf.add(j, 1), cols, area]

    r = tf.while_loop(while_condition, col_loop, [j, cols, areas])
    areas = areas.write(i, r[2].stack())
    return [tf.add(i, 1), rows, areas]

# do the loop:
r = tf.while_loop(while_condition, row_loop, [i, rows, areas])
areas = r[2].stack()
p = sess.run([areas], feed_dict={pred_batch: pred, gt_batch: gt})
```
However, it doesn't work and I'm not sure why. My code is similar to the one in this post: How TensorArray and while_loop work together in tensorflow? But I still can't get it to run. Does anyone know what the problem is? This is the error I get:
ValueError: Inconsistent shapes: saw (?,) but expected () (and infer_shape=True)
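One likely source of this particular error, independent of the loop bounds: the inner `tf.while_loop` is seeded with `[j, cols, areas]` rather than `[j, cols, area]`. The scalar writes in `col_loop` therefore go into the outer `TensorArray`, which infers an element shape of `()`; the following `areas.write(i, r[2].stack())` then tries to store a vector of shape `(?,)` in it, which would produce exactly the message above. The sketch below keeps the nested `tf.while_loop`/`TensorArray` structure but gives each loop its own array. It assumes TensorFlow 2.x with eager execution, 1-D float `rows` and `cols`, and separate bounds `n_rows`/`n_cols` derived with `tf.size` instead of a shared `length_rc`; `fill_areas` is a hypothetical name.

```python
import tensorflow as tf

def fill_areas(rows, cols):
    # Assumption: rows and cols are 1-D sequences of floats.
    rows = tf.convert_to_tensor(rows, dtype=tf.float32)
    cols = tf.convert_to_tensor(cols, dtype=tf.float32)
    n_rows = tf.size(rows) - 1  # number of row intervals
    n_cols = tf.size(cols) - 1  # number of column intervals

    def row_body(i, areas):
        def col_body(j, area):
            value = (rows[i + 1] - rows[i]) * (cols[j + 1] - cols[j])
            return j + 1, area.write(j, value)

        # Each row gets its own TensorArray, and the inner loop is seeded
        # with that array (`area`), not with the outer `areas`.
        area = tf.TensorArray(tf.float32, size=n_cols)
        _, area = tf.while_loop(lambda j, _: j < n_cols, col_body,
                                [tf.constant(0), area])
        return i + 1, areas.write(i, area.stack())

    areas = tf.TensorArray(tf.float32, size=n_rows)
    _, areas = tf.while_loop(lambda i, _: i < n_rows, row_body,
                             [tf.constant(0), areas])
    return areas.stack()  # shape: (n_rows, n_cols)

print(fill_areas([0.0, 1.0, 3.0], [0.0, 2.0, 5.0]).numpy())
```

In TF1 graph code the corresponding change is the same: pass `area` instead of `areas` as the inner loop's initial value.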
What exactly doesn't work? What do you expect to happen, and what actually happens?
For one thing, your loop conditions look off by one in both cases. In the first (NumPy) case, you skip the last row and the last column, since `range` only produces values strictly smaller than its argument.
Likewise, in the second case your condition is `tf.less(i, length_rc-1)`: you probably want `i` to be equal to `length_rc-1` in the last iteration, not less than it. The condition should be `tf.less(i, length_rc)`.
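To see which indices a `tf.while_loop` with a `tf.less(i, n)` condition actually visits, the small sketch below records every value of `i` for which the body runs. It assumes TensorFlow 2.x with eager execution; `iterations_visited` is a hypothetical helper.

```python
import tensorflow as tf

def iterations_visited(n):
    # Dynamic-size TensorArray that collects each loop index i.
    visited = tf.TensorArray(tf.int32, size=0, dynamic_size=True,
                             element_shape=())

    def body(i, acc):
        return i + 1, acc.write(i, i)

    _, visited = tf.while_loop(lambda i, _: tf.less(i, n), body,
                               [tf.constant(0), visited])
    return visited.stack().numpy().tolist()

print(iterations_visited(4))
print(list(range(4)))
```

Both lines print the same indices, so `tf.less(i, n)` behaves like `range(n)`. Whichever bound you pick, note that the body reads `rows[i+1]` and `cols[j+1]`, so `i+1` and `j+1` must also stay below the lengths of `rows` and `cols`.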