Commit 021d0b4

single shot experiments

1 parent 24e4c2d commit 021d0b4

File tree: .gitignore, src/eval/prompts.py, src/eval/run.py

3 files changed: +226 additions, -22 deletions


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 reports
 data/benchmark_results
+*.csv
+*.html
+*.png
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

src/eval/prompts.py

Lines changed: 45 additions & 0 deletions
@@ -69,3 +69,48 @@
 ## Current Board ##
 {{current_board}}
 """.strip()
+
+
+ONE_SHOT_PROMPT = """You are a professional Sudoku puzzle solver. Please solve the following Sudoku variant.
+
+## Format Explanation ##
+Coordinates:
+- We will use r{x}c{y} coordinates. For example, r1c1 is the top-left cell at row 1 column 1, r1c2 is the cell to the right at row 1 column 2, r2c1 is the cell below at row 2 column 1, and so on.
+{%- if pretty_visual_elements %}
+
+Visual Elements:
+- Any visual elements will be described in text using rxcy coordinates.
+- Please note the visual elements will be described as-is. If a thermo or arrow appears on the board, the location of the circle or bulb will be listed, and the line or arrow will be listed as a separate object. But you can infer they are part of the same object by their coordinates.
+- If a visual element is described as "between" two cells, it means the visual element appears on the edge between the two cells.
+- In some puzzles there may be visual elements outside of the grid and these will be described using the same coordinate system. For example, an arrow in r0c1 pointing to the lower right means there is an arrow above r1c1 that points in the direction of the diagonal: r1c2, r2c3, etc.
+- All visual elements are provided and provide sufficient information to solve the puzzle.
+{%- endif %}
+
+## Tips ##
+In solving the puzzle it often helps to understand that there exists a unique solution.
+It therefore helps to focus on what values must be forced given the puzzle constraints, and given the fact that the solution is unique.
+You should try to commit a single value to a cell.
+
+## Size ##
+{{rows}} x {{cols}}
+
+## Rules ##
+{{rules}}
+{%- if pretty_visual_elements %}
+
+## Visual Elements ##
+{{pretty_visual_elements}}
+{%- endif %}
+
+## Current Board ##
+{{current_board}}
+
+## Answer Format ##
+Please provide your answer at the end of your response. Put your answer within tags <ANSWER></ANSWER>. Your answer will be a sequence of {{rows}}x{{cols}} = {{ rows * cols }} digits.
+
+For example, if the solution is 1234, your answer will be:
+<ANSWER>
+1234
+</ANSWER>
+""".strip()

src/eval/run.py

Lines changed: 178 additions & 22 deletions
@@ -50,6 +50,8 @@
 import sys
 from typing import Any, Dict, List, Optional, Union
 import uuid
+import re
+
 
 import aiohttp
 import anthropic
@@ -72,6 +74,7 @@
     BOARD_PROMPT,
     PREFILLED_ASSISTANT_RESPONSE,
     RULE_PROMPT,
+    ONE_SHOT_PROMPT,
 )
 from eval.utils import (
     extract_action_from_response,
@@ -95,7 +98,19 @@ async def call_api(
     while attempt < args.max_retries:
         try:
             # OpenAI API
-            if isinstance(client, openai.AsyncOpenAI):
+            if args.api == "openrouter":
+                kwargs = {
+                    "model": model,
+                    "messages": messages,
+                    "temperature": args.temperature,
+                    "top_p": args.top_p,
+                }
+                completion = await client.chat.completions.create(**kwargs)
+                if not completion.choices:
+                    # typically due to rate limiting
+                    raise ValueError(f"API response missing 'choices'. Error: {completion.error['message']}")
+                output_text = completion.choices[0].message.content
+            elif isinstance(client, openai.AsyncOpenAI):
                 kwargs = {
                     "model": model,
                     "messages": messages,
@@ -106,8 +121,9 @@ async def call_api(
                 if "o1-" in model or "o3-" in model:
                     kwargs["max_completion_tokens"] = args.max_tokens
                     kwargs["temperature"] = 1.0
+                    kwargs.pop('top_p', None)
                 # Special setting for R1 (TogetherAI)
-                if model == "deepseek-ai/DeepSeek-R1":
+                elif model == "deepseek-ai/DeepSeek-R1":
                     kwargs["temperature"] = 0.6
                     kwargs["max_tokens"] = None
                 else:
@@ -132,6 +148,8 @@
                         "type": "enabled",
                         "budget_tokens": args.max_tokens - 1024,
                     }
+                    kwargs.pop('top_k', None)
+                    kwargs.pop('top_p', None)
                 # Handle system prompts for Anthropic APIs
                 if messages[0]["role"] == "system":
                     kwargs["system"] = messages[0]["content"]
@@ -164,6 +182,7 @@
 
         except Exception as e:
             attempt += 1
+            print(f"Attempt {attempt} failed. {e}")
             if attempt == args.max_retries:
                 print(f"Failed after {args.max_retries} attempts for request. {e}")
                 return None
@@ -357,18 +376,121 @@ async def process_one(
     }
 
 
+async def process_one_single_shot(
+    args: argparse.Namespace,
+    client: Union[Any],
+    request: Dict,
+    model: str,
+    tokenizer: Optional[AutoTokenizer] = None,
+) -> Dict:
+    # Load data
+    rules = request["rules"]
+    initial_board_ascii = request["initial_board"]
+    solution_ascii = request["solution"]
+    rows = request["rows"]
+    cols = request["cols"]
+    visual_elements = request["visual_elements"]
+    if pd.isna(visual_elements) or visual_elements == "":
+        visual_elements = None
+    n_history_turns = request["n_history_turns"]  # Keep for consistency in output format
+
+    # Construct setting string (simplified for single-shot)
+    settings = ["single-shot"]
+    if request["num_empty_cells"] > 0:
+        settings.append(f"{request['num_empty_cells']}-empty-{request['shuffle_seed']}-seed")
+    setting = "_".join(settings)
+
+    # Pretty print visual elements
+    pretty_visual_elements = None
+    if visual_elements is not None:
+        try:
+            visual_elements = json.loads(visual_elements)
+            pretty_visual_elements = pretty_print_visual_elements(visual_elements)
+        except json.JSONDecodeError:
+            print(f"Warning: Could not parse visual_elements for puzzle {request['puzzle_id']}")
+            visual_elements = None  # Set to None if parsing fails
+
+    # Construct the single prompt
+    one_shot_prompt = jinja2.Template(ONE_SHOT_PROMPT).render(
+        rows=rows,
+        cols=cols,
+        rules=rules,
+        pretty_visual_elements=pretty_visual_elements,
+        current_board=initial_board_ascii,
+    )
+
+    # Single conversation turn
+    input_conversation = [{"role": "user", "content": one_shot_prompt}]
+
+    # Call API
+    assistant_response = await call_api(
+        args=args,
+        client=client,
+        model=model,
+        tokenizer=tokenizer,
+        messages=input_conversation,
+    )
+
+    parsed_answer = ""
+    final_solved = 0
+
+    if assistant_response:
+        # Update conversation history (just this single turn)
+        conversation = input_conversation + [{"role": "assistant", "content": assistant_response}]
+
+        # Extract answer from <ANSWER> tags
+        match = re.search(r"<ANSWER>(.*?)</ANSWER>", assistant_response, re.DOTALL | re.IGNORECASE)
+        if match:
+            answer_content = match.group(1)
+            # Extract only digits
+            parsed_answer = "".join(filter(str.isdigit, answer_content))
+            # Check if parsed answer matches solution
+            if parsed_answer == solution_ascii:
+                final_solved = 1
+                print(f"[Pass] Puzzle {request['puzzle_id']} solved correctly.")
+            else:
+                print(f"[Fail] Puzzle {request['puzzle_id']} incorrect.")
+        else:
+            print(f"[Fail] Puzzle {request['puzzle_id']}. No <ANSWER> tag found in response.")
+            conversation = input_conversation  # Record only user prompt if assistant failed
+    else:
+        print(f"[Fail] Puzzle {request['puzzle_id']}. No response from server.")
+        conversation = input_conversation  # Record only user prompt if assistant failed
+
+
+    return {
+        # From input
+        "data_source": args.dataset,
+        "puzzle_id": request["puzzle_id"],
+        "model": args.model_save_name if args.model_save_name else model,
+        "num_empty_cells": request["num_empty_cells"],
+        "shuffle_seed": request["shuffle_seed"],
+        "n_response_idx": request["n_response_idx"],
+        "n_history_turns": n_history_turns,  # Retained for consistency
+        "setting": setting,
+        "initial_board": request["initial_board"],
+        # From output
+        "conversation": json.dumps(conversation),
+        "num_rounds": 1,  # Single shot = 1 round
+        "num_correct_placements": final_solved,  # Treat solved as 1 correct 'placement'
+        "final_solved": final_solved,
+        "final_board": parsed_answer,  # The parsed digit string from the response
+    }
+
+
 async def process_batch(
     args: argparse.Namespace,
     requests: List[Dict],
     client: Union[Any],
     model: str,
     tokenizer: Optional[AutoTokenizer] = None,
-    batch_size: int = 1
+    batch_size: int = 1,
+    process_func: callable = process_one
 ) -> List[Dict]:
     semaphore = asyncio.Semaphore(batch_size)
     async def process_with_semaphore(request):
         async with semaphore:
-            return await process_one(
+            return await process_func(
                 args=args,
                 client=client,
                 request=request,
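
The core of the new single-shot path is the answer extraction: find the <ANSWER> block, then keep only the digits so whitespace and line breaks inside the tags are harmless. A standalone version of that logic (the helper name is made up for illustration):

import re

def parse_single_shot_answer(assistant_response: str) -> str:
    # Return the digit string inside <ANSWER>...</ANSWER>, or "" if no tag is found.
    match = re.search(r"<ANSWER>(.*?)</ANSWER>", assistant_response, re.DOTALL | re.IGNORECASE)
    if not match:
        return ""
    return "".join(filter(str.isdigit, match.group(1)))

# Toy 2x2 example: the digit string is what gets compared against the solution.
assert parse_single_shot_answer("working...\n<ANSWER>\n12\n34\n</ANSWER>") == "1234"
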
@@ -383,7 +505,8 @@ async def process_with_semaphore(request):
     with tqdm(total=len(tasks), desc="Processing requests") as pbar:
         for coro in asyncio.as_completed(tasks):
             result = await coro
-            outputs.append(result)
+            if result:
+                outputs.append(result)
             pbar.update(1)
 
     return outputs
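
process_batch keeps the same semaphore-bounded fan-out; the changes are the pluggable process_func and dropping None results from failed requests instead of appending them. The pattern, reduced to a self-contained sketch (names here are illustrative, not the script's):

import asyncio
from typing import Callable, Dict, List, Optional

async def bounded_gather(requests: List[Dict], process_func: Callable, batch_size: int = 1) -> List[Dict]:
    semaphore = asyncio.Semaphore(batch_size)

    async def worker(request: Dict) -> Optional[Dict]:
        async with semaphore:
            return await process_func(request)

    outputs: List[Dict] = []
    for coro in asyncio.as_completed([worker(r) for r in requests]):
        result = await coro
        if result:  # failed requests return None and are skipped
            outputs.append(result)
    return outputs

# e.g. asyncio.run(bounded_gather(reqs, handle_one, batch_size=8)) with an async handle_one
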
@@ -430,7 +553,7 @@ def construct_request(
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Evaluate LLM on Sudoku puzzles in a multi-round manner.")
+    parser = argparse.ArgumentParser(description="Evaluate LLM on Sudoku puzzles.")
 
     # Filepaths
     parser.add_argument("--dataset", type=str, required=True, choices=["challenge_100", "nikoli_100", "ctc"],
@@ -447,6 +570,8 @@ def main():
                         help="Specific puzzle indices to evaluate. Overrides start/end.")
 
     # Eval setting
+    parser.add_argument("--mode", type=str, default="multi_round", choices=["multi_round", "single_shot"],
+                        help="Evaluation mode: multi-round interaction or single-shot completion.")
     # The number of evaluations for each puzzle is the product of the following four arguments.
     parser.add_argument("--num_empty_cells", type=int, nargs="+", default=[0, 10, 20],
                         help="Number of empty cells in the intial board after hint fill in random cells. "
@@ -460,7 +585,7 @@
 
     # Model
     parser.add_argument("--api", type=str, default="openai",
-                        choices=["openai", "anthropic", "anthropic_bedrock", "deepseek", "vllm", "togetherai"],
+                        choices=["openai", "anthropic", "anthropic_bedrock", "deepseek", "vllm", "togetherai", "openrouter"],
                         help="API to use.")
     parser.add_argument("--model", type=str, required=True,
                         help="Model name or path.")
@@ -494,6 +619,9 @@ def main():
     # Sanity check
     assert args.num_empty_cells != [0] or len(args.shuffle_seeds) == 1, \
         "shuffle_seed is only used when providing hints (i.e. num_empty_cells > 0)."
+    if args.mode == "single_shot" and args.n_history_turns != [5]:
+        print("Warning: --n_history_turns is ignored in single_shot mode.")
+        args.n_history_turns = [0]
 
     # Load puzzle
     dataset = datasets.load_dataset("SakanaAI/Sudoku-Bench", args.dataset, split="test")
@@ -538,6 +666,11 @@
         client = openai.AsyncOpenAI(
             api_key=os.environ.get("OPENAI_API_KEY"),
         )
+    elif args.api == "openrouter":
+        client = openai.AsyncOpenAI(
+            api_key=os.environ.get("OPENROUTER_API_KEY"),
+            base_url="https://openrouter.ai/api/v1",
+        )
     elif args.api == "anthropic":
         client = anthropic.AsyncAnthropic(
             api_key=os.environ.get("ANTHROPIC_API_KEY"),
@@ -555,9 +688,11 @@
         )
     elif args.api == "togetherai":
         client = openai.AsyncOpenAI(
-            api_key=os.environ.get("TOGETHERAI_API_KEY"),
+            api_key=os.environ.get("TOGETHER_API_KEY"),
             base_url="https://api.together.xyz/v1",
         )
+    elif args.api == "googleai":  # Add googleai client setup
+        client = genai.Client(api_key=os.environ.get("GOOGLE_API_KEY"))
     elif args.api == "vllm":
         client = AsyncLLMEngine.from_engine_args(
             AsyncEngineArgs(
@@ -575,36 +710,57 @@
         )
         tokenizer = AutoTokenizer.from_pretrained(args.model)
 
-    # Process batch
+    # Select processing function based on mode
+    process_func = process_one if args.mode == "multi_round" else process_one_single_shot
+    print(f"Running in {args.mode} mode.")
+
+    # Process batch using the selected function
     all_results = asyncio.run(process_batch(
         args=args,
         batch_size=args.batch_size,
         requests=requests,
         client=client,
        tokenizer=tokenizer,
-        model=args.model
+        model=args.model,
+        process_func=process_func  # Pass the selected function
     ))
 
     # Convert results to DataFrame
+    if not all_results:  # Check if list is empty
+        print("No results generated. Exiting.")
+        return  # Exit if no results
     res_df = pd.DataFrame(all_results)
     if len(res_df) == 0:
-        print("No results to save. Possibly no puzzles or an error occurred.")
+        print("No results to save. DataFrame is empty.")
         return
 
     # Print summary
     # We'll measure average number of correct placements and fraction of puzzles solved.
     group_cols = ["num_empty_cells", "setting", "model"]
-    summary = (
-        res_df
-        .groupby(group_cols)
-        .agg({
-            "num_correct_placements": "mean",
-            "final_solved": "mean"
-        })
-        .reset_index()
-    )
-    with pd.option_context("display.max_rows", None, "display.precision", 2):
-        print(summary)
+    agg_metrics = {"final_solved": "mean"}
+    if args.mode == "multi_round":
+        # Only include multi-round specific metrics if in that mode
+        agg_metrics["num_correct_placements"] = "mean"
+
+    # Ensure columns exist before aggregation
+    existing_cols = [col for col in group_cols if col in res_df.columns]
+    if not existing_cols:
+        print("Grouping columns not found in results DataFrame. Cannot generate summary.")
+        summary_str = "Summary could not be generated."
+    else:
+        summary = (
+            res_df
+            .groupby(existing_cols)
+            .agg(agg_metrics)
+            .reset_index()
+        )
+        with pd.option_context("display.max_rows", None, "display.precision", 3):
+            summary_str = summary.to_string()
+
+    print("\n--- Summary ---")
+    print(summary_str)
+    print("---------------")
+
 
     # Save results to CSV
     os.makedirs(os.path.dirname(args.output_csv), exist_ok=True)
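
The reworked summary aggregates only the metrics that apply to the chosen mode and prints the table via to_string(). The core of it, runnable on a toy frame (a sketch, not the script itself):

import pandas as pd

res_df = pd.DataFrame({
    "num_empty_cells": [0, 0, 10],
    "setting": ["single-shot"] * 3,
    "model": ["some-model"] * 3,
    "final_solved": [1, 0, 1],
})

agg_metrics = {"final_solved": "mean"}  # multi_round would also add "num_correct_placements"
group_cols = ["num_empty_cells", "setting", "model"]
existing_cols = [c for c in group_cols if c in res_df.columns]

summary = res_df.groupby(existing_cols).agg(agg_metrics).reset_index()
with pd.option_context("display.max_rows", None, "display.precision", 3):
    print(summary.to_string())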
