Source code for gym_gui.core.adapters.base

"""Abstract adapter contract for Gymnasium environments."""

from __future__ import annotations

from abc import ABC
from dataclasses import dataclass, field
import logging
from typing import Any, Callable, Generic, Mapping, Sequence, TypeVar

import gymnasium as gym  # type: ignore[import]

from gym_gui.core.enums import ControlMode, RenderMode, SteppingParadigm
from gym_gui.core.spaces.serializer import describe_space
from gym_gui.core.spaces.vector_metadata import describe_vector_environment
from gym_gui.logging_config.log_constants import (
    LOG_ADAPTER_ENV_CLOSED,
    LOG_ADAPTER_ENV_CREATED,
    LOG_ADAPTER_ENV_RESET,
    LOG_ADAPTER_INIT_ERROR,
    LOG_ADAPTER_PAYLOAD_ERROR,
    LOG_ADAPTER_RENDER_ERROR,
    LOG_ADAPTER_STEP_ERROR,
    LOG_ADAPTER_STEP_SUMMARY,
    LOG_ADAPTER_STATE_INVALID,
    LogConstant,
)
from gym_gui.logging_config.helpers import LogConstantMixin

# Generic type parameters: ObservationT parameterises AdapterStep and
# EnvironmentAdapter; ActionT parameterises EnvironmentAdapter (see ``step``).
ObservationT = TypeVar("ObservationT")
ActionT = TypeVar("ActionT")


# Module-level logger; EnvironmentAdapter assigns it to ``self._logger`` and
# returns it from its ``logger`` property.
_LOGGER = logging.getLogger(__name__)


@dataclass(slots=True)
class AdapterContext:
    """Context payload adapters can use for configuration and callbacks."""

    settings: Any
    control_mode: ControlMode
    logger_factory: Callable[[str], logging.Logger] | None = None

    def get_logger(self, name: str) -> logging.Logger:
        if self.logger_factory is not None:
            return self.logger_factory(name)
        return logging.getLogger(name)


@dataclass(frozen=True)
class WorkerCapabilities:
    """Declares what stepping paradigms and features a worker/adapter supports.

    This dataclass is used by the WorkerOrchestrator to:

    1. Match environments to compatible workers
    2. Configure paradigm-specific stepping behavior
    3. Validate worker compatibility before launching runs

    Attributes:
        stepping_paradigm: Primary stepping model (SINGLE_AGENT, SEQUENTIAL, etc.)
        supported_paradigms: All paradigms this worker can handle. Always
            contains ``stepping_paradigm`` after construction.
        env_types: Environment types supported (e.g., ["gymnasium", "pettingzoo"]).
        action_spaces: Action space types supported (e.g., ["discrete", "continuous"]).
        observation_spaces: Observation space types supported (e.g., ["box", "dict"]).
        max_agents: Maximum number of agents (1 for single-agent).
        supports_self_play: Whether worker can train via self-play.
        supports_population: Whether worker supports population-based training.
        supports_record: Whether worker supports recording episodes.
        supports_fast_reset: Whether worker supports fast environment reset.
        max_fps: Target frame rate for continuous paradigms (None for turn-based).
        requires_gpu: Whether GPU is required.
        gpu_memory_mb: Estimated GPU memory requirement in MB.
        cpu_cores: Recommended CPU cores.

    The collection fields accept any iterable (or a single string) and are
    stored as tuples, so instances stay hashable even when built from lists.

    Example:
        >>> caps = WorkerCapabilities(
        ...     stepping_paradigm=SteppingParadigm.SINGLE_AGENT,
        ...     env_types=["gymnasium"],
        ...     action_spaces=["discrete", "continuous"],
        ...     max_agents=1,
        ... )
    """

    stepping_paradigm: SteppingParadigm
    supported_paradigms: tuple[SteppingParadigm, ...] = ()
    env_types: tuple[str, ...] = ("gymnasium",)
    action_spaces: tuple[str, ...] = ("discrete",)
    observation_spaces: tuple[str, ...] = ("box",)
    max_agents: int = 1
    supports_self_play: bool = False
    supports_population: bool = False
    supports_record: bool = False
    supports_fast_reset: bool = False
    max_fps: float | None = None
    requires_gpu: bool = False
    gpu_memory_mb: int | None = None
    cpu_cores: int = 1

    def __post_init__(self) -> None:
        """Normalise collection fields and guarantee the primary paradigm is listed."""
        # frozen=True forbids normal attribute assignment, hence object.__setattr__.
        for field_name in (
            "supported_paradigms",
            "env_types",
            "action_spaces",
            "observation_spaces",
        ):
            value = getattr(self, field_name)
            # A bare string (including str-based enum members) is one entry,
            # not an iterable of characters.
            normalised = (value,) if isinstance(value, str) else tuple(value)
            object.__setattr__(self, field_name, normalised)

        # The primary paradigm must always be supported, even when the caller
        # supplied an explicit list that omitted it.
        if self.stepping_paradigm not in self.supported_paradigms:
            object.__setattr__(
                self,
                "supported_paradigms",
                (self.stepping_paradigm, *self.supported_paradigms),
            )

    def supports_paradigm(self, paradigm: SteppingParadigm) -> bool:
        """Check if this worker supports the given stepping paradigm."""
        return paradigm in self.supported_paradigms

    def supports_env_type(self, env_type: str) -> bool:
        """Check if this worker supports the given environment type."""
        return env_type in self.env_types

    def supports_action_space(self, space_type: str) -> bool:
        """Check if this worker supports the given action space type."""
        return space_type in self.action_spaces

    def supports_observation_space(self, space_type: str) -> bool:
        """Check if this worker supports the given observation space type."""
        return space_type in self.observation_spaces

    def is_multi_agent(self) -> bool:
        """Check if this worker supports multi-agent environments."""
        return self.max_agents > 1
