Implementation of LSTM in Trax framework

Cuong Tran
2 min read · Jun 10, 2021

This post adds some detail on how LSTM is implemented in the Trax deep learning framework. It assumes readers are already familiar with LSTM.

[Figure: LSTM cell]

Above is a diagram of an LSTM cell, which may help you understand how the computation is done for each cell. But for a more concrete case, such as a natural language processing (NLP) problem, it is not so clear how a given sentence is processed by the LSTM.

To explain the implementation of LSTM in detail, let’s assume we are processing a sentence that has already been converted to a list of tokens (a list of numbers). In an NLP setting, we first need to convert each token into an array (an embedding of the token) using an Embedding layer, as follows.
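Here is a minimal sketch of that step, which reproduces the output below. The vocabulary size of 100 and the toy token ids are assumptions for illustration.

import numpy as np
from trax import layers as tl
from trax import shapes

# A batch of 2 sentences, each padded to 5 token ids (0 is padding).
x = np.array([[1, 2, 3, 4, 0],
              [5, 6, 7, 0, 0]])

# Embedding layer: maps each token id to a 16-dimensional vector.
embedding = tl.Embedding(vocab_size=100, d_feature=16)
embedding.init(shapes.signature(x))

y = embedding(x)
print('input shape =', x.shape)
print('output shape =', y.shape)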

# The output
input shape = (2, 5)
output shape = (2, 5, 16)

Next, the output of the Embedding layer will be fed to the LSTM layer. From the output above, we can see it has 2 lists (each corresponding to a sentence); each list has 5 tokens (including padding), and each token has an embedding dimension of 16. The LSTM layer will scan/traverse each sentence in the batch, apply the LSTM cell calculation on each token (and the hidden state), and output an array for each token. This differs from TensorFlow, which by default only keeps the output for the final token (see the short comparison below).
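For reference, a rough TensorFlow/Keras sketch of that default behavior, with shapes mirroring our example:

import tensorflow as tf

x = tf.random.normal((2, 5, 16))  # batch of 2, 5 tokens, 16-dim embeddings

# By default, Keras LSTM returns only the output of the final token.
print(tf.keras.layers.LSTM(16)(x).shape)  # (2, 16)

# With return_sequences=True, it returns an output per token, like Trax.
print(tf.keras.layers.LSTM(16, return_sequences=True)(x).shape)  # (2, 5, 16)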

Concretely, in the example above, the LSTM layer will scan from the 1st token to the 5th token of each sentence in the batch (of 2). The output will have a shape of (2, 5, 16).

We can confirm the dimensions of the output of the LSTM as follows:
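The sketch below stacks the two layers and checks the shape (again, the vocabulary size and token ids are assumed):

import numpy as np
from trax import layers as tl
from trax import shapes

x = np.array([[1, 2, 3, 4, 0],
              [5, 6, 7, 0, 0]])

# Embedding followed by an LSTM with 16 units: one output vector per token.
model = tl.Serial(
    tl.Embedding(vocab_size=100, d_feature=16),
    tl.LSTM(n_units=16),
)
model.init(shapes.signature(x))

y = model(x)
print('output shape =', y.shape)  # (2, 5, 16)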

For readers who want to check the source code of Trax, the LSTM layer is implemented in trax/layers/rnn.py.
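Paraphrasing that implementation slightly (a sketch, not the exact current code; check the repository for the authoritative version): the layer is built from combinators, where Scan runs an LSTMCell along the token axis and Select drops the final hidden state so only the per-token outputs remain.

from trax.layers import combinators as cb
from trax.layers.rnn import LSTMCell, MakeZeroState

def LSTM(n_units, mode='train'):
  # Start from a zero (cell, hidden) state, scan LSTMCell over axis 1
  # (the token axis), then keep only the per-token outputs.
  zero_state = MakeZeroState(depth_multiplier=2)
  return cb.Serial(
      cb.Branch([], zero_state),
      cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
      cb.Select([0], n_in=2),  # Drop the final RNN state.
  )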

That’s it. You can modify the examples above to explore LSTM further, or run them with a Python debugger to see how they work.

Thanks for reading!
