0
votes

encode_plus in Hugging Face's transformers library allows truncation of the input sequence. Two parameters are relevant: truncation and max_length. I'm passing a paired input sequence to encode_plus and need to truncate it in a simple "cut off" manner, i.e., if the whole sequence consisting of both inputs text and text_pair is longer than max_length, it should just be truncated from the right.

It seems that none of the truncation strategies allows this. longest_first removes tokens from the longest sequence (which could be either text or text_pair, so not simply from the right end of the whole sequence; e.g., if text is longer than text_pair, it seems this would remove tokens from text first). only_first and only_second remove tokens from only the first or only the second sequence (hence, also not simply from the end), and do_not_truncate does not truncate at all. Or did I misunderstand this, and longest_first might actually be what I'm looking for?

1 Answer

6
votes

No, longest_first is not the same as cutting from the right. When you set the truncation strategy to longest_first, the tokenizer compares the lengths of text and text_pair every time a token needs to be removed and removes a token from the longer one. This could, for example, mean that it first cuts 3 tokens from text_pair and then cuts the remaining tokens alternately from text and text_pair. An example:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

seq1 = 'This is a long uninteresting text'
seq2 = 'What could be a second sequence to the uninteresting text'

print(len(tokenizer.tokenize(seq1)))
print(len(tokenizer.tokenize(seq2)))

print(tokenizer(seq1, seq2))

# truncation=True is equivalent to truncation='longest_first'
print(tokenizer(seq1, seq2, truncation=True, max_length=15))
print(tokenizer.decode(tokenizer(seq1, seq2, truncation=True, max_length=15)['input_ids']))

Output:

9
13
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 1037, 2117, 5537, 2000, 1996, 4895, 18447, 18702, 3436, 3793, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 102, 2054, 2071, 2022, 1037, 2117, 5537, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[CLS] this is a long unint [SEP] what could be a second sequence [SEP]
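To make the numbers above easier to follow, here is a minimal pure-Python sketch of the longest_first idea. It is not the library's implementation: the function name longest_first and the placeholder token lists are made up for illustration, and the tie-breaking rule (remove from the second sequence when both have the same length) is an assumption.

```python
def longest_first(first, second, max_tokens):
    """Remove one token at a time from the end of the longer list.

    Assumption: on a tie, the token is taken from `second`.
    """
    first, second = list(first), list(second)
    while len(first) + len(second) > max_tokens and (first or second):
        if len(first) > len(second):
            first.pop()
        else:
            second.pop()
    return first, second


# Placeholders standing in for the 9 tokens of seq1 and the 13 tokens of seq2
tokens1 = ["t1_%d" % i for i in range(9)]
tokens2 = ["t2_%d" % i for i in range(13)]

# max_length = 15 minus 3 special tokens ([CLS], [SEP], [SEP]) leaves 12 slots
kept1, kept2 = longest_first(tokens1, tokens2, 15 - 3)
print(len(kept1), len(kept2))  # 6 6
```

Ten tokens have to go (9 + 13 + 3 = 25 versus 15). The first four come from the longer second sequence until both have 9 tokens, and the remaining six are split evenly, which leaves 6 tokens on each side, as in the decoded output above.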

As far as I can tell from your question, you are actually looking for only_second, because it cuts from the right (which is text_pair):

print(tokenizer(seq1, seq2, truncation='only_second', max_length=15))

Output:

{'input_ids': [101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

It throws an exception when your text input alone is longer than the specified max_length. That is correct in my opinion, because in this case it is no longer a sequence pair input.

Just in case only_second doesn't meet your requirements, you can simply create your own truncation strategy. As an example, here is only_second by hand:


tok_seq1 = tokenizer.tokenize(seq1)
tok_seq2 = tokenizer.tokenize(seq2)

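myMax_len = 15  # same limit as max_length in the examples above
# 3 = number of special tokens for a BERT sequence pair ([CLS], [SEP], [SEP]);
# this assumes tok_seq1 plus the special tokens fits within myMax_len
maxLengthSeq2 = myMax_len - len(tok_seq1) - 3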
if len(tok_seq2) >  maxLengthSeq2:
    tok_seq2 = tok_seq2[:maxLengthSeq2]

input_ids = [tokenizer.cls_token_id] 
input_ids += tokenizer.convert_tokens_to_ids(tok_seq1)
input_ids += [tokenizer.sep_token_id]

token_type_ids = [0]*len(input_ids)

input_ids += tokenizer.convert_tokens_to_ids(tok_seq2)
input_ids += [tokenizer.sep_token_id]
token_type_ids += [1]*(len(tok_seq2)+1) 


attention_mask = [1]*len(input_ids)
print(input_ids)
print(token_type_ids)
print(attention_mask)

Output:

[101, 2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793, 102, 2054, 2071, 2022, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
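If the first sequence alone can exceed the limit, the same idea can be extended into a "cut off from the right" strategy. The following is only a sketch under stated assumptions: build_pair_cut_right is a hypothetical helper, not part of transformers. It works on plain lists of token ids, uses 101 and 102 for [CLS] and [SEP] as seen in the outputs above, and always keeps the three special tokens, so it is not a literal slice of the encoded sequence.

```python
def build_pair_cut_right(ids1, ids2, max_length, cls_id=101, sep_id=102):
    """Build [CLS] ids1 [SEP] ids2 [SEP], cutting content tokens from the right.

    Assumption: the three special tokens are always kept, so the second
    sequence is cut first and the first sequence only if it is still too long.
    """
    budget = max_length - 3  # slots left for content tokens
    if budget < 0:
        raise ValueError("max_length must be at least 3")
    ids1 = ids1[:budget]
    ids2 = ids2[:budget - len(ids1)]

    input_ids = [cls_id] + ids1 + [sep_id]
    token_type_ids = [0] * len(input_ids)
    input_ids += ids2 + [sep_id]
    token_type_ids += [1] * (len(ids2) + 1)
    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": [1] * len(input_ids),
    }


# Token ids of seq1 and seq2, copied from the untruncated output above
ids1 = [2023, 2003, 1037, 2146, 4895, 18447, 18702, 3436, 3793]
ids2 = [2054, 2071, 2022, 1037, 2117, 5537, 2000, 1996, 4895, 18447, 18702, 3436, 3793]

result15 = build_pair_cut_right(ids1, ids2, max_length=15)
result10 = build_pair_cut_right(ids1, ids2, max_length=10)
print(result15["input_ids"])
print(result10["input_ids"])
```

With max_length=15 this gives the same input_ids as only_second. With max_length=10 the second sequence becomes empty and the first is shortened to 7 tokens instead of an exception being raised; whether an empty second segment is acceptable depends on your model and task.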