50
50
import sys
51
51
from typing import Any , Dict , List , Optional , Union
52
52
import uuid
53
+ import re
54
+
53
55
54
56
import aiohttp
55
57
import anthropic
72
74
BOARD_PROMPT ,
73
75
PREFILLED_ASSISTANT_RESPONSE ,
74
76
RULE_PROMPT ,
77
+ ONE_SHOT_PROMPT ,
75
78
)
76
79
from eval .utils import (
77
80
extract_action_from_response ,
@@ -95,7 +98,19 @@ async def call_api(
95
98
while attempt < args .max_retries :
96
99
try :
97
100
# OpenAI API
98
- if isinstance (client , openai .AsyncOpenAI ):
101
+ if args .api == "openrouter" :
102
+ kwargs = {
103
+ "model" : model ,
104
+ "messages" : messages ,
105
+ "temperature" : args .temperature ,
106
+ "top_p" : args .top_p ,
107
+ }
108
+ completion = await client .chat .completions .create (** kwargs )
109
+ if not completion .choices :
110
+ # typically due to rate limiting
111
+ raise ValueError (f"API response missing 'choices'. Error: { completion .error ['message' ]} " )
112
+ output_text = completion .choices [0 ].message .content
113
+ elif isinstance (client , openai .AsyncOpenAI ):
99
114
kwargs = {
100
115
"model" : model ,
101
116
"messages" : messages ,
@@ -106,8 +121,9 @@ async def call_api(
106
121
if "o1-" in model or "o3-" in model :
107
122
kwargs ["max_completion_tokens" ] = args .max_tokens
108
123
kwargs ["temperature" ] = 1.0
124
+ kwargs .pop ('top_p' , None )
109
125
# Special setting for R1 (TogetherAI)
110
- if model == "deepseek-ai/DeepSeek-R1" :
126
+ elif model == "deepseek-ai/DeepSeek-R1" :
111
127
kwargs ["temperature" ] = 0.6
112
128
kwargs ["max_tokens" ] = None
113
129
else :
@@ -132,6 +148,8 @@ async def call_api(
132
148
"type" : "enabled" ,
133
149
"budget_tokens" : args .max_tokens - 1024 ,
134
150
}
151
+ kwargs .pop ('top_k' , None )
152
+ kwargs .pop ('top_p' , None )
135
153
# Handle system prompts for Anthropic APIs
136
154
if messages [0 ]["role" ] == "system" :
137
155
kwargs ["system" ] = messages [0 ]["content" ]
@@ -164,6 +182,7 @@ async def call_api(
164
182
165
183
except Exception as e :
166
184
attempt += 1
185
+ print (f"Attempt { attempt } failed. { e } " )
167
186
if attempt == args .max_retries :
168
187
print (f"Failed after { args .max_retries } attempts for request. { e } " )
169
188
return None
@@ -357,18 +376,121 @@ async def process_one(
357
376
}
358
377
359
378
379
async def process_one_single_shot(
    args: argparse.Namespace,
    client: Union[Any],
    request: Dict,
    model: str,
    tokenizer: Optional[AutoTokenizer] = None,
) -> Dict:
    """Evaluate one Sudoku puzzle in single-shot mode.

    Builds a single user prompt from ONE_SHOT_PROMPT, sends it to the model via
    call_api, extracts the digits inside the first <ANSWER>...</ANSWER> tag of
    the response, and compares them against the puzzle's solution string.

    Args:
        args: Parsed CLI arguments (dataset, model_save_name, API settings, ...).
        client: API client passed through to call_api (type depends on --api).
        request: One evaluation request dict; keys read here include "rules",
            "initial_board", "solution", "rows", "cols", "visual_elements",
            "n_history_turns", "num_empty_cells", "shuffle_seed", "puzzle_id",
            "n_response_idx".
        model: Model name or path forwarded to the API.
        tokenizer: Optional tokenizer, forwarded to call_api (used for vLLM).

    Returns:
        A flat result dict (input metadata + conversation JSON + scoring fields)
        suitable for building the results DataFrame. num_rounds is always 1 and
        num_correct_placements mirrors final_solved in this mode.
    """
    # Load data from the request record.
    rules = request["rules"]
    initial_board_ascii = request["initial_board"]
    solution_ascii = request["solution"]
    rows = request["rows"]
    cols = request["cols"]
    visual_elements = request["visual_elements"]
    # Normalize missing visual elements (NaN from pandas, or empty string) to None.
    if pd.isna(visual_elements) or visual_elements == "":
        visual_elements = None
    n_history_turns = request["n_history_turns"]  # Keep for consistency in output format

    # Construct setting string (simplified for single-shot).
    settings = ["single-shot"]
    if request["num_empty_cells"] > 0:
        settings.append(f"{request['num_empty_cells']}-empty-{request['shuffle_seed']}-seed")
    setting = "_".join(settings)

    # Pretty print visual elements (stored as a JSON string in the dataset —
    # presumably; verify against the dataset schema).
    pretty_visual_elements = None
    if visual_elements is not None:
        try:
            visual_elements = json.loads(visual_elements)
            pretty_visual_elements = pretty_print_visual_elements(visual_elements)
        except json.JSONDecodeError:
            print(f"Warning: Could not parse visual_elements for puzzle {request['puzzle_id']}")
            visual_elements = None  # Set to None if parsing fails

    # Construct the single prompt from the jinja2 template.
    one_shot_prompt = jinja2.Template(ONE_SHOT_PROMPT).render(
        rows=rows,
        cols=cols,
        rules=rules,
        pretty_visual_elements=pretty_visual_elements,
        current_board=initial_board_ascii,
    )

    # Single conversation turn: one user message, no history.
    input_conversation = [{"role": "user", "content": one_shot_prompt}]

    # Call API (returns None after exhausting retries).
    assistant_response = await call_api(
        args=args,
        client=client,
        model=model,
        tokenizer=tokenizer,
        messages=input_conversation,
    )

    parsed_answer = ""
    final_solved = 0

    if assistant_response:
        # Update conversation history (just this single turn).
        conversation = input_conversation + [{"role": "assistant", "content": assistant_response}]

        # Extract answer from <ANSWER> tags (first occurrence, case-insensitive,
        # DOTALL so the answer may span multiple lines).
        match = re.search(r"<ANSWER>(.*?)</ANSWER>", assistant_response, re.DOTALL | re.IGNORECASE)
        if match:
            answer_content = match.group(1)
            # Extract only digits — whitespace/grid formatting inside the tag is
            # discarded. NOTE(review): assumes solution_ascii is itself a pure
            # digit string; confirm against the dataset.
            parsed_answer = "".join(filter(str.isdigit, answer_content))
            # Check if parsed answer matches solution.
            if parsed_answer == solution_ascii:
                final_solved = 1
                print(f"[Pass] Puzzle {request['puzzle_id']} solved correctly.")
            else:
                print(f"[Fail] Puzzle {request['puzzle_id']} incorrect.")
        else:
            print(f"[Fail] Puzzle {request['puzzle_id']}. No <ANSWER> tag found in response.")
            # Deliberately drops the assistant message from the recorded
            # conversation when no answer tag is present.
            conversation = input_conversation  # Record only user prompt if assistant failed
    else:
        print(f"[Fail] Puzzle {request['puzzle_id']}. No response from server.")
        conversation = input_conversation  # Record only user prompt if assistant failed


    return {
        # From input
        "data_source": args.dataset,
        "puzzle_id": request["puzzle_id"],
        "model": args.model_save_name if args.model_save_name else model,
        "num_empty_cells": request["num_empty_cells"],
        "shuffle_seed": request["shuffle_seed"],
        "n_response_idx": request["n_response_idx"],
        "n_history_turns": n_history_turns,  # Retained for consistency
        "setting": setting,
        "initial_board": request["initial_board"],
        # From output
        "conversation": json.dumps(conversation),
        "num_rounds": 1,  # Single shot = 1 round
        "num_correct_placements": final_solved,  # Treat solved as 1 correct 'placement'
        "final_solved": final_solved,
        "final_board": parsed_answer,  # The parsed digit string from the response
    }
479
+
480
+
360
481
async def process_batch (
361
482
args : argparse .Namespace ,
362
483
requests : List [Dict ],
363
484
client : Union [Any ],
364
485
model : str ,
365
486
tokenizer : Optional [AutoTokenizer ] = None ,
366
- batch_size : int = 1
487
+ batch_size : int = 1 ,
488
+ process_func : callable = process_one
367
489
) -> List [Dict ]:
368
490
semaphore = asyncio .Semaphore (batch_size )
369
491
async def process_with_semaphore (request ):
370
492
async with semaphore :
371
- return await process_one (
493
+ return await process_func (
372
494
args = args ,
373
495
client = client ,
374
496
request = request ,
@@ -383,7 +505,8 @@ async def process_with_semaphore(request):
383
505
with tqdm (total = len (tasks ), desc = "Processing requests" ) as pbar :
384
506
for coro in asyncio .as_completed (tasks ):
385
507
result = await coro
386
- outputs .append (result )
508
+ if result :
509
+ outputs .append (result )
387
510
pbar .update (1 )
388
511
389
512
return outputs
@@ -430,7 +553,7 @@ def construct_request(
430
553
431
554
432
555
def main ():
433
- parser = argparse .ArgumentParser (description = "Evaluate LLM on Sudoku puzzles in a multi-round manner ." )
556
+ parser = argparse .ArgumentParser (description = "Evaluate LLM on Sudoku puzzles." )
434
557
435
558
# Filepaths
436
559
parser .add_argument ("--dataset" , type = str , required = True , choices = ["challenge_100" , "nikoli_100" , "ctc" ],
@@ -447,6 +570,8 @@ def main():
447
570
help = "Specific puzzle indices to evaluate. Overrides start/end." )
448
571
449
572
# Eval setting
573
+ parser .add_argument ("--mode" , type = str , default = "multi_round" , choices = ["multi_round" , "single_shot" ],
574
+ help = "Evaluation mode: multi-round interaction or single-shot completion." )
450
575
# The number of evaluations for each puzzle is the product of the following four arguments.
451
576
parser .add_argument ("--num_empty_cells" , type = int , nargs = "+" , default = [0 , 10 , 20 ],
452
577
help = "Number of empty cells in the intial board after hint fill in random cells. "
@@ -460,7 +585,7 @@ def main():
460
585
461
586
# Model
462
587
parser .add_argument ("--api" , type = str , default = "openai" ,
463
- choices = ["openai" , "anthropic" , "anthropic_bedrock" , "deepseek" , "vllm" , "togetherai" ],
588
+ choices = ["openai" , "anthropic" , "anthropic_bedrock" , "deepseek" , "vllm" , "togetherai" , "openrouter" ],
464
589
help = "API to use." )
465
590
parser .add_argument ("--model" , type = str , required = True ,
466
591
help = "Model name or path." )
@@ -494,6 +619,9 @@ def main():
494
619
# Sanity check
495
620
assert args .num_empty_cells != [0 ] or len (args .shuffle_seeds ) == 1 , \
496
621
"shuffle_seed is only used when providing hints (i.e. num_empty_cells > 0)."
622
+ if args .mode == "single_shot" and args .n_history_turns != [5 ]:
623
+ print ("Warning: --n_history_turns is ignored in single_shot mode." )
624
+ args .n_history_turns = [0 ]
497
625
498
626
# Load puzzle
499
627
dataset = datasets .load_dataset ("SakanaAI/Sudoku-Bench" , args .dataset , split = "test" )
@@ -538,6 +666,11 @@ def main():
538
666
client = openai .AsyncOpenAI (
539
667
api_key = os .environ .get ("OPENAI_API_KEY" ),
540
668
)
669
+ elif args .api == "openrouter" :
670
+ client = openai .AsyncOpenAI (
671
+ api_key = os .environ .get ("OPENROUTER_API_KEY" ),
672
+ base_url = "https://openrouter.ai/api/v1" ,
673
+ )
541
674
elif args .api == "anthropic" :
542
675
client = anthropic .AsyncAnthropic (
543
676
api_key = os .environ .get ("ANTHROPIC_API_KEY" ),
@@ -555,9 +688,11 @@ def main():
555
688
)
556
689
elif args .api == "togetherai" :
557
690
client = openai .AsyncOpenAI (
558
- api_key = os .environ .get ("TOGETHERAI_API_KEY " ),
691
+ api_key = os .environ .get ("TOGETHER_API_KEY " ),
559
692
base_url = "https://api.together.xyz/v1" ,
560
693
)
694
+ elif args .api == "googleai" : # Add googleai client setup
695
+ client = genai .Client (api_key = os .environ .get ("GOOGLE_API_KEY" ))
561
696
elif args .api == "vllm" :
562
697
client = AsyncLLMEngine .from_engine_args (
563
698
AsyncEngineArgs (
@@ -575,36 +710,57 @@ def main():
575
710
)
576
711
tokenizer = AutoTokenizer .from_pretrained (args .model )
577
712
578
- # Process batch
713
+ # Select processing function based on mode
714
+ process_func = process_one if args .mode == "multi_round" else process_one_single_shot
715
+ print (f"Running in { args .mode } mode." )
716
+
717
+ # Process batch using the selected function
579
718
all_results = asyncio .run (process_batch (
580
719
args = args ,
581
720
batch_size = args .batch_size ,
582
721
requests = requests ,
583
722
client = client ,
584
723
tokenizer = tokenizer ,
585
- model = args .model
724
+ model = args .model ,
725
+ process_func = process_func # Pass the selected function
586
726
))
587
727
588
728
# Convert results to DataFrame
729
+ if not all_results : # Check if list is empty
730
+ print ("No results generated. Exiting." )
731
+ return # Exit if no results
589
732
res_df = pd .DataFrame (all_results )
590
733
if len (res_df ) == 0 :
591
- print ("No results to save. Possibly no puzzles or an error occurred ." )
734
+ print ("No results to save. DataFrame is empty ." )
592
735
return
593
736
594
737
# Print summary
595
738
# We'll measure average number of correct placements and fraction of puzzles solved.
596
739
group_cols = ["num_empty_cells" , "setting" , "model" ]
597
- summary = (
598
- res_df
599
- .groupby (group_cols )
600
- .agg ({
601
- "num_correct_placements" : "mean" ,
602
- "final_solved" : "mean"
603
- })
604
- .reset_index ()
605
- )
606
- with pd .option_context ("display.max_rows" , None , "display.precision" , 2 ):
607
- print (summary )
740
+ agg_metrics = {"final_solved" : "mean" }
741
+ if args .mode == "multi_round" :
742
+ # Only include multi-round specific metrics if in that mode
743
+ agg_metrics ["num_correct_placements" ] = "mean"
744
+
745
+ # Ensure columns exist before aggregation
746
+ existing_cols = [col for col in group_cols if col in res_df .columns ]
747
+ if not existing_cols :
748
+ print ("Grouping columns not found in results DataFrame. Cannot generate summary." )
749
+ summary_str = "Summary could not be generated."
750
+ else :
751
+ summary = (
752
+ res_df
753
+ .groupby (existing_cols )
754
+ .agg (agg_metrics )
755
+ .reset_index ()
756
+ )
757
+ with pd .option_context ("display.max_rows" , None , "display.precision" , 3 ):
758
+ summary_str = summary .to_string ()
759
+
760
+ print ("\\ n--- Summary ---" )
761
+ print (summary_str )
762
+ print ("---------------" )
763
+
608
764
609
765
# Save results to CSV
610
766
os .makedirs (os .path .dirname (args .output_csv ), exist_ok = True )
0 commit comments