Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
fa394ca
feat(metrics): add Perplexity metric to NLP metrics
steaphenai Apr 20, 2026
e1e1e5f
fix(metrics): detach Perplexity accumulators and refine tests
steaphenai Apr 21, 2026
4335359
test(metrics): align Perplexity tests with metric patterns
steaphenai Apr 21, 2026
d5ab433
fix(metrics): address Perplexity review follow-ups
steaphenai Apr 21, 2026
dae37e9
test(metrics): use _reference_perplexity in token-weighted accumulati…
steaphenai Apr 21, 2026
f9ecaa1
Update tests/ignite/metrics/nlp/test_perplexity.py
steaphenai Apr 23, 2026
e5c0cfd
feat(metrics): add ignore_index to Perplexity, expose in docs, remove…
steaphenai Apr 23, 2026
46a3b8d
Merge branch 'master' into feat/perplexity-metric-pr
steaphenai Apr 23, 2026
fa8fb7f
style: fix ruff formatting in test_perplexity.py
steaphenai Apr 23, 2026
49ccdc4
Merge branch 'feat/perplexity-metric-pr' of https://github.com/steaph…
steaphenai Apr 23, 2026
c7a3720
fix(tests): fix token weighted accumulation test with different seq l…
steaphenai Apr 23, 2026
3143650
fix(tests): use _reference_perplexity and matching seq lengths in acc…
steaphenai Apr 23, 2026
b514e12
fix(metrics): remove explicit double dtype from Perplexity accumulato…
steaphenai Apr 24, 2026
832d3c2
Merge branch 'master' into feat/perplexity-metric-pr
steaphenai May 23, 2026
1d6ba0b
Merge branch 'master' into feat/perplexity-metric-pr
aaishwarymishra Jun 5, 2026
4aa534d
Merge branch 'master' into feat/perplexity-metric-pr
TahaZahid05 Jun 10, 2026
494d282
Merge branch 'master' into feat/perplexity-metric-pr
steaphenai Jun 12, 2026
c8202ac
fix(metrics): address Perplexity PR review feedback
steaphenai Jun 12, 2026
d191f11
Update ignite/metrics/nlp/perplexity.py
steaphenai Jun 13, 2026
901e50b
Update tests/ignite/metrics/nlp/test_perplexity.py
steaphenai Jun 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ignite.metrics.mutual_information import MutualInformation
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
from ignite.metrics.nlp.perplexity import Perplexity
from ignite.metrics.precision import Precision
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.metrics.psnr import PSNR
Expand Down Expand Up @@ -93,6 +94,7 @@
"Rouge",
"RougeN",
"RougeL",
"Perplexity",
"regression",
"clustering",
"fairness",
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/nlp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.perplexity import Perplexity
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN

__all__ = [
"Bleu",
"Perplexity",
Comment thread
steaphenai marked this conversation as resolved.
"Rouge",
"RougeN",
"RougeL",
Expand Down
103 changes: 103 additions & 0 deletions ignite/metrics/nlp/perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from collections.abc import Callable

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Perplexity"]


class Perplexity(Metric):
r"""Calculates the `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_ of a language model.

.. math::
\text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right)

where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the
conditional probability of token :math:`w_i` given the preceding tokens.

Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood
over all tokens. Lower perplexity indicates a better language model.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y_pred` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)``
containing the unnormalized log-probabilities (logits).
- `y` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices.

Note:
Perplexity uses token-weighted accumulation rather than batch-average to avoid bias
towards shorter sequences. The total NLL and total token count are accumulated across
all batches, and the final perplexity is computed as ``exp(total_nll / total_tokens)``.

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:

For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. testcode::

from ignite.metrics.nlp import Perplexity
import torch

ppl = Perplexity()

# batch_size=2, vocab_size=5, seq_len=3
y_pred = torch.log_softmax(torch.randn(2, 5, 3), dim=1)
Comment thread
TahaZahid05 marked this conversation as resolved.
Outdated
Comment thread
steaphenai marked this conversation as resolved.
Outdated
y = torch.randint(0, 5, (2, 3))

ppl.update((y_pred, y))

print(type(ppl.compute()))

.. testoutput::

<class 'float'>

.. versionadded:: 0.5.2
Comment thread
TahaZahid05 marked this conversation as resolved.
Outdated
Comment thread
steaphenai marked this conversation as resolved.
Outdated
"""

_state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens")

def __init__(
Comment thread
steaphenai marked this conversation as resolved.
self,
output_transform: Callable = lambda x: x,
device: str | torch.device = torch.device("cpu"),
):
super().__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_nll = torch.tensor(0.0, dtype=torch.double, device=self._device)
self._num_tokens = torch.tensor(0, dtype=torch.long, device=self._device)

@reinit__is_reduced
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output
Comment thread
steaphenai marked this conversation as resolved.

if y_pred.ndim < 2:
raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})")

