
I'm using Keras 1.0.1 and I'm trying to add an attention layer on top of an LSTM. This is what I have so far, but it doesn't work:

    input_ = Input(shape=(input_length, input_dim))
    lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
    att = TimeDistributed(Dense(1)(lstm))
    att = Reshape((-1, input_length))(att)
    att = Activation(activation="softmax")(att)
    att = RepeatVector(self.HID_DIM)(att)
    merge = Merge([att, lstm], "mul")
    hid = Merge("sum")(merge)

    last = Dense(self.HID_DIM, activation="relu")(hid)

The network should apply an LSTM over the input sequence. Each hidden state of the LSTM should then be fed into a fully connected layer that produces one score per timestep, and a softmax is applied over those scores. The softmax weights are replicated for each hidden dimension and multiplied elementwise with the LSTM hidden states. Finally, the resulting vectors should be averaged over time.
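Written out for a single sequence, the computation described above is roughly the following. The symbols are introduced here for illustration only: $h_1, \dots, h_T \in \mathbb{R}^{d}$ are the recurrent states, $T$ corresponds to input_length, $d$ to HID_DIM, and $w$, $b$ are the weights of the Dense(1) scoring layer.

```latex
\begin{aligned}
e_t      &= w^\top h_t + b, \qquad t = 1, \dots, T \\
\alpha_t &= \frac{\exp(e_t)}{\sum_{s=1}^{T} \exp(e_s)} \\
c        &= \frac{1}{T} \sum_{t=1}^{T} \alpha_t \, h_t \;\in\; \mathbb{R}^{d}
\end{aligned}
```

Replicating each $\alpha_t$ across the $d$ hidden dimensions is what lets a plain elementwise product compute $\alpha_t h_t$.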

EDIT: This compiles, but I'm not sure if it does what I think it should do.

    input_ = Input(shape=(input_length, input_dim))
    lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_)
    att = TimeDistributed(Dense(1))(lstm)
    att = Flatten()(att)
    att = Activation(activation="softmax")(att)
    att = RepeatVector(self.HID_DIM)(att)
    att = Permute((2,1))(att)
    mer = merge([att, lstm], "mul")
    hid = AveragePooling1D(pool_length=input_length)(mer)
    hid = Flatten()(hid)
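One way to check what the EDIT version computes is to replay its layers one by one in NumPy on random data. This is a sketch, not Keras: hidden stands in for the GRU output with return_sequences=True, w and b are hypothetical weights for the shared Dense(1), and the RepeatVector/Permute steps assume the usual Keras semantics (batch axis excluded from the permutation).

```python
import numpy as np

rng = np.random.default_rng(0)
batch, input_length, hid_dim = 2, 5, 3

# Stand-in for the GRU output with return_sequences=True: (batch, time, hidden)
hidden = rng.normal(size=(batch, input_length, hid_dim))

# TimeDistributed(Dense(1)): the same hypothetical w, b applied at every timestep
w = rng.normal(size=(hid_dim, 1))
b = 0.1
scores = hidden @ w + b                      # (batch, time, 1)

# Flatten + softmax over the time axis
flat = scores.reshape(batch, input_length)   # (batch, time)
alpha = np.exp(flat - flat.max(axis=1, keepdims=True))
alpha /= alpha.sum(axis=1, keepdims=True)

# RepeatVector(hid_dim) -> (batch, hidden, time); Permute((2, 1)) -> (batch, time, hidden)
repeated = np.repeat(alpha[:, np.newaxis, :], hid_dim, axis=1)
att = repeated.transpose(0, 2, 1)

# merge(..., "mul"): elementwise product with the hidden states
mer = att * hidden                           # (batch, time, hidden)

# AveragePooling1D(pool_length=input_length) + Flatten: mean over time
hid = mer.mean(axis=1)                       # (batch, hidden)

print(att.shape, hid.shape)                  # (2, 5, 3) (2, 3)
```

Since every att[:, t, :] equals alpha[:, t], hid works out to the attention-weighted sum of the states divided by input_length. That suggests the EDIT version does apply the weighting described in the question; the averaging only rescales the result by a constant.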
Comment from Marco Cerliani: Here is a simple way to add attention: stackoverflow.com/a/62949137/10375049

Answer


The first piece of code you have shared is incorrect. The second piece of code looks correct except for one thing: don't wrap the Dense layer in TimeDistributed, since the weights will then be the same for every timestep. Use a regular Dense layer with a non-linear activation instead.


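    from keras.layers import Input, GRU, Dense, Flatten, Activation, RepeatVector, Permute, merge

    input_ = Input(shape=(input_length, input_dim))
    lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length=input_length, return_sequences=True)(input_)
    # Score each timestep with a non-linear Dense layer: (batch, time, 1)
    att = Dense(1, activation='tanh')(lstm)
    # Softmax over the time axis: (batch, time)
    att = Flatten()(att)
    att = Activation(activation="softmax")(att)
    # Copy the weights across the hidden dimension: (batch, time, HID_DIM)
    att = RepeatVector(self.HID_DIM)(att)
    att = Permute((2,1))(att)
    # Weight each hidden state by its attention weight
    mer = merge([att, lstm], "mul")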

Now you have the weight-adjusted states. How you use them is up to you. Most versions of attention I have seen just add these up over the time axis and use the result as the context vector.
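Adding the weighted states up over the time axis can be sketched in NumPy as below. This is not Keras code: attention_context, w and b are hypothetical names standing in for the Dense(1, activation='tanh') scorer and the summation step, which in Keras itself would need an extra layer (for example a Lambda layer summing over axis 1).

```python
import numpy as np

def attention_context(hidden, w, b):
    """Sum-over-time attention mirroring the layers above (sketch).

    hidden: (batch, time, hidden) recurrent outputs
    w, b:   hypothetical weights of the tanh scoring layer
    """
    # Dense(1, activation='tanh') applied at every timestep -> (batch, time)
    scores = np.tanh(hidden @ w + b)[..., 0]
    # Softmax over the time axis
    alpha = np.exp(scores - scores.max(axis=1, keepdims=True))
    alpha /= alpha.sum(axis=1, keepdims=True)
    # Weight each state by its attention weight (the "mul" merge)
    weighted = alpha[:, :, np.newaxis] * hidden
    # Add up over the time axis: one context vector per example
    return weighted.sum(axis=1), alpha

rng = np.random.default_rng(1)
hidden = rng.normal(size=(2, 4, 3))
w = rng.normal(size=(3, 1))
context, alpha = attention_context(hidden, w, 0.0)
print(context.shape)  # (2, 3)
```

Because the softmax weights sum to 1 over time, this sum is a convex combination of the states. The AveragePooling1D(pool_length=input_length) used in the question yields the same vector divided by input_length.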