"""PettingZoo multi-agent environment adapters for the Gym GUI.
This module provides adapters for PettingZoo environments, supporting both
AEC (Agent Environment Cycle - turn-based) and Parallel (simultaneous) APIs.
PettingZoo environments can operate in:
- Single-Agent Mode: Run one agent while others use random/scripted policies
- Multi-Agent Mode: Multiple agents with different controllers (human, AI, policy)
- Human Control: Human plays turn-based games (Chess, Tic-Tac-Toe, etc.)
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union
import numpy as np
from gym_gui.core.adapters.base import (
AdapterContext,
AdapterStep,
AgentSnapshot,
EnvironmentAdapter,
StepState,
)
from gym_gui.core.enums import ControlMode, RenderMode, SteppingParadigm
from gym_gui.core.pettingzoo_enums import (
HUMAN_CONTROLLABLE_ENVS,
PETTINGZOO_CONTROL_MODES,
PETTINGZOO_ENV_METADATA,
PETTINGZOO_RENDER_MODES,
PettingZooAPIType,
PettingZooEnvId,
PettingZooFamily,
get_api_type,
get_display_name,
is_aec_env,
)
from gym_gui.logging_config.log_constants import (
LOG_ADAPTER_ENV_CLOSED,
LOG_ADAPTER_ENV_CREATED,
LOG_ADAPTER_ENV_RESET,
LOG_ADAPTER_RENDER_ERROR,
LOG_ADAPTER_STEP_SUMMARY,
)
# Module-level logger, named after this module's import path.
_LOGGER: logging.Logger = logging.getLogger(__name__)
@dataclass
class PettingZooConfig:
    """Configuration for PettingZoo multi-agent environments.

    Attributes:
        env_id: The PettingZoo environment identifier (e.g., "chess_v6").
        family: The environment family (classic, mpe, sisl, butterfly, atari).
            Replaced by the entry in ``PETTINGZOO_ENV_METADATA`` whenever
            ``env_id`` is a known environment.
        render_mode: Rendering mode ("rgb_array", "human", "ansi").
        max_cycles: Maximum number of cycles before truncation.
        seed: Random seed for reproducibility.
        env_kwargs: Additional keyword arguments passed to the environment
            constructor.
        human_player: Name of the agent controlled by a human (hybrid modes).
        agent_controllers: Per-agent controller type ("human", "random", "policy").
    """

    env_id: PettingZooEnvId
    family: PettingZooFamily = PettingZooFamily.CLASSIC
    render_mode: str = "rgb_array"
    max_cycles: int = 500
    seed: Optional[int] = None
    env_kwargs: Dict[str, Any] = field(default_factory=dict)
    human_player: Optional[str] = None
    agent_controllers: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Derive ``family`` from the environment metadata table.

        For env ids present in ``PETTINGZOO_ENV_METADATA`` the table is
        authoritative: it replaces whatever ``family`` was passed in.
        """
        if self.env_id in PETTINGZOO_ENV_METADATA:
            metadata = PETTINGZOO_ENV_METADATA[self.env_id]
            self.family = metadata[0]

    def to_dict(self) -> Dict[str, Any]:
        """Convert the config to a plain dictionary.

        Enum members are converted to their string values. The mapping fields
        are copied so that mutating the result cannot mutate this config.
        """
        return {
            "env_id": self.env_id.value if isinstance(self.env_id, PettingZooEnvId) else self.env_id,
            "family": self.family.value if isinstance(self.family, PettingZooFamily) else self.family,
            "render_mode": self.render_mode,
            "max_cycles": self.max_cycles,
            "seed": self.seed,
            "env_kwargs": dict(self.env_kwargs or {}),
            "human_player": self.human_player,
            "agent_controllers": dict(self.agent_controllers or {}),
        }

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "PettingZooConfig":
        """Create a config from a dictionary (e.g. the output of ``to_dict``).

        Args:
            data: Mapping with the same keys ``to_dict`` produces; missing keys
                fall back to the dataclass defaults.

        Returns:
            A new ``PettingZooConfig``. Unknown env ids are kept as raw strings.
            Unknown or missing families become ``PettingZooFamily.CLASSIC``.
        """
        env_id: Any = data.get("env_id", "")
        if isinstance(env_id, str):
            try:
                env_id = PettingZooEnvId(env_id)
            except ValueError:
                pass  # not an enum value: keep the raw string id

        family: Any = data.get("family")
        if family is None:
            # An explicit None would otherwise leak into module paths as "None".
            family = "classic"
        if isinstance(family, str):
            try:
                family = PettingZooFamily(family)
            except ValueError:
                family = PettingZooFamily.CLASSIC

        return cls(
            env_id=env_id,
            family=family,
            render_mode=data.get("render_mode", "rgb_array"),
            max_cycles=data.get("max_cycles", 500),
            seed=data.get("seed"),
            # Copy (and tolerate an explicit None) so the config never aliases
            # the caller's dictionaries.
            env_kwargs=dict(data.get("env_kwargs") or {}),
            human_player=data.get("human_player"),
            agent_controllers=dict(data.get("agent_controllers") or {}),
        )
