import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import gymnasium as gym
from pettingzoo.mpe import simple_crypto_v3


class A2C(nn.Module):
    def __init__(
        self,
        n_features: int,
        n_actions: int,
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
    ) -> None:
        super().__init__()
        self.device = device

        critic_layers = [
            nn.Linear(n_features, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.ReLU(),
            nn.Linear(8, 1),  # single state-value estimate
        ]
        actor_layers = [
            nn.Linear(n_features, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.ReLU(),
            nn.Linear(8, n_actions),  # raw logits; Categorical(logits=...) applies the softmax
        ]

        self.critic = nn.Sequential(*critic_layers).to(self.device)
        self.actor = nn.Sequential(*actor_layers).to(self.device)

        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)

    def forward(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
        x = torch.Tensor(x).to(self.device)
        state_values = self.critic(x)
        action_logits_vec = self.actor(x)
        return (state_values, action_logits_vec)

    def select_action(
        self, x: np.ndarray
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        state_values, action_logits = self.forward(x)
        action_pd = torch.distributions.Categorical(logits=action_logits)
        actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()
        return actions, action_log_probs, state_values, entropy

    def get_losses(
        self,
        rewards: torch.Tensor,
        action_log_probs: torch.Tensor,
        value_preds: torch.Tensor,
        entropy: torch.Tensor,
        masks: torch.Tensor,
        gamma: float,
        ent_coef: float,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # compute one-step TD advantages: A_t = r_t + gamma * mask_t * V(s_{t+1}) - V(s_t)
        # mask - 0 at the end of an episode, which cuts the bootstrapped value term
        # gamma - discount factor for the bootstrapped value prediction
        advantages = torch.zeros(len(rewards), device=device)
        for t in range(len(rewards) - 1):
            advantages[t] = rewards[t] + gamma * masks[t] * value_preds[t + 1] - value_preds[t]

        # calculate critic loss - MSE of the TD errors
        critic_loss = advantages.pow(2).mean()

        # calculate actor loss - give a bonus for entropy to encourage exploration
        actor_loss = (
            -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
        )
        return (critic_loss, actor_loss)

    def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()


# environment hyperparams
n_episodes = 10

# agent hyperparams
gamma = 0.999
ent_coef = 0.01  # coefficient for entropy bonus
actor_lr = 0.001
critic_lr = 0.005

# environment setup
env = simple_crypto_v3.parallel_env(render_mode="human")
# obs_space
# action_space
device = torch.device("cpu")

# init the agent
# agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)

# wrapper to record statistics
# env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
#     env, buffer_length=n_episodes
# )

observations, infos = env.reset()
done = False

# random-action rollout: every live agent samples from its own action space
while env.agents:
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
    print(observations)
env.close()
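
# --------------------------------------------------------------------------
# Hedged sketch, not part of the original script: one possible way to wire
# the A2C class above into the PettingZoo parallel loop, training one
# independent learner per agent. The names `train_env`, `learners`, and
# `buffers` are illustrative assumptions, as is the choice of a single
# actor/critic update per agent at the end of each episode.
# --------------------------------------------------------------------------
train_env = simple_crypto_v3.parallel_env()
train_env.reset()

# one learner per agent; observation/action sizes differ between the agents
learners = {
    agent: A2C(
        n_features=train_env.observation_space(agent).shape[0],
        n_actions=train_env.action_space(agent).n,
        device=device,
        critic_lr=critic_lr,
        actor_lr=actor_lr,
    )
    for agent in train_env.possible_agents
}

for episode in range(n_episodes):
    obs, _ = train_env.reset()
    # per-agent rollout buffers for a single episode
    buffers = {
        a: {"rewards": [], "log_probs": [], "values": [], "entropies": [], "masks": []}
        for a in train_env.possible_agents
    }

    while train_env.agents:
        actions, step_cache = {}, {}
        for agent in train_env.agents:
            action, log_prob, value, entropy = learners[agent].select_action(obs[agent])
            actions[agent] = int(action.item())
            step_cache[agent] = (log_prob, value, entropy)

        obs, rewards, terminations, truncations, _ = train_env.step(actions)

        for agent, (log_prob, value, entropy) in step_cache.items():
            finished = terminations[agent] or truncations[agent]
            buffers[agent]["rewards"].append(float(rewards[agent]))
            buffers[agent]["log_probs"].append(log_prob)
            buffers[agent]["values"].append(value.squeeze())
            buffers[agent]["entropies"].append(entropy)
            buffers[agent]["masks"].append(0.0 if finished else 1.0)

    # one update per agent from this episode's rollout
    for agent, buf in buffers.items():
        critic_loss, actor_loss = learners[agent].get_losses(
            rewards=torch.tensor(buf["rewards"], device=device),
            action_log_probs=torch.stack(buf["log_probs"]),
            value_preds=torch.stack(buf["values"]),
            entropy=torch.stack(buf["entropies"]),
            masks=torch.tensor(buf["masks"], device=device),
            gamma=gamma,
            ent_coef=ent_coef,
            device=device,
        )
        learners[agent].update_params(critic_loss, actor_loss)

train_env.close()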