/
pettingzoo_chess.py
149 lines (114 loc) · 4.79 KB
/
pettingzoo_chess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import re
from typing import List, Union
from pettingzoo.classic import chess_v6
from pettingzoo.classic.chess.chess_utils import chess, get_move_plane
from chatarena.environments.base import Environment, TimeStep, register_env
from ..message import Message, MessagePool
def action_string_to_alphazero_format(action: str, player_index: int) -> int:
    """Convert a textual move into PettingZoo's flat chess action index.

    The text must start with ``"Move (x1, y1) to (x2, y2)"``. The square index
    handed to ``chess.Move`` is ``8 * y + x``, i.e. ``x`` is the file and ``y``
    is the rank, both in the range 0-7 and given in absolute board
    coordinates.

    Args:
        action: The move text produced by a player.
        player_index: 0 for the first player (white), 1 for the second
            player (black).

    Returns:
        ``x1 * 8 * 73 + y1 * 73 + move_plane`` for the (possibly mirrored)
        source square, or ``-1`` when the text does not describe a move that
        can be encoded (bad format, off-board coordinate, or a displacement
        that is neither a straight/diagonal slide nor a knight jump).
    """
    # [0-7] instead of \d: rejects off-board digits (8, 9) and non-ASCII
    # digits, which would otherwise yield squares outside 0..63.
    pattern = r"Move \(([0-7]), ([0-7])\) to \(([0-7]), ([0-7])\)"
    match = re.match(pattern, action)
    if not match:
        return -1

    x1, y1, x2, y2 = (int(coord) for coord in match.groups())

    # The 73-plane encoding only has planes for queen-line slides, knight
    # jumps and under-promotions; reject anything else here rather than
    # letting get_move_plane deal with an impossible displacement.
    dx, dy = abs(x2 - x1), abs(y2 - y1)
    is_knight_jump = {dx, dy} == {1, 2}
    is_slide = (dx, dy) != (0, 0) and (dx == 0 or dy == 0 or dx == dy)
    if not (is_knight_jump or is_slide):
        return -1

    if player_index == 1:
        # PettingZoo presents black with a vertically mirrored board
        # (python-chess ``square_mirror``: ranks flipped, files unchanged),
        # so only the rank is flipped. Flipping the file as well would play
        # the move on the opposite wing (e.g. ...d5 instead of ...e5).
        y1, y2 = 7 - y1, 7 - y2

    move = chess.Move(from_square=8 * y1 + x1, to_square=8 * y2 + x2, promotion=None)
    move_plane = get_move_plane(move)
    return x1 * 8 * 73 + y1 * 73 + move_plane
