<aside> 📘 Series:

Beginner’s Guide on Recurrent Neural Networks with PyTorch

A Brief Introduction to Recurrent Neural Networks

Attention Mechanism

Attention? Attention!

Illustrated Guide to Transformers- Step by Step Explanation

How to code The Transformer in PyTorch

Tokenizers: How machines read

</aside>


Recurrent Neural Networks (RNNs) have for many years been the answer to most problems involving sequential data and Natural Language Processing (NLP), and their variants such as the LSTM are still widely used in numerous state-of-the-art models to this date. In this post, I’ll cover the basic concepts behind RNNs and implement a plain vanilla RNN model with PyTorch to generate text.

Although the content is introductory, the post assumes that you at least have a basic understanding of normal feed-forward neural nets.

Basic Concepts

What exactly are RNNs? First, let’s compare the architecture and flow of RNNs vs traditional feed-forward neural networks.

Overview of the feed-forward neural network and RNN structures

The main difference is in how the input data is taken in by the model.

Traditional feed-forward neural networks take in a fixed amount of input data all at the same time and produce a fixed amount of output each time. RNNs, on the other hand, do not consume all the input data at once. Instead, they take the inputs in one at a time, in sequence. At each step, the RNN does a series of calculations before producing an output. This output, known as the hidden state, is then combined with the next input in the sequence to produce another output. The process continues until the model is programmed to finish or the input sequence ends.
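To make the step-by-step flow concrete, here is a minimal sketch in plain Python. It assumes the common vanilla (Elman-style) update, where the new hidden state is `tanh` of a weighted sum of the current input and the previous hidden state. The scalar weights `w_xh`, `w_hh`, and `b` are arbitrary values I picked for illustration, not trained parameters.

```python
import math

def rnn_step(x, h_prev, w_xh, w_hh, b):
    """One time step: combine the current input with the previous hidden state."""
    return math.tanh(w_xh * x + w_hh * h_prev + b)

# Hypothetical toy sequence and weights, chosen only for illustration.
inputs = [0.5, -1.0, 0.25]
w_xh, w_hh, b = 0.8, 0.5, 0.0

h = 0.0  # initial hidden state
hidden_states = []
for x in inputs:  # inputs are consumed one at a time, in order
    h = rnn_step(x, h, w_xh, w_hh, b)  # the new state depends on the old one
    hidden_states.append(h)

print(hidden_states)
```

Each entry in `hidden_states` depends on every input seen so far, because the previous hidden state is fed back into the next step.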

Still confused? Don't worry just yet. Being able to visualize the flow of an RNN really helped me understand it when I started on this topic.

Simple process flow of an RNN cell

As we can see, the calculations at each time step take into account the context of the previous time steps, in the form of the hidden state. Being able to use this contextual information from previous inputs is the key to RNNs’ success in sequential problems.

While it may seem from the graphics that a different RNN cell is used at each time step, the underlying principle of Recurrent Neural Networks is that it is actually the exact same cell, with the same weights, reused at every time step.
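One way to see this reuse is a small sketch with PyTorch's `nn.RNNCell`: a single cell object is called in a loop, so the number of parameters stays the same no matter how long the sequence is. The sizes and the `run` helper below are my own illustrative choices, not part of any particular model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One cell, created once. Sizes are arbitrary, for illustration only.
cell = nn.RNNCell(input_size=4, hidden_size=8)

def run(cell, seq):
    """Apply the same cell at every time step of seq (seq_len, batch, input_size)."""
    h = torch.zeros(seq.size(1), cell.hidden_size)  # initial hidden state
    for x_t in seq:          # iterate over time steps
        h = cell(x_t, h)     # same weights every time
    return h

short_seq = torch.randn(3, 2, 4)    # 3 time steps
long_seq = torch.randn(50, 2, 4)    # 50 time steps

h_short = run(cell, short_seq)
h_long = run(cell, long_seq)

# The parameter count belongs to the cell, not to the sequence length.
num_params = sum(p.numel() for p in cell.parameters())
print(h_short.shape, h_long.shape, num_params)
```

Both sequences are processed by the same set of weights; only the number of loop iterations changes.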

Processing RNN Outputs

You might be wondering which portion of the RNN you should extract your output from. This really depends on your use case. For example, if you’re using the RNN for a classification task, you’ll only need one final output after passing in all the input: a vector representing the class probability scores. In another case, if you’re doing text generation based on the previous character/word, you’ll need an output at every single time step.
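For the text generation case, a rough sketch with PyTorch's `nn.RNN` might look like the following. The vocabulary size, hidden size, and the `to_vocab` layer are assumptions made up for this example; the point is only that the RNN returns a hidden state for every time step, and each one can be mapped to a prediction.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, hidden_size = 10, 16  # hypothetical sizes for illustration
rnn = nn.RNN(input_size=vocab_size, hidden_size=hidden_size, batch_first=True)
to_vocab = nn.Linear(hidden_size, vocab_size)  # maps a hidden state to character scores

# Stand-in batch: 2 sequences, 5 time steps each (random values instead of real text).
x = torch.randn(2, 5, vocab_size)

out, h_n = rnn(x)                # out holds the hidden state at every time step
logits_per_step = to_vocab(out)  # one score vector per time step

print(out.shape, logits_per_step.shape)
```

Here `out` has one entry per time step, so a prediction for the next character can be made at each position in the sequence.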

This image was taken from Andrej Karpathy’s blog post

This is where RNNs are really flexible and can adapt to your needs. As seen in the image above, your input and output size can come in different forms, yet they can still be fed into and extracted from the RNN model.

Many inputs to one output

For the case where you only need a single output from the whole process, getting that output is fairly straightforward: you simply take the output produced by the last RNN cell in the sequence. As this final output has already undergone calculations through all the previous cells, the context of all the previous inputs has been captured. This means that the final result is indeed dependent on all the previous computations and inputs.
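As a minimal sketch of this many-to-one setup, assuming a single-layer, unidirectional `nn.RNN` and sequences of equal length (no padding), the last time step of the output can be sliced off and passed to a classification layer. The sizes and the `classifier` layer are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 6 input features, 12 hidden units, 3 classes.
rnn = nn.RNN(input_size=6, hidden_size=12, batch_first=True)
classifier = nn.Linear(12, 3)

x = torch.randn(4, 7, 6)  # batch of 4 sequences, 7 time steps each

out, h_n = rnn(x)             # out: all time steps, h_n: final hidden state
last_output = out[:, -1, :]   # keep only the last time step

# For a single-layer, unidirectional RNN this equals the final hidden state.
same_as_final_state = torch.allclose(last_output, h_n[-1])

class_scores = classifier(last_output)  # one score vector per sequence
print(same_as_final_state, class_scores.shape)
```

With padded or variable-length batches, the last position may not be the true end of each sequence, so this simple slicing would need adjusting.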