diff --git a/evals/solvers/human_cli_solver.py b/evals/solvers/human_cli_solver.py
index c3ac19d62d..ef2ae1b959 100644
--- a/evals/solvers/human_cli_solver.py
+++ b/evals/solvers/human_cli_solver.py
@@ -1,6 +1,7 @@
+from copy import deepcopy
 from typing import Any
 
-from evals.record import record_sampling
+from evals.record import record_event, record_sampling
 from evals.solvers.solver import Solver, SolverResult
 from evals.task_state import Message, TaskState
 
@@ -15,6 +16,7 @@ class HumanCliSolver(Solver):
     def __init__(
         self,
         input_prompt: str = "assistant (you): ",
+        explain: bool = False,
         postprocessors: list[str] = [],
         registry: Any = None,
     ):
@@ -22,18 +24,54 @@ def __init__(
         Args:
             input_prompt: Prompt to be printed before the user input.
                 If None, no prompt is printed.
+            explain: If True, print structured context and output details for
+                human-in-the-loop debugging.
         """
         super().__init__(postprocessors=postprocessors)
         self.input_prompt = input_prompt
+        self.explain = explain
+
+    def _get_messages(self, task_state: TaskState) -> list[Message]:
+        return [Message("system", task_state.task_description)] + task_state.messages
+
+    def _get_prompt(self, task_state: TaskState) -> str:
+        msgs = self._get_messages(task_state)
+        return "\n".join([f"{msg.role}: {msg.content}" for msg in msgs]) + f"\n{self.input_prompt}"
+
+    def _print_explain_context(self, task_state: TaskState, prompt: str) -> None:
+        print("================ TASK CONTEXT (system) ================")
+        print(task_state.task_description)
+        print()
+        print("================ MESSAGE HISTORY ================")
+        if task_state.messages:
+            for idx, msg in enumerate(task_state.messages):
+                print(f"[{idx}] {msg.role}: {msg.content}")
+        else:
+            print("(no prior messages)")
+        print()
+        print("================ FINAL PROMPT STRING ================")
+        print(prompt)
+        print()
+        print("================ AWAITING HUMAN INPUT ================")
+
+    def _print_explain_outcome(self, prompt: str, raw_answer: str, final_answer: str) -> None:
+        print()
+        print("================ SAMPLING RECORD ================")
+        print("model: human")
+        print(f"prompt_chars: {len(prompt)}")
+        print(f"answer_chars: {len(raw_answer)}")
+        print()
+        print("================ OUTPUT ================")
+        print(f"raw_answer: {raw_answer}")
+        print(f"final_answer: {final_answer}")
 
     def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
-        msgs = [Message("system", task_state.task_description)]
-        msgs += task_state.messages
+        prompt = self._get_prompt(task_state)
+        if self.explain:
+            self._print_explain_context(task_state, prompt)
 
-        prompt = (
-            "\n".join([f"{msg.role}: {msg.content}" for msg in msgs]) + f"\n{self.input_prompt}"
-        )
-        answer = input(prompt)
+        cli_prompt = self.input_prompt if self.explain else prompt
+        answer = input(cli_prompt)
 
         record_sampling(
             prompt=prompt,
@@ -41,7 +79,29 @@ def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
             model="human",
         )
 
-        return SolverResult(answer)
+        return SolverResult(answer, prompt=prompt)
+
+    def __call__(self, task_state: TaskState, **kwargs) -> SolverResult:
+        res = self._solve(deepcopy(task_state), **kwargs)
+        raw_output = res.output
+
+        if hasattr(self, "postprocessors"):
+            for postprocessor in self.postprocessors:
+                prev_output = res.output
+                res = postprocessor(res)
+                record_event(
+                    "postprocessor",
+                    {
+                        "name": postprocessor.__class__.__name__,
+                        "input": prev_output,
+                        "output": res.output,
+                    },
+                )
+
+        if self.explain:
+            self._print_explain_outcome(res.metadata["prompt"], raw_output, res.output)
+
+        return res
 
     @property
     def name(self) -> str:
diff --git a/evals/solvers/human_cli_solver_test.py b/evals/solvers/human_cli_solver_test.py
new file mode 100644
index 0000000000..bb96dc2f1f
--- /dev/null
+++ b/evals/solvers/human_cli_solver_test.py
@@ -0,0 +1,101 @@
+import builtins
+
+import pytest
+
+from evals.base import RunSpec
+from evals.record import DummyRecorder
+from evals.solvers.human_cli_solver import HumanCliSolver
+from evals.task_state import Message, TaskState
+
+
+@pytest.fixture
+def dummy_recorder():
+    yield DummyRecorder(
+        RunSpec(
+            completion_fns=[""],
+            eval_name="",
+            base_eval="",
+            split="",
+            run_config={},
+            created_by="",
+            run_id="",
+            created_at="",
+        )
+    )
+
+
+def test_human_cli_solver_default_mode_keeps_existing_behavior(dummy_recorder, monkeypatch, capsys):
+    prompts = []
+
+    def fake_input(prompt: str) -> str:
+        prompts.append(prompt)
+        return "raw answer"
+
+    monkeypatch.setattr(builtins, "input", fake_input)
+
+    solver = HumanCliSolver()
+    task_state = TaskState("Follow the instructions.", [Message("user", "Hello")])
+
+    with dummy_recorder.as_default_recorder("x"):
+        result = solver(task_state)
+
+    assert result.output == "raw answer"
+    assert prompts == ["system: Follow the instructions.\nuser: Hello\nassistant (you): "]
+    assert capsys.readouterr().out == ""
+
+    sampling_event = dummy_recorder.get_events("sampling")[0]
+    assert sampling_event.data["prompt"] == prompts[0]
+    assert sampling_event.data["sampled"] == "raw answer"
+    assert sampling_event.data["model"] == "human"
+
+
+def test_human_cli_solver_explain_mode_shows_context_and_postprocessed_output(
+    dummy_recorder, monkeypatch, capsys
+):
+    prompts = []
+
+    def fake_input(prompt: str) -> str:
+        prompts.append(prompt)
+        return "  done.  "
+
+    monkeypatch.setattr(builtins, "input", fake_input)
+
+    solver = HumanCliSolver(
+        explain=True,
+        postprocessors=[
+            "evals.solvers.postprocessors.postprocessors:Strip",
+            "evals.solvers.postprocessors.postprocessors:RemovePeriod",
+        ],
+    )
+    task_state = TaskState(
+        "Answer as a human baseline.",
+        [Message("user", "First question"), Message("assistant", "First answer")],
+    )
+
+    with dummy_recorder.as_default_recorder("x"):
+        result = solver(task_state)
+
+    assert result.output == "done"
+    assert prompts == ["assistant (you): "]
+
+    output = capsys.readouterr().out
+    assert "================ TASK CONTEXT (system) ================" in output
+    assert "Answer as a human baseline." in output
+    assert "================ MESSAGE HISTORY ================" in output
+    assert "[0] user: First question" in output
+    assert "[1] assistant: First answer" in output
+    assert "================ FINAL PROMPT STRING ================" in output
+    assert "system: Answer as a human baseline." in output
+    assert "assistant (you): " in output
+    assert "================ SAMPLING RECORD ================" in output
+    assert "model: human" in output
+    assert "prompt_chars:" in output
+    assert "answer_chars: 9" in output
+    assert "================ OUTPUT ================" in output
+    assert "raw_answer:   done.  " in output
+    assert "final_answer: done" in output
+
+    sampling_event = dummy_recorder.get_events("sampling")[0]
+    assert sampling_event.data["sampled"] == "  done.  "
+    postprocessor_events = dummy_recorder.get_events("postprocessor")
+    assert len(postprocessor_events) == 2
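
For reviewers, a minimal usage sketch of the new explain flag, outside the diff itself. It mirrors the test wiring above; the DummyRecorder/RunSpec setup is only needed because record_sampling and record_event resolve the active default recorder, and the sample id "demo" plus the question text are arbitrary placeholders.

# usage_sketch.py -- minimal sketch, assuming the DummyRecorder wiring from the test above
from evals.base import RunSpec
from evals.record import DummyRecorder
from evals.solvers.human_cli_solver import HumanCliSolver
from evals.task_state import Message, TaskState

# A recorder must be active: the solver calls record_sampling/record_event.
recorder = DummyRecorder(
    RunSpec(
        completion_fns=[""],
        eval_name="",
        base_eval="",
        split="",
        run_config={},
        created_by="",
        run_id="",
        created_at="",
    )
)

solver = HumanCliSolver(
    explain=True,
    # Postprocessor paths are resolved by the Solver base class, as in the test.
    postprocessors=["evals.solvers.postprocessors.postprocessors:Strip"],
)
task_state = TaskState(
    "Answer as a human baseline.",
    [Message("user", "What is 2 + 2?")],  # hypothetical sample content
)

with recorder.as_default_recorder("demo"):
    # Prints the TASK CONTEXT / MESSAGE HISTORY / FINAL PROMPT STRING banners,
    # reads the answer from stdin via input(), then prints the SAMPLING RECORD
    # and OUTPUT banners with raw vs. post-processed answers.
    result = solver(task_state)

print(result.output)  # the post-processed (stripped) answer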