LSTM: Why do we need it?

Sumanth
7 min read · Jun 4, 2021

Motivation: Share your knowledge before it becomes useless

Hi folks, I have come up with a new story on deep learning, and this story gives you insights into the LSTM, which is a state-of-the-art topic in deep learning.

Let's get started…

What will you learn by the end of the story?

>Intro to plain RNN

>Why do we need LSTM?

>GRU Intro

Intro to plain RNN

RNNs (Recurrent Neural Networks) are special neural networks with a single layer of hidden activation functions that take the input along the time axis.

Got confused? Don't worry, let me make it simpler.

If you notice in the above pic, the blue-colored cell is a single RNN unit which is unfolded along the time axis, in the sense that we have a single RNN unit and we pass the inputs at different time stamps.

For example, consider the sentence "Thanks for a great party at the weekend, we really enjoyed it", which we need to classify as a positive/negative comment. First we pass the corresponding w2v(Thanks) as input at time t to the RNN unit, followed by w2v(for) at t+1 to the same RNN unit where we passed the 1st word, then w2v(a) at t+2, again to the same RNN unit, and likewise we pass all the inputs to the same RNN unit along the time axis.

Once we pass all the words of the first sentence, the network predicts the output as highlighted in red, and then backpropagation happens through time to update the corresponding weights. These weights are then used for the second sentence's words, and this repeats for the other sentences as well.

Are you with me so far?

Don't worry, we will go into more detail to get a better idea. Let's dive deeper now and get our hands dirty.

Just go through the pic, which will help you better understand how an RNN actually works.

Forward propagation:

w, w' and w'' are trainable parameters which get initialized just as in SGD optimization. At time t we pass the input Xi1 along w and also O0 along w', so (O0*w' + Xi1*w) is passed at time t to the RNN unit, which generates O1 as output.

Note: O0 is just the zero vector [0,0,0,0,0,…], since we don't have any previous output to pass at time t.

Now, at time t+1 we pass (Xi2*w + O1*w'), where O1 is the previous output at time t.

The above procedure is repeated until we reach the last word. Consider that we have 10 words in the given sentence Xi; the final output we get is O10, and this is passed to a sigmoid to generate the probabilistic output that classifies whether the sentence Xi is positive/negative.
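To make the forward pass concrete, here is a minimal numpy sketch of the loop described above. The tanh activation, the vector sizes and the random stand-ins for the w2v vectors are my own assumptions for illustration, not something fixed by the story.

```python
import numpy as np

# Minimal sketch of the RNN forward pass described above.
# Assumptions: tanh hidden activation, 4-d random vectors standing in for w2v(word).
rng = np.random.default_rng(0)
embed_dim, hidden_dim = 4, 3

W_in  = rng.normal(scale=0.1, size=(hidden_dim, embed_dim))   # w   : input -> hidden
W_rec = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))  # w'  : previous output -> hidden
W_out = rng.normal(scale=0.1, size=(1, hidden_dim))           # w'' : last output -> sigmoid

sentence = "Thanks for a great party at the weekend we really enjoyed it".split()
x = [rng.normal(size=embed_dim) for _ in sentence]            # stand-ins for w2v(word)

O = np.zeros(hidden_dim)                                      # O0: zero vector, no previous output
for x_t in x:                                                 # same RNN unit, different time stamps
    O = np.tanh(W_in @ x_t + W_rec @ O)                       # O_t from (x_t*w + O_{t-1}*w')

y_hat = 1 / (1 + np.exp(-(W_out @ O)))                        # sigmoid over the final output
print("P(positive) =", y_hat.item())
```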

Backward propagation:

Here in RNN backpropagation happens through time.

w'' is updated only once, since w'' does not travel along the time axis. But w and w' do travel along the time axis, as we can see in the pic, so we should calculate the respective gradients, i.e. dL/dw and dL/dw', at the different time stamps at which the forward propagations happened.

w''_new = w''_old - eta*(dL/dw'' evaluated at w''_old), and this update is calculated only once.

At time t+10:

w_new = w_old - eta*(dL/dw evaluated at w_old)

dL/dw = (dL/dŷ)*(dŷ/dO10)*(dO10/dw)

w'_new = w'_old - eta*(dL/dw' evaluated at w'_old)

dL/dw' = (dL/dŷ)*(dŷ/dO10)*(dO10/dw')

Now at t+9:

w_new1 = w_new - eta*(dL/dw evaluated at w_new)

dL/dw = (dL/dŷ)*(dŷ/dO10)*(dO10/dO9)*(dO9/dw)

w'_new1 = w'_new - eta*(dL/dw' evaluated at w'_new)

dL/dw' = (dL/dŷ)*(dŷ/dO10)*(dO10/dO9)*(dO9/dw')

(… and so on down to t+1 …)

Now at t:

w_new10 = w_new9 - eta*(dL/dw evaluated at w_new9)

dL/dw = (dL/dŷ)*(dŷ/dO10)*(dO10/dO9)*(dO9/dO8)*…*(dO1/dw)

w'_new10 = w'_new9 - eta*(dL/dw' evaluated at w'_new9)

dL/dw' = (dL/dŷ)*(dŷ/dO10)*(dO10/dO9)*(dO9/dO8)*…*(dO1/dw')

Once we go through all the time stamps as explained above, we get the final updated weights (w, w', w''); these are carried forward to the second sentence, and this repeats for all the sentences.
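The story above walks through a separate weight update at each time stamp. In practice, BPTT as implemented in common frameworks accumulates the per-time-stamp contributions into a single gradient for w and w' and then applies one update. Below is a toy scalar sketch of that accumulation; the tanh/sigmoid choices, the binary cross-entropy loss and the 1-D "word" inputs are my own assumptions for illustration, not the exact network from the pictures.

```python
import numpy as np

# Toy scalar BPTT sketch: O_t = tanh(w*x_t + wp*O_{t-1}), y_hat = sigmoid(wpp*O_T),
# with a binary cross-entropy loss on the final output only.
def bptt(x, y, w, wp, wpp):
    T = len(x)
    O = np.zeros(T + 1)                      # O[0] is the zero "previous output"
    for t in range(1, T + 1):                # forward pass through time
        O[t] = np.tanh(w * x[t - 1] + wp * O[t - 1])

    y_hat = 1 / (1 + np.exp(-wpp * O[T]))    # sigmoid on the last output
    dz = y_hat - y                           # dL/d(pre-sigmoid) for sigmoid + BCE
    dwpp = dz * O[T]                         # gradient for w'' (single term, as in the story)
    dO = dz * wpp                            # dL/dO_T

    dw = dwp = 0.0
    for t in range(T, 0, -1):                # backward pass through time
        da = dO * (1 - O[t] ** 2)            # back through tanh
        dw  += da * x[t - 1]                 # contribution of time stamp t to dL/dw
        dwp += da * O[t - 1]                 # contribution of time stamp t to dL/dw'
        dO = da * wp                         # carries the dO_t/dO_{t-1} chain backwards
    return dw, dwp, dwpp

x = np.array([0.5, -1.2, 0.3, 0.8])          # 1-D stand-ins for the word vectors
print(bptt(x, y=1.0, w=0.4, wp=0.9, wpp=1.1))
```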

Are you messed up by the above explanation? Don't worry, just keep in mind that backpropagation in an RNN happens through time, and the weights finally get updated along the time stamps. If you still have confusion, take a pen and paper, redraw the diagram of the RNN and write down the gradients.

Why do we need LSTM?

Did you notice that a sentence which has 20 words has to be passed at 20 different time stamps, and during backpropagation we have to update the weights at 20 different time stamps?

Consider how our gradient for w looks at time t: dL/dw = (dL/dŷ)*(dŷ/dO20)*(dO20/dO19)*(dO19/dO18)*…*(dO2/dO1)*(dO1/dw).

What do you think happens if we have lots of gradients multiplied by each other? We run into the vanishing gradient and also the exploding gradient problem.

Did you notice that we have just a single RNN unit but still run into the vanishing and exploding gradient problem? That is because we have a long sentence consisting of many words.
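To see why a long chain of multiplied terms is a problem, here is a tiny numeric illustration; the factor values 0.5 and 1.5 are assumed purely for the demo, standing in for the dO_t/dO_{t-1} terms of a 20-word sentence.

```python
import numpy as np

# Each dO_t/dO_{t-1} term is a similar factor; multiplying 20 of them either
# shrinks the gradient towards zero or blows it up.
factors_small = np.full(20, 0.5)   # factors a bit below 1
factors_large = np.full(20, 1.5)   # factors a bit above 1

print(np.prod(factors_small))      # ~9.5e-07 -> vanishing gradient
print(np.prod(factors_large))      # ~3.3e+03 -> exploding gradient
```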

Consider that we have a many-to-many RNN whose output length is equal to its input length.

Suppose our O7 depends on the input x1. By the time we reach the time stamp where O7 is generated, x1 might have undergone many changes due to the mathematical operations performed on it, so O7 would see a transformed x1, not the actual x1.

Hence we won't get the desired O7. This is the problem of long-term dependencies, where our later outputs are not able to see the original earlier inputs. So the plain RNN can manage short-term dependencies but not long-term dependencies.

To put it in short, this is why we need the LSTM:

  1. Long sentences make the gradients vanish & explode.
  2. The plain RNN has the problem of long-term dependencies.

To compensate for the above problems we have LSTMs.

We will look into the structure of the LSTM and how it solves the above 2 problems.

Let me break it down for you now.

All you need to know about the LSTM is its 3 gates:

  1. Forget Gate:

This gate tells how much information we should retain from the previous cell state, i.e. the previous time stamp's LSTM cell state.

Suppose we want to carry forward the previous cell state as it is; it would be multiplied by the vector [1,1,1,1,…]. If we multiply by [1/2,1/2,…,1/2], we retain 50% of the previous cell state.

  2. Input Gate:

After retaining the previous cell state's information, this gate tells how much new information we should add to generate the current cell state.

Suppose we want to carry forward the previous cell state as it is; we would add the vector [0,0,0,…,0] to it.

  3. Output Gate:

This gate helps us generate the output of the current LSTM cell with the help of the current input, the previous output, and the current cell state.

Are these gates helpful in overcoming the problems we faced with the plain RNN?

Yes, of course. For the long-term dependency where the later output O7 depends on the earlier input X1, we can carry X1's information forward with the help of the forget gate and the input gate, which multiply by [1,1,1,…,1] and add [0,0,0,…,0] respectively, so that X1's information is carried forward as it is, without undergoing transformation, to the next time stamp, and this can be done until we reach t+7.

Since the LSTM unit can become a short circuit that transfers the previous cell state as it is to the next time stamp, during backpropagation the gradients pass almost unchanged through the time stamps where the short circuit exists, which in turn means we can use long sentences as inputs to our LSTM without the vanishing and exploding gradient problems.

With LSTMs we can manage both the long-term and short-term dependencies with the help of this short circuit. That's why the name LONG SHORT-TERM MEMORY.
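If it helps to see the 3 gates as code, here is a minimal numpy sketch of a single LSTM step in the standard formulation; the parameter names (Wf, Wi, Wo, Wc) and sizes are my own assumptions for illustration. Notice the cell-state line: with f_t close to [1,1,…,1] and i_t close to [0,0,…,0], the previous cell state passes through untouched, which is exactly the short circuit described above.

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM time step (standard formulation; parameter names are assumed)."""
    z = np.concatenate([h_prev, x_t])            # previous output + current input
    f_t = sigmoid(p["Wf"] @ z + p["bf"])         # forget gate: how much of c_prev to retain
    i_t = sigmoid(p["Wi"] @ z + p["bi"])         # input gate: how much new info to add
    o_t = sigmoid(p["Wo"] @ z + p["bo"])         # output gate: how much of the cell state to expose
    c_tilde = np.tanh(p["Wc"] @ z + p["bc"])     # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde           # the "short circuit" path
    h_t = o_t * np.tanh(c_t)                     # current output
    return h_t, c_t

# Tiny usage example with random weights (sizes are assumed for the demo)
rng = np.random.default_rng(0)
embed_dim, hidden_dim = 4, 3
p = {k: rng.normal(scale=0.1, size=(hidden_dim, hidden_dim + embed_dim))
     for k in ("Wf", "Wi", "Wo", "Wc")}
p.update({k: np.zeros(hidden_dim) for k in ("bf", "bi", "bo", "bc")})

h, c = np.zeros(hidden_dim), np.zeros(hidden_dim)
for x_t in (rng.normal(size=embed_dim) for _ in range(10)):   # 10 "words"
    h, c = lstm_step(x_t, h, c, p)
print(h)
```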

Hurray! We have almost reached the end of the story, just one step away from it.

GRU Intro:

GRU also solves the above 2 problems faced by the vanilla RNN.

If you notice, the LSTM has a total of 3 gates, which incurs a lot of equations.

The GRU is simpler, with just 2 gates and fewer equations compared to the LSTM.

Comparison of LSTM equations and GRU equations:

GRU CELL

The GRU has 2 gates: the reset gate and the update gate.

The reset gate tells how much info we should reset (drop) from the previous output.

The update gate tells us how much new info we should add to generate the current output.
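For comparison, here is the same kind of sketch for a single GRU step, again in the standard formulation with assumed parameter names; note there are only 2 gates and a single state vector, which is why the GRU has fewer equations than the LSTM.

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def gru_step(x_t, h_prev, p):
    """One GRU time step (standard formulation; parameter names are assumed)."""
    z_in = np.concatenate([h_prev, x_t])
    r_t = sigmoid(p["Wr"] @ z_in + p["br"])      # reset gate: how much of h_prev to use
    z_t = sigmoid(p["Wz"] @ z_in + p["bz"])      # update gate: how much new info to mix in
    h_tilde = np.tanh(p["Wh"] @ np.concatenate([r_t * h_prev, x_t]) + p["bh"])
    return (1 - z_t) * h_prev + z_t * h_tilde    # single state, no separate cell state

# Tiny usage example with random weights (sizes are assumed for the demo)
rng = np.random.default_rng(0)
embed_dim, hidden_dim = 4, 3
p = {k: rng.normal(scale=0.1, size=(hidden_dim, hidden_dim + embed_dim))
     for k in ("Wr", "Wz", "Wh")}
p.update({k: np.zeros(hidden_dim) for k in ("br", "bz", "bh")})

h = np.zeros(hidden_dim)
for x_t in (rng.normal(size=embed_dim) for _ in range(10)):
    h = gru_step(x_t, h, p)
print(h)
```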

The choice is yours; you can pick either the LSTM or the GRU.

To make it concrete which one you want to choose:

  1. The LSTM is older; it was published in 1997, but it is powerful.
  2. The GRU is newer; it was published in 2014, is a simplified version of the LSTM, and is faster to train since it has fewer equations.

Kudos to you, you have reached the end of the story. I hope I made the concept crisp and clear. I will engage you with another story; until then, bye-bye.