@dataclass(slots=True)
class AgentSnapshot:
    """Per-agent state for one participant in the environment."""

    name: str
    role: str | None = None
    position: tuple[int, int] | None = None
    orientation: str | None = None
    info: Mapping[str, Any] = field(default_factory=dict)

    def as_dict(self) -> dict[str, Any]:
        """Return the snapshot as a plain dictionary; ``info`` is shallow-copied."""
        return dict(
            name=self.name,
            role=self.role,
            position=self.position,
            orientation=self.orientation,
            info=dict(self.info),
        )


@dataclass(slots=True)
class StepState:
    """Machine-readable snapshot of a single environment step."""

    active_agent: str | None = None
    agents: Sequence[AgentSnapshot] = field(default_factory=tuple)
    objectives: Sequence[Mapping[str, Any]] = field(default_factory=tuple)
    hazards: Sequence[Mapping[str, Any]] = field(default_factory=tuple)
    inventory: Mapping[str, Any] = field(default_factory=dict)
    metrics: Mapping[str, Any] = field(default_factory=dict)
    environment: Mapping[str, Any] = field(default_factory=dict)
    raw: Mapping[str, Any] = field(default_factory=dict)

    def as_dict(self) -> dict[str, Any]:
        """Return a plain dictionary representation for policies and UI.

        Sequences become lists and every mapping is shallow-copied into a
        ``dict``; agents are converted through :meth:`AgentSnapshot.as_dict`.
        """
        agent_payloads: list[dict[str, Any]] = []
        for snapshot in self.agents:
            agent_payloads.append(snapshot.as_dict())

        return dict(
            active_agent=self.active_agent,
            agents=agent_payloads,
            objectives=list(map(dict, self.objectives)),
            hazards=list(map(dict, self.hazards)),
            inventory=dict(self.inventory),
            metrics=dict(self.metrics),
            environment=dict(self.environment),
            raw=dict(self.raw),
        )


