Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions evals/cli/oaieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ def _purple(str: str) -> str:
return f"\033[1;35m{str}\033[0m"


def parse_key_value_args(param_str: str) -> dict[str, Any]:
def to_number(x: str) -> Union[str, int, float]:
try:
return int(x)
except (ValueError, TypeError):
pass

try:
return float(x)
except (ValueError, TypeError):
pass
return x

str_dict = dict(kv.split("=") for kv in param_str.split(",") if kv)
return {k: to_number(v) for k, v in str_dict.items()}


def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run evals through the API")
parser.add_argument(
Expand Down Expand Up @@ -162,8 +179,7 @@ def to_number(x: str) -> Union[int, float, str]:
eval_spec.args.update(extra_eval_params)

# If the user provided an argument to --completion_args, parse it into a dict here, to be passed to the completion_fn creation **kwargs
completion_args = args.completion_args.split(",")
additional_completion_args = {k: v for k, v in (kv.split("=") for kv in completion_args if kv)}
additional_completion_args = parse_key_value_args(args.completion_args)

completion_fns = args.completion_fn.split(",")
completion_fn_instances = [
Expand Down
18 changes: 18 additions & 0 deletions evals/cli/oaieval_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

os.environ.setdefault("OPENAI_API_KEY", "dummy")

from evals.cli.oaieval import parse_key_value_args
from evals.registry import Registry


def test_parse_key_value_args_coerces_numbers() -> None:
parsed = parse_key_value_args("temperature=0.5,max_tokens=10,label=test")

assert parsed == {"temperature": 0.5, "max_tokens": 10, "label": "test"}


def test_registry_make_completion_fn_passes_completion_args_to_chat_models() -> None:
completion_fn = Registry().make_completion_fn("gpt-3.5-turbo", temperature=0.5)

assert completion_fn.extra_options == {"temperature": 0.5}
11 changes: 9 additions & 2 deletions evals/completion_fns/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def __init__(
self.api_base = api_base
self.api_key = api_key
self.n_ctx = n_ctx
self.extra_options = extra_options
self.extra_options = {
**(extra_options or {}),
**{key: value for key, value in kwargs.items() if key != "registry"},
}

def __call__(
self,
Expand Down Expand Up @@ -139,12 +142,16 @@ def __init__(
api_key: Optional[str] = None,
n_ctx: Optional[int] = None,
extra_options: Optional[dict] = {},
**kwargs,
):
self.model = model
self.api_base = api_base
self.api_key = api_key
self.n_ctx = n_ctx
self.extra_options = extra_options
self.extra_options = {
**(extra_options or {}),
**{key: value for key, value in kwargs.items() if key != "registry"},
}

def __call__(
self,
Expand Down
13 changes: 13 additions & 0 deletions evals/completion_fns/openai_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from evals.completion_fns.openai import OpenAIChatCompletionFn, OpenAICompletionFn


def test_openai_chat_completion_fn_promotes_extra_kwargs_to_extra_options() -> None:
completion_fn = OpenAIChatCompletionFn(model="gpt-3.5-turbo", temperature=0.5, registry=object())

assert completion_fn.extra_options == {"temperature": 0.5}


def test_openai_completion_fn_promotes_extra_kwargs_to_extra_options() -> None:
completion_fn = OpenAICompletionFn(model="text-davinci-003", temperature=0.5, registry=object())

assert completion_fn.extra_options == {"temperature": 0.5}
Loading