
I have two questions about how to use the TensorFlow implementation of Transformers for text classification.

  • First, it seems people mostly use only the encoder layer to do the text classification task. However, the encoder layer generates one prediction for each input word. Based on my understanding of Transformers, the input to the encoder at each step is one word from the input sentence. Then the attention weights and the output are calculated using the current input word, and we can repeat this process for all of the words in the input sentence. As a result, we end up with a pair of (attention weights, output) for each word in the input sentence. Is that correct? If so, how would you use these pairs to perform a text classification?
  • Second, based on the TensorFlow implementation of the Transformer here, they embed the whole input sentence into one vector and feed a batch of these vectors to the Transformer. However, I expected the input to be a batch of words instead of sentences, based on what I've learned from The Illustrated Transformer.

Thank you!



There are two approaches you can take:

  1. Just average the states you get from the encoder;
  2. Prepend a special token [CLS] (or whatever you like to call it) and use the hidden state for the special token as input to your classifier.

The second approach is used by BERT. During pre-training, the hidden state corresponding to this special token is used for predicting whether two sentences are consecutive. In downstream tasks, it is also used for sentence classification. However, in my experience, averaging the hidden states sometimes gives a better result.
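For concreteness, here is a minimal TensorFlow 2.x sketch of both pooling strategies. Names and shapes are illustrative: `encoder_outputs` stands for whatever your encoder produces, with shape (batch, seq_len, hidden_dim), and `attention_mask` marks real tokens vs. padding.

    import tensorflow as tf

    def mean_pool(encoder_outputs, attention_mask):
        """Approach 1: average the hidden states of the real (non-padding) tokens."""
        mask = tf.cast(attention_mask, encoder_outputs.dtype)    # (batch, seq_len)
        mask = tf.expand_dims(mask, axis=-1)                     # (batch, seq_len, 1)
        summed = tf.reduce_sum(encoder_outputs * mask, axis=1)   # (batch, hidden_dim)
        counts = tf.maximum(tf.reduce_sum(mask, axis=1), 1e-9)
        return summed / counts

    def cls_pool(encoder_outputs):
        """Approach 2: take the hidden state of the first ([CLS]) token."""
        return encoder_outputs[:, 0, :]                          # (batch, hidden_dim)

    # Either pooled vector can then go to a classification head, e.g.:
    # logits = tf.keras.layers.Dense(num_classes)(pooled)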

Instead of training a Transformer model from scratch, it is probably more convenient to use (and possibly fine-tune) a pre-trained model (BERT, XLNet, DistilBERT, ...) from the transformers package. It provides pre-trained models ready to use in PyTorch and TensorFlow 2.0.
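A minimal fine-tuning sketch with the transformers package in TensorFlow 2 might look like the following. The model name, label count, and hyperparameters are placeholders, and the exact API can vary slightly between library versions.

    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = TFAutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    texts = ["great movie", "terrible plot"]       # toy data for illustration
    labels = tf.constant([1, 0])

    # Tokenize whole sentences; the model sees padded batches of token ids.
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="tf")

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    model.fit(dict(inputs), labels, epochs=1, batch_size=2)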

  • It seems odd to use a single token from the encoded sequence, discarding the rest. Why not just add an additional attention layer to attend across the encoded sequence and output a single value?
    – Robz
    Dec 30, 2020 at 6:45
  •
    What you suggest basically happens in the last layer, so in some sense the last layer might be wasting computation on the remaining hidden states. The intuition behind the [CLS] token is that it collects information on the fly from all hidden layers, not just the last one, as it is passed along via the residual connections. Also, the information collected for the classification influences the self-attention in the layers. The model should thus be able to learn more complex patterns more easily compared to what you suggest.
    – Jindřich
    Dec 30, 2020 at 9:54
  1. Transformers are designed to take the whole input sentence at once. The main motivation for the Transformer architecture was to enable parallel processing of the words in a sentence; this parallelism is not possible in LSTMs, RNNs, or GRUs, which consume the words of the input sentence one by one. So in the encoder part of a Transformer, the very first layer contains as many units as there are words in the sentence, and each unit converts its word into the corresponding embedding vector; the rest of the processing then follows. For more details, see http://jalammar.github.io/illustrated-transformer/. As for how to use this architecture for text classification: since the output in text classification is a single label rather than a sequence of vectors, we can drop the decoder and keep only the encoder. The encoder outputs one vector per word of the input sentence; these vectors can then be pooled, or fed into a CNN, LSTM, or RNN, to perform the classification (a sketch follows this list).
  2. The input is the whole sentence (or a batch of sentences), not word by word. You may have misunderstood that part.
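A minimal Keras sketch of such an encoder-only classifier is below. All sizes (vocab_size, max_len, embed_dim, and so on) are illustrative placeholders, and positional embeddings are omitted for brevity.

    import tensorflow as tf
    from tensorflow.keras import layers

    vocab_size, max_len, embed_dim, num_heads, ff_dim, num_classes = 20000, 128, 64, 4, 128, 2

    inputs = layers.Input(shape=(max_len,), dtype="int32")   # the whole sentence at once
    x = layers.Embedding(vocab_size, embed_dim)(inputs)      # one vector per word

    # One Transformer encoder block: self-attention + feed-forward, with residuals.
    attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(x, x)
    x = layers.LayerNormalization()(x + attn)
    ff = layers.Dense(ff_dim, activation="relu")(x)
    ff = layers.Dense(embed_dim)(ff)
    x = layers.LayerNormalization()(x + ff)                  # (batch, max_len, embed_dim)

    # Collapse the per-word vectors into a single sentence vector, then classify.
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])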
