Skip to content

Commit

Permalink
refactor: remove repetitive imports in run.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jerry871002 committed Apr 30, 2024
1 parent 6e4c82b commit f259f29
Showing 1 changed file with 32 additions and 48 deletions.
80 changes: 32 additions & 48 deletions src/run.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,40 @@
import argparse
import importlib


def run(args: argparse.Namespace) -> None:
if args.scenario == 'grid':
from grid_world.run import (
run_bpr_okr,
run_bpr_plus,
run_bsi,
run_bsi_pt,
run_deep_bpr_plus,
run_tom,
)
elif args.scenario == 'nav':
from navigation_game.run import (
run_bpr_okr,
run_bpr_plus,
run_bsi,
run_bsi_pt,
run_deep_bpr_plus,
run_tom,
)
elif args.scenario == 'soccer':
from soccer_game.run import (
run_bpr_okr,
run_bpr_plus,
run_bsi,
run_bsi_pt,
run_deep_bpr_plus,
run_tom,
)
elif args.scenario == 'baseball':
from baseball_game.run import (
run_bpr_okr,
run_bpr_plus,
run_bsi,
run_bsi_pt,
run_deep_bpr_plus,
run_tom,
)
scenario_modules = {
'grid': 'grid_world.run',
'nav': 'navigation_game.run',
'soccer': 'soccer_game.run',
'baseball': 'baseball_game.run',
}

if args.agent == 'bpr+':
run_bpr_plus(args)
elif args.agent == 'deep-bpr+':
run_deep_bpr_plus(args)
elif args.agent == 'tom':
run_tom(args)
elif args.agent == 'bpr-okr':
run_bpr_okr(args)
elif args.agent == 'bsi':
run_bsi(args)
elif args.agent == 'bsi-pt':
run_bsi_pt(args)
scenario = args.scenario
if scenario in scenario_modules:
run_module = importlib.import_module(scenario_modules[scenario])
run_bpr_okr = run_module.run_bpr_okr
run_bpr_plus = run_module.run_bpr_plus
run_bsi = run_module.run_bsi
run_bsi_pt = run_module.run_bsi_pt
run_deep_bpr_plus = run_module.run_deep_bpr_plus
run_tom = run_module.run_tom
else:
raise ValueError(f"Unsupported scenario: {scenario}")

agent_functions = {
'bpr+': run_bpr_plus,
'deep-bpr+': run_deep_bpr_plus,
'tom': run_tom,
'bpr-okr': run_bpr_okr,
'bsi': run_bsi,
'bsi-pt': run_bsi_pt,
}

if agent in agent_functions:
agent_functions[agent](args)
else:
raise ValueError(f"Unsupported agent type: {agent}")


def positive_int(value: str) -> int:
Expand Down

0 comments on commit f259f29

Please sign in to comment.