In my opinion, the BERT pretraining batch loss computed in the function _get_batch_loss_bert is not correct. The details are as follows:
The CrossEntropyLoss is initialized with the default reduction='mean':
loss = nn.CrossEntropyLoss()
In the function _get_batch_loss_bert, the MLM loss and the NSP loss are computed with this same loss instance:
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
Since reduction='mean', the result of loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) is a scalar tensor: the loss has already been averaged over all prediction positions, including padded ones. Multiplying this scalar element-wise by mlm_weights_X therefore does not mask out the padded positions, so the MLM loss computation is wrong.
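A minimal sketch of one possible fix: construct the loss with reduction='none' so it returns one value per prediction position, then apply the weights and normalize by the number of real (non-padded) tokens. The shapes and tensor names below are illustrative assumptions, not the book's actual data.

```python
import torch
from torch import nn

# Hypothetical shapes for illustration: a batch of 2 sequences,
# 3 masked-LM prediction positions each, and a vocabulary of 5 tokens.
vocab_size = 5
mlm_Y_hat = torch.randn(2, 3, vocab_size)          # predicted logits
mlm_Y = torch.randint(0, vocab_size, (2, 3))       # ground-truth token ids
mlm_weights_X = torch.tensor([[1.0, 1.0, 0.0],     # 0.0 marks padded positions
                              [1.0, 0.0, 0.0]])

# reduction='none' keeps a per-position loss vector, so the element-wise
# product with the flattened weights actually zeroes out padded positions.
loss = nn.CrossEntropyLoss(reduction='none')
per_position_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
weighted_l = per_position_l * mlm_weights_X.reshape(-1)

# Average over real tokens only (small epsilon guards against all-pad batches).
mlm_l = weighted_l.sum() / (mlm_weights_X.sum() + 1e-8)
```

With reduction='mean' the weighting would have been applied after the average was already taken, which is exactly the problem described above.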