T5为啥没法用fp16训练

最近基于T5做一些分类任务的训练,发现使用混合精度训练,无论哪个型号的mT5,必出现输出为nan的问题。

调试的时候发现,模型在Encoder阶段还算正常,在Decoder阶段,模型计算的过程中很容易出现inf,然后进而导致在后续的计算过程中出现nan。

目前发现在T5的self Attention的计算过程中,会出现-inf的情况,主要是Attention的分数在进行mask的过程中出现的,在transformers的实现中,Attention的mask来自get_extended_attention_mask,这里返回的mask的数据类型是fp32,这就导致fp16格式的scores会出现-inf的问题:

T5 self-Attention的部分代码

if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
        if self.gradient_checkpointing and self.training:
            position_bias.requires_grad = True
    else:
        position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

    # if key and values are already calculated
    # we want only the last query position bias
    if past_key_value is not None:
        position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

    if mask is not None:
        position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

if self.pruned_heads:
    mask = torch.ones(position_bias.shape[1])
    mask[list(self.pruned_heads)] = 0
    position_bias_masked = position_bias[:, mask.bool()]
else:
    position_bias_masked = position_bias

# 这行是重点
scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    scores
)  # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.dropout, training=self.training
)  # (batch_size, n_heads, seq_length, key_length)

但是即使修改了get_extended_attention_mask返回的mask为fp16,或者mask的数值改成-32752.0(fp16可以表示的最小数字的一半),在计算过程中依旧会出现-inf的问题,目测是在经过一些线性层的时候,出现的这个问题,毕竟fp16的可表示范围小了很多。

所以-inf的问题难以解决,然后在经过MT5LayerFF的时候,其中会经过一个激活函数,是transformer中自定义的一个NewGELUActivation激活函数,其中会出现inf * 0的情况,导致了出现nan的问题。

其实看T5的代码,里面有想解决这个inf问题的一些代码,比如在MT5Block中有下面这段代码:

self_attention_outputs = self.layer[0](
    hidden_states,
    attention_mask=attention_mask,
    position_bias=position_bias,
    layer_head_mask=layer_head_mask,
    past_key_value=self_attn_past_key_value,
    use_cache=use_cache,
    output_attentions=output_attentions,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
    clamp_value = torch.where(
        torch.isinf(hidden_states).any(),
        torch.finfo(hidden_states.dtype).max - 1000,
        torch.finfo(hidden_states.dtype).max,
    )
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

这段代码前半部分是经过Decoder的第一层multi-head-Attention,然后输出的结果为hidden_states,正常来说如果hidden_states为fp16,那么会经过下面的剪切代码,防止出现inf的情况,然后再经过MT5LayerFF层就不会那么容易出现nan的问题。但是可惜的是,第一层multi-head-Attention的实现里面MT5LayerSelfAttention,最后输出是残差结构,而残差是加法操作,在混合精度计算里面,加法操作用fp32进行计算,这就导致Decoder的第一层multi-head-Attention的输出的结果为fp32,这样就没有经过下面的针对fp16的裁剪。目前看这个针对fp16的裁剪可能在推理的时候有点用。

同时根据网上搜集的资料:
http://github.com/huggingface/transformers/issues/23918
http://www.philschmid.de/fine-tune-flan-t5-deepspeed
http://github.com/huggingface/transformers/issues/4586
http://github.com/huggingface/transformers/issues/10830
http://github.com/huggingface/transformers/pull/10956

发现这个可能并不是T5的问题,而是使用bf16训练的模型,然后使用fp16进行训练的通病,因为bf16比fp16表示的整数范围更大,和fp32一致,这就导致bf16转到fp16的时候出现问题,导致模型的数据出现异常。目前针对这个问题的比较好的解决办法就是,如果显卡支持bf16,可以使用bf16进行计算,否则使用fp32进行计算

本站无任何商业行为
个人在线分享 » T5 fp16训练nan问题
E-->