Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions evals/cli/oaieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import shlex
import sys
from numbers import Number
from typing import Any, Mapping, Optional, Union, cast

import evals
Expand All @@ -18,6 +19,26 @@
logger = logging.getLogger(__name__)


def _flatten_usage_metrics(usage: Any, prefix: str = "") -> dict[str, Number]:
if usage is None:
return {}

if isinstance(usage, Number):
return {prefix: usage} if prefix else {}

if hasattr(usage, "model_dump"):
return _flatten_usage_metrics(usage.model_dump(exclude_none=True), prefix)

if isinstance(usage, Mapping):
flattened: dict[str, Number] = {}
for key, value in usage.items():
nested_prefix = f"{prefix}_{key}" if prefix else str(key)
flattened.update(_flatten_usage_metrics(value, nested_prefix))
return flattened

return {}


def _purple(str: str) -> str:
return f"\033[1;35m{str}\033[0m"

Expand Down Expand Up @@ -274,13 +295,14 @@ def add_token_usage_to_result(result: dict[str, Any], recorder: RecorderBase) ->
sampling_events = recorder.get_events("sampling")
for event in sampling_events:
if "usage" in event.data:
usage_events.append(dict(event.data["usage"]))
usage_events.append(_flatten_usage_metrics(event.data["usage"]))
logger.info(f"Found {len(usage_events)}/{len(sampling_events)} sampling events with usage data")
if usage_events:
# Sum up the usage of all samples (assumes the usage is the same for all samples)
# Sum up token usage across all sampling events, including nested usage breakdowns.
usage_keys = set().union(*(usage_event.keys() for usage_event in usage_events))
total_usage = {
key: sum(u[key] if u[key] is not None else 0 for u in usage_events)
for key in usage_events[0]
key: sum(usage_event.get(key, 0) for usage_event in usage_events)
for key in sorted(usage_keys)
}
total_usage_str = "\n".join(f"{key}: {value:,}" for key, value in total_usage.items())
logger.info(f"Token usage from {len(usage_events)} sampling events:\n{total_usage_str}")
Expand Down
54 changes: 54 additions & 0 deletions evals/cli/oaieval_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails

from evals.base import RunSpec
from evals.cli.oaieval import add_token_usage_to_result
from evals.record import DummyRecorder


def test_add_token_usage_to_result_flattens_nested_usage_details() -> None:
    """End-to-end check that nested usage payloads are flattened and summed per key."""
    spec = RunSpec(
        completion_fns=[""],
        eval_name="",
        base_eval="",
        split="",
        run_config={},
        created_by="",
        run_id="",
        created_at="",
    )
    recorder = DummyRecorder(spec)

    # (sample_id, prompt, sampled, prompt_tok, completion_tok, total_tok, reasoning_tok, cached_tok)
    samples = [
        ("sample-1", "prompt-1", "answer-1", 10, 5, 15, 2, 4),
        ("sample-2", "prompt-2", "answer-2", 7, 6, 13, 3, 1),
    ]
    for sample_id, prompt, sampled, n_prompt, n_completion, n_total, n_reasoning, n_cached in samples:
        with recorder.as_default_recorder(sample_id):
            recorder.record_sampling(
                prompt=prompt,
                sampled=sampled,
                usage=CompletionUsage(
                    prompt_tokens=n_prompt,
                    completion_tokens=n_completion,
                    total_tokens=n_total,
                    completion_tokens_details=CompletionTokensDetails(reasoning_tokens=n_reasoning),
                    prompt_tokens_details=PromptTokensDetails(cached_tokens=n_cached),
                ),
            )

    result: dict[str, int] = {}
    add_token_usage_to_result(result, recorder)

    expected_totals = {
        "usage_prompt_tokens": 17,
        "usage_completion_tokens": 11,
        "usage_total_tokens": 28,
        "usage_completion_tokens_details_reasoning_tokens": 5,
        "usage_prompt_tokens_details_cached_tokens": 5,
    }
    for key, total in expected_totals.items():
        assert result[key] == total
Loading