How to visualize attention weight from AttentionWrapper

I want to visualize the attention scores in the latest TensorFlow (1.2). I am using AttentionWrapper from contrib.seq2seq to create an RNNCell, with BasicDecoder as the decoder, and then using dynamic_decode() to generate outputs step by step.

How can I access the attention weights for all steps? Thanks!

+3


source to share


1 answer


You can access the attention weights by setting the alignment_history=True flag in the AttentionWrapper constructor.

Here's an example:

# Define attention mechanism.
# The attention "memory" must be the ENCODER outputs — the sequence the
# decoder attends over.  (The original snippet passed decoder_outputs here,
# which is inconsistent with memory_sequence_length=input_lengths.)
attn_mech = tf.contrib.seq2seq.LuongMonotonicAttention(
    num_units = attention_unit_size, memory = encoder_outputs,
    memory_sequence_length = input_lengths)

# Define attention cell.
# alignment_history=True makes the wrapper record the attention
# (alignment) distribution of every decoding step in a TensorArray,
# which is what we later read back to visualize attention.
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell = decoder_cell, attention_mechanism = attn_mech,
    alignment_history=True)

# Define train helper.
# NOTE(review): TrainingHelper feeds the *decoder* at each step, so
# `inputs` would normally be the (embedded) target/decoder inputs and
# `sequence_length` the target lengths — confirm against your data setup.
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs = encoder_inputs, 
    sequence_length = input_lengths)

# Define decoder
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell = attn_cell, 
    helper = train_helper, initial_state=decoder_initial_state)

# Dynamic decoding: returns (outputs, final_state, sequence_lengths).
# dec_states.alignment_history holds the per-step attention weights.
dec_outputs, dec_states, _ = tf.contrib.seq2seq.dynamic_decode(decoder)

      

And then, inside the session, you can access the weights as shown below:



# Fetch the recorded alignments inside a session.
# alignment_history is a TensorArray; .stack() turns it into one tensor
# of shape (decode_steps, batch_size, memory_length) so it can be run.
with tf.Session() as sess:
    ...
    alignments = sess.run(dec_states.alignment_history.stack(), feed_dict)

      

Finally, you can visualize attention (alignment) like this:

def plot_attention(attention_map, input_tags = None, output_tags = None):
    """Render a 2-D attention (alignment) map as a heat map.

    Args:
        attention_map: 2-D array-like of shape (output_steps, input_steps)
            holding the softmax attention weights.
        input_tags: optional sequence of input-token labels for the x-axis;
            pass None to skip labeling.
        output_tags: optional sequence of output-token labels for the y-axis;
            pass None to skip labeling.
    """
    attn_len = len(attention_map)        # number of output steps (rows)
    input_len = len(attention_map[0])    # number of input steps (columns)

    # Plot the attention_map
    plt.clf()
    f = plt.figure(figsize=(15, 10))
    ax = f.add_subplot(1, 1, 1)

    # Add image
    i = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # Add colorbar (placed in its own horizontal axes below the map)
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (Probability output of the "softmax")', labelpad=2)

    # Add labels (use `is not None` — NumPy arrays make `!= None` ambiguous)
    ax.set_yticks(range(attn_len))
    if output_tags is not None:
      ax.set_yticklabels(output_tags[:attn_len])

    # BUG FIX: the x-axis must use the COLUMN count (input steps);
    # the original used the row count, mislabeling non-square maps.
    ax.set_xticks(range(input_len))
    if input_tags is not None:
      ax.set_xticklabels(input_tags[:input_len], rotation=45)

    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # add grid and legend
    ax.grid()

    plt.show()

# input_tags - word representation of input sequence, use None to skip
# output_tags - word representation of output sequence, use None to skip
# i - index of input element in batch

# alignments has shape (decode_steps, batch_size, memory_length), so
# alignments[:, i, :] selects the (output_steps x input_steps) map for
# the i-th batch element.
plot_attention(alignments[:, i, :], input_tags, output_tags)

      

(Example output: a heat map showing the alignment between input and output tokens.)

+1


source







All Articles