TensorFlow: tf.case with parameterized callables, list of cases defined in a for-loop
I am trying to implement a case branch in the training loop of an ensemble of autoencoders: depending on a certain condition, only one specific autoencoder needs to be updated. I'm trying to implement this with tf.case(), but it doesn't work as I expected...
def f(k_win):
    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])
    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)]

winner_index = tf.argmin(Cost_Alpha_List, 0)
Case_List = []
for k in range(N_Class):
    Case = (tf.equal(winner_index, k), lambda: f(k))
    Case_List.append(Case)
Execution_List = tf.case(Case_List, lambda: f(0))
winner_index: index of the autoencoder to update
f(k_win): returns all update ops for the autoencoder with index k_win
Case_List: contains (boolean predicate, parameterized function) pairs
Execution_List: passed to sess.run() in the training loop
The loop variable k should give each entry of Case_List its own "lambda: f(k)", but it looks as if, once the list is built, every "lambda: f(k)" uses the last value, k = N_Class - 1. The effect is that only the last autoencoder is ever updated, not the one selected by winner_index. Does anyone know what's going on here?
Thanks.
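The problem is that the lambdas you define refer to the variable k itself, not to the value it had when each lambda was created. Python looks k up only when a lambda is called, and tf.case calls the branch functions after the loop has finished, by which time k holds its last value, N_Class - 1.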
A simpler example:
lst = []
for k in range(10):
    lst.append(lambda: k * k)
print([lst_i() for lst_i in lst])
gives:
[81, 81, 81, 81, 81, 81, 81, 81, 81, 81]
instead of the expected:
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
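The usual fix is to bind the current value of k when each function is created rather than when it is called. Below is a minimal sketch of three common Python idioms for this, applied to the same squares example (these are standard techniques chosen for illustration, not necessarily the ones in the linked answer):

```python
from functools import partial

lst_default = []
lst_factory = []
lst_partial = []
for k in range(10):
    # 1. Default argument: k=k is evaluated once, when the lambda is created.
    lst_default.append(lambda k=k: k * k)
    # 2. Factory: calling the outer lambda creates a new scope in which
    #    kk holds the current value of k.
    lst_factory.append((lambda kk: lambda: kk * kk)(k))
    # 3. functools.partial: stores the current value of k as a bound argument.
    lst_partial.append(partial(lambda kk: kk * kk, k))

print([fn() for fn in lst_default])
print([fn() for fn in lst_factory])
print([fn() for fn in lst_partial])
```

All three print [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]. The default-argument version is the shortest, but it exposes k as an optional parameter that a caller could override; since tf.case calls its branch functions with no arguments, any of the three works there.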
This answer explains the problem in more detail and describes several ways to work around it. In your case, you can do something like this:
def f(k_win):
    update_BW = tf.train.AdamOptimizer(learning_rate=learningrate).minimize(Cost_List[k_win])
    return update_MSE_winner(k_win) + [update_BW, update_n_List(k_win), update_n_alpha_List(k_win)]

winner_index = tf.argmin(Cost_Alpha_List, 0)
Case_List = []
for k in range(N_Class):
    # The outer lambda is called immediately with the current k, so each
    # inner lambda keeps its own kk instead of looking up k later.
    Case = (tf.equal(winner_index, k), (lambda kk: lambda: f(kk))(k))
    Case_List.append(Case)
Execution_List = tf.case(Case_List, lambda: f(0))
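To check the fix without building a TensorFlow graph, here is a minimal pure-Python sketch. The helper `select_case` is hypothetical: it only mimics tf.case with its default exclusive=False, returning the result of the function paired with the first true predicate, or the default. The `f` here is a stand-in that just returns the index it was built for, not the real update-op builder, and `N_Class = 4`, `winner_index = 2` are arbitrary illustration values.

```python
from typing import Callable, Sequence, Tuple


def select_case(pairs: Sequence[Tuple[bool, Callable[[], int]]],
                default: Callable[[], int]) -> int:
    """Hypothetical stand-in for tf.case: return the result of the function
    paired with the first true predicate, else the result of default()."""
    for pred, fn in pairs:
        if pred:
            return fn()
    return default()


def f(k_win: int) -> int:
    # Stand-in for the real f(k_win): report which index it was built for.
    return k_win


N_Class = 4
winner_index = 2

broken_cases = []
fixed_cases = []
for k in range(N_Class):
    # Late binding: this lambda looks up k when called, after the loop ends.
    broken_cases.append((winner_index == k, lambda: f(k)))
    # Early binding via a factory, as in the fix above.
    fixed_cases.append((winner_index == k, (lambda kk: lambda: f(kk))(k)))

print(select_case(broken_cases, lambda: f(0)))  # N_Class - 1, i.e. 3
print(select_case(fixed_cases, lambda: f(0)))   # winner_index, i.e. 2
```

The sketch only models which branch's result is returned. In the real graph, tf.case calls every branch function while building the graph, so with the broken version every branch would build the update ops for index N_Class - 1.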