
I was following a paper on BERT-based lexical substitution (specifically trying to implement equation (2) - if someone has already implemented the whole paper, that would also be great). I therefore wanted to obtain both the last hidden layers (the only thing I am unsure about is the ordering of the layers in the output: last first or first first?) and the attention from a basic BERT model (bert-base-uncased).

However, I am a bit unsure whether the huggingface/transformers library actually outputs the attention for bert-base-uncased (I was using torch, but am open to using TF instead).

From what I had read, I expected to get a tuple of (logits, hidden_states, attentions), but with the example below (it runs e.g. in Google Colab), I get a tuple of length 2 instead.

Am I misinterpreting what I am getting, or am I going about this the wrong way? I did the obvious test and used output_attention=False instead of output_attention=True (while output_hidden_states=True does indeed seem to add the hidden states, as expected), and nothing changed in the output I got. That's clearly a bad sign about my understanding of the library, or it indicates an issue.

import numpy as np
import torch
!pip install transformers

from transformers import (AutoModelWithLMHead, 
                          AutoTokenizer, 
                          BertConfig)

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
config = BertConfig.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attention=True) # Nothing changes when I switch to output_attention=False
bert_model = AutoModelWithLMHead.from_config(config)

sequence = "We went to an ice cream cafe and had a chocolate ice cream."
bert_tokenized_sequence = bert_tokenizer.tokenize(sequence)

indexed_tokens = bert_tokenizer.encode(bert_tokenized_sequence, return_tensors='pt')

predictions = bert_model(indexed_tokens)

########## Now let's have a look at what the predictions look like #############
print(len(predictions)) # Length is 2, I expected 3: logits, hidden_layers, attention

print(predictions[0].shape) # torch.Size([1, 16, 30522]) - seems to be the logits (shape is 1 x sequence length x vocabulary size)

print(len(predictions[1])) # Length is 13 - the hidden layers?! There are meant to be 12, right? Is one somehow the attention?

for k in range(len(predictions[1])):
  print(predictions[1][k].shape) # These all seem to be torch.Size([1, 16, 768]), so presumably the hidden layers?

Edit: explanation of what worked in the end, inspired by the accepted answer:

import numpy as np
import torch
!pip install transformers

from transformers import BertModel, BertConfig, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attentions=True)
model = BertModel.from_pretrained('bert-base-uncased', config=config)
sequence = "We went to an ice cream cafe and had a chocolate ice cream."
tokenized_sequence = tokenizer.tokenize(sequence)
indexed_tokens = tokenizer.encode(tokenized_sequence, return_tensors='pt')
outputs = model(indexed_tokens)
print( len(outputs) ) # 4 
print( outputs[0].shape ) # 1, 16, 768 - the last hidden layer for each token
print( outputs[1].shape ) # 1, 768 - the pooler output
print( len(outputs[2]) ) # 13 = input embedding (index 0) + 12 hidden layers (indices 1 to 12)
print( outputs[2][0].shape ) # for each of these 13: 1, 16, 768 = batch size, index of each input id in sequence, size of hidden layer
print( len(outputs[3]) ) # 12 (= attention for each layer)
print( outputs[3][0].shape ) # index 0 = first layer; 1, 12, 16, 16 = batch size, attention heads, index of each input id in sequence, index of each input id in sequence
  • How did you parse the final tensor of [1, 12, 16, 16]? The documentation says that it represents batch_size, num_heads, sequence_length, sequence_length, but I am not sure how to interpret the last two dimensions. Do you have any ideas? Mar 30, 2020 at 23:28
    The attention for each layer bit? So, you get the attention for a certain layer, let's say the first one (index 0) as outputs[3][0], then you may want e.g. the attention that attention head number 3 (index 2) "pays to" item 2 (index 1) when "interpreting" item 15 (index 14). To get that you take outputs[3][0][0,2,1,14], or perhaps outputs[3][0][0,2,14,1] - I forgot which way around this last bit is. I think github.com/jessevig/bertviz visualizes this quite nicely.
    – Björn
    Mar 31, 2020 at 18:34
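
A small sketch of that indexing, assuming the outputs from the edited snippet above (the head and token indices are just the ones from the comment). In the (batch_size, num_heads, sequence_length, sequence_length) layout from the docs, the third dimension is the query token and the fourth is the token being attended to:

attentions = outputs[3]           # tuple of 12 tensors, one per layer
layer0 = attentions[0]            # shape [1, 12, 16, 16] = batch, heads, query token, attended-to token
weight = layer0[0, 2, 14, 1]      # head 3 (index 2): attention paid to item 2 (index 1) when interpreting item 15 (index 14)
print(layer0[0, 2, 14, :].sum())  # each query row is a softmax over the attended-to tokens, so it sums to ~1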

2 Answers


I think it's a bit late to add an answer here, but with the updates to huggingface's transformers, I think we can use this:

import torch
from transformers import BertConfig, BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained('bert-base-uncased',
                                    output_hidden_states=True, output_attentions=True)
bert_model = BertModel.from_pretrained('bert-base-uncased', config=config)

# e.g. the sentence from the question
input_ids = tokenizer.encode("We went to an ice cream cafe and had a chocolate ice cream.",
                             return_tensors='pt')

with torch.no_grad():
    out = bert_model(input_ids)
    last_hidden_states = out.last_hidden_state
    pooler_output = out.pooler_output
    hidden_states = out.hidden_states
    attentions = out.attentions
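
As a side note on the layer-ordering question from the original post: hidden_states runs from the embedding output (index 0) up to the last encoder layer (index 12), which you can check with a quick sketch using the variables from the snippet above:

print(torch.allclose(hidden_states[-1], last_hidden_states))  # True: the last entry is the final layer
print(len(attentions), attentions[0].shape)                   # 12 layers, each (batch, num_heads, seq_len, seq_len)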

The reason is that you are using AutoModelWithLMHead, which is a wrapper for the actual model. It calls the BERT model (i.e., an instance of BertModel) and then uses the embedding matrix as a weight matrix for the word prediction. In between, the underlying model indeed returns attentions, but the wrapper does not care and only returns the logits.

You can either get the BERT model directly by calling AutoModel. Note that this model does not return the logits, but the hidden states.

bert_model = AutoModel.from_config(config)

Or you can get it from the AutoModelWithLMHead object via its base_model attribute:

wrapped_model = bert_model.base_model
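
Either way you end up with the underlying BertModel. A minimal sketch of the second route, assuming the config and indexed_tokens from the question's first snippet (the shape comments match what the follow-up comment below reports):

wrapped_model = bert_model.base_model          # the underlying BertModel inside the LM-head wrapper
base_outputs = wrapped_model(indexed_tokens)

print(len(base_outputs))       # 3: last hidden state, pooler output, hidden states
print(base_outputs[0].shape)   # [1, 16, 768] - last hidden layer for each token
print(base_outputs[1].shape)   # [1, 768]     - pooler output
print(len(base_outputs[2]))    # 13           - embedding output + 12 layer outputs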
  • thank you! I admit to being confused about what I got now (after switching to predictions = bert_model.base_model(indexed_tokens) in the code above). I get a tuple with 3 elements (great!) with shapes: [1, 16, 768] (presumably the final hidden layer?), the next element is [1, 768] and the final one a tuple of length 13 with each element being shaped [1, 16, 768]. This may be a stupid Q, but what are these? Are there not 12 hidden layers (n=768) + 144 attention heads (12 for each of 12 layers)? I was expecting 12 x [1, 16, 768] + 144 x [1, 16, 768] or so... I clearly misunderstood something.
    – Björn
    Feb 10, 2020 at 14:02
    Figured it out myself now. Added at the bottom of the question.
    – Björn
    Feb 11, 2020 at 23:33
