by Roopal Garg
This is a continuation of Part 1, on word embeddings. If you haven't read it, I would encourage you to go through it first so that this post gels in well.
By the end of the previous post we had an understanding of how to embed and represent words. The next stage is stitching these words together into sentences and being able to represent those.
The simplest way would be to just add up the individual word embeddings and average them (as in the following diagram), and hey, you know what? It works!!
But… only to some extent. You lose the information carried by word order. The system will not differentiate between:
“The boy kicked the football” vs “The football kicked the boy”
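To make the averaging idea concrete, here is a minimal sketch. The three-dimensional embeddings are made-up toy numbers (a real system would use trained embeddings like the ones from Part 1), and `sentence_embedding` is a hypothetical helper name. It also shows the word-order problem: both sentences contain exactly the same words, so their averages come out identical.

```python
import math

# Toy 3-d word embeddings: hypothetical values, for illustration only.
EMBEDDINGS = {
    "the": [0.1, 0.0, 0.2],
    "boy": [0.7, 0.3, 0.1],
    "kicked": [0.2, 0.9, 0.4],
    "football": [0.6, 0.1, 0.8],
}

def sentence_embedding(sentence):
    """Represent a sentence as the average of its word embeddings."""
    words = sentence.lower().split()
    dim = len(EMBEDDINGS[words[0]])
    total = [0.0] * dim
    for word in words:
        for i, value in enumerate(EMBEDDINGS[word]):
            total[i] += value
    return [value / len(words) for value in total]

a = sentence_embedding("The boy kicked the football")
b = sentence_embedding("The football kicked the boy")
# Same bag of words, so the averages match (up to floating-point rounding).
print(all(math.isclose(x, y) for x, y in zip(a, b)))  # True
```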
And this is Deep Learning! That was way too simple to be worth the hype, and we haven't talked about any black magic yet!!!
Why settle for Expelliarmus when you can learn Avada Kedavra!
Let’s dig deeper.
Sequences of words (sentences or phrases) can be represented in different ways. We will talk about a few common ones here:
1. Recurrent Neural Networks aka RNN
- Vanilla RNNs
- Long Short Term Memory aka LSTMs
- Gated Recurrent Units aka GRUs
2. Convolutional Neural Networks aka CNNs (Yes! they work for text as well!)
Let's look at RNNs first.
Recurrent Neural Networks (RNNs)
Let's think for a minute about how we as humans process and comprehend sentences. For most languages we read the text from left to right, and after each word we have some understanding of what the text is about, until we reach the end and have cumulatively comprehended the entire meaning. As adults, through years of training, our brains have become super efficient at this (unless it's text that requires frequent dictionary lookups), to the extent that we don't even feel the lag between comprehending one word and the next. The idea behind RNNs is very similar.
I am sure most of you have seen a variant of the diagram below numerous times while musing over DL literature. There are two parts to understanding this diagram:
- Understanding what happens at each step, i.e., how the current word is processed
- Understanding how that gels together with preceding words
Understanding what happens at each step
ht is called the hidden state at time t; think of it as the way the RNN internally represents the phrase it has seen up to time t. At time t=1, i.e., when we encounter the first word x1, the previous hidden state h0 can be thought of as a zero vector representing no prior knowledge.
There are three stages at each time step t:
- Input: takes in the word embedding of the current word and creates another representation for it using the input weight Wx
- Hidden: combines the representation of the previous words with the current one using Wh; ht-1 is the hidden state summarizing all previous words
- Output: converts the internal hidden representation to the actual desired output using Wo. Whether you calculate the output at every step or only at the end, for the last word, depends on what you are trying to achieve, e.g., Part of Speech tagging needs an output per word, whereas Sentiment Analysis needs only the overall sentiment of the text.
Note: all the weights (W) are matrices and the hidden states (h) are vectors. In practice a bias vector (b) is also added at each stage; it is not shown here to keep things simple.
Most of the math is pretty straightforward. I will elaborate a little on one scary, alien-looking symbol there, "𝞂", called the sigmoid function. Just think of it as a function that, given an input, spits out a number between 0 and 1. Not that scary after all, huh!
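Here is one time step written out in Python, as a minimal sketch assuming the formulation ht = 𝞂(Wx·xt + Wh·ht-1 + bh) and yt = 𝞂(Wo·ht + bo), with the biases from the note above included. Using 𝞂 at both stages follows the description here; many implementations use tanh for the hidden state and a task-specific output (e.g., softmax) instead. The weight values and sizes are hypothetical.

```python
import math

def sigmoid(z):
    """The 'scary' sigmoid: squashes any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def matvec(W, v):
    """Matrix-vector product, with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_step(x_t, h_prev, Wx, Wh, Wo, bh, bo):
    """One time step of a vanilla RNN (hidden and output stages both use sigmoid here)."""
    # Input + Hidden: combine the current word with the summary of previous words
    pre_h = [a + b + c for a, b, c in zip(matvec(Wx, x_t), matvec(Wh, h_prev), bh)]
    h_t = [sigmoid(z) for z in pre_h]
    # Output: turn the hidden state into the desired output
    y_t = [sigmoid(z + c) for z, c in zip(matvec(Wo, h_t), bo)]
    return h_t, y_t

# Hypothetical sizes: 2-d word embedding, 3-d hidden state, 1 output.
Wx = [[0.5, -0.2], [0.1, 0.4], [-0.3, 0.8]]
Wh = [[0.2, 0.0, -0.1], [0.0, 0.3, 0.1], [0.1, -0.2, 0.2]]
Wo = [[0.6, -0.4, 0.3]]
bh, bo = [0.0, 0.0, 0.0], [0.0]

h1, y1 = rnn_step([1.0, 0.5], [0.0, 0.0, 0.0], Wx, Wh, Wo, bh, bo)  # h0 = zero vector
print(h1, y1)
```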
Understanding how that gels together with preceding words
Ok, so now let’s unroll the loop in the diagram above on a phrase and see how it looks:
The idea is the same, to try and mimic the brain. Move from left to right and at each word create an understanding of what we have read until now.
Let's throw in all the parameters as well.
RNNs look amazing, but there is a problem: at each step we create an understanding of what we have read so far, so it's always a mix of the past and the present. What happens when it's a really long piece of text and the only vital information you needed to remember was the first word? Vanilla RNNs suffer from what is referred to as the long-term dependency problem.
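The unrolled picture corresponds to a simple loop. This self-contained sketch (hypothetical 2-d embeddings and weights, biases set to zero, 𝞂 as the hidden activation) reuses the same Wx, Wh and bh at every step, starts from a zero h0, and returns h1 … hT. Unlike averaging, the final state now depends on word order.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_forward(xs, Wx, Wh, bh):
    """Unroll a vanilla RNN over word vectors xs and return [h_1, ..., h_T].

    The same Wx, Wh and bh are reused at every step; only x_t and h_{t-1} change.
    """
    h = [0.0] * len(Wh)                      # h_0: no prior knowledge
    states = []
    for x_t in xs:                           # left to right, one word at a time
        pre = [a + b + c for a, b, c in zip(matvec(Wx, x_t), matvec(Wh, h), bh)]
        h = [sigmoid(z) for z in pre]
        states.append(h)
    return states

# Hypothetical 2-d embeddings and weights.
emb = {"the": [0.1, 0.2], "boy": [0.9, -0.4], "kicked": [0.3, 0.7], "football": [-0.5, 0.6]}
Wx = [[0.5, -0.3], [0.2, 0.8]]
Wh = [[0.1, 0.4], [-0.6, 0.3]]
bh = [0.0, 0.0]

s1 = rnn_forward([emb[w] for w in "the boy kicked the football".split()], Wx, Wh, bh)
s2 = rnn_forward([emb[w] for w in "the football kicked the boy".split()], Wx, Wh, bh)
print(s1[-1], s2[-1])  # different final states: word order now matters
```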
eg: Coldplay gave a really memorable performance at the Rose Bowl last night. It was a sell out...random….text….something...something....think...of..it...as..a..really...long...text...almost...the...end...of...the...story….almost...there.... ...and….finally...it...ends!
And then someone asks: Who played at the concert?
The answer is the very first word! But while building an understanding of the story, a Vanilla RNN mixes the representation of the word "Coldplay" with so many later words that it may not be feasible to recover that one word as the answer in a Q&A system.
If you think about it, the problem is that the system has no path along which information can flow unaltered. LSTMs (and GRUs) to the rescue!
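The "mixing" can be made measurable with a toy, one-dimensional RNN. This sketch (hypothetical weights wx = wh = 1, sigmoid hidden state, a first word followed by filler words) uses the chain rule to compute how much the final hidden state hT still depends on the first word x1. Because the sigmoid's slope is at most 0.25, with these weights every extra word shrinks that dependence by at least a factor of 4. Gradient-based training relies on exactly this sensitivity, which is one way to see why plain RNNs struggle with long-term dependencies.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def influence_of_first_word(xs, wx, wh):
    """d h_T / d x_1 for the scalar RNN h_t = sigmoid(wx * x_t + wh * h_{t-1}).

    Computed step by step with the chain rule (forward-mode differentiation).
    """
    h, dh = 0.0, 0.0                       # h_0 = 0 does not depend on x_1
    for t, x in enumerate(xs):
        s = sigmoid(wx * x + wh * h)
        ds = s * (1.0 - s)                 # sigmoid'(z) = s(1 - s) <= 0.25
        # x_1 enters directly only at the first step; afterwards only via h_{t-1}
        dh = ds * ((wx if t == 0 else 0.0) + wh * dh)
        h = s
    return dh

lengths = (2, 5, 20, 50)
influences = [influence_of_first_word([1.0] + [0.5] * (n - 1), wx=1.0, wh=1.0) for n in lengths]
for n, d in zip(lengths, influences):
    print(f"{n:>2} words: d h_T / d x_1 = {d:.3e}")
```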
Long Short Term Memory (LSTMs)
The abstract representation is exactly the same as for RNNs; the difference lies in that orange cell.
Let's open that Pandora's box!
Now that looks scary!!!!! (I promise I will try my best to simplify things as much as possible)
The first major difference you may notice is that we now have Ct, called the cell state. As in the RNN structure, we still have ht, which derives its value as a curated version of Ct, as we will see shortly. The cell state is a very important component of the LSTM because it acts like a conveyor belt along which information flows within and across cells.
The small boxes with the three "𝞂" you see there are called "gates", simply because "𝞂" gives a number between 0 and 1, which can be read as the degree to which information is allowed to flow through.
- Information * 0 = 0: no information flows through
- Information * 1 = Information: all the information passes through
- Information * (a value between 0 and 1) = some fraction of the information passes through
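A gate is just an element-wise multiplication by 𝞂 of some score. This small sketch (hypothetical helper `gate`, made-up numbers) shows the three cases from the list: a very negative score closes the gate, a very positive one opens it, and a score of 0 lets exactly half through.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gate(information, scores):
    """Element-wise multiply a vector of information by sigmoid(scores)."""
    return [info * sigmoid(s) for info, s in zip(information, scores)]

information = [2.0, -3.0, 5.0]
out = gate(information, [-20.0, 20.0, 0.0])
# Roughly [0.0, -3.0, 2.5]: closed, fully open, half open.
print(out)
```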
There are a couple of things in that cell that might look new:
tanh: a function that gives out a number between -1 and 1
The circled "×": element-wise (pointwise) multiplication
The circled "+": element-wise addition (that one was easy!)
There are 3 gates and 4 steps overall. Let's look at the gates one by one and cover the steps along the way.
- Forget Gate (ft): as the name suggests, decides what information we want to forget from the cell state. It takes ht-1 and the current input xt and applies a "𝞂", resulting in a vector of numbers between 0 and 1, one for each entry of Ct-1.
- Input Gate (it): with the current input, we need to make changes in the cell state to form Ct.
Now this raises 2 questions:
i. which values to update? The input gate decides which values to update, and by how much, using the "𝞂" function
ii. what should the new values be? New candidate cell values Ĉt are generated using tanh
We now know what to forget, what to update, and the new candidate values Ĉt.
Let's actually do the update to get the new cell state Ct!! The old cell state is scaled by the forget gate, the candidate values are scaled by the input gate, and the two are added: Ct = ft * Ct-1 + it * Ĉt (element-wise).
- Output Gate (ot): decides which parts of the cell state Ct we should send out as the new hidden state ht. The "𝞂" decides what parts to output, and multiplying it by tanh(Ct) (the cell state squashed to between -1 and 1) gives the new hidden state ht.
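The four steps can be put together in one function. This is a minimal sketch assuming the common formulation (as in Christopher Olah's post) where every gate acts on the concatenation [ht-1, xt]; the diagram here may use slightly different notation, and `toy_params` with its values is hypothetical. The usage shows the conveyor belt: with the forget gate wide open and the input gate shut, the cell state passes through almost unchanged.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """W v + b, with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM step. p maps 'f', 'i', 'c', 'o' to (W, b) acting on [h_{t-1}, x_t]."""
    hx = h_prev + x_t                                     # list concatenation [h_{t-1}, x_t]
    f = [sigmoid(z) for z in affine(*p["f"], hx)]         # forget gate f_t
    i = [sigmoid(z) for z in affine(*p["i"], hx)]         # input gate i_t
    c_hat = [math.tanh(z) for z in affine(*p["c"], hx)]   # candidate values C^_t
    # Update: C_t = f_t * C_{t-1} + i_t * C^_t (element-wise)
    c = [fv * cp + iv * ch for fv, cp, iv, ch in zip(f, c_prev, i, c_hat)]
    o = [sigmoid(z) for z in affine(*p["o"], hx)]         # output gate o_t
    h = [ov * math.tanh(cv) for ov, cv in zip(o, c)]      # h_t = o_t * tanh(C_t)
    return h, c

def toy_params(hidden, inputs, biases):
    """Hypothetical parameters: every weight 0.1, one bias value per gate."""
    return {g: ([[0.1] * (hidden + inputs) for _ in range(hidden)], [biases[g]] * hidden)
            for g in "fico"}

# Forget gate wide open, input gate shut: the cell state rides the conveyor belt.
p = toy_params(2, 2, {"f": 10.0, "i": -10.0, "c": 0.0, "o": 0.0})
h, c = lstm_step([0.5, -0.5], [0.0, 0.0], [1.0, -2.0], p)
print(c)  # approximately [1.0, -2.0]
```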
The math may seem complex, but try to focus on the intuition rather than thinking too hard about the "math" part (the extremely curious ones out there should, of course, feel free to explore).
There is an excellent post by Christopher Olah that helped me understand these concepts when I was starting out; I highly recommend reading it: Understanding LSTM Networks
Gated Recurrent Units (GRUs)
A well known variation of the LSTM is the GRU, which looks something like this:
As you might have noticed, there is no cell state Ct; the hidden state ht carries the memory on its own.
In short, it has 2 gates instead of the 3 in an LSTM. The two gates are:
- Reset Gate (rt): determines how to combine the previous memory with the current input
- Update Gate (zt): decides how much of the previous memory to keep around
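The GRU step fits in a few lines too. This sketch follows the formulation in Olah's post, ht = (1 - zt) * ht-1 + zt * ĥt; some papers and libraries swap the roles of zt and 1 - zt. The parameter values are hypothetical: with a strongly negative update-gate bias, zt is near 0 and the previous memory is kept almost unchanged.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """W v + b, with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def gru_step(x_t, h_prev, p):
    """One GRU step. p maps 'z', 'r', 'h' to (W, b); there is no separate cell state."""
    hx = h_prev + x_t                                      # [h_{t-1}, x_t]
    z = [sigmoid(v) for v in affine(*p["z"], hx)]          # update gate z_t
    r = [sigmoid(v) for v in affine(*p["r"], hx)]          # reset gate r_t
    # Reset gate: how much of the previous memory is combined with the current input
    rhx = [rv * hv for rv, hv in zip(r, h_prev)] + x_t
    h_hat = [math.tanh(v) for v in affine(*p["h"], rhx)]   # candidate state
    # Update gate: interpolate between the old memory and the candidate
    return [(1.0 - zv) * hv + zv * hh for zv, hv, hh in zip(z, h_prev, h_hat)]

# Hypothetical parameters: every weight 0.1; update-gate bias strongly negative.
p = {g: ([[0.1] * 4 for _ in range(2)], [b] * 2)
     for g, b in {"z": -10.0, "r": 0.0, "h": 0.0}.items()}
h_prev = [0.3, -0.7]
h = gru_step([0.5, -0.5], h_prev, p)
print(h)  # with z_t near 0, h_t stays close to h_prev
```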
GRUs are relatively new and have fewer parameters than LSTMs (three sets of gate/candidate weights instead of four), which makes them lighter and usually faster to train. That being said, there isn't a clear winner between the two.
There are many small modifications to these, and other variants of RNNs, each with its own pros and cons. What works best depends entirely on the data and the use case.
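The parameter saving is easy to count. In this sketch (hypothetical sizes: 100-d input, 128-d hidden state) each gate or candidate block has a weight matrix acting on [h, x] plus one bias vector, so an LSTM has 4 blocks and a GRU 3, i.e., the GRU has three quarters of the parameters. Some frameworks (PyTorch, for instance) keep two bias vectors per block, which changes the totals slightly but not the 3:4 ratio.

```python
def recurrent_cell_params(input_size, hidden_size, blocks):
    """Weights + one bias vector for each gate/candidate block acting on [h, x]."""
    per_block = hidden_size * (hidden_size + input_size) + hidden_size
    return blocks * per_block

lstm = recurrent_cell_params(100, 128, blocks=4)  # forget, input, candidate, output
gru = recurrent_cell_params(100, 128, blocks=3)   # update, reset, candidate
print(lstm, gru)  # 117248 87936
```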
Convolutional Neural Networks (CNNs)
The idea of using CNNs for text is very similar to how n-grams work in practice.
Instead of looking at one word at a time and combining it sequentially with the representation of the past words, we look at pairs of adjacent words and combine their representations. One way to do this is hierarchically, as follows:
There is an excellent paper by Yoon Kim, Convolutional Neural Networks for Sentence Classification, which explains the concepts in detail and also introduces a concept called pooling.
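The n-gram analogy can be sketched in plain Python: each filter looks at a window of 2 consecutive word vectors (a bigram), and max-over-time pooling keeps each filter's strongest response, giving a fixed-size sentence vector regardless of sentence length. This is a simplified, single-layer illustration in the spirit of Kim's paper (which uses several filter widths and learned filters); the embeddings and filter values are hypothetical.

```python
def conv1d_text(embeddings, filters, width=2):
    """Slide each (weights, bias) filter over every window of `width` word vectors.

    Each filter holds width * dim weights, i.e. it scores one n-gram at a time.
    Returns, per filter, one ReLU activation per window.
    """
    feature_maps = []
    for weights, bias in filters:
        activations = []
        for start in range(len(embeddings) - width + 1):
            window = [v for vec in embeddings[start:start + width] for v in vec]
            score = sum(w * v for w, v in zip(weights, window)) + bias
            activations.append(max(0.0, score))   # ReLU
        feature_maps.append(activations)
    return feature_maps

def max_over_time_pool(feature_maps):
    """Keep each filter's strongest response: one number per filter."""
    return [max(fm) for fm in feature_maps]

# Hypothetical 2-d embeddings for "the boy kicked the football" and three filters,
# each covering two stacked word vectors (2 * 2 = 4 weights).
sentence = [[0.1, 0.2], [0.9, -0.4], [0.3, 0.7], [0.1, 0.2], [-0.5, 0.6]]
filters = [([1.0, 0.0, 0.0, 1.0], 0.0),
           ([0.0, 1.0, 1.0, 0.0], 0.0),
           ([-1.0, 0.5, 0.5, -1.0], 0.1)]
maps = conv1d_text(sentence, filters)
vector = max_over_time_pool(maps)
print(len(maps[0]), vector)  # 4 bigram windows; one pooled value per filter
```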
Wow! That was intense.
If you understood most of it, then you understand the building blocks of Deep Learning for NLP. From here on it's smooth sailing, since all you need to do is combine these building blocks, like ingredients in a magic spell, to form complex architectures.
We now know how to represent words (hope you remember that), and we just read how to represent sentences. Let's look at how to treat paragraphs now.
If you think about it, paragraphs are nothing but a bunch of sentences, one after the other.
If you recall, when we stepped up from words to sentences, the first thought was to simply add up the embeddings of all the words in a sentence and average them.
The sentence to paragraph version of that would look something like this:
The key thing to note here is that for the first word of the first sentence, "I", the initial hidden state is a zero vector h0, but for the first word of the second sentence the initial hidden state is h5, i.e., the last hidden state from the first sentence. The idea is basically to treat the sentences as one long string of words.
As you might have guessed, this works… but not like magic!
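In code, "one long string of words" simply means concatenating the sentences before running the RNN. This self-contained sketch (hypothetical weights and embeddings, biases omitted, 𝞂 as the activation) checks the point made above: running over the whole paragraph gives exactly the same result as running the second sentence starting from h5, the last hidden state of the five-word first sentence.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_run(xs, Wx, Wh, h0):
    """Run a vanilla RNN (no biases) over xs starting from h0; return the last state."""
    h = h0
    for x_t in xs:
        h = [sigmoid(a + b) for a, b in zip(matvec(Wx, x_t), matvec(Wh, h))]
    return h

Wx = [[0.5, -0.3], [0.2, 0.8]]   # hypothetical weights
Wh = [[0.1, 0.4], [-0.6, 0.3]]
zero = [0.0, 0.0]
sentence1 = [[0.1, 0.2], [0.9, -0.4], [0.3, 0.7], [0.1, 0.2], [-0.5, 0.6]]  # 5 words -> h1..h5
sentence2 = [[0.4, 0.1], [-0.2, 0.5], [0.6, 0.3]]

# Treat the paragraph as one long string of words.
paragraph_state = rnn_run(sentence1 + sentence2, Wx, Wh, zero)

# Equivalent view: sentence 2 starts from h5, the last state of sentence 1.
h5 = rnn_run(sentence1, Wx, Wh, zero)
print(paragraph_state == rnn_run(sentence2, Wx, Wh, h5))  # True
```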
Let’s look at an alternate approach.
The second approach is to split the paragraph into sentences, process each sentence separately, and then feed the final hidden state of each sentence as input to another layer of RNNs (Vanilla RNNs, LSTMs or GRUs), i.e., we process the sentences hierarchically. It looks something like this:
Doing this also has the advantage that you can process your sentences in parallel before combining them in the second layer of RNNs.
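A minimal sketch of the hierarchical version, assuming vanilla RNNs with 𝞂 at both levels, no biases, and hypothetical weights: a word-level RNN encodes each sentence independently from a zero state (so the list comprehension could be parallelized), and a second, sentence-level RNN with its own weights reads those final states as its inputs.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_run(xs, Wx, Wh):
    """Vanilla RNN (no biases) from a zero start state; returns the last hidden state."""
    h = [0.0] * len(Wh)
    for x_t in xs:
        h = [sigmoid(a + b) for a, b in zip(matvec(Wx, x_t), matvec(Wh, h))]
    return h

def encode_paragraph(sentences, word_rnn, sentence_rnn):
    """Hierarchical encoding: one word-level RNN run per sentence, then a sentence-level RNN."""
    # Every sentence starts from a zero state, so these runs are independent
    # of each other and could be computed in parallel.
    sentence_vectors = [rnn_run(s, *word_rnn) for s in sentences]
    # Each sentence's final hidden state becomes one input step of the second RNN.
    return rnn_run(sentence_vectors, *sentence_rnn), sentence_vectors

# Hypothetical weights: 2-d word vectors -> 2-d sentence vectors -> 3-d paragraph vector.
word_rnn = ([[0.5, -0.3], [0.2, 0.8]], [[0.1, 0.4], [-0.6, 0.3]])
sentence_rnn = ([[0.3, 0.7], [-0.4, 0.2], [0.5, 0.5]],
                [[0.1, 0.0, 0.2], [0.0, 0.1, 0.0], [0.2, 0.0, 0.1]])
sentences = [
    [[0.1, 0.2], [0.9, -0.4], [0.3, 0.7], [0.1, 0.2], [-0.5, 0.6]],
    [[0.4, 0.1], [-0.2, 0.5], [0.6, 0.3]],
]
paragraph_vec, sentence_vecs = encode_paragraph(sentences, word_rnn, sentence_rnn)
print(paragraph_vec)
```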
Enough of theory!
Let’s look at a real application: Sentiment Analysis
A paper that we implemented for this task, and that works pretty well, is titled:
Hierarchical Attention Networks for Document Classification
The idea looks like this:
Note: in the original paper, the RNN layers are bidirectional, i.e., they go both from left to right and from right to left; for simplicity we consider unidirectional RNN cells here.
The paper processes text hierarchically, very much as we just discussed, but in addition it uses something called an Attention Layer.
e.g., "I like playing soccer"
Now if the task at hand is Sentiment Analysis, then it really doesn't matter what you like, as long as the sentence talks about you "liking" something. The sentiment should be positive, since we would pay more focus to "I like"; this focus is what the Attention Layer intends to capture.
The model attempts to capture what's important within each sentence at the first layer of attention, and then tries to judge what's important among those sentences at the second layer of attention. The last layer, a softmax layer, is a simple way to get a probability distribution across the possible output classes (positive and negative in the case of Sentiment Analysis).
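Here is a minimal sketch of one attention layer followed by the final softmax, following the formulation in the Hierarchical Attention Networks paper: each hidden state h is projected to u = tanh(W h + b), its importance is the softmax of u · context, and the output is the importance-weighted sum of the hidden states. The same function would be applied at the word level and again at the sentence level. All values are hypothetical; in particular the context vector is hand-picked here so that "like" stands out, whereas the paper learns it during training.

```python
import math

def softmax(scores):
    """Turn scores into a probability distribution (max subtracted for stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pool(states, W, b, context):
    """HAN-style attention: u = tanh(W h + b), weight = softmax(u . context)."""
    keys = [[math.tanh(sum(w * x for w, x in zip(row, h)) + bias) for row, bias in zip(W, b)]
            for h in states]
    weights = softmax([sum(k * c for k, c in zip(key, context)) for key in keys])
    # Importance-weighted sum of the hidden states
    pooled = [sum(a * h[d] for a, h in zip(weights, states)) for d in range(len(states[0]))]
    return pooled, weights

# Hypothetical word-level hidden states for "I like playing soccer".
states = [[0.1, 0.2], [2.0, 0.1], [0.3, 0.5], [0.2, 0.9]]
W, b = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
context = [1.0, 0.0]   # hand-picked here; learned during training in the paper
sentence_vec, weights = attention_pool(states, W, b, context)
print([round(a, 2) for a in weights])  # "like" gets the largest weight

# Final softmax layer over the classes (positive, negative), hypothetical weights.
Wc, bc = [[1.0, 0.5], [-1.0, -0.5]], [0.0, 0.0]
probs = softmax([sum(w * x for w, x in zip(row, sentence_vec)) + c for row, c in zip(Wc, bc)])
print(probs)
```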
Reading all that should have given you some intuition for how to process text with NNs. There is no one-size-fits-all solution here. The network architecture, the size of the layers and the resulting performance depend on what the data looks like and how much of it there is. Larger networks tend to do better, since they can capture more information, but they come at the cost of complexity and longer training time. In the end, the choice is yours.
There are some pretty neat frameworks out there which you can use to get started:
- Apache MXNet
- Torch / PyTorch
- Keras (this would be the choice if you want to stay at a higher level of abstraction and would rather not get your hands dirty)
Some good reading resources:
- Stanford Course: Deep Learning for NLP : highly recommend this!
- Twitter Feeds: excellent resource to stay updated with new research papers, blog posts, etc.