The MLM loss computation in the function `_get_batch_loss_bert` seems wrong in the d2l PyTorch code #2582

@lyconghk

Description

In my opinion, the BERT pretraining batch loss computed in the function `_get_batch_loss_bert` is not correct. Details below.

The `CrossEntropyLoss` is initialized with the default reduction `'mean'`:

loss = nn.CrossEntropyLoss()

In the function `_get_batch_loss_bert`, the MLM loss and the NSP loss are computed with this same `loss` instance:

mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)

Since `reduction='mean'`, the result of `loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))` is a scalar tensor: the loss is already averaged over all prediction positions, including padded ones. The positionwise product with `mlm_weights_X` then merely broadcasts that scalar over the weights, so the weights cannot mask out padded positions, and the MLM loss computation is wrong.
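A minimal sketch of one possible fix (my own suggestion, not an official patch), assuming the tensor shapes used in the d2l pretraining section (`mlm_Y_hat` of shape `(batch, num_preds, vocab_size)`, `mlm_Y` and `mlm_weights_X` of shape `(batch, num_preds)`): the MLM loss needs `reduction='none'` so that per-position losses exist before the weights are applied.

```python
import torch
from torch import nn

# Toy shapes, assumed for illustration only.
batch, num_preds, vocab_size = 2, 3, 10
mlm_Y_hat = torch.randn(batch, num_preds, vocab_size)
mlm_Y = torch.randint(0, vocab_size, (batch, num_preds))
mlm_weights_X = torch.tensor([[1., 1., 0.], [1., 0., 0.]])  # 0 marks padded positions

# With the default reduction='mean', the loss is already a scalar, so
# multiplying by the weights has no masking effect (scalar broadcasts).
loss_mean = nn.CrossEntropyLoss()
scalar = loss_mean(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
assert scalar.dim() == 0  # scalar tensor, as described above

# Suggested fix: reduction='none' gives one loss per prediction position,
# which can then be masked by mlm_weights_X and averaged over real positions.
loss_none = nn.CrossEntropyLoss(reduction='none')
per_pos = loss_none(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
mlm_l = (per_pos * mlm_weights_X.reshape(-1)).sum() / (mlm_weights_X.sum() + 1e-8)
```

With this change, padded prediction positions (weight 0) contribute nothing to `mlm_l`, whereas under `reduction='mean'` they are averaged in before the weights are ever seen.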
