diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py index e48a09ac19..e6db7f5251 100644 --- a/evals/cli/oaieval.py +++ b/evals/cli/oaieval.py @@ -22,6 +22,23 @@ def _purple(str: str) -> str: return f"\033[1;35m{str}\033[0m" +def parse_key_value_args(param_str: str) -> dict[str, Any]: + def to_number(x: str) -> Union[str, int, float]: + try: + return int(x) + except (ValueError, TypeError): + pass + + try: + return float(x) + except (ValueError, TypeError): + pass + return x + + str_dict = dict(kv.split("=") for kv in param_str.split(",") if kv) + return {k: to_number(v) for k, v in str_dict.items()} + + def get_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run evals through the API") parser.add_argument( @@ -162,8 +179,7 @@ def to_number(x: str) -> Union[int, float, str]: eval_spec.args.update(extra_eval_params) # If the user provided an argument to --completion_args, parse it into a dict here, to be passed to the completion_fn creation **kwargs - completion_args = args.completion_args.split(",") - additional_completion_args = {k: v for k, v in (kv.split("=") for kv in completion_args if kv)} + additional_completion_args = parse_key_value_args(args.completion_args) completion_fns = args.completion_fn.split(",") completion_fn_instances = [ diff --git a/evals/cli/oaieval_test.py b/evals/cli/oaieval_test.py new file mode 100644 index 0000000000..700ba0e959 --- /dev/null +++ b/evals/cli/oaieval_test.py @@ -0,0 +1,18 @@ +import os + +os.environ.setdefault("OPENAI_API_KEY", "dummy") + +from evals.cli.oaieval import parse_key_value_args +from evals.registry import Registry + + +def test_parse_key_value_args_coerces_numbers() -> None: + parsed = parse_key_value_args("temperature=0.5,max_tokens=10,label=test") + + assert parsed == {"temperature": 0.5, "max_tokens": 10, "label": "test"} + + +def test_registry_make_completion_fn_passes_completion_args_to_chat_models() -> None: + completion_fn = Registry().make_completion_fn("gpt-3.5-turbo", temperature=0.5) + + assert completion_fn.extra_options == {"temperature": 0.5} \ No newline at end of file diff --git a/evals/completion_fns/openai.py b/evals/completion_fns/openai.py index 21524bfc1a..d6730cf58b 100644 --- a/evals/completion_fns/openai.py +++ b/evals/completion_fns/openai.py @@ -94,7 +94,10 @@ def __init__( self.api_base = api_base self.api_key = api_key self.n_ctx = n_ctx - self.extra_options = extra_options + self.extra_options = { + **(extra_options or {}), + **{key: value for key, value in kwargs.items() if key != "registry"}, + } def __call__( self, @@ -139,12 +142,16 @@ def __init__( api_key: Optional[str] = None, n_ctx: Optional[int] = None, extra_options: Optional[dict] = {}, + **kwargs, ): self.model = model self.api_base = api_base self.api_key = api_key self.n_ctx = n_ctx - self.extra_options = extra_options + self.extra_options = { + **(extra_options or {}), + **{key: value for key, value in kwargs.items() if key != "registry"}, + } def __call__( self, diff --git a/evals/completion_fns/openai_test.py b/evals/completion_fns/openai_test.py new file mode 100644 index 0000000000..93676e8f19 --- /dev/null +++ b/evals/completion_fns/openai_test.py @@ -0,0 +1,13 @@ +from evals.completion_fns.openai import OpenAIChatCompletionFn, OpenAICompletionFn + + +def test_openai_chat_completion_fn_promotes_extra_kwargs_to_extra_options() -> None: + completion_fn = OpenAIChatCompletionFn(model="gpt-3.5-turbo", temperature=0.5, registry=object()) + + assert completion_fn.extra_options == {"temperature": 0.5} + + +def test_openai_completion_fn_promotes_extra_kwargs_to_extra_options() -> None: + completion_fn = OpenAICompletionFn(model="text-davinci-003", temperature=0.5, registry=object()) + + assert completion_fn.extra_options == {"temperature": 0.5} \ No newline at end of file