if y.ndim < 1:
raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})")
Comment thread
steaphenai marked this conversation as resolved.

Comment thread
TahaZahid05 marked this conversation as resolved.
nll = F.cross_entropy(y_pred, y, reduction="sum")
self._sum_of_nll += nll.to(self._device, dtype=torch.double)
self._num_tokens += y.numel()

@sync_all_reduce("_sum_of_nll", "_num_tokens")
def compute(self) -> float:
if self._num_tokens == 0:
raise NotComputableError("Perplexity must have at least one example before it can be computed.")

return torch.exp(self._sum_of_nll / self._num_tokens).item()
99 changes: 99 additions & 0 deletions tests/ignite/metrics/nlp/test_perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.nlp import Perplexity


def test_zero_sample():
ppl = Perplexity()
ppl.reset()
with pytest.raises(NotComputableError):
ppl.compute()


def test_compute_matches_manual():
torch.manual_seed(42)
ppl = Perplexity()
ppl.reset()

y_pred = torch.randn(4, 10, 5)
y = torch.randint(0, 10, (4, 5))

ppl.update((y_pred, y))

nll_manual = F.cross_entropy(y_pred, y, reduction="sum").item()
ppl_manual = torch.exp(torch.tensor(nll_manual / y.numel())).item()

assert abs(ppl.compute() - ppl_manual) < 1e-4


def test_token_weighted_accumulation():
"""Token-weighted accumulation must differ from naive batch average."""
torch.manual_seed(0)
ppl = Perplexity()
ppl.reset()

# Two batches with different sequence lengths
b1_pred = torch.randn(2, 5, 4)
b1_y = torch.randint(0, 5, (2, 4))
b2_pred = torch.randn(3, 5, 10)
b2_y = torch.randint(0, 5, (3, 10))

ppl.update((b1_pred, b1_y))
ppl.update((b2_pred, b2_y))

nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum").item()
Comment thread
steaphenai marked this conversation as resolved.
Outdated
nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum").item()
total_tokens = b1_y.numel() + b2_y.numel()
ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item()

assert abs(ppl.compute() - ppl_ref) < 1e-4


def test_returns_float():
Comment thread
steaphenai marked this conversation as resolved.
Outdated
torch.manual_seed(1)
ppl = Perplexity()
ppl.reset()

y_pred = torch.randn(2, 5, 3)
y = torch.randint(0, 5, (2, 3))
ppl.update((y_pred, y))

result = ppl.compute()
assert isinstance(result, float)


def test_invalid_y_pred_shape():
ppl = Perplexity()
ppl.reset()

with pytest.raises(ValueError, match="y_pred must be at least 2-dimensional"):
ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0])))


def test_reset_clears_state():
torch.manual_seed(2)
ppl = Perplexity()

y_pred = torch.randn(2, 5, 3)
y = torch.randint(0, 5, (2, 3))
ppl.update((y_pred, y))

ppl.reset()
with pytest.raises(NotComputableError):
ppl.compute()


def test_single_token():
ppl = Perplexity()
ppl.reset()

y_pred = torch.randn(1, 5, 1)
y = torch.randint(0, 5, (1, 1))
ppl.update((y_pred, y))

result = ppl.compute()
assert result > 0
assert isinstance(result, float)