
I'm using the HuggingFace Transformers BERT model, and I want to compute a summary vector (a.k.a. embedding) over the tokens in a sentence, using either the mean or max function. The complication is that some tokens are [PAD], so I want to ignore the vectors for those tokens when computing the average or max.

Here's an example. I initially instantiate a BertTokenizer and a BertModel:

import torch
import transformers
from transformers import AutoTokenizer, AutoModel

transformer_name = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True)

model = AutoModel.from_pretrained(transformer_name)

I then input some sentences into the tokenizer and get out input_ids and attention_mask. Notably, an attention_mask value of 0 means that the token was a [PAD] that I can ignore.

sentences = ['Deep learning is difficult yet very rewarding.',
             'Deep learning is not easy.',
             'But is rewarding if done right.']
tokenizer_result = tokenizer(sentences, max_length=32, padding=True, return_attention_mask=True, return_tensors='pt')

input_ids = tokenizer_result.input_ids
attention_mask = tokenizer_result.attention_mask

print(input_ids.shape) # torch.Size([3, 11])

print(input_ids)
# tensor([[  101,  2784,  4083,  2003,  3697,  2664,  2200, 10377,  2075,  1012,  102],
#         [  101,  2784,  4083,  2003,  2025,  3733,  1012,   102,     0,     0,    0],
#         [  101,  2021,  2003, 10377,  2075,  2065,  2589,  2157,  1012,   102,   0]])

print(attention_mask.shape) # torch.Size([3, 11])

print(attention_mask)
# tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])

Now, I call the BERT model to get the 768-D token embeddings (the top-layer hidden states).

model_result = model(input_ids, attention_mask=attention_mask, return_dict=True)

token_embeddings = model_result.last_hidden_state
print(token_embeddings.shape) # torch.Size([3, 11, 768])

So at this point, I have:

  1. token embeddings in a [3, 11, 768] tensor: 3 sentences, 11 tokens, and a 768-D vector for each token.
  2. attention mask in a [3, 11] matrix: 3 sentences, 11 tokens. A 1 value indicates non-[PAD].

How do I compute the mean / max over the vectors for the valid, non-[PAD] tokens?

I tried using the attention mask as a boolean index and then calling torch.max(), but boolean indexing just selects the 29 non-[PAD] token vectors and flattens the batch and token dimensions together, so I don't get the right dimensions:

masked_token_embeddings = token_embeddings[attention_mask==1]
print(masked_token_embeddings.shape) # torch.Size([29, 768]) <-- WRONG. SHOULD BE [3, 11, 768]

pooled = torch.max(masked_token_embeddings, 1)
print(pooled.values.shape) # torch.Size([29]) <-- WRONG. SHOULD BE [3, 768]

What I really want is a tensor of shape [3, 768]. That is, a 768-D vector for each of the 3 sentences.


1 Answer


For max, you can multiply token_embeddings with attention_mask (which zeroes out the [PAD] positions) and take the max along the token dimension. Note that torch.max with a dim argument returns both values and indices, so take .values:

pooled = torch.max(token_embeddings * attention_mask.unsqueeze(-1), dim=1).values  # [3, 768]
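One caveat: multiplying by the mask leaves zeros at the [PAD] positions, and a zero can win the max in dimensions where every real token value is negative. A safer variant (a minimal sketch, reusing the tensors from the question) fills the [PAD] positions with a large negative value before taking the max:

# Fill [PAD] positions with -inf so they can never be selected by the max.
neg_masked = token_embeddings.masked_fill(attention_mask.unsqueeze(-1) == 0, float('-inf'))
pooled = torch.max(neg_masked, dim=1).values  # [3, 768]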

For mean, zero out the [PAD] positions the same way, sum along the token dimension, and divide by the number of non-[PAD] tokens in each sentence:

mean_pooled = (token_embeddings * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)  # [3, 768]
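As a quick sanity check (a sketch reusing the variables above; per the attention_mask printed in the question, the second sentence has 8 non-[PAD] tokens), the masked mean for that sentence should match a plain mean over just its first 8 token vectors:

# Compare the masked mean with a manual mean over the 8 real tokens of sentence index 1.
manual = token_embeddings[1, :8].mean(dim=0)
print(torch.allclose(mean_pooled[1], manual, atol=1e-6))  # expected: True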