TensorFlow: tf.case with parameterized callables, list of cases defined in a for-loop

I am trying to implement a case branch in the training loop of an ensemble of autoencoders: depending on a condition, only one specific autoencoder should be updated. I'm trying to do this with tf.case(), but it doesn't work as I expected:

def f(k_win):
    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])
    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)]

winner_index = tf.argmin(Cost_Alpha_List, 0)

Case_List = []
for k in range(N_Class):
    Case = (tf.equal(winner_index, k), lambda: f(k))
    Case_List.append(Case)

Execution_List = tf.case(Case_List, lambda: f(0))

      

winner_index: index of the autoencoder to update

f(k_win): returns all update ops for the autoencoder with index k_win

Case_List: list of (boolean, callable) pairs

Execution_List: passed to sess.run() in the training loop

I expected the loop variable k to give each entry of Case_List its own "lambda: f(k)", but once the list is built, every "lambda: f(k)" seems to use the last value, k = N_Class - 1. As a result, only the last autoencoder is updated instead of the one selected by winner_index. Does anyone know what is going on here?

Thanks.

1 answer


The problem is that the lambdas you define refer to the variable k itself (a global variable here), not to the value k had when each lambda was created. By the time the functions are called, k holds the last value it took in the loop (N_Class - 1).

A simpler example:

lst = []
for k in range(10):
    lst.append(lambda: k * k)
print([lst_i() for lst_i in lst])

gives:

[81, 81, 81, 81, 81, 81, 81, 81, 81, 81]

instead of the expected:

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
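A few standard ways to bind the current value of k when each function is created, sketched on the same squares example (the helper names square and make_square are illustrative and not part of the original answer):

```python
from functools import partial

def square(x):
    return x * x

# 1. Default argument: the default k=k is evaluated when each lambda is created.
with_default = [lambda k=k: k * k for k in range(10)]

# 2. functools.partial stores the current value of k inside the partial object.
with_partial = [partial(square, k) for k in range(10)]

# 3. Factory function: each call creates a new scope holding its own kk.
def make_square(kk):
    return lambda: kk * kk

with_factory = [make_square(k) for k in range(10)]

for fns in (with_default, with_partial, with_factory):
    print([fn() for fn in fns])  # each prints [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

The default-argument version is the shortest, but a caller could override k by passing an argument. tf.case calls its branch functions with no arguments, so that is not a concern there.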

      

This answer explains the problem in more detail and shows several ways to work around it. In your case, you can do something like this:

def f(k_win):
    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])
    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)]

winner_index = tf.argmin(Cost_Alpha_List, 0)

Case_List = []
for k in range(N_Class):
    # The outer lambda is called immediately with the current k, so each
    # inner lambda closes over its own kk instead of the shared loop variable.
    Case = (tf.equal(winner_index, k), (lambda kk: lambda: f(kk))(k))
    Case_List.append(Case)

Execution_List = tf.case(Case_List, lambda: f(0))
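Why this bites with tf.case in particular: in graph mode, tf.case calls each branch callable once while the graph is being built, that is, when tf.case(...) itself runs. That happens after the loop has finished, when k already equals N_Class - 1, so every branch builds the update ops for the last autoencoder. The sketch below imitates this without TensorFlow. fake_case is a hypothetical stand-in (not the TensorFlow API), f returns a string instead of real update ops, and make_branch is an illustrative named factory that reads more easily than (lambda kk: lambda: f(kk))(k).

```python
N_Class = 4  # assumed ensemble size for the demo

def f(k_win):
    # Stand-in for the question's f: names the autoencoder whose ops get built.
    return "update AE %d" % k_win

def fake_case(pred_fn_pairs, default):
    # Hypothetical imitation of graph-mode tf.case: build every branch
    # (and the default) once, then return the first branch whose predicate is true.
    built = [(pred, fn()) for pred, fn in pred_fn_pairs]
    default_result = default()
    for pred, result in built:
        if pred:
            return result
    return default_result

winner_index = 1  # pretend tf.argmin selected autoencoder 1

# Buggy: every lambda looks up the loop variable k only when it is called.
buggy = []
for k in range(N_Class):
    buggy.append((winner_index == k, lambda: f(k)))

# Fixed: the factory binds the value of k current at each iteration.
def make_branch(kk):
    return lambda: f(kk)

fixed = []
for k in range(N_Class):
    fixed.append((winner_index == k, make_branch(k)))

print(fake_case(buggy, lambda: f(0)))  # update AE 3
print(fake_case(fixed, lambda: f(0)))  # update AE 1
```

With the real code, the same change amounts to building each Case as (tf.equal(winner_index, k), make_branch(k)).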

      