class PettingZooAdapter(EnvironmentAdapter[Any, Any]):
    """Unified adapter for PettingZoo AEC and Parallel environments.

    This adapter provides a consistent interface for multi-agent environments.
    It supports human control for turn-based games and AI control for all others.

    Attributes:
        id: Environment identifier (e.g., "chess_v6").
        supported_control_modes: Tuple of supported control modes.
        default_render_mode: Default rendering mode (RGB_ARRAY).
    """

    default_render_mode = RenderMode.RGB_ARRAY
    supported_render_modes = (RenderMode.RGB_ARRAY,)

    def __init__(
        self,
        context: AdapterContext | None = None,
        *,
        config: PettingZooConfig | None = None,
        env_id: PettingZooEnvId | str | None = None,
    ) -> None:
        """Initialize the PettingZoo adapter.

        Args:
            context: Adapter context with settings and control mode.
            config: Full PettingZoo configuration; takes precedence over ``env_id``.
            env_id: Environment ID (alternative to ``config``). Strings that are
                not ``PettingZooEnvId`` values are kept as raw strings.

        Raises:
            ValueError: If neither ``config`` nor ``env_id`` is provided.
        """
        super().__init__(context)

        # Resolve env_id
        if config is not None:
            self._config = config
            self._env_id = config.env_id
        elif env_id is not None:
            if isinstance(env_id, str):
                try:
                    self._env_id = PettingZooEnvId(env_id)
                except ValueError:
                    self._env_id = env_id  # type: ignore
            else:
                self._env_id = env_id
            self._config = PettingZooConfig(env_id=self._env_id)
        else:
            raise ValueError("Either config or env_id must be provided")

        # Set id for adapter interface
        self.id = self._env_id.value if isinstance(self._env_id, PettingZooEnvId) else str(self._env_id)

        # Determine control modes from enum or use defaults
        if isinstance(self._env_id, PettingZooEnvId) and self._env_id in PETTINGZOO_CONTROL_MODES:
            self.supported_control_modes = PETTINGZOO_CONTROL_MODES[self._env_id]
        else:
            self.supported_control_modes = (ControlMode.AGENT_ONLY,)

        # Multi-agent state
        self._pz_env: Any = None  # PettingZoo environment instance
        self._is_parallel: bool = False
        self._agents: List[str] = []
        self._current_agent: Optional[str] = None
        self._step_count: int = 0
        self._episode_rewards: Dict[str, float] = {}
        self._terminated_agents: set[str] = set()
        self._action_masks: Dict[str, Optional[np.ndarray]] = {}
        self._last_observations: Dict[str, Any] = {}

    @property
    def is_parallel(self) -> bool:
        """Check if using Parallel API (vs AEC)."""
        return self._is_parallel

    @property
    def agents(self) -> List[str]:
        """Get list of active agent names."""
        if self._pz_env is None:
            return []
        return list(self._pz_env.agents) if hasattr(self._pz_env, "agents") else []

    @property
    def possible_agents(self) -> List[str]:
        """Get list of all possible agent names."""
        if self._pz_env is None:
            return []
        return list(self._pz_env.possible_agents) if hasattr(self._pz_env, "possible_agents") else []

    @property
    def current_agent(self) -> Optional[str]:
        """Get current agent (AEC mode only; None for Parallel or after episode end)."""
        return self._current_agent

    @property
    def num_agents(self) -> int:
        """Get number of possible agents."""
        return len(self.possible_agents)

    @property
    def stepping_paradigm(self) -> SteppingParadigm:  # type: ignore[override]
        """Return the stepping paradigm based on environment type.

        Returns:
            SIMULTANEOUS for Parallel API environments,
            SEQUENTIAL for AEC environments.
        """
        if self._is_parallel:
            return SteppingParadigm.SIMULTANEOUS
        return SteppingParadigm.SEQUENTIAL

    def load(self) -> None:
        """Instantiate the PettingZoo environment.

        Imports ``pettingzoo.<family>.<env_module>`` and builds either its
        ``parallel_env`` or its AEC ``env``. If an env registered as Parallel
        has no ``parallel_env`` constructor, the AEC constructor is used instead.

        Raises:
            Exception: Any import or construction error is logged and re-raised.
        """
        try:
            # Import pettingzoo dynamically
            import importlib

            # Determine API type
            if isinstance(self._env_id, PettingZooEnvId):
                self._is_parallel = get_api_type(self._env_id) == PettingZooAPIType.PARALLEL
                family = PETTINGZOO_ENV_METADATA[self._env_id][0].value
                # Use full env_id with version (e.g., "tictactoe_v3")
                env_module_name = self._env_id.value
            else:
                # Fallback for string env_id
                self._is_parallel = False
                family = (
                    self._config.family.value
                    if isinstance(self._config.family, PettingZooFamily)
                    else self._config.family
                )
                env_module_name = str(self._env_id)

            # Import the environment module (e.g., pettingzoo.classic.tictactoe_v3)
            module_path = f"pettingzoo.{family}.{env_module_name}"
            env_module = importlib.import_module(module_path)

            # Build kwargs; an explicit None env_kwargs is treated as "no extras".
            kwargs: Dict[str, Any] = {"render_mode": self._config.render_mode}
            kwargs.update(self._config.env_kwargs or {})
            # max_cycles is only forwarded to Parallel-API constructors.
            if self._config.max_cycles and self._is_parallel:
                kwargs["max_cycles"] = self._config.max_cycles

            # Create environment using appropriate API
            if self._is_parallel:
                if hasattr(env_module, "parallel_env"):
                    self._pz_env = env_module.parallel_env(**kwargs)
                else:
                    # Fallback to AEC
                    self._is_parallel = False
                    self._pz_env = env_module.env(**kwargs)
            else:
                self._pz_env = env_module.env(**kwargs)

            self.log_constant(
                LOG_ADAPTER_ENV_CREATED,
                extra={
                    "env_id": self.id,
                    "api_type": "parallel" if self._is_parallel else "aec",
                    "render_mode": self._config.render_mode,
                    "family": family,
                },
            )
        except Exception as exc:
            _LOGGER.error("Failed to load PettingZoo environment %s: %s", self.id, exc)
            raise

    @staticmethod
    def _extract_action_mask(observation: Any, info: Any) -> Optional[np.ndarray]:
        """Return the legal-action mask for an agent, or None if there is none.

        PettingZoo's action-masking convention allows the mask in either
        ``info["action_mask"]`` or ``observation["action_mask"]``; the classic
        board games use the observation dict. The info entry is checked first,
        matching this adapter's original lookup.
        """
        if isinstance(info, Mapping) and info.get("action_mask") is not None:
            return info["action_mask"]
        if isinstance(observation, Mapping) and observation.get("action_mask") is not None:
            return observation["action_mask"]
        return None

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> AdapterStep[Any]:
        """Reset the environment, loading it first if necessary.

        Args:
            seed: Optional random seed; falls back to ``config.seed``.
            options: Additional reset options (currently not forwarded).

        Returns:
            Initial step result. For Parallel envs the observation is the first
            agent's observation, and all observations are in ``info``.
        """
        if self._pz_env is None:
            self.load()

        self._step_count = 0
        self._terminated_agents.clear()
        self._action_masks.clear()
        # Drop observations from the previous episode; the AEC branch below
        # only overwrites the entry for the first selected agent.
        self._last_observations.clear()
        actual_seed = seed if seed is not None else self._config.seed

        if self._is_parallel:
            observations, infos = self._pz_env.reset(seed=actual_seed)
            self._agents = list(self._pz_env.agents)
            self._current_agent = None
            self._episode_rewards = {agent: 0.0 for agent in self._agents}
            self._last_observations = dict(observations)

            # Package initial observations
            obs = observations.get(self._agents[0]) if self._agents else None

            self.log_constant(
                LOG_ADAPTER_ENV_RESET,
                extra={
                    "env_id": self.id,
                    "seed": actual_seed,
                    "num_agents": len(self._agents),
                    "agents": ",".join(self._agents),
                },
            )

            return self._package_step(
                observation=obs,
                reward=0.0,
                terminated=False,
                truncated=False,
                info={
                    "agents": self._agents,
                    "all_observations": observations,
                    "all_infos": infos,
                },
            )

        # AEC API
        self._pz_env.reset(seed=actual_seed)
        self._agents = list(self._pz_env.agents)
        self._current_agent = self._pz_env.agent_selection
        self._episode_rewards = {agent: 0.0 for agent in self.possible_agents}

        observation, reward, termination, truncation, info = self._pz_env.last()

        action_mask = self._extract_action_mask(observation, info)
        self._action_masks[self._current_agent] = action_mask
        self._last_observations[self._current_agent] = observation

        self.log_constant(
            LOG_ADAPTER_ENV_RESET,
            extra={
                "env_id": self.id,
                "seed": actual_seed,
                "num_agents": len(self._agents),
                "current_agent": self._current_agent,
                "has_action_mask": action_mask is not None,
            },
        )

        info_dict: Dict[str, Any] = {
            "current_agent": self._current_agent,
            "action_mask": action_mask,
            "agents": self._agents,
        }
        if isinstance(info, dict):
            info_dict.update(info)

        return self._package_step(
            observation=observation,
            reward=float(reward) if reward else 0.0,
            terminated=bool(termination),
            truncated=bool(truncation),
            info=info_dict,
        )

    def step(self, action: Any) -> AdapterStep[Any]:
        """Execute action(s) in the environment.

        For AEC environments, pass a single action for the current agent.
        For Parallel environments, pass a dict mapping agent names to actions.

        Args:
            action: Single action (AEC) or dict of actions (Parallel).

        Returns:
            Step result.

        Raises:
            RuntimeError: If the environment has not been created yet.
        """
        if self._pz_env is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        self._step_count += 1

        if self._is_parallel:
            return self._step_parallel(action)
        return self._step_aec(action)

    def _step_parallel(self, actions: Dict[str, Any]) -> AdapterStep[Any]:
        """Step in Parallel mode.

        The returned observation is the full per-agent observation dict, and the
        reward is the sum over all agents. The episode is reported as done once
        every reporting agent has terminated or truncated.
        """
        observations, rewards, terminations, truncations, infos = self._pz_env.step(actions)

        # Update episode rewards
        for agent, reward in rewards.items():
            if agent in self._episode_rewards:
                self._episode_rewards[agent] += float(reward)

        # Track terminated agents
        for agent, terminated in terminations.items():
            if terminated:
                self._terminated_agents.add(agent)

        self._last_observations = dict(observations)

        # Check if all agents are done
        all_terminated = bool(terminations) and all(terminations.values())
        all_truncated = bool(truncations) and all(truncations.values())
        if not (all_terminated or all_truncated):
            reported = set(terminations) | set(truncations)
            # Some agents terminated and the rest truncated on the same step:
            # the episode is still over, so surface it as a truncation rather
            # than never signalling completion.
            if reported and all(
                bool(terminations.get(agent, False)) or bool(truncations.get(agent, False))
                for agent in reported
            ):
                all_truncated = True

        # Sum rewards for primary output
        total_reward = float(sum(rewards.values())) if rewards else 0.0

        self.log_constant(
            LOG_ADAPTER_STEP_SUMMARY,
            extra={
                "env_id": self.id,
                "step": self._step_count,
                "total_reward": total_reward,
                "terminated_agents": len(self._terminated_agents),
                "active_agents": len(self.agents),
            },
        )

        return self._package_step(
            observation=observations,
            reward=total_reward,
            terminated=all_terminated,
            truncated=all_truncated,
            info={
                "all_rewards": rewards,
                "all_terminations": terminations,
                "all_truncations": truncations,
                "all_infos": infos,
                "agents": self.agents,
                "episode_rewards": self._episode_rewards.copy(),
            },
        )

    def _step_aec(self, action: Any) -> AdapterStep[Any]:
        """Step in AEC mode.

        Applies ``action`` for the currently selected agent, then reports the
        state of the next selected agent. PettingZoo's ``last()`` returns the
        reward accumulated by the *newly selected* agent since it last acted.
        That reward is therefore credited to that agent, not to the agent that
        just moved.
        """
        previous_agent = self._current_agent

        # Execute action for current agent
        self._pz_env.step(action)

        if not self._pz_env.agents:
            # Episode is done - no more agents; keep the adapter state in line
            # with the returned info (no current agent).
            self._current_agent = None
            self.log_constant(
                LOG_ADAPTER_STEP_SUMMARY,
                extra={
                    "env_id": self.id,
                    "step": self._step_count,
                    "status": "episode_complete",
                    "episode_rewards": self._episode_rewards,
                },
            )
            return self._package_step(
                observation=None,
                reward=0.0,
                terminated=True,
                truncated=False,
                info={
                    "current_agent": None,
                    "agents": [],
                    "episode_rewards": self._episode_rewards.copy(),
                    "status": "episode_complete",
                },
            )

        # Get next agent's state
        self._current_agent = self._pz_env.agent_selection
        observation, reward, termination, truncation, info = self._pz_env.last()
        reward_value = float(reward)

        # Credit the reward to the agent last() reports on.
        if self._current_agent in self._episode_rewards:
            self._episode_rewards[self._current_agent] += reward_value

        # Track terminated agents
        if termination and self._current_agent:
            self._terminated_agents.add(self._current_agent)

        action_mask = self._extract_action_mask(observation, info)
        self._action_masks[self._current_agent] = action_mask
        self._last_observations[self._current_agent] = observation

        self.log_constant(
            LOG_ADAPTER_STEP_SUMMARY,
            extra={
                "env_id": self.id,
                "step": self._step_count,
                "previous_agent": previous_agent,
                "current_agent": self._current_agent,
                "reward": reward_value,
                "terminated": termination,
                "has_action_mask": action_mask is not None,
            },
        )

        aec_info: Dict[str, Any] = {
            "current_agent": self._current_agent,
            "previous_agent": previous_agent,
            "action_mask": action_mask,
            "agents": self.agents,
            "episode_rewards": self._episode_rewards.copy(),
        }
        if isinstance(info, dict):
            aec_info.update(info)

        return self._package_step(
            observation=observation,
            reward=reward_value,
            terminated=bool(termination),
            truncated=bool(truncation),
            info=aec_info,
        )

    def render(self) -> Any:
        """Render the environment.

        Returns:
            A dict with an ``"rgb"`` frame plus agent metadata when the env
            returns a numpy array. Otherwise the env's raw render result, or
            None if the env is not loaded or rendering failed (the failure is
            logged).
        """
        if self._pz_env is None:
            return None

        try:
            result = self._pz_env.render()
            if isinstance(result, np.ndarray):
                return {
                    "mode": RenderMode.RGB_ARRAY.value,
                    "rgb": result,
                    "game_id": self.id,
                    "current_agent": self._current_agent,
                    "agents": self.agents,
                    "step": self._step_count,
                }
            return result
        except Exception as exc:
            self.log_constant(
                LOG_ADAPTER_RENDER_ERROR,
                exc_info=exc,
                extra={"env_id": self.id},
            )
            return None

    def close(self) -> None:
        """Close the environment and drop the reference to it.

        The reference is cleared even if the underlying ``close()`` raises, so
        a later ``reset()`` loads a fresh environment instead of reusing a
        half-closed one.
        """
        if self._pz_env is None:
            return
        self.log_constant(
            LOG_ADAPTER_ENV_CLOSED,
            extra={"env_id": self.id},
        )
        try:
            self._pz_env.close()
        finally:
            self._pz_env = None

    def build_step_state(
        self,
        observation: Any,
        info: Mapping[str, Any],
    ) -> StepState:
        """Construct the canonical StepState for the current step."""
        agent_snapshots: List[AgentSnapshot] = []
        for agent_name in self.possible_agents:
            is_active = agent_name == self._current_agent
            is_terminated = agent_name in self._terminated_agents
            snapshot = AgentSnapshot(
                name=agent_name,
                role="active" if is_active else ("terminated" if is_terminated else "waiting"),
                info={
                    "reward": self._episode_rewards.get(agent_name, 0.0),
                    # Entries can exist with a None mask; only real masks count.
                    "has_action_mask": self._action_masks.get(agent_name) is not None,
                },
            )
            agent_snapshots.append(snapshot)

        return StepState(
            active_agent=self._current_agent,
            agents=tuple(agent_snapshots),
            metrics={
                "step_count": self._step_count,
                "active_agents": len(self.agents),
                "terminated_agents": len(self._terminated_agents),
            },
            environment={
                "is_parallel": self._is_parallel,
                "family": (
                    self._config.family.value
                    if isinstance(self._config.family, PettingZooFamily)
                    else self._config.family
                ),
            },
            raw=dict(info) if isinstance(info, Mapping) else {},
        )

    # ─────────────────────────────────────────────────────────────────
    # Multi-agent specific methods
    # ─────────────────────────────────────────────────────────────────

    def _resolve_agent(self, agent: Optional[str]) -> str:
        """Return ``agent`` or, for AEC envs, the agent whose turn it is.

        Raises:
            RuntimeError: If the env is not loaded or no current agent exists.
            ValueError: If ``agent`` is omitted for a Parallel env.
        """
        if self._pz_env is None:
            raise RuntimeError("Environment not initialized")
        if agent is None:
            if self._is_parallel:
                raise ValueError("Agent name required for Parallel environments")
            agent = self._current_agent
        if agent is None:
            raise RuntimeError("No current agent available")
        return agent

    def get_action_space(self, agent: Optional[str] = None):
        """Get action space for an agent.

        Args:
            agent: Agent name (uses current agent if None for AEC).

        Returns:
            The space returned by the env's ``action_space(agent)``.
        """
        return self._pz_env.action_space(self._resolve_agent(agent))

    def get_observation_space(self, agent: Optional[str] = None):
        """Get observation space for an agent.

        Args:
            agent: Agent name (uses current agent if None for AEC).

        Returns:
            The space returned by the env's ``observation_space(agent)``.
        """
        return self._pz_env.observation_space(self._resolve_agent(agent))

    def sample_action(self, agent: Optional[str] = None) -> Any:
        """Sample a random action for an agent.

        Uses the stored action mask (uniform over legal actions, via numpy's
        global RNG) when one is available and non-empty. Otherwise falls back to
        ``action_space.sample()``.

        Args:
            agent: Agent name (uses current agent if None for AEC).

        Returns:
            Sampled action.
        """
        action_space = self.get_action_space(agent)

        # Apply action mask if available
        if agent is None:
            agent = self._current_agent

        if agent and agent in self._action_masks and self._action_masks[agent] is not None:
            mask = self._action_masks[agent]
            valid_actions = np.where(mask)[0]
            if len(valid_actions) > 0:
                return np.random.choice(valid_actions)

        return action_space.sample()

    def sample_actions(self) -> Dict[str, Any]:
        """Sample random actions for all active agents.

        Returns:
            Dict mapping agent names to sampled actions.
        """
        return {agent: self.sample_action(agent) for agent in self.agents}

    def get_action_mask(self, agent: Optional[str] = None) -> Optional[np.ndarray]:
        """Get action mask for an agent.

        Args:
            agent: Agent name (uses current agent if None).

        Returns:
            Action mask array or None if not available.
        """
        if agent is None:
            agent = self._current_agent
        if agent is None:
            return None
        return self._action_masks.get(agent)

    def is_done(self) -> bool:
        """Check if episode is complete (no env loaded, or no active agents)."""
        if self._pz_env is None:
            return True
        return len(self.agents) == 0

    def is_human_controllable(self) -> bool:
        """Check if this environment supports human control."""
        if isinstance(self._env_id, PettingZooEnvId):
            return self._env_id in HUMAN_CONTROLLABLE_ENVS
        return False

    def get_human_agent(self) -> Optional[str]:
        """Get the agent designated for human control."""
        if self._config.human_player:
            return self._config.human_player

        # Default to first agent for human-controllable envs
        if self.is_human_controllable() and self.possible_agents:
            return self.possible_agents[0]

        return None