@register_env
class PettingzooChess(Environment):
    """Chat environment that plays chess through PettingZoo's ``chess_v6``.

    Players submit moves as text (``"Move (x1, y1) to (x2, y2)"``). Each move
    is converted to a PettingZoo action index and applied to the wrapped
    environment; board renderings and the players' moves are recorded in a
    ``MessagePool``.
    """

    type_name = "pettingzoo:chess"

    def __init__(self, player_names: List[str], **kwargs):
        """Create the wrapped chess environment and reset it.

        Args:
            player_names: Names of the players; index 0 moves first.
            **kwargs: Forwarded unchanged to the base ``Environment``.
        """
        super().__init__(player_names=player_names, **kwargs)
        self.env = chess_v6.env(render_mode="ansi")

        # The "state" of the environment is maintained by the message pool
        self.message_pool = MessagePool()
        self._terminal = False
        self.reset()

    def reset(self):
        """Reset the board, turn counters and message pool.

        Returns:
            A ``TimeStep`` carrying the (empty) observation plus the reward
            and terminal flag reported by the wrapped environment.
        """
        self.env.reset()
        self.current_player = 0
        self.turn = 0
        self.message_pool.reset()

        _, reward, terminal, _, _ = self.env.last()
        self._terminal = terminal
        return TimeStep(
            observation=self.get_observation(), reward=reward, terminal=terminal
        )

    def get_next_player(self) -> str:
        """Return the name of the player whose turn it is."""
        return self.player_names[self.current_player]

    def get_observation(self, player_name=None) -> List[Message]:
        """Return the messages visible to ``player_name``.

        With ``player_name=None`` every message in the pool is returned.
        """
        if player_name is None:
            return self.message_pool.get_all_messages()
        return self.message_pool.get_visible_messages(player_name, turn=self.turn + 1)

    def _moderator_speak(self, text: str, visible_to: Union[str, List[str]] = "all"):
        """Append a message from the "Moderator" to the message pool."""
        message = Message(
            agent_name="Moderator", content=text, turn=self.turn, visible_to=visible_to
        )
        self.message_pool.append_message(message)

    def is_terminal(self) -> bool:
        """Return whether the last reset/step reported the game as over."""
        return self._terminal

    def step(self, player_name: str, action: str) -> TimeStep:
        """Apply ``action`` for ``player_name`` and advance to the next player.

        Args:
            player_name: Must be the player returned by ``get_next_player``.
            action: Move text in the ``"Move (x1, y1) to (x2, y2)"`` format.

        Returns:
            A ``TimeStep`` with all messages so far, a reward dict keyed by
            player name, and the terminal flag after the move.

        Raises:
            ValueError: If ``action`` cannot be converted to an action index.
        """
        assert (
            player_name == self.get_next_player()
        ), f"Wrong player! It is {self.get_next_player()} turn."

        # Convert (and validate) first, so an unparseable action is rejected
        # before anything is written to the message pool.
        alphazero_move = action_string_to_alphazero_format(action, self.current_player)
        if alphazero_move == -1:
            raise ValueError(f"Invalid action: {action}")

        self._moderator_speak("\n" + self.env.render())
        message = Message(agent_name=player_name, content=action, turn=self.turn)
        self.message_pool.append_message(message)

        self.env.step(alphazero_move)

        # Query the outcome AFTER the move has been applied; reading it before
        # env.step() describes the previous position, so a game-ending move
        # would be reported as non-terminal.
        _, _, terminal, truncation, _ = self.env.last()
        terminal = bool(terminal or truncation)
        self._terminal = terminal  # Update the terminal state

        # Map PettingZoo's per-agent rewards onto our player names; players
        # and agents are paired by position (index 0 moves first in both).
        agent_rewards = self.env.rewards
        reward = {
            name: agent_rewards.get(agent, 0)
            for name, agent in zip(self.player_names, self.env.possible_agents)
        }

        self.current_player = 1 - self.current_player
        self.turn += 1

        return TimeStep(
            observation=self.get_observation(), reward=reward, terminal=terminal
        )

    def check_action(self, action: str, agent_name: str) -> bool:
        """Return True if ``action`` is legal for the player to move.

        ``agent_name`` is not consulted: the action is always checked against
        the action mask of the agent currently selected in the wrapped
        environment.
        """
        alphazero_move = action_string_to_alphazero_format(action, self.current_player)
        if alphazero_move == -1:
            return False

        action_mask = self.env.last()[0]["action_mask"]
        # An index outside the mask cannot be a legal action; report it as
        # invalid instead of raising IndexError.
        if not 0 <= alphazero_move < len(action_mask):
            return False
        return bool(action_mask[alphazero_move] != 0)

    def print(self):
        """Print the wrapped environment's rendering of the board."""
        print(self.env.render())
def test_chess_environment():
    """Smoke-test PettingzooChess by playing four opening plies.

    Each move is first validated with ``check_action`` and then applied with
    ``step``; the resulting reward, terminal flag and board are printed.
    """
    names = ["player1", "player2"]
    chess_env = PettingzooChess(names)
    chess_env.reset()
    assert chess_env.get_next_player() == "player1"
    chess_env.print()

    # Move sequence: 1. e4 e5 2. Nf3 Nc6
    opening = (
        "Move (4, 1) to (4, 3)",
        "Move (4, 6) to (4, 4)",
        "Move (6, 0) to (5, 2)",
        "Move (1, 7) to (2, 5)",
    )

    for move_text in opening:
        mover = chess_env.get_next_player()
        assert chess_env.check_action(move_text, mover)

        result = chess_env.step(mover, move_text)
        print(result.reward)
        print(result.terminal)
        chess_env.print()
if __name__ == "__main__":
    # Test the conversion function with an example action string.
    # (The previously created ``chess_v6.env()`` instance was never used, so
    # it is no longer constructed here.)
    action = "Move (0, 1) to (0, 3)"
    alphazero_move = action_string_to_alphazero_format(action, 0)
    print(alphazero_move)

    # Then run a short game through the full environment wrapper.
    test_chess_environment()