Working of LSTM and GRU
Recurrent Neural Networks
Recurrent Neural Networks (RNNs) are models popularly used for sequence processing tasks and work well on sequential data. RNNs have shown great success in solving Natural Language Processing (NLP) problems. RNNs are a generalization of feedforward neural networks with an internal memory. They are ‘recurrent’ because they apply the same function to every input, while the output at each step depends on the previous time step. After producing an output, it is fed back as input to obtain the next output. Thus, to make a decision, an RNN depends on the current input and the output from the previous step. An RNN is usually represented as in Figure 1.
In Figure 1, X, h and y can be compared to the input, hidden and output layers of a feedforward network. There can be one or many hidden layers, and each hidden layer can have many neurons. A, B and C are the network parameters.
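The recurrence described above can be sketched in a few lines of NumPy. The weight names A, B and C follow Figure 1; the sizes and random initialization are illustrative assumptions, not part of the original figure:

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size = 4, 3

# Network parameters (A: input->hidden, B: hidden->hidden, C: hidden->output)
A = rng.standard_normal((hidden_size, input_size))
B = rng.standard_normal((hidden_size, hidden_size))
C = rng.standard_normal((input_size, hidden_size))

def rnn_step(x_t, h_prev):
    """One recurrence: the new hidden state depends on the current
    input and the hidden state from the previous time step."""
    h_t = np.tanh(A @ x_t + B @ h_prev)
    y_t = C @ h_t
    return h_t, y_t

# The same function is applied at every time step; the hidden state
# produced at one step is fed back in at the next.
h = np.zeros(hidden_size)
for x_t in rng.standard_normal((5, input_size)):  # a sequence of 5 inputs
    h, y = rnn_step(x_t, h)
```

Note how the loop reuses the same `rnn_step` and carries `h` forward, which is exactly the "internal memory" the text refers to.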
Problems with Vanilla RNNs
- RNNs cannot track long-term dependencies.
- RNNs suffer from vanishing and exploding gradient problems.
Fortunately, LSTMs do not suffer from these problems.
LSTM (Long Short-Term Memory)
Long Short-Term Memory (LSTM) was introduced by Hochreiter & Schmidhuber in 1997. In this section, we look at the working of an LSTM. An LSTM cell is shown in Figure 2.
The main component of an LSTM is the cell state, represented by the horizontal line running along the top, shown in Figure 3. C_t is the cell state at time t and denotes the context vector, which acts as a carrier of information from the previous states.
Information can be added to or removed from the context vector C with the help of gates. Gates control how much information is changed (i.e. added or removed) in the context vector. A gate is usually controlled by a sigmoid function, whose value falls between 0 and 1. When the sigmoid's value is 0, the gate does not allow any information to flow through, whereas a value of 1 allows all the information through. Thus, depending on the value of the sigmoid, the amount of information flow is controlled. An LSTM has 3 gates. They are
- Forget Gate
- Input Gate
- Output Gate
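The sigmoid-as-valve idea above can be demonstrated directly: multiplying a vector pointwise by a sigmoid output near 0 blocks it, and by a value near 1 passes it through. The numbers below are illustrative:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

context = np.array([2.0, -1.0, 0.5])

# A gate value near 0 blocks information; near 1 lets it all through.
closed_gate = sigmoid(np.array([-50.0, -50.0, -50.0]))  # ~0
open_gate = sigmoid(np.array([50.0, 50.0, 50.0]))       # ~1

blocked = closed_gate * context   # ~[0, 0, 0]: nothing flows through
passed = open_gate * context      # ~context: everything flows through
```

In practice the gate's pre-activation is learned from the input, so the network itself decides how far open each valve is at each time step.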
The forget gate is shown in Figure 4.
The input to the forget gate is a concatenation of the previous output and the current input. The function of the forget gate is to produce an output that can be used to forget information stored in the context vector. The output of the forget gate is therefore multiplied with the context vector: if the sigmoid outputs a zero, the information is completely discarded, whereas a one keeps all the incoming information. The output of the sigmoid function, f_t, is multiplied pointwise with the context vector at time t-1. Let's see it step by step:
Step 1: Concatenate the previous output with the current input to obtain the concatenated vector. The input to the sigmoid function is [h_{t-1}, x_t].
Step 2: Pass it through the sigmoid function after multiplying with the weights, where b_f is the bias term: f_t = sigmoid(W_f · [h_{t-1}, x_t] + b_f).
Step 3: Multiply the above output f_t with the context vector from the previous time step: f_t * C_{t-1}. Doing this removes or keeps some of the information depending on f_t.
This product, depending on the value of f_t, retains or removes information. Note that it is only a partial result, using which we calculate the next context vector.
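The three steps above can be sketched as follows; the weight matrix W_f and the layer sizes are illustrative assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
input_size, hidden_size = 4, 3

# Forget-gate parameters (W_f acts on the concatenated vector, b_f is the bias)
W_f = rng.standard_normal((hidden_size, hidden_size + input_size))
b_f = np.zeros(hidden_size)

def forget_gate(h_prev, x_t, C_prev):
    concat = np.concatenate([h_prev, x_t])  # Step 1: [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ concat + b_f)       # Step 2: f_t = sigmoid(W_f.[h,x] + b_f)
    return f_t * C_prev                     # Step 3: pointwise scale the old context

partial_C = forget_gate(np.zeros(hidden_size),
                        rng.standard_normal(input_size),
                        np.ones(hidden_size))
```

Because `C_prev` is all ones here, the partial result equals f_t itself, making it easy to see that each component of the old context is scaled by a value strictly between 0 and 1.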
The input gate gets to decide what information has to be added depending on the current input. It is shown in Figure 5.
In this step we select what information needs to be added. Let's take it step by step again:
Step 1: Apply the sigmoid function to the concatenated input (from Step 1 of the forget gate) after applying the weights: i_t = sigmoid(W_i · [h_{t-1}, x_t] + b_i).
Step 2: Apply the tanh function to the concatenated input (from Step 1 of the forget gate) after applying the weights: c_t = tanh(W_c · [h_{t-1}, x_t] + b_c).
Step 3: Multiply the results of the equations above to get the information. With the results of Step 1 and Step 2 represented by i_t and c_t, the output is i_t * c_t.
The above product represents the information that needs to be added. The final context vector is obtained by adding this information to the result from Step 3 of the forget gate. Thus C_t = f_t * C_{t-1} + i_t * c_t.
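The input-gate steps and the context update can be sketched as below; the weight names W_i and W_c and the sizes are illustrative assumptions, and f_t is taken as given (it comes from the forget gate):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(2)
input_size, hidden_size = 4, 3

W_i = rng.standard_normal((hidden_size, hidden_size + input_size))  # input gate
W_c = rng.standard_normal((hidden_size, hidden_size + input_size))  # candidate values
b_i = np.zeros(hidden_size)
b_c = np.zeros(hidden_size)

def update_context(h_prev, x_t, f_t, C_prev):
    concat = np.concatenate([h_prev, x_t])
    i_t = sigmoid(W_i @ concat + b_i)      # Step 1: how much to add
    c_t = np.tanh(W_c @ concat + b_c)      # Step 2: candidate information
    # Step 3 + final update: forget part of the old context, add the new
    return f_t * C_prev + i_t * c_t

C_t = update_context(np.zeros(hidden_size),
                     rng.standard_normal(input_size),
                     np.full(hidden_size, 0.5),   # f_t from the forget gate
                     np.ones(hidden_size))
```

The final line mirrors the context update: the first term keeps a fraction of the old context, the second adds gated new information.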
The output gate controls what needs to be passed on and evaluates the output of the current LSTM cell. We can understand it from Figure 6.
The output depends on the context vector generated. The sigmoid function decides what goes out as the output and what is withheld. Thus:
Step 1: The concatenated input again goes through a sigmoid, which acts as a control: o_t = sigmoid(W_o · [h_{t-1}, x_t] + b_o).
Step 2: The new context vector goes through a tanh function: tanh(C_t).
Step 3: The output is the product of the above two. Thus, h_t = o_t * tanh(C_t).
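The output-gate steps can be sketched as follows; the weight matrix W_o and the sizes are illustrative assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(3)
input_size, hidden_size = 4, 3

W_o = rng.standard_normal((hidden_size, hidden_size + input_size))
b_o = np.zeros(hidden_size)

def output_gate(h_prev, x_t, C_t):
    concat = np.concatenate([h_prev, x_t])
    o_t = sigmoid(W_o @ concat + b_o)  # Step 1: sigmoid control
    return o_t * np.tanh(C_t)          # Steps 2-3: h_t = o_t * tanh(C_t)

h_t = output_gate(np.zeros(hidden_size),
                  rng.standard_normal(input_size),
                  rng.standard_normal(hidden_size))
```

Since both the sigmoid and tanh outputs lie in bounded ranges, each component of the cell's output h_t is always strictly between -1 and 1.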
This describes how the LSTM cell computes the context vector and the output. There are many variants of the LSTM, and one such variant is the Gated Recurrent Unit (GRU), described in the next section. The main advantage GRUs have over LSTMs is that they are computationally cheaper; research has shown that the performance of the two is similar.
Gated Recurrent Units (GRUs)
Gated Recurrent Unit (GRU) was proposed by Cho et al. in 2014. A GRU is shown in Figure 7.
In a GRU, the cell state is done away with and the hidden state h is used for information transfer. Also, unlike the LSTM, the GRU has only 2 gates. They are (indicated using dotted lines in Figure 7)
- Reset Gate
- Update Gate
The reset gate is shown in Figure 7. The function of this gate is to decide how much past information to let through and how much to forget. As in the LSTM, the input is concatenated with the previous output before passing it into the sigmoid (biases are left out here). The output of the reset gate is given by the equation r_t = sigmoid(W_r · [h_{t-1}, x_t]).
The update gate combines the functions of the forget gate and the input gate of the LSTM. This gate decides which information to throw away and which new information to add. The value of z_t controls this and can be represented using the equation z_t = sigmoid(W_z · [h_{t-1}, x_t]).
The candidate information that has to be remembered is then calculated using r_t obtained above: h'_t = tanh(W · [r_t * h_{t-1}, x_t]).
Finally, the output is calculated using the equation h_t = (1 - z_t) * h_{t-1} + z_t * h'_t.
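A full GRU step combining the reset gate, the update gate, and the final interpolation can be sketched as below; the weight names W_r, W_z and W_h and the sizes are illustrative assumptions, with biases left out as in the equations above:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(4)
input_size, hidden_size = 4, 3

W_r = rng.standard_normal((hidden_size, hidden_size + input_size))  # reset gate
W_z = rng.standard_normal((hidden_size, hidden_size + input_size))  # update gate
W_h = rng.standard_normal((hidden_size, hidden_size + input_size))  # candidate state

def gru_step(h_prev, x_t):
    concat = np.concatenate([h_prev, x_t])
    r_t = sigmoid(W_r @ concat)  # how much past information to let through
    z_t = sigmoid(W_z @ concat)  # how much to replace with new information
    # Candidate state uses the reset-scaled previous hidden state
    h_cand = np.tanh(W_h @ np.concatenate([r_t * h_prev, x_t]))
    # Update gate interpolates between the old state and the candidate
    return (1 - z_t) * h_prev + z_t * h_cand

h_t = gru_step(np.zeros(hidden_size), rng.standard_normal(input_size))
```

Compared with the LSTM sketches earlier, there is no separate context vector: the single hidden state carries the information, and one gate (z_t) plays the roles of both the forget and input gates.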