Patch summary: extend chat-model detection in evals/registry.py and cover it with tests.

evals/registry.py: inside is_chat_model, the inline prefix set {"gpt-3.5-turbo-", "gpt-4-"}
is replaced by a named CHAT_MODEL_PREFIXES set that also lists "gpt-4o", "gpt-4.1",
"gpt-4.5", "o1", "o3" and "o4".
evals/registry_test.py: test_is_chat_model gains assertions for the new model names, and a
new test checks that Registry().make_completion_fn("gpt-4o-mini") returns an
OpenAIChatCompletionFn instance.

NOTE(review): prefixes are compared with str.startswith, so the bare "o1" / "o3" / "o4"
entries match any model name that begins with those characters; confirm that is intended.

diff --git a/evals/registry.py b/evals/registry.py
index 2d1c0fee1d..5926e2f9fd 100644
--- a/evals/registry.py
+++ b/evals/registry.py
@@ -85,11 +85,21 @@ def is_chat_model(model_name: str) -> bool:
         return False
 
     CHAT_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"}
+    CHAT_MODEL_PREFIXES = {
+        "gpt-3.5-turbo-",
+        "gpt-4-",
+        "gpt-4o",
+        "gpt-4.1",
+        "gpt-4.5",
+        "o1",
+        "o3",
+        "o4",
+    }
 
     if model_name in CHAT_MODEL_NAMES:
         return True
 
-    for model_prefix in {"gpt-3.5-turbo-", "gpt-4-"}:
+    for model_prefix in CHAT_MODEL_PREFIXES:
         if model_name.startswith(model_prefix):
             return True
 
diff --git a/evals/registry_test.py b/evals/registry_test.py
index ef05316220..1e7bd341fc 100644
--- a/evals/registry_test.py
+++ b/evals/registry_test.py
@@ -1,4 +1,5 @@
-from evals.registry import is_chat_model, n_ctx_from_model_name
+from evals import OpenAIChatCompletionFn
+from evals.registry import Registry, is_chat_model, n_ctx_from_model_name
 
 
 def test_n_ctx_from_model_name():
@@ -27,6 +28,22 @@ def test_is_chat_model():
     assert is_chat_model("gpt-4-0613")
     assert is_chat_model("gpt-4-32k")
     assert is_chat_model("gpt-4-32k-0613")
+    assert is_chat_model("gpt-4o")
+    assert is_chat_model("gpt-4o-mini")
+    assert is_chat_model("gpt-4.1")
+    assert is_chat_model("gpt-4.1-mini")
+    assert is_chat_model("o1")
+    assert is_chat_model("o1-mini")
+    assert is_chat_model("o3-mini")
+    assert is_chat_model("o4-mini")
     assert not is_chat_model("text-davinci-003")
     assert not is_chat_model("gpt4-base")
     assert not is_chat_model("code-davinci-002")
+
+
+def test_make_completion_fn_uses_chat_completion_for_modern_chat_models():
+    registry = Registry()
+
+    completion_fn = registry.make_completion_fn("gpt-4o-mini")
+
+    assert isinstance(completion_fn, OpenAIChatCompletionFn)