
I'm building an LSTM network from scratch, based on my own understanding of how LSTM cells work.

There are no layers, so I'm trying to implement non-vectorized forms of the equations I see in the tutorials. I'm also using peepholes from the cell state.

So far, I understand that it looks like this: [LSTM network diagram]

With that, I've made these equations for each of the gates for the forward pass:

i_t = sigmoid( i_w * (x_t + c_t) + i_b )
f_t = sigmoid( f_w * (x_t + c_t) + f_b )

cell_gate = tanh( c_w * x_t + c_b )

c_t = (f_t * c_t) + (i_t * cell_gate)

o_t = sigmoid( o_w * (x_t + c_t) + o_b )

h_t = o_t * tanh(c_t)

Here the _w's are the weights for the respective gate and the _b's are the biases. Also, I've named that first sigmoid on the far left the "cell_gate".
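For concreteness, here's a minimal scalar sketch of that forward pass in Java (the same language as the answer's code below): fields and a method of a hypothetical cell class, transcribing the equations above literally with one weight and bias per gate. I'm reading the c_t inside the input and forget gates as the previous cell state, since the new one hasn't been computed yet; all names here are my own, not from any library.

    double iW, fW, cW, oW;   // gate weights (i_w, f_w, c_w, o_w)
    double iB, fB, cB, oB;   // gate biases  (i_b, f_b, c_b, o_b)
    double cPrev = 0.0;      // cell state carried over from the previous time step

    double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    double forward(double xT) {
        double iT = sigmoid(iW * (xT + cPrev) + iB);    // input gate (peephole from old state)
        double fT = sigmoid(fW * (xT + cPrev) + fB);    // forget gate (peephole from old state)
        double cellGate = Math.tanh(cW * xT + cB);      // candidate cell input
        double cT = fT * cPrev + iT * cellGate;         // new cell state
        double oT = sigmoid(oW * (xT + cT) + oB);       // output gate (peephole from new state)
        double hT = oT * Math.tanh(cT);                 // hidden output
        cPrev = cT;                                     // roll the state over to the next step
        return hT;
    }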


The backward pass is where things get fuzzy for me; I'm not sure how to derive these equations correctly.

I know that, in general, the error is calculated as: error = f'(x_t) * (received_error), where f'(x_t) is the first derivative of the activation function and received_error is either (target - output) for output neurons or ∑(o_e * w_io) for hidden neurons.

Where o_e is the error of one of the cells the current cell outputs to and w_io is the weight connecting them.
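As a sketch, that generic rule in Java would look roughly like this (illustrative names only; fPrimeAtX is the activation derivative evaluated at the neuron's pre-activation input):

    // Output neuron: error = f'(x_t) * (target - output)
    double outputError(double fPrimeAtX, double target, double output) {
        return fPrimeAtX * (target - output);
    }

    // Hidden neuron: error = f'(x_t) * sum over downstream cells of (o_e * w_io)
    double hiddenError(double fPrimeAtX, double[] downstreamErrors, double[] outgoingWeights) {
        double sum = 0.0;
        for (int k = 0; k < downstreamErrors.length; k++) {
            sum += downstreamErrors[k] * outgoingWeights[k];
        }
        return fPrimeAtX * sum;
    }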

I'm not sure if the LSTM cell as a whole is considered a neuron, so I treated each of the gates as a neuron and tried to calculate error signals for each, then used the error signal from the cell gate alone to pass back up the network:

o_e = sigmoid'(o_w * (x_t + c_t) + o_b) * (received_error)
o_w += o_l * x_t * o_e
o_b += o_l * sigmoid(o_b) * o_e

...The rest of the gates follow the same format...

Then the error for the entire LSTM cell is equal to o_e.

Then, for an LSTM cell above the current cell, the error it receives is equal to:

tanh'(x_t) * ∑(o_e * w_io)
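Spelled out as code, the per-gate update I'm describing looks roughly like this. It is just a literal transcription of the equations above, with o_l read as the learning rate and sigmoidPrime/tanhPrime as placeholder derivative helpers; whether these are the correct gradients is exactly what I'm asking.

    // Output gate, treated as its own neuron (the other gates follow the same pattern).
    // xT, cT, receivedError, oL, downstreamErrors and outgoingWeights are assumed in scope.
    double oE = sigmoidPrime(oW * (xT + cT) + oB) * receivedError;  // gate error signal
    oW += oL * xT * oE;                                             // weight update
    oB += oL * sigmoid(oB) * oE;                                    // bias update, as written above

    // Error handed back to the LSTM cell above this one: tanh'(x_t) * sum of (o_e * w_io)
    double sumOverOutputs = 0.0;
    for (int k = 0; k < downstreamErrors.length; k++) {
        sumOverOutputs += downstreamErrors[k] * outgoingWeights[k];
    }
    double errorToCellAbove = tanhPrime(xT) * sumOverOutputs;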

Is this all correct? Am I doing anything completely wrong?


1 Answer


I too am taking on this task, and I believe your approach is correct:

https://github.com/evolvingstuff/LongShortTermMemory/blob/master/src/com/evolvingstuff/LSTM.java

Some nice work from Thomas Lahore:

    ////////////////////////////////////////////////////////////// 
    ////////////////////////////////////////////////////////////// 
    //BACKPROP 
    ////////////////////////////////////////////////////////////// 
    ////////////////////////////////////////////////////////////// 

    //scale partials 
    for (int c = 0; c < cell_blocks; c++) { 
        for (int i = 0; i < full_input_dimension; i++) { 
            this.dSdwWeightsInputGate[c][i] *= ForgetGateAct[c]; 
            this.dSdwWeightsForgetGate[c][i] *= ForgetGateAct[c]; 
            this.dSdwWeightsNetInput[c][i] *= ForgetGateAct[c]; 

            dSdwWeightsInputGate[c][i] += full_input[i] * neuronInputGate.Derivative(InputGateSum[c]) * NetInputAct[c]; 
            dSdwWeightsForgetGate[c][i] += full_input[i] * neuronForgetGate.Derivative(ForgetGateSum[c]) * CEC1[c]; 
            dSdwWeightsNetInput[c][i] += full_input[i] * neuronNetInput.Derivative(NetInputSum[c]) * InputGateAct[c]; 
        } 
    } 

    if (target_output != null) { 
        double[] deltaGlobalOutputPre = new double[output_dimension]; 
        for (int k = 0; k < output_dimension; k++) { 
            deltaGlobalOutputPre[k] = target_output[k] - output[k]; 
        } 

        //output to hidden 
        double[] deltaNetOutput = new double[cell_blocks]; 
        for (int k = 0; k < output_dimension; k++) { 
            //links 
            for (int c = 0; c < cell_blocks; c++) { 
                deltaNetOutput[c] += deltaGlobalOutputPre[k] * weightsGlobalOutput[k][c]; 
                weightsGlobalOutput[k][c] += deltaGlobalOutputPre[k] * NetOutputAct[c] * learningRate; 
            } 
            //bias 
            weightsGlobalOutput[k][cell_blocks] += deltaGlobalOutputPre[k] * 1.0 * learningRate; 
        } 

        for (int c = 0; c < cell_blocks; c++) { 

            //update output gates 
            double deltaOutputGatePost = deltaNetOutput[c] * CECSquashAct[c]; 
            double deltaOutputGatePre = neuronOutputGate.Derivative(OutputGateSum[c]) * deltaOutputGatePost; 
            for (int i = 0; i < full_input_dimension; i++) { 
                weightsOutputGate[c][i] += full_input[i] * deltaOutputGatePre * learningRate; 
            } 
            peepOutputGate[c] += CEC3[c] * deltaOutputGatePre * learningRate; 

            //before outgate 
            double deltaCEC3 = deltaNetOutput[c] * OutputGateAct[c] * neuronCECSquash.Derivative(CEC3[c]); 

            //update input gates 
            double deltaInputGatePost = deltaCEC3 * NetInputAct[c]; 
            double deltaInputGatePre = neuronInputGate.Derivative(InputGateSum[c]) * deltaInputGatePost; 
            for (int i = 0; i < full_input_dimension; i++) { 
                weightsInputGate[c][i] += dSdwWeightsInputGate[c][i] * deltaCEC3 * learningRate; 
            } 
            peepInputGate[c] += CEC2[c] * deltaInputGatePre * learningRate; 

            //before ingate 
            double deltaCEC2 = deltaCEC3; 

            //update forget gates 
            double deltaForgetGatePost = deltaCEC2 * CEC1[c]; 
            double deltaForgetGatePre = neuronForgetGate.Derivative(ForgetGateSum[c]) * deltaForgetGatePost; 
            for (int i = 0; i < full_input_dimension; i++) { 
                weightsForgetGate[c][i] += dSdwWeightsForgetGate[c][i] * deltaCEC2 * learningRate; 
            } 
            peepForgetGate[c] += CEC1[c] * deltaForgetGatePre * learningRate; 

            //update cell inputs 
            for (int i = 0; i < full_input_dimension; i++) { 
                weightsNetInput[c][i] += dSdwWeightsNetInput[c][i] * deltaCEC3 * learningRate; 
            } 
            //no peeps for cell inputs 
        } 
    } 

    ////////////////////////////////////////////////////////////// 

    //roll-over context to next time step 
    for (int j = 0; j < cell_blocks; j++) { 
        context[j] = NetOutputAct[j]; 
        CEC[j] = CEC3[j]; 
    } 

Also, and perhaps even more interesting, are the lecture and lecture notes from Andrej Karpathy:

https://youtu.be/cO0a0QYmFm8?t=45m36s

http://cs231n.stanford.edu/slides/2016/winter1516_lecture10.pdf