Source code for gym_gui.services.policy_mapping

# gym_gui/services/policy_mapping.py

"""PolicyMappingService for per-agent policy mapping in multi-agent environments.

This module provides:
- AgentPolicyBinding: Binding between an agent and its policy controller
- PolicyMappingService: Per-agent policy mapping with paradigm awareness

The PolicyMappingService extends ActorService to support:
1. Multiple active policies (one per agent)
2. Paradigm-aware action selection (Sequential vs Simultaneous)
3. Worker-specific routing

For single-agent environments, it delegates to ActorService.
For multi-agent, it maintains agent_id → policy_id mapping.

See Also:
    - :doc:`/documents/architecture/policy_mapping` for policy mapping details
    - :doc:`/documents/architecture/paradigms` for stepping paradigm architecture
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional

from gym_gui.core.enums import SteppingParadigm
from gym_gui.logging_config.helpers import LogConstantMixin
from gym_gui.services.actor import (
    Actor,
    ActorService,
    EpisodeSummary,
    StepSnapshot,
)


[docs] @dataclass class AgentPolicyBinding: """Binding between an agent and its policy controller. Attributes: agent_id: Unique identifier for the agent in the environment. policy_id: References an Actor registered in ActorService. worker_id: Optional worker identifier (e.g., "cleanrl_worker", "llm_worker"). config: Worker-specific configuration options. """ agent_id: str policy_id: str worker_id: Optional[str] = None config: Dict[str, Any] = field(default_factory=dict)
[docs] class PolicyMappingService(LogConstantMixin): """Per-agent policy mapping for multi-agent environments. Extends ActorService to support: 1. Multiple active policies (one per agent) 2. Paradigm-aware action selection 3. Worker-specific routing For single-agent environments, delegates to ActorService. For multi-agent, maintains agent_id → policy_id mapping. Example: >>> actor_service = ActorService() >>> actor_service.register_actor(HumanKeyboardActor(), activate=True) >>> actor_service.register_actor(CleanRLWorkerActor()) >>> >>> mapping = PolicyMappingService(actor_service) >>> mapping.set_paradigm(SteppingParadigm.SEQUENTIAL) >>> mapping.set_agents(["player_0", "player_1"]) >>> mapping.bind_agent_policy("player_0", "human_keyboard") >>> mapping.bind_agent_policy("player_1", "cleanrl_worker") See Also: - :doc:`/documents/architecture/policy_mapping` for policy mapping details """ def __init__(self, actor_service: ActorService) -> None: """Initialize PolicyMappingService. Args: actor_service: The underlying ActorService for policy management. """ self._actor_service = actor_service self._bindings: Dict[str, AgentPolicyBinding] = {} self._paradigm: SteppingParadigm = SteppingParadigm.SINGLE_AGENT self._agent_ids: List[str] = [] self._logger = logging.getLogger("gym_gui.services.policy_mapping") # ------------------------------------------------------------------ # Configuration # ------------------------------------------------------------------
[docs] def set_paradigm(self, paradigm: SteppingParadigm) -> None: """Set the stepping paradigm for this session. Args: paradigm: The stepping paradigm (SINGLE_AGENT, SEQUENTIAL, etc.) """ self._paradigm = paradigm self._logger.debug(f"Paradigm set to {paradigm.name}")
[docs] def set_agents(self, agent_ids: List[str]) -> None: """Configure the list of agents in the environment. Auto-binds agents to the default policy if not already bound. Args: agent_ids: List of agent identifiers from the environment. """ self._agent_ids = list(agent_ids) default_policy = self._actor_service.get_active_actor_id() # Auto-bind to default policy if not already bound for agent_id in agent_ids: if agent_id not in self._bindings and default_policy is not None: self._bindings[agent_id] = AgentPolicyBinding( agent_id=agent_id, policy_id=default_policy, ) self._logger.debug( f"Auto-bound agent '{agent_id}' to policy '{default_policy}'" )
[docs] def bind_agent_policy( self, agent_id: str, policy_id: str, *, worker_id: Optional[str] = None, config: Optional[Dict[str, Any]] = None, ) -> None: """Bind an agent to a specific policy. Args: agent_id: The agent to bind. policy_id: The policy (Actor) to use for this agent. worker_id: Optional worker identifier for remote execution. config: Optional worker-specific configuration. Raises: KeyError: If policy_id is not registered in ActorService. """ available = list(self._actor_service.available_actor_ids()) if policy_id not in available: raise KeyError( f"Unknown policy '{policy_id}'. Available: {available}" ) self._bindings[agent_id] = AgentPolicyBinding( agent_id=agent_id, policy_id=policy_id, worker_id=worker_id, config=config or {}, ) self._logger.debug(f"Bound agent '{agent_id}' to policy '{policy_id}'")
[docs] def unbind_agent(self, agent_id: str) -> None: """Remove binding for an agent. Args: agent_id: The agent to unbind. """ if agent_id in self._bindings: del self._bindings[agent_id] self._logger.debug(f"Unbound agent '{agent_id}'")
[docs] def get_binding(self, agent_id: str) -> Optional[AgentPolicyBinding]: """Get the policy binding for an agent. Args: agent_id: The agent identifier. Returns: The binding if found, otherwise None. """ return self._bindings.get(agent_id)
[docs] def get_all_bindings(self) -> Dict[str, AgentPolicyBinding]: """Get all agent-policy bindings. Returns: Copy of bindings dictionary. """ return dict(self._bindings)
[docs] def available_policy_ids(self) -> Iterable[str]: """Get available policy IDs from ActorService. Returns: Iterable of policy IDs. """ return self._actor_service.available_actor_ids()
# ------------------------------------------------------------------ # Action Selection (Paradigm-Aware) # ------------------------------------------------------------------
[docs] def select_action( self, agent_id: str, snapshot: StepSnapshot, ) -> Optional[int]: """Select action for a specific agent (Sequential/AEC mode). Args: agent_id: The agent needing an action. snapshot: Current step state. Returns: The action to take, or None to abstain. """ binding = self._bindings.get(agent_id) if binding is None: # Fallback to legacy ActorService for unbound agents self._logger.debug( f"No binding for agent '{agent_id}', using legacy ActorService" ) return self._actor_service.select_action(snapshot) # Get the actor for this agent's policy actor = self._get_actor(binding.policy_id) if actor is None: self._logger.warning( f"Policy '{binding.policy_id}' not found for agent '{agent_id}'" ) return None return actor.select_action(snapshot)
[docs] def select_actions( self, observations: Dict[str, Any], snapshots: Dict[str, StepSnapshot], ) -> Dict[str, Optional[int]]: """Select actions for all agents (Simultaneous/POSG mode). Args: observations: Dict mapping agent_id to observation. snapshots: Dict mapping agent_id to StepSnapshot. Returns: Dict mapping agent_id to action (or None). """ actions: Dict[str, Optional[int]] = {} for agent_id, snapshot in snapshots.items(): actions[agent_id] = self.select_action(agent_id, snapshot) return actions
# ------------------------------------------------------------------ # Step Notification # ------------------------------------------------------------------
[docs] def notify_step( self, agent_id: str, snapshot: StepSnapshot, ) -> None: """Notify the appropriate policy of a step result. Args: agent_id: The agent that took the step. snapshot: The step result. """ binding = self._bindings.get(agent_id) if binding is None: self._actor_service.notify_step(snapshot) return actor = self._get_actor(binding.policy_id) if actor is not None: actor.on_step(snapshot)
[docs] def notify_steps( self, snapshots: Dict[str, StepSnapshot], ) -> None: """Notify all agents of their step results (Simultaneous mode). Args: snapshots: Dict mapping agent_id to StepSnapshot. """ for agent_id, snapshot in snapshots.items(): self.notify_step(agent_id, snapshot)
[docs] def notify_episode_end( self, agent_id: str, summary: EpisodeSummary, ) -> None: """Notify the appropriate policy of episode end. Args: agent_id: The agent whose episode ended. summary: Episode summary information. """ binding = self._bindings.get(agent_id) if binding is None: self._actor_service.notify_episode_end(summary) return actor = self._get_actor(binding.policy_id) if actor is not None: actor.on_episode_end(summary)
[docs] def notify_all_episode_end( self, summaries: Dict[str, EpisodeSummary], ) -> None: """Notify all agents of episode end. Args: summaries: Dict mapping agent_id to EpisodeSummary. """ for agent_id, summary in summaries.items(): self.notify_episode_end(agent_id, summary)
# ------------------------------------------------------------------ # Reset # ------------------------------------------------------------------
[docs] def reset(self) -> None: """Reset all bindings for a new session.""" self._bindings.clear() self._agent_ids.clear() self._paradigm = SteppingParadigm.SINGLE_AGENT self._logger.debug("PolicyMappingService reset")
# ------------------------------------------------------------------ # Convenience Properties # ------------------------------------------------------------------
[docs] def is_multi_agent(self) -> bool: """Check if we're in multi-agent mode. Returns: True if more than one agent is configured. """ return len(self._agent_ids) > 1
@property def paradigm(self) -> SteppingParadigm: """Get the current stepping paradigm.""" return self._paradigm @property def agent_ids(self) -> List[str]: """Get list of configured agent IDs.""" return list(self._agent_ids) @property def actor_service(self) -> ActorService: """Get the underlying ActorService.""" return self._actor_service # ------------------------------------------------------------------ # Internal Helpers # ------------------------------------------------------------------ def _get_actor(self, policy_id: str) -> Optional[Actor]: """Get an Actor by ID from the underlying ActorService. Args: policy_id: The policy/actor ID. Returns: The Actor if found, otherwise None. """ # Access internal _actors dict (composition pattern) return self._actor_service._actors.get(policy_id)
__all__ = [ "AgentPolicyBinding", "PolicyMappingService", ]