@dataclass(slots=True)
class AdapterStep(Generic[ObservationT]):
    """Standardised step result consumed by orchestrators."""

    # Core transition values.
    observation: ObservationT
    reward: float
    terminated: bool
    truncated: bool
    info: Mapping[str, Any]
    # Optional rendering / telemetry attachments.
    render_payload: Any | None = None
    render_hint: Mapping[str, Any] | None = None
    agent_id: str | None = None
    frame_ref: str | None = None
    payload_version: int = 1
    state: StepState = field(default_factory=StepState)


class AdapterNotReadyError(RuntimeError):
    """Raised when an adapter is used before `load` has been called."""


class UnsupportedModeError(RuntimeError):
    """Raised when a requested control mode is incompatible with the adapter."""
class EnvironmentAdapter(ABC, Generic[ObservationT, ActionT], LogConstantMixin):
    """Lifecycle contract for all Gymnasium environment adapters.

    Attributes:
        id: The Gymnasium environment ID (e.g., "CartPole-v1").
        supported_control_modes: Control modes this adapter supports.
        supported_render_modes: Render modes this adapter supports.
        default_render_mode: Default render mode for this adapter.
        stepping_paradigm: The RL stepping paradigm (default: SINGLE_AGENT).
    """

    id: str
    supported_control_modes: tuple[ControlMode, ...]
    supported_render_modes: tuple[RenderMode, ...] = ()
    default_render_mode: RenderMode
    stepping_paradigm: SteppingParadigm = SteppingParadigm.SINGLE_AGENT

    def __init__(self, context: AdapterContext | None = None) -> None:
        """Create an unloaded adapter, optionally bound to *context*."""
        self._context = context
        self._logger = _LOGGER
        self._env: gym.Env[Any, Any] | None = None
        self._space_signature: Mapping[str, Any] | None = None
        self._vector_metadata: Mapping[str, Any] | None = None
        # Episode accounting aligned with xuance wrappers
        self._episode_step: int = 0
        self._episode_return: float = 0.0

    # ------------------------------------------------------------------
    # Lifecycle hooks
    # ------------------------------------------------------------------

    def bind(self, context: AdapterContext) -> None:
        """Bind the adapter to a runtime context after instantiation."""
        self._context = context

    def load(self) -> None:
        """Instantiate underlying Gymnasium environment resources.

        Raises:
            AdapterNotReadyError: If ``default_render_mode`` is not set to a
                :class:`RenderMode`.
        """
        default_mode = self._resolve_default_render_mode()
        kwargs: dict[str, Any] = {"render_mode": default_mode.value}
        extra_kwargs = self.gym_kwargs()
        if extra_kwargs:
            # Subclass-provided kwargs take precedence over the defaults.
            kwargs.update(extra_kwargs)
        env = gym.make(self.id, **kwargs)
        env = self.apply_wrappers(env)
        self.log_constant(
            LOG_ADAPTER_ENV_CREATED,
            extra={
                "env_id": self.id,
                "render_mode": default_mode.value,
                "gym_kwargs": ",".join(sorted(extra_kwargs.keys())) if extra_kwargs else "-",
                "wrapped_class": env.__class__.__name__,
            },
        )
        self._set_env(env)

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> AdapterStep[ObservationT]:
        """Reset the environment and return the initial packaged step.

        The episode counters are zeroed and ``info["episode_step"]`` is set to
        0. The returned step has reward 0.0 and both termination flags False.

        Raises:
            AdapterNotReadyError: If the adapter has not been loaded.
        """
        env = self._require_env()
        observation, info = env.reset(seed=seed, options=options)
        # Reset episode counters and mirror xuance-style info keys
        self._episode_step = 0
        self._episode_return = 0.0
        try:
            info = dict(info) if isinstance(info, Mapping) else {}
            info["episode_step"] = self._episode_step
        except Exception as exc:  # pragma: no cover - defensive
            # Surface unexpected info-shape/coercion issues for observability
            self.log_constant(
                LOG_ADAPTER_STATE_INVALID,
                exc_info=exc,
                extra={
                    "env_id": self.id,
                    "context": "episode_info_reset",
                },
            )
        self.log_constant(
            LOG_ADAPTER_ENV_RESET,
            extra={
                "env_id": self.id,
                "seed": seed if seed is not None else "None",
                "has_options": bool(options),
            },
        )
        return self._package_step(observation, 0.0, False, False, info)

    def step(self, action: ActionT) -> AdapterStep[ObservationT]:
        """Advance the environment by one action and return the packaged step.

        The per-episode step count and cumulative return are updated and
        mirrored into ``info`` as ``episode_step`` and ``episode_score``.

        Raises:
            AdapterNotReadyError: If the adapter has not been loaded.
            TypeError, ValueError: If the reward cannot be converted with
                ``float()`` (raised when the step is packaged).
        """
        env = self._require_env()
        observation, reward, terminated, truncated, info = env.step(action)

        # Coerce with float() rather than an isinstance check so scalar reward
        # types that are not int/float subclasses still count towards the
        # episode return, consistent with the float(reward) used for packaging.
        reward_value: float | None
        try:
            reward_value = float(reward)
        except (TypeError, ValueError):
            reward_value = None

        # Update episode counters and mirror xuance-style info keys
        self._episode_step += 1
        self._episode_return += reward_value if reward_value is not None else 0.0
        try:
            info = dict(info) if isinstance(info, Mapping) else {}
            info["episode_step"] = self._episode_step
            info["episode_score"] = self._episode_return
        except Exception as exc:  # pragma: no cover - defensive
            # Surface unexpected info-shape/coercion issues for observability
            self.log_constant(
                LOG_ADAPTER_STATE_INVALID,
                exc_info=exc,
                extra={
                    "env_id": self.id,
                    "context": "episode_info_step",
                    "episode_step": self._episode_step,
                    "episode_return": self._episode_return,
                },
            )
        self.log_constant(
            LOG_ADAPTER_STEP_SUMMARY,
            extra={
                "env_id": self.id,
                "action": repr(action),
                "reward": reward_value if reward_value is not None else repr(reward),
                "terminated": terminated,
                "truncated": truncated,
            },
        )
        # Non-convertible rewards still raise here, after the summary log.
        packaged_reward = reward_value if reward_value is not None else float(reward)
        return self._package_step(observation, packaged_reward, terminated, truncated, info)

    def close(self) -> None:
        """Close the underlying environment, if any, and drop cached metadata.

        Adapter state is cleared even when the environment's ``close`` raises,
        so a failed close never leaves a half-closed environment attached.
        """
        env = self._env
        if env is None:
            return
        self.log_constant(
            LOG_ADAPTER_ENV_CLOSED,
            extra={"env_id": self.id},
        )
        try:
            env.close()
        finally:
            self._env = None
            self._space_signature = None
            self._vector_metadata = None

    # ------------------------------------------------------------------
    # Protected helpers
    # ------------------------------------------------------------------

    def _require_env(self) -> gym.Env[Any, Any]:
        """Return the loaded environment or raise :class:`AdapterNotReadyError`."""
        if self._env is None:
            raise AdapterNotReadyError(f"Adapter '{self.id}' has not been loaded.")
        return self._env

    def _set_env(self, env: gym.Env[Any, Any]) -> None:
        """Attach *env* and refresh the cached space and vector metadata."""
        self._env = env
        self._space_signature = self._build_space_signature(env)
        self._vector_metadata = describe_vector_environment(env)

    def render(self) -> Any:
        """Return whatever the underlying environment's ``render()`` produces.

        Raises:
            AdapterNotReadyError: If the adapter has not been loaded.
        """
        env = self._require_env()
        return env.render()

    def gym_kwargs(self) -> dict[str, Any]:
        """Keyword arguments forwarded to :func:`gymnasium.make`."""
        return {}

    def apply_wrappers(self, env: gym.Env[Any, Any]) -> gym.Env[Any, Any]:
        """Hook for subclasses to apply Gymnasium wrappers before use."""
        return env

    def _package_step(
        self,
        observation: ObservationT,
        reward: float,
        terminated: bool,
        truncated: bool,
        info: Mapping[str, Any],
    ) -> AdapterStep[ObservationT]:
        """Assemble an :class:`AdapterStep` from raw transition values.

        Rendering, the render hint and the frame reference are best-effort:
        a failure in any of them is logged and the field is left as ``None``.
        """
        state = self.build_step_state(observation, info)

        render_payload: Any | None = None
        try:
            render_payload = self.render()
        except Exception as exc:
            self.log_constant(
                LOG_ADAPTER_RENDER_ERROR,
                exc_info=exc,
                extra={
                    "env_id": self.id,
                    "state_snapshot": bool(state.raw),
                },
            )

        render_hint = None
        try:
            render_hint = self.build_render_hint(observation, info, state)
        except Exception as exc:
            self.log_constant(
                LOG_ADAPTER_PAYLOAD_ERROR,
                exc_info=exc,
                extra={
                    "env_id": self.id,
                    "context": "render_hint",
                },
            )

        frame_ref = None
        try:
            frame_ref = self.build_frame_reference(render_payload, state)
        except Exception as exc:
            self.log_constant(
                LOG_ADAPTER_PAYLOAD_ERROR,
                exc_info=exc,
                extra={
                    "env_id": self.id,
                    "context": "frame_reference",
                },
            )

        return AdapterStep(
            observation=observation,
            reward=reward,
            terminated=terminated,
            truncated=truncated,
            info=info,
            render_payload=render_payload,
            render_hint=render_hint,
            agent_id=state.active_agent,
            frame_ref=frame_ref,
            payload_version=self.telemetry_payload_version(),
            state=state,
        )

    def build_step_state(self, observation: ObservationT, info: Mapping[str, Any]) -> StepState:
        """Construct the canonical :class:`StepState` for the current step."""
        return StepState()

    def build_render_hint(
        self,
        observation: ObservationT,
        info: Mapping[str, Any],
        state: StepState,
    ) -> Mapping[str, Any] | None:
        """Return lightweight render metadata for downstream consumers.

        Only non-empty parts of *state* are included; ``None`` is returned
        when nothing would be included.
        """
        hint: dict[str, Any] = {}
        if state.active_agent:
            hint["active_agent"] = state.active_agent
        if state.metrics:
            hint["metrics"] = dict(state.metrics)
        if state.environment:
            hint["environment"] = dict(state.environment)
        if state.inventory:
            hint["inventory"] = dict(state.inventory)
        return hint or None

    def build_frame_reference(self, render_payload: Any | None, state: StepState) -> str | None:
        """Optional hook to derive an external frame reference for media pipelines."""
        del render_payload, state
        return None

    def telemetry_payload_version(self) -> int:
        """Version marker for downstream telemetry consumers."""
        return 1

    # ------------------------------------------------------------------
    # Mouse delta support (optional; overridden by adapters that need it)
    # ------------------------------------------------------------------

    def has_mouse_delta_support(self) -> bool:
        """Return True if the adapter supports FPS-style mouse delta control.

        Subclasses that provide mouse-look controls (e.g., ViZDoom) should
        override this to return True.
        """
        return False

    def apply_mouse_delta(self, delta_x: float, delta_y: float) -> None:
        """Apply mouse movement deltas for FPS-style control.

        Subclasses override to queue or immediately apply mouse motion.
        The base implementation is a no-op to keep Pylance satisfied when
        a generic EnvironmentAdapter is referenced.
        """
        del delta_x, delta_y

    # ------------------------------------------------------------------
    # Optional utilities
    # ------------------------------------------------------------------

    def supports_control_mode(self, mode: ControlMode) -> bool:
        """Return True if *mode* is listed in ``supported_control_modes``."""
        return mode in self.supported_control_modes

    def ensure_control_mode(self, mode: ControlMode) -> None:
        """Raise :class:`UnsupportedModeError` unless *mode* is supported."""
        if not self.supports_control_mode(mode):
            raise UnsupportedModeError(
                f"Adapter '{self.id}' does not support control mode '{mode.value}'."
            )

    def supports_render_mode(self, mode: RenderMode) -> bool:
        """Return True if *mode* is supported.

        When ``supported_render_modes`` is empty, only the default render
        mode counts as supported.
        """
        if self.supported_render_modes:
            return mode in self.supported_render_modes
        return mode == self._resolve_default_render_mode()

    def _resolve_default_render_mode(self) -> RenderMode:
        """Return the configured default render mode, raising if missing."""
        default_mode = getattr(self, "default_render_mode", None)
        if not isinstance(default_mode, RenderMode):
            raise AdapterNotReadyError(
                f"Adapter '{self.id}' must set 'default_render_mode' before loading."
            )
        return default_mode

    # ------------------------------------------------------------------
    # Convenience accessors
    # ------------------------------------------------------------------

    @property
    def context(self) -> AdapterContext | None:
        """The bound runtime context, if any."""
        return self._context

    @property
    def logger(self) -> logging.Logger:
        """Return the module-level logger for backward compatibility."""
        return _LOGGER

    @property
    def settings(self) -> Any | None:
        """Settings from the bound context, or ``None`` when unbound."""
        return self._context.settings if self._context else None

    @property
    def action_space(self) -> gym.Space[Any]:
        """Action space of the loaded environment."""
        return self._require_env().action_space

    @property
    def observation_space(self) -> gym.Space[Any]:
        """Observation space of the loaded environment."""
        return self._require_env().observation_space

    @property
    def space_signature(self) -> Mapping[str, Any] | None:
        """Cached space description, rebuilt lazily while an env is loaded."""
        if self._space_signature is None and self._env is not None:
            self._space_signature = self._build_space_signature(self._env)
        return self._space_signature

    @property
    def vector_metadata(self) -> Mapping[str, Any] | None:
        """Cached vector-environment metadata, rebuilt lazily while loaded."""
        if self._vector_metadata is None and self._env is not None:
            self._vector_metadata = describe_vector_environment(self._env)
        return self._vector_metadata

    def elapsed_steps(self) -> int | None:
        """Return the first ``_elapsed_steps`` counter found in the wrapper chain.

        The chain is walked outermost-first through each wrapper's ``env``
        attribute, so a counter held by an intermediate wrapper is found even
        when further wrappers were applied on top of it. ``unwrapped`` is
        only a fallback for objects without an ``env`` attribute.

        Returns:
            The counter as an ``int``, or ``None`` when no environment is
            loaded, no counter exists, or the counter is not int-convertible.
        """
        current: Any = self._env
        seen: set[int] = set()
        # The identity set guards against cycles (e.g. ``unwrapped`` returning
        # the object itself for a base environment).
        while current is not None and id(current) not in seen:
            seen.add(id(current))
            elapsed = getattr(current, "_elapsed_steps", None)
            if elapsed is not None:
                try:
                    return int(elapsed)
                except (TypeError, ValueError):  # pragma: no cover - defensive
                    return None
            inner = getattr(current, "env", None)
            if inner is None:
                inner = getattr(current, "unwrapped", None)
            current = inner
        return None

    def _build_space_signature(self, env: gym.Env[Any, Any]) -> Mapping[str, Any] | None:
        """Describe the env's observation/action spaces; ``None`` on any failure."""
        try:
            observation = describe_space(env.observation_space)
            action = describe_space(env.action_space)
        except Exception:  # pragma: no cover - best-effort metadata capture
            return None
        return {"observation": observation, "action": action}
# Public API of this module. WorkerCapabilities is a documented public class
# and is therefore exported alongside the adapter types.
__all__ = [
    "AdapterContext",
    "AdapterStep",
    "AdapterNotReadyError",
    "UnsupportedModeError",
    "EnvironmentAdapter",
    "StepState",
    "AgentSnapshot",
    "WorkerCapabilities",
]