# ═══════════════════════════════════════════════════════════════════════════
# Concrete adapter classes for specific environments
# ═══════════════════════════════════════════════════════════════════════════
class ChessAdapter(PettingZooAdapter):
    """PettingZoo adapter preconfigured for the Chess environment."""

    id = PettingZooEnvId.CHESS.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.CHESS,
        (
            ControlMode.HUMAN_ONLY,
            ControlMode.AGENT_ONLY,
            ControlMode.HYBRID_TURN_BASED,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Chess config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.CHESS)
        super().__init__(context, config=resolved)
class ConnectFourAdapter(PettingZooAdapter):
    """PettingZoo adapter preconfigured for the Connect Four environment."""

    id = PettingZooEnvId.CONNECT_FOUR.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.CONNECT_FOUR,
        (
            ControlMode.HUMAN_ONLY,
            ControlMode.AGENT_ONLY,
            ControlMode.HYBRID_TURN_BASED,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Connect Four config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.CONNECT_FOUR)
        super().__init__(context, config=resolved)
class TicTacToeAdapter(PettingZooAdapter):
    """PettingZoo adapter preconfigured for the Tic-Tac-Toe environment."""

    id = PettingZooEnvId.TIC_TAC_TOE.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.TIC_TAC_TOE,
        (
            ControlMode.HUMAN_ONLY,
            ControlMode.AGENT_ONLY,
            ControlMode.HYBRID_TURN_BASED,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Tic-Tac-Toe config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.TIC_TAC_TOE)
        super().__init__(context, config=resolved)
