import os

import numpy as np
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
from tqdm import tqdm

import gymnasium as gym
from pettingzoo.mpe import simple_crypto_v3


class A2C(nn.Module):
    """Advantage actor-critic agent with separate actor and critic networks."""

    def __init__(
        self,
        n_features: int,
        n_actions: int,
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
    ) -> None:
        super().__init__()
        self.device = device

        # The critic maps an observation to a single state-value estimate.
        critic_layers = [
            nn.Linear(n_features, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
        ]

        # The actor maps an observation to unnormalized action logits.
        # No Softmax here: torch.distributions.Categorical(logits=...) in
        # select_action normalizes the logits itself.
        actor_layers = [
            nn.Linear(n_features, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.ReLU(),
            nn.Linear(8, n_actions),
        ]

        self.critic = nn.Sequential(*critic_layers).to(self.device)
        self.actor = nn.Sequential(*actor_layers).to(self.device)

        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)

    def forward(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
        x = torch.Tensor(x).to(self.device)
        state_values = self.critic(x)
        action_logits_vec = self.actor(x)
        return (state_values, action_logits_vec)

    def select_action(
        self, x: np.ndarray
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        state_values, action_logits = self.forward(x)
        action_pd = torch.distributions.Categorical(logits=action_logits)
        actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()
        return actions, action_log_probs, state_values, entropy

    def get_losses(self):
        # Not implemented yet; see the a2c_losses sketch at the end of this file.
        pass

    def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()


# environment hyperparams
n_episodes = 10

# agent hyperparams
ent_coef = 0.01  # coefficient for entropy bonus
actor_lr = 0.001
critic_lr = 0.005

# environment setup
env = simple_crypto_v3.parallel_env(render_mode="human")
# obs_space
# action_space

device = torch.device("cpu")

# init the agent
# agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)

# wrapper to record statistics
# env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
#     env, buffer_length=n_episodes
# )

observations, infos = env.reset()
done = False

# Random-action rollout: sample a valid action for every live agent each step.
while env.agents:
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
    print(observations)

env.close()
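

# --------------------------------------------------------------------------- #
# Hedged sketch: one possible body for A2C.get_losses, which is still a stub
# above. It follows the usual A2C recipe: generalized advantage estimation
# (GAE), an MSE critic loss, and a policy-gradient actor loss with an entropy
# bonus. The function name a2c_losses and the gamma / lam / masks parameters
# are illustrative assumptions, not part of the original script.
# --------------------------------------------------------------------------- #
def a2c_losses(
    rewards: torch.Tensor,           # shape [T], rewards collected in one rollout
    action_log_probs: torch.Tensor,  # shape [T], log-probs of the taken actions
    value_preds: torch.Tensor,       # shape [T], critic estimates V(s_t)
    entropy: torch.Tensor,           # shape [T], per-step policy entropy
    masks: torch.Tensor,             # shape [T], 0.0 where the episode ended, else 1.0
    gamma: float,                    # discount factor (assumed hyperparameter)
    lam: float,                      # GAE lambda (assumed hyperparameter)
    ent_coef: float,                 # entropy-bonus coefficient
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    T = len(rewards)
    advantages = torch.zeros(T, device=device)

    # Walk backwards through the rollout, accumulating the GAE advantage.
    gae = 0.0
    for t in reversed(range(T - 1)):
        td_error = rewards[t] + gamma * masks[t] * value_preds[t + 1] - value_preds[t]
        gae = td_error + gamma * lam * masks[t] * gae
        advantages[t] = gae

    # Critic is trained to shrink the advantages (value-estimation error);
    # actor maximizes advantage-weighted log-probs plus an entropy bonus.
    critic_loss = advantages.pow(2).mean()
    actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
    return critic_loss, actor_loss


# Possible usage once a rollout buffer has been collected (names are assumptions):
# critic_loss, actor_loss = a2c_losses(rewards, log_probs, values, entropies,
#                                      masks, gamma=0.99, lam=0.95,
#                                      ent_coef=ent_coef, device=device)
# agent.update_params(critic_loss, actor_loss)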