
(I'm following this PyTorch tutorial about BERT word embeddings, in which the author accesses the intermediate layers of the BERT model.)

What I want is to access, let's say, the last 4 layers for a single input token of the BERT model in TensorFlow 2, using HuggingFace's Transformers library. Since each layer outputs a vector of length 768 for each token, the last 4 layers together give 4*768 = 3072 values (per token).

How can I implement this in TF/Keras/TF2 to get the intermediate layers' outputs of a pretrained model for an input token? (Later I will try to get the vectors for each token in a sentence, but for now one token is enough.)

I'm using HuggingFace's BERT model:

!pip install transformers
from transformers import (TFBertModel, BertTokenizer)

bert_model = TFBertModel.from_pretrained("bert-base-uncased")  # Automatically loads the config
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
sentence_marked = "hello"
tokenized_text = bert_tokenizer.tokenize(sentence_marked)
indexed_tokens = bert_tokenizer.convert_tokens_to_ids(tokenized_text)

print (indexed_tokens)
>> prints [7592]

The output is a token id ([7592]), which should be the input to the BERT model.
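
To feed this single token id to the model, I assume it needs a batch dimension; here is a minimal sketch of what I mean (the attention mask is all ones since there is only one real token, and special tokens are omitted as above):

import tensorflow as tf

# Wrap the token ids in a 2-D tensor: shape (1, 1) = one sequence of one token.
input_ids = tf.constant([indexed_tokens])
attention_mask = tf.ones_like(input_ids)

outputs = bert_model(input_ids, attention_mask=attention_mask)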

  • I would suggest changing your tags, as you have asked a question about PyTorch and tagged tensorflow. It's misleading and wouldn't help you either.
    – chrisgiffy
    Apr 27, 2020 at 18:33
  • By "get the intermediate layers of pretrained model" I assume you are referring to hidden states of intermediate layers, right? And note that BERT produces contextual token representations and therefore it does not make sense to use the representation of a token based on an an input sequence which only contains that token. Further, it uses wordpieces to tokenize an input, therefore one word may be represent as two or more wordpiece tokens, hence two or more representation vectors for that word (which needs to be combined back to get one single vector for that word).
    – today
    Apr 28, 2020 at 16:54
@today Yeah, I know that BERT has to get the context of the sentence in order to get the best embeddings. But my question is about how to get the middle layers' outputs: each of BERT's 12 layers outputs an array of 768 values for each token, and my question is how to access those values.
    – Yagel
    Apr 28, 2020 at 23:42
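
As a side note on the wordpiece point raised in the comments, here is a minimal sketch (the exact split depends on the bert-base-uncased vocabulary) of how one word can become several tokens, each of which gets its own hidden-state vector per layer:

# A word that is not in the vocabulary as a whole gets split into wordpieces.
tokens = bert_tokenizer.tokenize("embeddings")
print(tokens)
# expected to look like ['em', '##bed', '##ding', '##s'];
# each piece gets its own 768-dimensional vector per layer, so the pieces
# have to be pooled (e.g. averaged) to get a single vector for the word.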

1 Answer


The third element of the BERT model's output is a tuple which consists of the output of the embedding layer as well as the hidden states of the intermediate layers. From the documentation:

hidden_states (tuple(tf.Tensor), optional, returned when config.output_hidden_states=True): tuple of tf.Tensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

Hidden-states of the model at the output of each layer plus the initial embedding outputs.

For the bert-base-uncased model, config.output_hidden_states is True by default. Therefore, to access the hidden states of the 12 intermediate layers, you can do the following:

outputs = bert_model(input_ids, attention_mask)
hidden_states = outputs[2][1:]  # drop index 0, which is the embedding-layer output

There are 12 elements in the hidden_states tuple, corresponding to all the layers from the first to the last, and each of them is an array of shape (batch_size, sequence_length, hidden_size). So, for example, to access the hidden state of the third layer for the fifth token of all the samples in the batch, you can do: hidden_states[2][:,4].
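
To connect this back to the question, here is a minimal sketch (assuming the input_ids and attention_mask tensors built earlier, and that the hidden states are returned) of concatenating the last 4 layers into one 4*768 = 3072-dimensional vector per token:

import tensorflow as tf

outputs = bert_model(input_ids, attention_mask=attention_mask)
hidden_states = outputs[2][1:]           # 12 per-layer outputs, embedding output dropped

# Concatenate the last 4 layers along the hidden dimension:
# (batch_size, sequence_length, 4 * 768) -> here (1, 1, 3072).
token_embeddings = tf.concat(hidden_states[-4:], axis=-1)

# 3072-dimensional vector for the first (and only) token of the first sample.
single_token_vector = token_embeddings[0, 0]
print(single_token_vector.shape)         # (3072,)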


Note that if the model you are loading does not return the hidden states by default, you can load the config using the BertConfig class and pass the output_hidden_states=True argument, like this:

config = BertConfig.from_pretrained("name_or_path_of_model",
                                    output_hidden_states=True)

bert_model = TFBertModel.from_pretrained("name_or_path_of_model",
                                         config=config)
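
As a quick sanity check after loading the model this way (a small sketch, assuming input_ids and attention_mask tensors like the ones built in the question above), the hidden-states tuple should contain the embedding output plus one entry per layer:

outputs = bert_model(input_ids, attention_mask=attention_mask)
# 1 embedding output + 12 encoder layers for bert-base-uncased
print(len(outputs[2]))   # expected: 13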