class GoAdapter(PettingZooAdapter):
    """PettingZoo adapter preconfigured for the Go environment."""

    id = PettingZooEnvId.GO.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.GO,
        (
            ControlMode.HUMAN_ONLY,
            ControlMode.AGENT_ONLY,
            ControlMode.HYBRID_TURN_BASED,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Go config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.GO)
        super().__init__(context, config=resolved)
class SimpleSpreadAdapter(PettingZooAdapter):
    """PettingZoo adapter preconfigured for the Simple Spread (MPE) environment."""

    id = PettingZooEnvId.SIMPLE_SPREAD.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.SIMPLE_SPREAD,
        (
            ControlMode.AGENT_ONLY,
            ControlMode.MULTI_AGENT_COOP,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Simple Spread config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.SIMPLE_SPREAD)
        super().__init__(context, config=resolved)
class SimpleTagAdapter(PettingZooAdapter):
    """PettingZoo adapter preconfigured for the Simple Tag (MPE) environment."""

    id = PettingZooEnvId.SIMPLE_TAG.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.SIMPLE_TAG,
        (
            ControlMode.AGENT_ONLY,
            ControlMode.MULTI_AGENT_COMPETITIVE,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Simple Tag config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.SIMPLE_TAG)
        super().__init__(context, config=resolved)
