I'm using the HuggingFace Transformers BERT model, and I want to compute a summary vector (a.k.a. embedding) over the tokens in a sentence, using either the mean or max function. The complication is that some tokens are [PAD], so I want to ignore the vectors for those tokens when computing the average or max.
Here's an example. I first instantiate a BertTokenizer and a BertModel:
import torch
import transformers
from transformers import AutoTokenizer, AutoModel
transformer_name = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True)
model = AutoModel.from_pretrained(transformer_name)
I then pass some sentences to the tokenizer and get back input_ids and attention_mask. Notably, an attention_mask value of 0 means that the token is a [PAD] that I can ignore.
sentences = ['Deep learning is difficult yet very rewarding.',
'Deep learning is not easy.',
'But is rewarding if done right.']
tokenizer_result = tokenizer(sentences, max_length=32, padding=True, truncation=True, return_attention_mask=True, return_tensors='pt')
input_ids = tokenizer_result.input_ids
attention_mask = tokenizer_result.attention_mask
print(input_ids.shape) # torch.Size([3, 11])
print(input_ids)
# tensor([[ 101, 2784, 4083, 2003, 3697, 2664, 2200, 10377, 2075, 1012, 102],
# [ 101, 2784, 4083, 2003, 2025, 3733, 1012, 102, 0, 0, 0],
# [ 101, 2021, 2003, 10377, 2075, 2065, 2589, 2157, 1012, 102, 0]])
print(attention_mask.shape) # torch.Size([3, 11])
print(attention_mask)
# tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
# [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
# [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])
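As a quick sanity check on the mask, summing it along the token dimension should give the number of real (non-[PAD]) tokens per sentence. The sketch below hard-codes the mask printed above instead of calling the tokenizer, so it runs without downloading anything; real_token_counts is a name introduced here purely for illustration.

```python
import torch

# Attention mask copied from the printed output above (1 = real token, 0 = [PAD]).
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
])

# Sum over the token dimension to count the real tokens in each sentence.
real_token_counts = attention_mask.sum(dim=1)
print(real_token_counts.tolist())    # [11, 8, 10]
print(int(real_token_counts.sum()))  # 29
```

Note that the counts include the [CLS] (101) and [SEP] (102) tokens, since the tokenizer marks them with 1 as well.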
Now, I call the BERT model to get the 768-D token embeddings (the top-layer hidden states).
model_result = model(input_ids, attention_mask=attention_mask, return_dict=True)
token_embeddings = model_result.last_hidden_state
print(token_embeddings.shape) # torch.Size([3, 11, 768])
So at this point, I have:
- token embeddings in a [3, 11, 768] matrix: 3 sentences, 11 tokens, 768-D vector for each token.
- attention mask in a [3, 11] matrix: 3 sentences, 11 tokens. A value of 1 indicates a non-[PAD] token.
How do I compute the mean / max over the vectors for the valid, non-[PAD] tokens?
I tried using the attention mask as a boolean index and then called torch.max(), but I don't get the right dimensions:
masked_token_embeddings = token_embeddings[attention_mask==1]
print(masked_token_embeddings.shape) # torch.Size([29, 768]) <-- WRONG. SHOULD BE [3, 11, 768]
pooled = torch.max(masked_token_embeddings, 1)
print(pooled.values.shape) # torch.Size([29]) <-- WRONG. SHOULD BE [3, 768]
What I really want is a tensor of shape [3, 768]. That is, a 768-D vector for each of the 3 sentences.
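One common way to get that shape, sketched here under a few assumptions, is to keep the [3, 11, 768] layout and apply the mask arithmetically instead of using it as an index: zero out the [PAD] vectors before summing for the mean, and set them to -inf before taking the max. Random tensors stand in for the BERT output so the snippet runs on its own. The helper names masked_mean and masked_max are made up for this example and are not part of Transformers or PyTorch.

```python
import torch

def masked_mean(token_embeddings, attention_mask):
    """Mean over tokens, counting only positions where attention_mask == 1."""
    # [batch, seq] -> [batch, seq, 1] so the mask broadcasts over the hidden dim.
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)  # [batch, hidden]
    counts = mask.sum(dim=1).clamp(min=1)          # [batch, 1]; avoids dividing by zero
    return summed / counts

def masked_max(token_embeddings, attention_mask):
    """Element-wise max over tokens, ignoring positions where attention_mask == 0."""
    pad = (attention_mask == 0).unsqueeze(-1)      # [batch, seq, 1], True at [PAD]
    # Fill [PAD] positions with -inf so they can never win the max.
    filled = token_embeddings.masked_fill(pad, float('-inf'))
    return filled.max(dim=1).values                # [batch, hidden]

# Stand-in data with the shapes from the question (not real BERT output).
torch.manual_seed(0)
token_embeddings = torch.randn(3, 11, 768)
attention_mask = torch.tensor([
    [1] * 11,
    [1] * 8 + [0] * 3,
    [1] * 10 + [0] * 1,
])

mean_pooled = masked_mean(token_embeddings, attention_mask)
max_pooled = masked_max(token_embeddings, attention_mask)
print(mean_pooled.shape)  # torch.Size([3, 768])
print(max_pooled.shape)   # torch.Size([3, 768])
```

With the real model, the same helpers should accept model_result.last_hidden_state and tokenizer_result.attention_mask directly. Both poolings still include the [CLS] and [SEP] vectors, because the attention mask is 1 at those positions; excluding them would need a separate mask.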