-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Add explain mode to HumanCliSolver #1652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from copy import deepcopy | ||
| from typing import Any | ||
|
|
||
| from evals.record import record_sampling | ||
| from evals.record import record_event, record_sampling | ||
| from evals.solvers.solver import Solver, SolverResult | ||
| from evals.task_state import Message, TaskState | ||
|
|
||
|
|
@@ -15,33 +16,92 @@ class HumanCliSolver(Solver): | |
| def __init__( | ||
| self, | ||
| input_prompt: str = "assistant (you): ", | ||
| explain: bool = False, | ||
| postprocessors: list[str] = [], | ||
| registry: Any = None, | ||
| ): | ||
| """ | ||
| Args: | ||
| input_prompt: Prompt to be printed before the user input. | ||
| If None, no prompt is printed. | ||
| explain: If True, print structured context and output details for | ||
|
Comment on lines 18 to +27
|
||
| human-in-the-loop debugging. | ||
| """ | ||
| super().__init__(postprocessors=postprocessors) | ||
| self.input_prompt = input_prompt | ||
| self.explain = explain | ||
|
|
||
| def _get_messages(self, task_state: TaskState) -> list[Message]: | ||
| return [Message("system", task_state.task_description)] + task_state.messages | ||
|
|
||
| def _get_prompt(self, task_state: TaskState) -> str: | ||
| msgs = self._get_messages(task_state) | ||
| return "\n".join([f"{msg.role}: {msg.content}" for msg in msgs]) + f"\n{self.input_prompt}" | ||
|
|
||
| def _print_explain_context(self, task_state: TaskState, prompt: str) -> None: | ||
| print("================ TASK CONTEXT (system) ================") | ||
| print(task_state.task_description) | ||
| print() | ||
| print("================ MESSAGE HISTORY ================") | ||
| if task_state.messages: | ||
| for idx, msg in enumerate(task_state.messages): | ||
| print(f"[{idx}] {msg.role}: {msg.content}") | ||
| else: | ||
| print("(no prior messages)") | ||
| print() | ||
| print("================ FINAL PROMPT STRING ================") | ||
| print(prompt) | ||
| print() | ||
| print("================ AWAITING HUMAN INPUT ================") | ||
|
|
||
| def _print_explain_outcome(self, prompt: str, raw_answer: str, final_answer: str) -> None: | ||
| print() | ||
| print("================ SAMPLING RECORD ================") | ||
| print("model: human") | ||
| print(f"prompt_chars: {len(prompt)}") | ||
| print(f"answer_chars: {len(raw_answer)}") | ||
| print() | ||
| print("================ OUTPUT ================") | ||
| print(f"raw_answer: {raw_answer}") | ||
| print(f"final_answer: {final_answer}") | ||
|
|
||
| def _solve(self, task_state: TaskState, **kwargs) -> SolverResult: | ||
| msgs = [Message("system", task_state.task_description)] | ||
| msgs += task_state.messages | ||
| prompt = self._get_prompt(task_state) | ||
| if self.explain: | ||
| self._print_explain_context(task_state, prompt) | ||
|
|
||
| prompt = ( | ||
| "\n".join([f"{msg.role}: {msg.content}" for msg in msgs]) + f"\n{self.input_prompt}" | ||
| ) | ||
| answer = input(prompt) | ||
| cli_prompt = self.input_prompt if self.explain else prompt | ||
| answer = input(cli_prompt) | ||
|
|
||
| record_sampling( | ||
| prompt=prompt, | ||
| sampled=answer, | ||
| model="human", | ||
| ) | ||
|
|
||
| return SolverResult(answer) | ||
| return SolverResult(answer, prompt=prompt) | ||
|
|
||
| def __call__(self, task_state: TaskState, **kwargs) -> SolverResult: | ||
| res = self._solve(deepcopy(task_state), **kwargs) | ||
| raw_output = res.output | ||
|
|
||
| if hasattr(self, "postprocessors"): | ||
| for postprocessor in self.postprocessors: | ||
| prev_output = res.output | ||
| res = postprocessor(res) | ||
| record_event( | ||
| "postprocessor", | ||
| { | ||
| "name": postprocessor.__class__.__name__, | ||
| "input": prev_output, | ||
| "output": res.output, | ||
| }, | ||
| ) | ||
|
Comment on lines +84 to +99
|
||
|
|
||
| if self.explain: | ||
| self._print_explain_outcome(res.metadata["prompt"], raw_output, res.output) | ||
|
|
||
| return res | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| import builtins | ||
|
|
||
| import pytest | ||
|
|
||
| from evals.base import RunSpec | ||
| from evals.record import DummyRecorder | ||
| from evals.solvers.human_cli_solver import HumanCliSolver | ||
| from evals.task_state import Message, TaskState | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def dummy_recorder(): | ||
| yield DummyRecorder( | ||
| RunSpec( | ||
| completion_fns=[""], | ||
| eval_name="", | ||
| base_eval="", | ||
| split="", | ||
| run_config={}, | ||
| created_by="", | ||
| run_id="", | ||
| created_at="", | ||
| ) | ||
| ) | ||
|
|
||
|
|
||
| def test_human_cli_solver_default_mode_keeps_existing_behavior(dummy_recorder, monkeypatch, capsys): | ||
| prompts = [] | ||
|
|
||
| def fake_input(prompt: str) -> str: | ||
| prompts.append(prompt) | ||
| return "raw answer" | ||
|
|
||
| monkeypatch.setattr(builtins, "input", fake_input) | ||
|
|
||
| solver = HumanCliSolver() | ||
| task_state = TaskState("Follow the instructions.", [Message("user", "Hello")]) | ||
|
|
||
| with dummy_recorder.as_default_recorder("x"): | ||
| result = solver(task_state) | ||
|
|
||
| assert result.output == "raw answer" | ||
| assert prompts == ["system: Follow the instructions.\nuser: Hello\nassistant (you): "] | ||
| assert capsys.readouterr().out == "" | ||
|
|
||
| sampling_event = dummy_recorder.get_events("sampling")[0] | ||
| assert sampling_event.data["prompt"] == prompts[0] | ||
| assert sampling_event.data["sampled"] == "raw answer" | ||
| assert sampling_event.data["model"] == "human" | ||
|
|
||
|
|
||
| def test_human_cli_solver_explain_mode_shows_context_and_postprocessed_output( | ||
| dummy_recorder, monkeypatch, capsys | ||
| ): | ||
| prompts = [] | ||
|
|
||
| def fake_input(prompt: str) -> str: | ||
| prompts.append(prompt) | ||
| return " done. " | ||
|
|
||
| monkeypatch.setattr(builtins, "input", fake_input) | ||
|
|
||
| solver = HumanCliSolver( | ||
| explain=True, | ||
| postprocessors=[ | ||
| "evals.solvers.postprocessors.postprocessors:Strip", | ||
| "evals.solvers.postprocessors.postprocessors:RemovePeriod", | ||
| ], | ||
| ) | ||
| task_state = TaskState( | ||
| "Answer as a human baseline.", | ||
| [Message("user", "First question"), Message("assistant", "First answer")], | ||
| ) | ||
|
|
||
| with dummy_recorder.as_default_recorder("x"): | ||
| result = solver(task_state) | ||
|
|
||
| assert result.output == "done" | ||
| assert prompts == ["assistant (you): "] | ||
|
|
||
| output = capsys.readouterr().out | ||
| assert "================ TASK CONTEXT (system) ================" in output | ||
| assert "Answer as a human baseline." in output | ||
| assert "================ MESSAGE HISTORY ================" in output | ||
| assert "[0] user: First question" in output | ||
| assert "[1] assistant: First answer" in output | ||
| assert "================ FINAL PROMPT STRING ================" in output | ||
| assert "system: Answer as a human baseline." in output | ||
| assert "assistant (you): " in output | ||
| assert "================ SAMPLING RECORD ================" in output | ||
| assert "model: human" in output | ||
| assert "prompt_chars:" in output | ||
| assert "answer_chars: 9" in output | ||
| assert "================ OUTPUT ================" in output | ||
| assert "raw_answer: done. " in output | ||
| assert "final_answer: done" in output | ||
|
|
||
| sampling_event = dummy_recorder.get_events("sampling")[0] | ||
| assert sampling_event.data["sampled"] == " done. " | ||
| postprocessor_events = dummy_recorder.get_events("postprocessor") | ||
| assert len(postprocessor_events) == 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`postprocessors` uses a mutable list default (`[]`). This can lead to accidental state sharing if the list is ever mutated (now or in future refactors). Prefer `postprocessors: list[str] | None = None` and then pass `postprocessors or []` to `super().__init__`.