Hugging Face's BERT SQuAD example (run_squad.py) has a hack that removes the pooler from the optimizer:

https://github.com/huggingface/transformers/blob/b832d5bb8a6dfc5965015b828e577677eace601e/examples/run_squad.py#L927

# hack to remove pooler, which is not used
# thus it produce None grad that break apex
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

We are trying to run pretraining on Hugging Face BERT models. The training always diverges at some later point if this pooler hack is not applied. I also see that the pooler layer is used during classification:

pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

The pooler layer is a feed-forward layer (FFN) with a tanh activation:

from torch import nn

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

My question is: why does this pooler hack solve the numeric instability?

Problem seen with the pooler: a diminishing loss scaler.
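For context on what a "diminishing loss scaler" means: in fp16 mixed-precision training the loss is multiplied by a scale factor before backward, and a dynamic scaler shrinks that factor every time a gradient overflows to inf/NaN. Below is a toy sketch of that general scheme, assuming a halve-on-overflow / double-after-N-clean-steps policy; the class and its defaults are hypothetical and are not apex's actual implementation. It only tracks the scale from a list of already-computed gradients.

```python
import math


class DynamicLossScaler:
    """Toy dynamic loss scaler (hypothetical sketch, not apex's code).

    In mixed-precision training the loss would be multiplied by `scale`
    before backward; this class only decides how `scale` evolves.
    """

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without inf/NaN gradients

    def update(self, grads):
        """Adjust the scale from this step's gradients.

        Returns True if the optimizer step should be applied, False if it
        must be skipped because some gradient overflowed (inf or NaN).
        """
        if any(not math.isfinite(g) for g in grads):
            # Overflow: shrink the scale and skip this step.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of clean steps: try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


# If every step overflows, the scale keeps halving and no step is applied.
scaler = DynamicLossScaler()
applied = [scaler.update([0.5, float("inf")]) for _ in range(5)]
print(scaler.scale, applied)  # 2048.0 [False, False, False, False, False]
```

A loss scale that keeps shrinking therefore signals that some gradients repeatedly overflow (or are NaN), so the optimizer keeps skipping steps.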

1 Answer

There are quite a few resources out there that probably tackle this issue better than I can; see, for example, here or here.

Specifically, the problem is that you are dealing with vanishing (or exploding) gradients, particularly when using activation functions that flatten out for very small or very large inputs. This is the case for both sigmoid and tanh; the only difference between them is the range of their outputs, (0, 1) and (-1, 1) respectively.
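To make the "flattening" concrete, here is a small sketch (plain Python, `math` module only) that evaluates the derivatives of tanh and sigmoid at a few inputs. The helper names are just for illustration.

```python
import math


def tanh_grad(x: float) -> float:
    """d/dx tanh(x) = 1 - tanh(x)**2; tanh outputs lie in (-1, 1)."""
    t = math.tanh(x)
    return 1.0 - t * t


def sigmoid_grad(x: float) -> float:
    """d/dx sigmoid(x) = s * (1 - s); sigmoid outputs lie in (0, 1)."""
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 - s)


# Both derivatives peak at x = 0 and decay towards 0 as |x| grows,
# which is what "flattening out" means for the gradient.
for x in [0.0, 2.0, 5.0, 10.0]:
    print(f"x={x:4.1f}  tanh'={tanh_grad(x):.2e}  sigmoid'={sigmoid_grad(x):.2e}")
```

At x = 10 the tanh derivative is already below 1e-8, so almost no gradient flows back through a saturated unit.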

Additionally, if you use low-precision floating point, as is the case with APEX (fp16) training, vanishing gradients are much more likely to appear even for relatively moderate outputs, because the limited precision restricts which numbers can be distinguished from zero. One way to deal with this is to use activation functions whose derivatives are strictly non-zero and cheap to compute, such as Leaky ReLU, or simply to avoid the activation function altogether (which I'm assuming is what Hugging Face is doing here).
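A rough way to see the precision effect without a GPU is to round values to IEEE 754 half precision with the standard library's `struct` format `"e"`. The sketch below is only an emulation under that assumption (activation and gradient rounded to fp16); it is not how apex actually computes the backward pass, and the helper names are hypothetical.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("e", struct.pack("e", x))[0]


def tanh_grad_fp64(x: float) -> float:
    t = math.tanh(x)
    return 1.0 - t * t


def tanh_grad_fp16(x: float) -> float:
    # Emulation: store the activation in fp16, then round the gradient too.
    t = to_fp16(math.tanh(x))
    return to_fp16(1.0 - t * t)


def leaky_relu_grad_fp16(x: float, alpha: float = 0.01) -> float:
    # Leaky ReLU has slope 1 for x > 0 and alpha otherwise, never exactly 0.
    return to_fp16(1.0 if x > 0 else alpha)


for x in [1.0, 3.0, 5.0, -5.0]:
    print(f"x={x:4.1f}  tanh' fp64={tanh_grad_fp64(x):.2e}  "
          f"tanh' fp16={tanh_grad_fp16(x):.2e}  "
          f"leaky' fp16={leaky_relu_grad_fp16(x):.2e}")
```

At |x| = 5, tanh(x) rounds to exactly ±1.0 in fp16, so the emulated gradient becomes exactly 0, while in double precision it is still about 1.8e-4. The Leaky ReLU slope stays non-zero.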

Note that exploding gradients are usually less of a problem, since we can apply gradient clipping (capping the gradient at a fixed maximum norm), but the underlying principle is the same. Zeroed gradients, on the other hand, have no such easy fix: they cause your neurons to "die" (no learning happens when zero gradient flows back), which I assume is why you see the diverging behavior.
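The asymmetry between the two cases can be sketched with plain lists: clipping by norm tames a huge gradient, but it has nothing to work with when the gradient is zero. `clip_by_norm` and `sgd_step` are hypothetical helpers for illustration; in PyTorch the analogous built-in is `torch.nn.utils.clip_grad_norm_`.

```python
import math


def clip_by_norm(grads, max_norm):
    """Rescale a gradient vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        factor = max_norm / norm
        return [g * factor for g in grads]
    return list(grads)


def sgd_step(weights, grads, lr=0.1):
    """Plain SGD update: w <- w - lr * g."""
    return [w - lr * g for w, g in zip(weights, grads)]


weights = [1.0, -2.0]

# Exploding gradient: clipping shrinks it to norm 5 but keeps its direction.
big = clip_by_norm([30.0, 40.0], max_norm=5.0)
print(big)  # [3.0, 4.0]

# Zero ("dead") gradient: clipping leaves it at zero and the weights never move.
dead = clip_by_norm([0.0, 0.0], max_norm=5.0)
print(sgd_step(weights, dead) == weights)  # True
```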