Embedding layer in Trax

Cuong Tran
Sep 7, 2021

While taking a Natural Language Processing course on Coursera, I wondered how the Embedding layer works under the hood. I spent some time looking into the source code and writing some simple sample code to check my understanding, and I think I finally have a clear picture of how it works.

Today I saw that some learners have the same question I had, so I decided to write this in the hope of helping them understand it without going through the effort I did. Of course, if you want to know exactly how it is implemented, it is worth taking the time to read the source code.

In a nutshell, the Embedding layer converts each token (a non-negative integer) into an array of size d_feature. It does this by keeping an internal weights array of shape (vocab_size, d_feature). This array is created when the layer initializes its weights, in the init_weights_and_state method of the Trax source.
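A minimal sketch in plain NumPy (not Trax itself) of what "converting a token into an array of size d_feature" means: the token ID simply selects a row of the weights array. The sizes and values are made up for illustration. The one-hot product is included only to show that a lookup gives the same result as multiplying one-hot vectors by the weights, without doing the multiplication.

```python
import numpy as np

# Toy sizes, chosen for illustration only.
vocab_size, d_feature = 5, 3

# The layer's only parameter: one row of length d_feature per token ID.
weights = np.arange(vocab_size * d_feature, dtype=np.float32).reshape(vocab_size, d_feature)

tokens = np.array([4, 0, 2])  # a short sequence of token IDs

# Embedding lookup: token ID i selects row i of the weights array.
lookup = weights[tokens]  # shape (3, d_feature)

# Equivalent but wasteful view: one-hot vectors multiplied by the weights.
one_hot = np.eye(vocab_size, dtype=np.float32)[tokens]  # shape (3, vocab_size)
via_matmul = one_hot @ weights                          # shape (3, d_feature)

print(lookup[0])                          # row 4 of weights: [12. 13. 14.]
print(np.array_equal(lookup, via_matmul)) # True
```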

The input parameter input_signature is ignored here. The shape of the weights matrix is defined by shape_w, which is the tuple (self._vocab_size, self._d_feature). It is used to initialize the weights, which are then assigned to self.weights.
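To make the initialization step concrete, here is a hypothetical NumPy stand-in that follows the structure described above. The class name TinyEmbedding and the normal(0, 0.01) initializer are assumptions for this sketch, not the Trax code; Trax uses its own kernel initializer.

```python
import numpy as np

class TinyEmbedding:
    """Hypothetical NumPy stand-in mirroring the described structure (not Trax)."""

    def __init__(self, vocab_size, d_feature, seed=0):
        self._vocab_size = vocab_size
        self._d_feature = d_feature
        self._rng = np.random.default_rng(seed)
        self.weights = None

    def init_weights_and_state(self, input_signature):
        # The table's shape does not depend on the input, so the signature is unused.
        del input_signature
        shape_w = (self._vocab_size, self._d_feature)
        # Assumed initializer for this sketch: small normal values.
        w = self._rng.normal(0.0, 0.01, size=shape_w).astype(np.float32)
        self.weights = w

layer = TinyEmbedding(vocab_size=5, d_feature=3)
layer.init_weights_and_state(input_signature=None)
print(layer.weights.shape)  # (5, 3)
```

Because the table's shape depends only on vocab_size and d_feature, the same weights serve any batch size or sequence length, which is why the input signature can be discarded.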

How is the embedding generated? That logic lives in the forward method.

All the heavy lifting is done by jnp.take, a function from JAX's NumPy implementation (lax_numpy). The implementation of take is quite complex, but as used in the forward method, it gathers the slices (rows) of self.weights selected by the token IDs in x. Let's look at an example to confirm the idea. In this example we have a vocabulary of size vocab_size=5 and a batch of tokens of shape (2, 4): two sequences of 4 tokens each. First we look at Trax's Embedding layer, then we will see how the source code works inside.
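The sketch below uses np.take, whose indexing semantics jnp.take follows for in-range indices (JAX does not raise on out-of-range indices, so behavior differs there). It assumes the forward method calls take with axis=0, which is what makes each token ID select a whole row; without the axis argument, take indexes the flattened array instead.

```python
import numpy as np

vocab_size, d_feature = 5, 3
weights = np.arange(vocab_size * d_feature).reshape(vocab_size, d_feature)

# A batch of 2 sequences with 4 token IDs each: shape (2, 4).
x = np.array([[0, 1, 2, 3],
              [4, 3, 2, 1]])

# axis=0: each token ID picks a whole row (a d_feature-long slice).
rows = np.take(weights, x, axis=0)
print(rows.shape)  # (2, 4, 3)
print(rows[1, 0])  # row 4: [12 13 14]

# Without axis, take indexes the flattened array, so ID 4 gives the scalar 4.
flat = np.take(weights, x)
print(flat.shape)  # (2, 4)
```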

What we did here is manually initialize an Embedding layer, then give it an input (a batch of tokens) to get the output. We can confirm that the output has shape (2, 4, d_feature): one d_feature-long vector per token. Next we will do the same thing, but this time using the source code of the Embedding layer, specifically the take function mentioned above.
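The shape check generalizes: the output shape is always the token array's shape with d_feature appended. A quick NumPy check of that rule, assuming d_feature=3 and take along axis 0:

```python
import numpy as np

vocab_size, d_feature = 5, 3
weights = np.zeros((vocab_size, d_feature), dtype=np.float32)

# Output shape = token array shape + (d_feature,), for any number of batch dims.
for token_shape in [(4,), (2, 4), (3, 2, 4)]:
    tokens = np.zeros(token_shape, dtype=np.int32)  # any valid IDs work
    out = np.take(weights, tokens, axis=0)
    assert out.shape == token_shape + (d_feature,)
    print(token_shape, "->", out.shape)
```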

Now we initialize the weights ourselves using random.normal, then use take to get the embedding result.
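A dependency-light version of this step, assuming take along axis 0. JAX code would call jax.random.normal with a PRNG key; here NumPy's Generator.normal is swapped in, so the random values will not match a JAX run. It also shows that a repeated token ID maps to the same vector wherever it appears.

```python
import numpy as np

vocab_size, d_feature = 5, 3

# Initialize the weights ourselves from a standard normal distribution.
# NumPy's Generator.normal stands in for jax.random.normal here.
rng = np.random.default_rng(42)
weights = rng.normal(size=(vocab_size, d_feature)).astype(np.float32)

# A batch of 2 sequences, 4 token IDs each; token 3 appears in both rows.
x = np.array([[0, 1, 2, 3],
              [4, 3, 2, 1]])

# Gather one weights row per token, as the forward method does with take.
embeddings = np.take(weights, x, axis=0)
print(embeddings.shape)  # (2, 4, 3)

# The same token ID always maps to the same vector, wherever it appears.
print(np.array_equal(embeddings[0, 3], embeddings[1, 1]))  # True
```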

That's it. Feel free to inspect the weights matrix and compare it with embed1 to confirm that each embedding really is taken from the weights matrix.
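One way to do that comparison programmatically, sketched in NumPy with made-up weights and assuming take along axis 0: check that the vector at every (batch, position) equals the weights row for that token ID.

```python
import numpy as np

vocab_size, d_feature = 5, 3
weights = np.random.default_rng(0).normal(size=(vocab_size, d_feature))
x = np.array([[0, 1, 2, 3],
              [4, 3, 2, 1]])
embeddings = np.take(weights, x, axis=0)

# The vector at (b, t) must be exactly row x[b, t] of the weights matrix.
for b in range(x.shape[0]):
    for t in range(x.shape[1]):
        assert np.array_equal(embeddings[b, t], weights[x[b, t]])

print("every position matches its weights row")
```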
