Skip to content

Commit

Permalink
Added an option to run a simple, non-learning heuristic AI
Browse files Browse the repository at this point in the history
  • Loading branch information
samvelyan committed Oct 1, 2019
1 parent 85d496d commit bea53c1
Showing 1 changed file with 66 additions and 3 deletions.
69 changes: 66 additions & 3 deletions smac/env/starcraft2/starcraft2.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class StarCraft2Env(MultiAgentEnv):
def __init__(
self,
map_name="8m",
step_mul=None,
step_mul=8,
move_amount=2,
difficulty="7",
game_version=None,
Expand All @@ -91,6 +91,7 @@ def __init__(
replay_prefix="",
window_size_x=1920,
window_size_y=1200,
heuristic_ai=False,
debug=False,
):
"""
Expand All @@ -102,7 +103,7 @@ def __init__(
The name of the SC2 map to play (default is "8m"). The full list
can be found by running bin/map_list.
step_mul : int, optional
How many game steps per agent step (default is None). None
How many game steps per agent step (default is 8). None
indicates to use the default map step_mul.
move_amount : float, optional
How far away units are ordered to move per step (default is 2).
Expand Down Expand Up @@ -179,6 +180,8 @@ def __init__(
The width of the StarCraft II window (default is 1920).
window_size_y: int, optional
The height of the StarCraft II window (default is 1200).
heuristic_ai: bool, optional
Whether or not to use a non-learning heuristic AI (default False).
debug: bool, optional
Log messages about observations, state, actions and rewards for
debugging purposes (default is False).
Expand Down Expand Up @@ -222,6 +225,7 @@ def __init__(
self.game_version = game_version
self.continuing_episode = continuing_episode
self._seed = seed
self.heuristic_ai = heuristic_ai
self.debug = debug
self.window_size = (window_size_x, window_size_y)
self.replay_dir = replay_dir
Expand Down Expand Up @@ -348,6 +352,9 @@ def reset(self):

self.last_action = np.zeros((self.n_agents, self.n_actions))

if self.heuristic_ai:
self.heuristic_targets = [None] * self.n_agents

try:
self._obs = self._controller.observe()
self.init_units()
Expand Down Expand Up @@ -389,7 +396,10 @@ def step(self, actions):
logging.debug("Actions".center(60, "-"))

for a_id, action in enumerate(actions):
agent_action = self.get_agent_action(a_id, action)
if not self.heuristic_ai:
agent_action = self.get_agent_action(a_id, action)
else:
agent_action = self.get_agent_action_heuristic(a_id, action)
if agent_action:
sc_actions.append(agent_action)

Expand Down Expand Up @@ -548,6 +558,59 @@ def get_agent_action(self, a_id, action):
sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))
return sc_action

def get_agent_action_heuristic(self, a_id, action):
    """Build the scripted (non-learning) command for agent ``a_id``.

    A medivac heals the closest ally that is alive, damaged and not itself
    a medivac. Any other unit attacks the closest enemy with positive
    health; marauders never pick enemy medivacs. The chosen target id is
    cached in ``self.heuristic_targets[a_id]`` and reused on later calls
    until it stops being valid (dead, or fully healed for heal targets).

    Parameters
    ----------
    a_id : int
        Id of the agent, used with ``self.get_unit_by_id`` and as the index
        into ``self.heuristic_targets``.
    action : any
        Ignored. Accepted only so the signature matches
        ``get_agent_action``.

    Returns
    -------
    sc_pb.Action or None
        The raw unit command, or None when no valid target exists. In that
        case the cached target is reset to None.
    """
    unit = self.get_unit_by_id(a_id)
    is_medivac = unit.unit_type == self.medivac_id
    target = self.heuristic_targets[a_id]

    if is_medivac:
        pool = self.agents
        needs_new_target = (
            target is None
            or pool[target].health == 0
            or pool[target].health == pool[target].health_max
        )
    else:
        pool = self.enemies
        needs_new_target = target is None or pool[target].health == 0

    if needs_new_target:
        # Only candidates strictly closer than this bound can be selected.
        min_dist = math.hypot(self.max_distance_x, self.max_distance_y)
        min_id = None
        for c_id, candidate in pool.items():
            if is_medivac:
                if candidate.unit_type == self.medivac_id:
                    continue
                if (candidate.health == 0 or
                        candidate.health == candidate.health_max):
                    continue
            else:
                if (unit.unit_type == self.marauder_id and
                        candidate.unit_type == self.medivac_id):
                    continue
                if not candidate.health > 0:
                    continue
            dist = self.distance(unit.pos.x, unit.pos.y,
                                 candidate.pos.x, candidate.pos.y)
            if dist < min_dist:
                min_dist = dist
                min_id = c_id

        self.heuristic_targets[a_id] = min_id
        if min_id is None:
            # Nothing to heal/attack. Previously the attack branch fell
            # through here and looked up a non-existent id of -1.
            return None
        target = min_id

    action_id = actions['heal'] if is_medivac else actions['attack']

    cmd = r_pb.ActionRawUnitCommand(
        ability_id=action_id,
        target_unit_tag=pool[target].tag,
        unit_tags=[unit.tag],
        queue_command=False)

    return sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))

def reward_battle(self):
"""Reward function when self.reward_spare==False.
Returns accumulative hit/shield point damage dealt to the enemy
Expand Down

0 comments on commit bea53c1

Please sign in to comment.