Feature/python3dot10 gym compatibility update #25

Open · wants to merge 5 commits into master
16 changes: 10 additions & 6 deletions lbforaging.py
@@ -1,9 +1,9 @@
'''Basic flow to see if the base install worked over one environment.'''
import argparse
import logging
import random
import time
import gym
import numpy as np
import gymnasium as gym
import lbforaging


@@ -13,7 +13,7 @@
def _game_loop(env, render):
"""
"""
obs = env.reset()
_, _ = env.reset()
done = False

if render:
@@ -24,7 +24,7 @@ def _game_loop(env, render):

actions = env.action_space.sample()

nobs, nreward, ndone, _ = env.step(actions)
_, nreward, ndone, _, _ = env.step(actions)
if sum(nreward) > 0:
print(nreward)

@@ -38,9 +38,11 @@ def _game_loop(env, render):

def main(game_count=1, render=False):
env = gym.make("Foraging-8x8-2p-2f-v2")
obs = env.reset()

for episode in range(game_count):
_, info = env.reset()
assert info == {}

for _ in range(game_count):
_game_loop(env, render)


@@ -54,3 +56,5 @@ def main(game_count=1, render=False):

args = parser.parse_args()
main(args.times, args.render)

print("Done. NO RUNTIME ERRORS.")
3 changes: 2 additions & 1 deletion lbforaging/__init__.py
@@ -1,5 +1,6 @@
from gym.envs.registration import registry, register, make, spec
from gymnasium.envs.registration import registry, register, make, spec
from itertools import product
from lbforaging import foraging

sizes = range(5, 20)
players = range(2, 20)
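
For context (not part of the diff): lbforaging/__init__.py iterates over these size/player/food ranges and registers one environment id per combination with Gymnasium. A rough sketch of a single such registration; the entry point and kwargs below are illustrative assumptions, not copied from this pull request:

# Rough sketch of one registration; entry_point and kwargs are assumptions
# for illustration, not taken from this pull request.
from gymnasium.envs.registration import register

register(
    id="Foraging-8x8-2p-2f-v2",
    entry_point="lbforaging.foraging:ForagingEnv",
    kwargs={
        "players": 2,
        "max_player_level": 3,
        "field_size": (8, 8),
        "max_food": 2,
        "sight": 8,
        "max_episode_steps": 50,
        "force_coop": False,
    },
)
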
6 changes: 1 addition & 5 deletions lbforaging/agents/__init__.py
@@ -1,5 +1 @@
# from lbforaging.agents.random_agent import RandomAgent
# from lbforaging.agents.heuristic_agent import H1, H2, H3, H4
# from lbforaging.agents.q_agent import QAgent
# from lbforaging.agents.monte_carlo import MonteCarloAgent
# from lbforaging.agents.hba import HBAAgent
from lbforaging.agents import *
6 changes: 3 additions & 3 deletions lbforaging/agents/hba.py
@@ -1,8 +1,8 @@
from . import QAgent
from foraging import Env
from lbforaging.agents.q_agent import QAgent
from lbforaging.foraging.environment import ForagingEnv as Env
import random
import numpy as np
from agents import H1, H2, H3, H4
from lbforaging.agents.heuristic_agent import H1, H2, H3, H4
from itertools import product
from collections import defaultdict
from functools import reduce
6 changes: 3 additions & 3 deletions lbforaging/agents/heuristic_agent.py
@@ -1,7 +1,7 @@
import random
import numpy as np
from foraging import Agent
from foraging.environment import Action
from lbforaging.foraging.agent import Agent
from lbforaging.foraging.environment import Action


class HeuristicAgent(Agent):
@@ -28,7 +28,7 @@ def _move_towards(self, target, allowed):
raise ValueError("No simple path found")

def step(self, obs):
raise NotImplemented("Heuristic agent is implemented by H1-H4")
raise NotImplementedError("Heuristic agent is implemented by H1-H4")


class H1(HeuristicAgent):
3 changes: 2 additions & 1 deletion lbforaging/agents/monte_carlo.py
@@ -8,7 +8,8 @@
import plotly.graph_objs as go
from networkx.drawing.nx_pydot import graphviz_layout

from foraging import Agent, Env
from lbforaging.foraging.agent import Agent
from lbforaging.foraging.environment import ForagingEnv as Env

MCTS_DEPTH = 15

2 changes: 1 addition & 1 deletion lbforaging/agents/nn_agent.py
@@ -1,6 +1,6 @@
import random

from foraging import Agent
from lbforaging.foraging.agent import Agent


class NNAgent(Agent):
7 changes: 4 additions & 3 deletions lbforaging/agents/q_agent.py
@@ -4,9 +4,10 @@
import numpy as np
import pandas as pd

from agents import H1
from lbforaging import Agent, Env
from lbforaging.environment import Action
from lbforaging.agents.heuristic_agent import H1
from lbforaging.foraging.agent import Agent
from lbforaging.foraging.environment import Action
from lbforaging.foraging.environment import ForagingEnv as Env

_CACHE = None

2 changes: 1 addition & 1 deletion lbforaging/agents/random_agent.py
@@ -1,6 +1,6 @@
import random

from lbforaging import Agent
from lbforaging.foraging.agent import Agent


class RandomAgent(Agent):
3 changes: 1 addition & 2 deletions lbforaging/foraging/agent.py
@@ -1,5 +1,4 @@
import logging

import numpy as np

_MAX_INT = 999999
@@ -30,7 +29,7 @@ def _step(self, obs):
return action

def step(self, obs):
raise NotImplemented("You must implement an agent")
raise NotImplementedError("You must implement an agent")

def _closest_food(self, obs, max_food_level=None, start=None):

69 changes: 51 additions & 18 deletions lbforaging/foraging/environment.py
@@ -2,9 +2,9 @@
from collections import namedtuple, defaultdict
from enum import Enum
from itertools import product
from gym import Env
import gym
from gym.utils import seeding
from gymnasium import Env
import gymnasium as gym
from gymnasium.utils import seeding
import numpy as np


@@ -93,7 +93,7 @@ def __init__(
self.field = np.zeros(field_size, np.int32)

self.penalty = penalty

self.max_food = max_food
self._food_spawned = 0.0
self.max_player_level = max_player_level
@@ -109,13 +109,14 @@ def __init__(
self._grid_observation = grid_observation

self.action_space = gym.spaces.Tuple(tuple([gym.spaces.Discrete(6)] * len(self.players)))
self.observation_space = gym.spaces.Tuple(tuple([self._get_observation_space()] * len(self.players)))
self.observation_space = gym.spaces.Tuple(
tuple([self._get_observation_space()] * len(self.players)))

self.viewer = None

self.n_agents = len(self.players)

def seed(self, seed=None):
def seed(self, seed=0):
self.np_random, seed = seeding.np_random(seed)
return [seed]

@@ -159,7 +160,15 @@ def _get_observation_space(self):
min_obs = np.stack([agents_min, foods_min, access_min])
max_obs = np.stack([agents_max, foods_max, access_max])

return gym.spaces.Box(np.array(min_obs), np.array(max_obs), dtype=np.float32)
low_obs = np.array(min_obs)
high_obs = np.array(max_obs)
assert len(low_obs) == len(high_obs)
composed_obs_space = gym.spaces.Box(
low=low_obs,
high=high_obs,
shape=[len(low_obs)],
dtype=np.float32)
return composed_obs_space

@classmethod
def from_obs(cls, obs):
@@ -170,7 +179,15 @@ def from_obs(cls, obs):
player.score = p.score if p.score else 0
players.append(player)

env = cls(players, None, None, None, None)
env = cls(
players=players,
max_player_level=None,
field_size=None,
max_food=None,
sight=None,
max_episode_steps=50,
force_coop=False
)
env.field = np.copy(obs.field)
env.current_step = obs.current_step
env.sight = obs.sight
@@ -202,6 +219,14 @@ def _gen_valid_moves(self):
for player in self.players
}

def test_gen_valid_moves(self) -> bool:
''' Wrapper around a private method to test if the generated moves are valid. '''
try:
self._gen_valid_moves()
except Exception as _:
return False
return True

def neighborhood(self, row, col, distance=1, ignore_diag=False):
if not ignore_diag:
return self.field[
@@ -253,8 +278,8 @@ def spawn_food(self, max_food, max_level):

while food_count < max_food and attempts < 1000:
attempts += 1
row = self.np_random.randint(1, self.rows - 1)
col = self.np_random.randint(1, self.cols - 1)
row = self.np_random.integers(1, self.rows - 1)
col = self.np_random.integers(1, self.cols - 1)

# check if it has neighbors:
if (
@@ -269,7 +294,7 @@ def spawn_food(self, max_food, max_level):
if min_level == max_level
# ! this is excluding food of level `max_level` but is kept for
# ! consistency with prior LBF versions
else self.np_random.randint(min_level, max_level)
else self.np_random.integers(min_level, max_level)
)
food_count += 1
self._food_spawned = self.field.sum()
@@ -290,12 +315,12 @@ def spawn_players(self, max_player_level):
player.reward = 0

while attempts < 1000:
row = self.np_random.randint(0, self.rows)
col = self.np_random.randint(0, self.cols)
row = self.np_random.integers(0, self.rows)
col = self.np_random.integers(0, self.cols)
if self._is_empty_location(row, col):
player.setup(
(row, col),
self.np_random.randint(1, max_player_level + 1),
self.np_random.integers(1, max_player_level + 1),
self.field_size,
)
break
@@ -465,9 +490,15 @@ def get_player_reward(observation):
assert self.observation_space[i].contains(obs), \
f"obs space error: obs: {obs}, obs_space: {self.observation_space[i]}"

return nobs, nreward, ndone, ninfo
truncated_term = False
# To turn this into a single agent task, you need to sum the nreward and the ndone
return nobs, nreward, ndone, truncated_term, ninfo

def test_make_gym_obs(self):
''' Public wrapper exposing the current observation for tests. '''
return self._make_gym_obs()

def reset(self):
def reset(self, *, seed=None, options=None):
self.field = np.zeros(self.field_size, np.int32)
self.spawn_players(self.max_player_level)
player_levels = sorted([player.level for player in self.players])
@@ -479,8 +510,10 @@ def reset(self):
self._game_over = False
self._gen_valid_moves()

nobs, _, _, _ = self._make_gym_obs()
return nobs
nobs, _, _, _, _ = self._make_gym_obs()
# The new gym spec and gym utils require that
# the new observation and an info dictionary are returned
return nobs, {}

def step(self, actions):
self.current_step += 1
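
For context (not part of the diff): gymnasium.utils.seeding.np_random returns a numpy.random.Generator rather than the legacy RandomState, which is why the randint calls above become integers and why reset() now accepts seed and options keyword arguments. A minimal sketch of the replacement, assuming only NumPy and Gymnasium:

# Minimal sketch: gymnasium's seeding helper yields a numpy.random.Generator,
# so RandomState.randint(low, high) becomes Generator.integers(low, high)
# (the upper bound stays exclusive).
from gymnasium.utils import seeding

np_random, seed = seeding.np_random(0)
row = np_random.integers(1, 8 - 1)   # e.g. a food row on an 8x8 field, border excluded
col = np_random.integers(1, 8 - 1)
print(seed, row, col)
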
12 changes: 12 additions & 0 deletions lbforaging/foraging/rendering.py
@@ -29,6 +29,18 @@
"""
)

try:
from pyglet import gl
except ImportError as e:
raise ImportError(
"""
Cannot 'from pyglet import gl'
HINT: you can install pyglet directly via 'pip install pyglet'.
But if you really just want to install all Gym dependencies and not have to think about it,
'pip install -e .[all]' or 'pip install gym[all]' will do it.
"""
)

try:
from pyglet.gl import *
except ImportError as e:
3 changes: 2 additions & 1 deletion setup.py
@@ -13,8 +13,9 @@
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.10",
],
install_requires=["numpy", "gym==0.21", "pyglet"],
install_requires=["numpy", "gym==0.26", "pyglet<2"],
extras_require={"test": ["pytest"]},
include_package_data=True,
)