"""Paradigm-specific adapter abstractions for multi-paradigm RL orchestration.
This module provides:
- ParadigmAdapter: ABC for paradigm-aware stepping behavior
- Concrete adapters: SingleAgentAdapter, SequentialAdapter, SimultaneousAdapter
The ParadigmAdapter bridges between the GUI/orchestrator and paradigm-specific
environments (Gymnasium, PettingZoo AEC, PettingZoo Parallel).
See Also:
- :doc:`/documents/architecture/paradigms` for stepping paradigm details
- :doc:`/documents/architecture/operators/concept` for operator architecture
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional, Sequence, TYPE_CHECKING
from gym_gui.core.enums import SteppingParadigm
if TYPE_CHECKING:
from gym_gui.core.adapters.base import AdapterStep
@dataclass(slots=True)
class ParadigmStepResult:
"""Unified step result for all paradigms.
This normalizes results from different paradigms:
- Single-agent: One observation, one reward
- Sequential (AEC): Per-agent observation/reward for current agent
- Simultaneous (POSG): Dict of observations/rewards for all agents
Attributes:
observations: Mapping from agent_id to observation.
For single-agent, uses key "agent_0".
rewards: Mapping from agent_id to reward.
terminations: Mapping from agent_id to terminated flag.
truncations: Mapping from agent_id to truncated flag.
infos: Mapping from agent_id to info dict.
current_agent: The agent that just acted (Sequential mode).
all_done: Whether the episode is complete for all agents.
adapter_steps: Raw AdapterStep results (if available).
"""
observations: Dict[str, Any] = field(default_factory=dict)
rewards: Dict[str, float] = field(default_factory=dict)
terminations: Dict[str, bool] = field(default_factory=dict)
truncations: Dict[str, bool] = field(default_factory=dict)
infos: Dict[str, Dict[str, Any]] = field(default_factory=dict)
current_agent: Optional[str] = None
all_done: bool = False
adapter_steps: Dict[str, "AdapterStep[Any]"] = field(default_factory=dict)
def is_agent_done(self, agent_id: str) -> bool:
"""Check if a specific agent's episode is done."""
return self.terminations.get(agent_id, False) or self.truncations.get(agent_id, False)
[docs]
class ParadigmAdapter(ABC):
"""Abstract base class for paradigm-specific stepping behavior.
ParadigmAdapter bridges between Mosaic's GUI/orchestrator and paradigm-specific
workers. It abstracts:
1. Which agents need actions at any given time
2. How to execute a step (single action vs. joint action dict)
3. How to normalize results across paradigms
Subclasses implement paradigm-specific logic:
- SingleAgentAdapter: Gymnasium-style single agent
- SequentialAdapter: PettingZoo AEC-style turn-based
- SimultaneousAdapter: PettingZoo Parallel / RLlib POSG-style
Example:
>>> adapter = get_paradigm_adapter(env)
>>> while not adapter.is_done():
... agents = adapter.get_agents_to_act()
... actions = {a: policy(a, adapter.get_observation(a)) for a in agents}
... result = adapter.step(actions)
See Also:
- :doc:`/documents/architecture/paradigms` for paradigm details
"""
@property
@abstractmethod
def paradigm(self) -> SteppingParadigm:
"""The stepping paradigm this adapter implements."""
...
@property
@abstractmethod
def agent_ids(self) -> Sequence[str]:
"""All agent identifiers in this environment.
For single-agent environments, returns ["agent_0"].
For multi-agent environments, returns all agent IDs.
"""
...
[docs]
@abstractmethod
def get_agents_to_act(self) -> List[str]:
"""Return agents that need actions NOW.
Returns:
List of agent IDs that require actions in the current step.
- SINGLE_AGENT: ["agent_0"]
- SEQUENTIAL: [current_agent_id] (one agent per step)
- SIMULTANEOUS: [all active agent IDs]
"""
...
[docs]
@abstractmethod
def get_observation(self, agent_id: str) -> Any:
"""Get the current observation for a specific agent.
Args:
agent_id: The agent to get observation for.
Returns:
The observation for the specified agent.
Raises:
KeyError: If agent_id is not valid.
"""
...
[docs]
@abstractmethod
def get_observations(self, agent_ids: Optional[List[str]] = None) -> Dict[str, Any]:
"""Get observations for multiple agents.
Args:
agent_ids: List of agent IDs. If None, returns all active agents.
Returns:
Dict mapping agent_id to observation.
"""
...
[docs]
@abstractmethod
def step(self, actions: Dict[str, Any]) -> ParadigmStepResult:
"""Execute paradigm-appropriate step.
Args:
actions: Dict mapping agent_id to action.
- SINGLE_AGENT: {"agent_0": action}
- SEQUENTIAL: {current_agent: action}
- SIMULTANEOUS: {agent_id: action for all active agents}
Returns:
ParadigmStepResult with normalized observations, rewards, etc.
"""
...
[docs]
@abstractmethod
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
) -> ParadigmStepResult:
"""Reset the environment and return initial observations.
Args:
seed: Optional random seed for reproducibility.
options: Optional reset options dict.
Returns:
ParadigmStepResult with initial observations (rewards=0, done=False).
"""
...
[docs]
@abstractmethod
def is_done(self) -> bool:
"""Check if the episode is complete for all agents."""
...
[docs]
@abstractmethod
def close(self) -> None:
"""Clean up environment resources."""
...
# ------------------------------------------------------------------
# Optional lifecycle hooks (subclasses may override)
# ------------------------------------------------------------------
[docs]
def get_info(self, agent_id: str) -> Dict[str, Any]:
"""Get the info dict for a specific agent.
Default implementation returns empty dict. Subclasses should override.
"""
return {}
[docs]
def get_infos(self, agent_ids: Optional[List[str]] = None) -> Dict[str, Dict[str, Any]]:
"""Get info dicts for multiple agents.
Default implementation calls get_info for each agent.
"""
ids = agent_ids if agent_ids is not None else list(self.agent_ids)
return {agent_id: self.get_info(agent_id) for agent_id in ids}
[docs]
def render(self) -> Any:
"""Render the environment.
Default implementation returns None. Subclasses should override.
"""
return None
# ------------------------------------------------------------------
# Utility methods
# ------------------------------------------------------------------
[docs]
def is_single_agent(self) -> bool:
"""Check if this is a single-agent environment."""
return self.paradigm == SteppingParadigm.SINGLE_AGENT
[docs]
def is_sequential(self) -> bool:
"""Check if this is a sequential (AEC) environment."""
return self.paradigm == SteppingParadigm.SEQUENTIAL
[docs]
def is_simultaneous(self) -> bool:
"""Check if this is a simultaneous (POSG) environment."""
return self.paradigm == SteppingParadigm.SIMULTANEOUS
[docs]
def num_agents(self) -> int:
"""Return the number of agents in the environment."""
return len(self.agent_ids)
# =============================================================================
# Concrete Paradigm Adapters
# =============================================================================
class SingleAgentParadigmAdapter(ParadigmAdapter):
"""Paradigm adapter for single-agent Gymnasium environments.
Wraps a standard Gymnasium environment with the ParadigmAdapter interface.
Uses "agent_0" as the canonical agent ID.
Example:
>>> import gymnasium as gym
>>> env = gym.make("CartPole-v1")
>>> adapter = SingleAgentParadigmAdapter(env)
>>> result = adapter.reset()
>>> while not adapter.is_done():
... action = policy(result.observations["agent_0"])
... result = adapter.step({"agent_0": action})
"""
AGENT_ID = "agent_0"
def __init__(self, env: Any) -> None:
"""Initialize with a Gymnasium environment.
Args:
env: A Gymnasium-compatible environment.
"""
self._env = env
self._current_obs: Any = None
self._current_info: Dict[str, Any] = {}
self._done = False
@property
def paradigm(self) -> SteppingParadigm:
return SteppingParadigm.SINGLE_AGENT
@property
def agent_ids(self) -> Sequence[str]:
return (self.AGENT_ID,)
def get_agents_to_act(self) -> List[str]:
if self._done:
return []
return [self.AGENT_ID]
def get_observation(self, agent_id: str) -> Any:
if agent_id != self.AGENT_ID:
raise KeyError(f"Unknown agent '{agent_id}'. Single-agent env uses '{self.AGENT_ID}'.")
return self._current_obs
def get_observations(self, agent_ids: Optional[List[str]] = None) -> Dict[str, Any]:
if agent_ids is not None and self.AGENT_ID not in agent_ids:
return {}
return {self.AGENT_ID: self._current_obs}
def get_info(self, agent_id: str) -> Dict[str, Any]:
if agent_id != self.AGENT_ID:
raise KeyError(f"Unknown agent '{agent_id}'.")
return self._current_info
def step(self, actions: Dict[str, Any]) -> ParadigmStepResult:
action = actions.get(self.AGENT_ID)
if action is None:
raise ValueError(f"Missing action for '{self.AGENT_ID}'")
obs, reward, terminated, truncated, info = self._env.step(action)
self._current_obs = obs
self._current_info = info if isinstance(info, dict) else {}
self._done = terminated or truncated
return ParadigmStepResult(
observations={self.AGENT_ID: obs},
rewards={self.AGENT_ID: float(reward)},
terminations={self.AGENT_ID: terminated},
truncations={self.AGENT_ID: truncated},
infos={self.AGENT_ID: self._current_info},
current_agent=self.AGENT_ID,
all_done=self._done,
)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
) -> ParadigmStepResult:
reset_kwargs: Dict[str, Any] = {}
if seed is not None:
reset_kwargs["seed"] = seed
if options is not None:
reset_kwargs["options"] = options
obs, info = self._env.reset(**reset_kwargs)
self._current_obs = obs
self._current_info = info if isinstance(info, dict) else {}
self._done = False
return ParadigmStepResult(
observations={self.AGENT_ID: obs},
rewards={self.AGENT_ID: 0.0},
terminations={self.AGENT_ID: False},
truncations={self.AGENT_ID: False},
infos={self.AGENT_ID: self._current_info},
current_agent=self.AGENT_ID,
all_done=False,
)
def is_done(self) -> bool:
return self._done
def close(self) -> None:
self._env.close()
def render(self) -> Any:
return self._env.render()
class SequentialParadigmAdapter(ParadigmAdapter):
"""Paradigm adapter for sequential (AEC) multi-agent environments.
Wraps a PettingZoo AEC environment with the ParadigmAdapter interface.
Agents take turns acting one at a time.
Example:
>>> from pettingzoo.classic import chess_v6
>>> env = chess_v6.env()
>>> adapter = SequentialParadigmAdapter(env)
>>> result = adapter.reset()
>>> while not adapter.is_done():
... agents = adapter.get_agents_to_act() # Returns [current_agent]
... action = policy(agents[0], result.observations[agents[0]])
... result = adapter.step({agents[0]: action})
"""
def __init__(self, env: Any) -> None:
"""Initialize with a PettingZoo AEC environment.
Args:
env: A PettingZoo AEC-compatible environment.
"""
self._env = env
self._all_done = False
@property
def paradigm(self) -> SteppingParadigm:
return SteppingParadigm.SEQUENTIAL
@property
def agent_ids(self) -> Sequence[str]:
return tuple(self._env.possible_agents)
def get_agents_to_act(self) -> List[str]:
if self._all_done or not self._env.agents:
return []
# AEC: only current agent needs to act
agent = self._env.agent_selection
return [agent] if agent else []
def get_observation(self, agent_id: str) -> Any:
# In AEC, use observe() or last() depending on implementation
if hasattr(self._env, "observe"):
return self._env.observe(agent_id)
# Fallback: use last() for current agent
if agent_id == self._env.agent_selection:
obs, _, _, _, _ = self._env.last()
return obs
raise KeyError(f"Cannot get observation for non-current agent '{agent_id}' in AEC mode.")
def get_observations(self, agent_ids: Optional[List[str]] = None) -> Dict[str, Any]:
if agent_ids is None:
agent_ids = self.get_agents_to_act()
return {agent_id: self.get_observation(agent_id) for agent_id in agent_ids}
def get_info(self, agent_id: str) -> Dict[str, Any]:
if agent_id == self._env.agent_selection:
_, _, _, _, info = self._env.last()
return info if isinstance(info, dict) else {}
return {}
def step(self, actions: Dict[str, Any]) -> ParadigmStepResult:
current_agent = self._env.agent_selection
action = actions.get(current_agent)
# Check if agent is terminated/truncated (pass None)
obs, reward, terminated, truncated, info = self._env.last()
if terminated or truncated:
action = None
self._env.step(action)
# Get new state after step
if self._env.agents:
new_agent = self._env.agent_selection
new_obs, new_reward, new_terminated, new_truncated, new_info = self._env.last()
else:
self._all_done = True
new_agent = current_agent
new_obs, new_reward, new_terminated, new_truncated, new_info = obs, reward, True, False, info
return ParadigmStepResult(
observations={new_agent: new_obs} if new_agent else {},
rewards={current_agent: float(reward)},
terminations={current_agent: terminated},
truncations={current_agent: truncated},
infos={current_agent: info if isinstance(info, dict) else {}},
current_agent=new_agent,
all_done=self._all_done or not self._env.agents,
)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
) -> ParadigmStepResult:
reset_kwargs: Dict[str, Any] = {}
if seed is not None:
reset_kwargs["seed"] = seed
if options is not None:
reset_kwargs["options"] = options
self._env.reset(**reset_kwargs)
self._all_done = False
current_agent = self._env.agent_selection
obs, _, _, _, info = self._env.last()
return ParadigmStepResult(
observations={current_agent: obs},
rewards={agent: 0.0 for agent in self._env.agents},
terminations={agent: False for agent in self._env.agents},
truncations={agent: False for agent in self._env.agents},
infos={current_agent: info if isinstance(info, dict) else {}},
current_agent=current_agent,
all_done=False,
)
def is_done(self) -> bool:
return self._all_done or not self._env.agents
def close(self) -> None:
self._env.close()
def render(self) -> Any:
return self._env.render()
class SimultaneousParadigmAdapter(ParadigmAdapter):
"""Paradigm adapter for simultaneous (POSG) multi-agent environments.
Wraps a PettingZoo Parallel or RLlib MultiAgentEnv with the ParadigmAdapter interface.
All agents act simultaneously each step.
Example:
>>> from pettingzoo.butterfly import pistonball_v6
>>> env = pistonball_v6.parallel_env()
>>> adapter = SimultaneousParadigmAdapter(env)
>>> result = adapter.reset()
>>> while not adapter.is_done():
... agents = adapter.get_agents_to_act() # Returns all agents
... actions = {a: policy(a, result.observations[a]) for a in agents}
... result = adapter.step(actions)
"""
def __init__(self, env: Any) -> None:
"""Initialize with a PettingZoo Parallel or RLlib MultiAgentEnv.
Args:
env: A parallel multi-agent environment.
"""
self._env = env
self._current_obs: Dict[str, Any] = {}
self._current_infos: Dict[str, Dict[str, Any]] = {}
self._terminations: Dict[str, bool] = {}
self._truncations: Dict[str, bool] = {}
self._all_done = False
@property
def paradigm(self) -> SteppingParadigm:
return SteppingParadigm.SIMULTANEOUS
@property
def agent_ids(self) -> Sequence[str]:
return tuple(self._env.possible_agents)
def get_agents_to_act(self) -> List[str]:
if self._all_done:
return []
# In parallel mode, all active (not done) agents act
return [
agent
for agent in self._env.agents
if not self._terminations.get(agent, False) and not self._truncations.get(agent, False)
]
def get_observation(self, agent_id: str) -> Any:
if agent_id not in self._current_obs:
raise KeyError(f"No observation for agent '{agent_id}'")
return self._current_obs[agent_id]
def get_observations(self, agent_ids: Optional[List[str]] = None) -> Dict[str, Any]:
if agent_ids is None:
return dict(self._current_obs)
return {agent_id: self._current_obs[agent_id] for agent_id in agent_ids if agent_id in self._current_obs}
def get_info(self, agent_id: str) -> Dict[str, Any]:
return self._current_infos.get(agent_id, {})
def get_infos(self, agent_ids: Optional[List[str]] = None) -> Dict[str, Dict[str, Any]]:
if agent_ids is None:
return dict(self._current_infos)
return {agent_id: self._current_infos.get(agent_id, {}) for agent_id in agent_ids}
def step(self, actions: Dict[str, Any]) -> ParadigmStepResult:
obs, rewards, terminations, truncations, infos = self._env.step(actions)
self._current_obs = obs
self._current_infos = {k: v if isinstance(v, dict) else {} for k, v in infos.items()}
self._terminations = terminations
self._truncations = truncations
# Check if all agents are done
# PettingZoo parallel uses "__all__" key or empty agents list
if "__all__" in terminations:
self._all_done = terminations["__all__"]
else:
self._all_done = not self._env.agents
return ParadigmStepResult(
observations=obs,
rewards={k: float(v) for k, v in rewards.items()},
terminations=terminations,
truncations=truncations,
infos=self._current_infos,
current_agent=None, # No single current agent in simultaneous mode
all_done=self._all_done,
)
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
) -> ParadigmStepResult:
reset_kwargs: Dict[str, Any] = {}
if seed is not None:
reset_kwargs["seed"] = seed
if options is not None:
reset_kwargs["options"] = options
obs, infos = self._env.reset(**reset_kwargs)
self._current_obs = obs
self._current_infos = {k: v if isinstance(v, dict) else {} for k, v in infos.items()}
self._terminations = {agent: False for agent in self._env.agents}
self._truncations = {agent: False for agent in self._env.agents}
self._all_done = False
return ParadigmStepResult(
observations=obs,
rewards={agent: 0.0 for agent in self._env.agents},
terminations=self._terminations,
truncations=self._truncations,
infos=self._current_infos,
current_agent=None,
all_done=False,
)
def is_done(self) -> bool:
return self._all_done
def close(self) -> None:
self._env.close()
def render(self) -> Any:
return self._env.render()
# =============================================================================
# Factory Function
# =============================================================================
def create_paradigm_adapter(
env: Any,
paradigm: Optional[SteppingParadigm] = None,
) -> ParadigmAdapter:
"""Create a ParadigmAdapter for the given environment.
Args:
env: The environment to wrap.
paradigm: Optional explicit paradigm. If None, auto-detected.
Returns:
An appropriate ParadigmAdapter subclass instance.
Raises:
ValueError: If paradigm cannot be determined.
"""
# Auto-detect paradigm if not specified
if paradigm is None:
paradigm = _detect_paradigm(env)
if paradigm == SteppingParadigm.SINGLE_AGENT:
return SingleAgentParadigmAdapter(env)
elif paradigm == SteppingParadigm.SEQUENTIAL:
return SequentialParadigmAdapter(env)
elif paradigm == SteppingParadigm.SIMULTANEOUS:
return SimultaneousParadigmAdapter(env)
else:
raise ValueError(f"Unknown paradigm: {paradigm}")
def _detect_paradigm(env: Any) -> SteppingParadigm:
"""Auto-detect the stepping paradigm from environment type.
Args:
env: The environment to inspect.
Returns:
The detected SteppingParadigm.
"""
# Check for PettingZoo AEC (has agent_iter and last methods)
if hasattr(env, "agent_iter") and hasattr(env, "last"):
return SteppingParadigm.SEQUENTIAL
# Check for PettingZoo Parallel (has possible_agents but step takes dict)
if hasattr(env, "possible_agents") and hasattr(env, "agents"):
# Parallel envs don't have agent_iter
if not hasattr(env, "agent_iter"):
return SteppingParadigm.SIMULTANEOUS
# Default to single-agent (standard Gymnasium)
return SteppingParadigm.SINGLE_AGENT
__all__ = [
"ParadigmAdapter",
"ParadigmStepResult",
"SingleAgentParadigmAdapter",
"SequentialParadigmAdapter",
"SimultaneousParadigmAdapter",
"create_paradigm_adapter",
]