<aside> 📘 Series:
Beginner’s Guide on Recurrent Neural Networks with PyTorch
A Brief Introduction to Recurrent Neural Networks
Illustrated Guide to Transformers- Step by Step Explanation
How to code The Transformer in PyTorch
</aside>
Recurrent Neural Networks (RNNs) have been the go-to approach for sequential data and Natural Language Processing (NLP) problems for many years, and variants such as the LSTM are still widely used in numerous state-of-the-art models. In this post, I’ll cover the basic concepts behind RNNs and implement a plain vanilla RNN model with PyTorch to generate text.
Although the content is introductory, the post assumes that you have at least a basic understanding of ordinary feed-forward neural nets.
What exactly are RNNs? First, let’s compare the architecture and flow of RNNs vs traditional feed-forward neural networks.
Overview of the feed-forward neural network and RNN structures
The main difference is in how the input data is taken in by the model.
Traditional feed-forward neural networks take in a fixed amount of input data all at once and produce a fixed amount of output each time. RNNs, on the other hand, do not consume all the input data at once. Instead, they take the inputs in one at a time, in sequence. At each step, the RNN does a series of calculations before producing an output. This output, known as the hidden state, is then combined with the next input in the sequence to produce another output. The process continues until the input sequence ends or the model is told to stop.
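To make the step-by-step flow concrete, here is a minimal sketch of that loop written by hand. The sizes, the weight names (`W_xh`, `W_hh`, `b_h`), the helper `rnn_step`, and the choice of `tanh` as the non-linearity are my own assumptions for illustration; they follow the common vanilla RNN formulation and are not something this post has defined yet.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumed, not from the post)
input_size, hidden_size, seq_len = 3, 4, 5

# One set of weights that is used at every time step
W_xh = torch.randn(hidden_size, input_size) * 0.1   # input  -> hidden
W_hh = torch.randn(hidden_size, hidden_size) * 0.1  # hidden -> hidden
b_h = torch.zeros(hidden_size)

def rnn_step(x_t, h_prev):
    """Combine the current input with the previous hidden state."""
    return torch.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)

inputs = torch.randn(seq_len, input_size)  # a toy sequence of 5 input vectors
h = torch.zeros(hidden_size)               # initial hidden state
hidden_states = []
for x_t in inputs:                         # consume the sequence one item at a time
    h = rnn_step(x_t, h)                   # new hidden state depends on the old one
    hidden_states.append(h)

hidden_states = torch.stack(hidden_states)  # shape: (seq_len, hidden_size)
print(hidden_states.shape)
```

Each pass through the loop produces one hidden state, and that hidden state is fed straight back in on the next pass.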
Still confused? Don't worry. Being able to visualize the flow of an RNN really helped my understanding when I started on this topic.
Simple process flow of an RNN cell
As we can see, the calculations at each time step take into account the context of the previous time steps in the form of the hidden state. Being able to use this contextual information from previous inputs is the key to RNNs’ success in sequential problems.
While it may seem that a different RNN cell is being used at each time step in the graphics, the underlying principle of Recurrent Neural Networks is that the RNN cell is actually the exact same one and reused throughout.
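A quick way to convince yourself of this is to create a single cell and call it in a loop. The sketch below uses PyTorch's `nn.RNNCell`; the sizes and toy data are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A single cell, created once (sizes are illustrative)
cell = nn.RNNCell(input_size=3, hidden_size=4)
n_params_before = sum(p.numel() for p in cell.parameters())

inputs = torch.randn(5, 1, 3)  # (seq_len, batch, input_size), toy data
h = torch.zeros(1, 4)          # (batch, hidden_size)
for x_t in inputs:
    h = cell(x_t, h)           # same object, hence same weights, at every step

n_params_after = sum(p.numel() for p in cell.parameters())
print(n_params_before, n_params_after)
```

The number of parameters is the same before and after the loop, and it would stay the same for a sequence of any length, because no new cell is created per time step.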
You might be wondering: which portion of the RNN do I extract my output from? This really depends on your use case. For example, if you’re using the RNN for a classification task, you’ll only need one final output after passing in all the inputs: a vector of class probability scores. If instead you’re doing text generation based on the previous character or word, you’ll need an output at every single time step.
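In PyTorch, both options are available from a single forward pass. As a minimal sketch with assumed toy sizes, `nn.RNN` returns the hidden state at every time step as well as the final hidden state:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

rnn = nn.RNN(input_size=3, hidden_size=4, batch_first=True)
x = torch.randn(2, 5, 3)   # (batch, seq_len, input_size), toy data

output, h_n = rnn(x)
# output: hidden state at every time step -> (batch, seq_len, hidden_size)
# h_n:    hidden state after the last step -> (num_layers, batch, hidden_size)
print(output.shape, h_n.shape)

# For this single-layer, one-directional RNN, the last time step of
# `output` is the same tensor of values as the final hidden state.
print(torch.allclose(output[:, -1, :], h_n[0]))
```

For a per-step task such as text generation you would work with `output`; for a single prediction you would typically use only the last step.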
This image was taken from Andrej Karpathy’s blog post
This is where RNNs are really flexible and can adapt to your needs. As seen in the image above, your input and output size can come in different forms, yet they can still be fed into and extracted from the RNN model.
Many inputs to one output
For the case where you’ll only need a single output from the whole process, getting that output can be fairly straightforward as you can easily take the output produced by the last RNN cell in the sequence. As this final output has already undergone calculations through all the previous cells, the context of all the previous inputs has been captured. This means that the final result is indeed dependent on all the previous computations and inputs.
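Here is a rough sketch of what a many-to-one setup could look like in PyTorch. The class name `LastStepClassifier`, the layer sizes, and the extra linear layer that maps the last hidden state to class scores are all my own assumptions for illustration, not a fixed recipe.

```python
import torch
import torch.nn as nn

class LastStepClassifier(nn.Module):
    """Hypothetical many-to-one model: read a sequence, emit one prediction."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        output, _ = self.rnn(x)    # (batch, seq_len, hidden_size)
        last = output[:, -1, :]    # keep only the last time step
        return self.fc(last)       # (batch, num_classes) raw scores

torch.manual_seed(0)
model = LastStepClassifier(input_size=3, hidden_size=8, num_classes=2)
x = torch.randn(4, 6, 3)               # 4 toy sequences of length 6
logits = model(x)
probs = torch.softmax(logits, dim=1)   # turn scores into class probabilities
print(logits.shape, probs.sum(dim=1))
```

Only the last time step is passed to the linear layer, but because that hidden state was built up from every earlier step, changing any earlier input changes the prediction as well.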