class PistonballAdapter(PettingZooAdapter):
    """PettingZoo adapter preconfigured for the Pistonball (Butterfly) environment."""

    id = PettingZooEnvId.PISTONBALL.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.PISTONBALL,
        (
            ControlMode.AGENT_ONLY,
            ControlMode.MULTI_AGENT_COOP,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Pistonball config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.PISTONBALL)
        super().__init__(context, config=resolved)
class KnightsArchersZombiesAdapter(PettingZooAdapter):
    """PettingZoo adapter for Knights Archers Zombies (Butterfly).

    A cooperative 4-agent game in which knights and archers defend against
    zombies. Default agents: ['archer_0', 'archer_1', 'knight_0', 'knight_1'].
    """

    id = PettingZooEnvId.KNIGHTS_ARCHERS_ZOMBIES.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.KNIGHTS_ARCHERS_ZOMBIES,
        (
            ControlMode.AGENT_ONLY,
            ControlMode.MULTI_AGENT_COOP,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default KAZ config when none is supplied."""
        resolved = (
            config
            if config is not None
            else PettingZooConfig(env_id=PettingZooEnvId.KNIGHTS_ARCHERS_ZOMBIES)
        )
        super().__init__(context, config=resolved)
class CooperativePongAdapter(PettingZooAdapter):
    """PettingZoo adapter for Cooperative Pong (Butterfly).

    Two paddles work together to keep a ball in play.
    """

    id = PettingZooEnvId.COOPERATIVE_PONG.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.COOPERATIVE_PONG,
        (
            ControlMode.AGENT_ONLY,
            ControlMode.MULTI_AGENT_COOP,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Cooperative Pong config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.COOPERATIVE_PONG)
        super().__init__(context, config=resolved)
class MultiwalkerAdapter(PettingZooAdapter):
    """PettingZoo adapter preconfigured for the Multiwalker (SISL) environment."""

    id = PettingZooEnvId.MULTIWALKER.value
    supported_control_modes = PETTINGZOO_CONTROL_MODES.get(
        PettingZooEnvId.MULTIWALKER,
        (
            ControlMode.AGENT_ONLY,
            ControlMode.MULTI_AGENT_COOP,
        ),
    )

    def __init__(self, context: AdapterContext | None = None, *, config: PettingZooConfig | None = None) -> None:
        """Create the adapter, using a default Multiwalker config when none is supplied."""
        resolved = config if config is not None else PettingZooConfig(env_id=PettingZooEnvId.MULTIWALKER)
        super().__init__(context, config=resolved)
# ═══════════════════════════════════════════════════════════════════════════
# Adapter registry for factory pattern
# ═══════════════════════════════════════════════════════════════════════════
# Note: This maps to GameId-compatible keys for integration with existing factory
# PettingZoo environments use their own enum but need string keys for compatibility
# Keyed by each adapter class's ``id`` attribute (the env id string value).
PETTINGZOO_ADAPTERS: Dict[str, Type[PettingZooAdapter]] = {
    adapter_cls.id: adapter_cls
    for adapter_cls in (
        # Classic (AEC - turn-based)
        ChessAdapter,
        ConnectFourAdapter,
        TicTacToeAdapter,
        GoAdapter,
        # MPE (Parallel)
        SimpleSpreadAdapter,
        SimpleTagAdapter,
        # Butterfly (Parallel - cooperative visual)
        PistonballAdapter,
        KnightsArchersZombiesAdapter,
        CooperativePongAdapter,
        # SISL (Parallel - continuous control)
        MultiwalkerAdapter,
    )
}
def create_pettingzoo_adapter(
    env_id: PettingZooEnvId | str,
    context: AdapterContext | None = None,
    config: PettingZooConfig | None = None,
) -> PettingZooAdapter:
    """Build the adapter registered for ``env_id``, or a generic one.

    Args:
        env_id: Environment identifier, either an enum member or its string value.
        context: Adapter context forwarded to the adapter.
        config: Optional configuration, forwarded unchanged.

    Returns:
        An instance of the dedicated subclass from ``PETTINGZOO_ADAPTERS`` when
        one is registered, otherwise a plain ``PettingZooAdapter``.
    """
    key = env_id if not isinstance(env_id, PettingZooEnvId) else env_id.value
    registered = PETTINGZOO_ADAPTERS.get(key)
    if registered is None:
        # No dedicated subclass: use the generic adapter for unlisted envs.
        return PettingZooAdapter(context, env_id=env_id, config=config)
    return registered(context, config=config)
# Public API of this module, grouped by environment family.
__all__ = [
    # Configuration and generic adapter
    "PettingZooConfig", "PettingZooAdapter",
    # Classic
    "ChessAdapter", "ConnectFourAdapter", "TicTacToeAdapter", "GoAdapter",
    # MPE
    "SimpleSpreadAdapter", "SimpleTagAdapter",
    # Butterfly
    "PistonballAdapter", "KnightsArchersZombiesAdapter", "CooperativePongAdapter",
    # SISL
    "MultiwalkerAdapter",
    # Registry & Factory
    "PETTINGZOO_ADAPTERS", "create_pettingzoo_adapter",
]