import os

import numpy as np
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from pettingzoo.mpe import simple_reference_v3


class A2C(nn.Module):
    """Advantage Actor-Critic agent with separate actor and critic networks."""

    def __init__(
        self,
        n_features: int,
        n_actions: int,
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
    ) -> None:
        super().__init__()
        self.device = device

        critic_layers = [
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # scalar state-value estimate
        ]
        actor_layers = [
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
            nn.Softmax(dim=-1),  # action probabilities
        ]

        self.critic = nn.Sequential(*critic_layers).to(self.device)
        self.actor = nn.Sequential(*actor_layers).to(self.device)

        self.critic_optim = optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=actor_lr)
        # gamma=1 keeps both learning rates constant; lower it to enable decay.
        self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=1)
        self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=1)

    def forward(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
        x = torch.Tensor(x).to(self.device)
        state_values = self.critic(x)
        action_probs = self.actor(x)
        return (state_values, action_probs)

    def select_action(
        self, x: np.ndarray
    ) -> tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
        state_values, action_probs = self.forward(x)
        # The actor ends in a softmax, so its output is already a probability vector.
        action_pd = torch.distributions.Categorical(probs=action_probs)
        actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()
        return actions.item(), action_log_probs, state_values, entropy

    def get_losses(
        self,
        rewards: list[float],
        action_log_probs: torch.Tensor,
        value_preds: list[torch.Tensor],
        entropy: list[torch.Tensor],
        masks: list[int],
        gamma: float,
        ent_coef: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        T = len(rewards)
        advantages = torch.zeros(T, device=self.device)

        # Compute the advantages with GAE (lambda = 0.95).
        # masks[t] is 0 at the end of an episode, which cuts the bootstrap there.
        # The final step has no bootstrap value, so its advantage stays zero.
        gae = 0.0
        for t in reversed(range(T - 1)):
            td_error = (
                rewards[t]
                + gamma * masks[t] * value_preds[t + 1].squeeze()
                - value_preds[t].squeeze()
            )
            gae = td_error + gamma * 0.95 * masks[t] * gae
            advantages[t] = gae

        # Critic loss: mean squared advantage (TD error).
        critic_loss = advantages.pow(2).mean()

        # Actor loss: policy gradient with an entropy bonus to encourage exploration.
        # action_log_probs is stacked to shape (T, 1), so squeeze it to match advantages.
        actor_loss = (
            -(advantages.detach() * action_log_probs.squeeze()).mean()
            - ent_coef * torch.stack(entropy).mean()
        )
        return (critic_loss, actor_loss)

    def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
        self.critic_optim.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optim.step()
        self.critic_scheduler.step()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optim.step()
        self.actor_scheduler.step()

    def set_eval(self) -> None:
        self.critic.eval()
        self.actor.eval()
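
# A minimal smoke test for the A2C module above (a sketch, not part of the
# training flow). It assumes an arbitrary 10-dimensional observation and
# 5 discrete actions, and only checks that one rollout, one loss computation,
# and one parameter update run end to end with finite losses.
def _smoke_test_a2c() -> None:
    test_agent = A2C(n_features=10, n_actions=5, device=torch.device("cpu"),
                     critic_lr=1e-3, actor_lr=1e-3)
    obs = np.random.rand(1, 10).astype(np.float32)
    log_probs, values, entropies, rewards, masks = [], [], [], [], []
    for step in range(4):
        action, log_prob, value, ent = test_agent.select_action(obs)
        log_probs.append(log_prob)
        values.append(value)
        entropies.append(ent)
        rewards.append(float(-abs(action)))  # dummy reward for the check
        masks.append(1 if step < 3 else 0)
    critic_loss, actor_loss = test_agent.get_losses(
        rewards, torch.stack(log_probs), values, entropies, masks,
        gamma=0.99, ent_coef=0.01,
    )
    assert torch.isfinite(critic_loss) and torch.isfinite(actor_loss)
    test_agent.update_params(critic_loss, actor_loss)

# _smoke_test_a2c()  # uncomment to run a quick sanity check before training
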
# Environment hyperparameters
n_episodes = 100

# Agent hyperparameters
gamma = 0.999
ent_coef = 0.01  # coefficient for the entropy bonus
actor_lr = 0.001
critic_lr = 0.005

# Environment setup
# env = simple_reference_v3.parallel_env(render_mode="human")
env = simple_reference_v3.parallel_env(max_cycles=50, render_mode="rgb_array")

device = torch.device("cpu")

# One independent A2C learner per agent in the environment.
agent0 = A2C(
    n_features=env.observation_space("agent_0").shape[0],
    n_actions=env.action_space("agent_0").n,
    device=device,
    critic_lr=critic_lr,
    actor_lr=actor_lr,
)
agent1 = A2C(
    n_features=env.observation_space("agent_1").shape[0],
    n_actions=env.action_space("agent_1").n,
    device=device,
    critic_lr=critic_lr,
    actor_lr=actor_lr,
)

# Per-episode training statistics
agent0_critic_loss = []
agent0_actor_loss = []
agent1_critic_loss = []
agent1_actor_loss = []
agent0_rewards = []
agent1_rewards = []
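
# Optional helper (a sketch, not part of the original training flow): a rolling
# mean that can be applied to the reward curves plotted during training.
# The 20-episode window is an assumption.
def moving_average(values: list[float], window: int = 20) -> np.ndarray:
    """Return the rolling mean of `values` over `window` episodes."""
    if len(values) < window:
        return np.asarray(values, dtype=np.float64)
    kernel = np.ones(window) / window
    return np.convolve(np.asarray(values, dtype=np.float64), kernel, mode="valid")

# Example: plt.plot(moving_average(agent0_rewards), label="Agent 0 (smoothed)")
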
for episode in tqdm(range(n_episodes)):
    observations, infos = env.reset()

    # Per-episode rollout buffers for each agent
    agent_0_rewards = []
    agent_0_probs = []
    agent_0_pred = []
    agent_0_ents = []
    agent_0_mask = []

    agent_1_rewards = []
    agent_1_probs = []
    agent_1_pred = []
    agent_1_ents = []
    agent_1_mask = []

    while env.agents:
        actions = {}
        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(
            torch.FloatTensor(observations["agent_0"]).unsqueeze(0)
        )
        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(
            torch.FloatTensor(observations["agent_1"]).unsqueeze(0)
        )
        actions["agent_0"] = agent_0_action
        actions["agent_1"] = agent_1_action

        observations, rewards, terminations, truncations, infos = env.step(actions)

        agent_0_rewards.append(rewards["agent_0"])
        agent_0_probs.append(agent_0_log_probs)
        agent_0_pred.append(agent_0_state_val)
        agent_0_ents.append(agent_0_ent)
        agent_0_mask.append(1 if env.agents else 0)  # 0 marks the end of the episode

        agent_1_rewards.append(rewards["agent_1"])
        agent_1_probs.append(agent_1_log_probs)
        agent_1_pred.append(agent_1_state_val)
        agent_1_ents.append(agent_1_ent)
        agent_1_mask.append(1 if env.agents else 0)

    # One A2C update per agent at the end of each episode
    agent_0_closs, agent_0_aloss = agent0.get_losses(
        agent_0_rewards, torch.stack(agent_0_probs), agent_0_pred,
        agent_0_ents, agent_0_mask, gamma, ent_coef,
    )
    agent0_critic_loss.append(agent_0_closs.item())
    agent0_actor_loss.append(agent_0_aloss.item())
    agent0.update_params(agent_0_closs, agent_0_aloss)

    agent_1_closs, agent_1_aloss = agent1.get_losses(
        agent_1_rewards, torch.stack(agent_1_probs), agent_1_pred,
        agent_1_ents, agent_1_mask, gamma, ent_coef,
    )
    agent1_critic_loss.append(agent_1_closs.item())
    agent1_actor_loss.append(agent_1_aloss.item())
    agent1.update_params(agent_1_closs, agent_1_aloss)

    agent0_rewards.append(np.array(agent_0_rewards).sum())
    agent1_rewards.append(np.array(agent_1_rewards).sum())

    # Periodically save a snapshot of the reward curves
    # (with n_episodes = 100 this only triggers at episode 0)
    if episode % 500 == 0:
        plt.plot(agent0_rewards, label="Agent 0 Rewards")
        plt.plot(agent1_rewards, label="Agent 1 Rewards")
        plt.legend()
        plt.savefig("data.png")
        plt.clf()

# Final reward plot
plt.plot(agent0_rewards, label="Agent 0 Rewards")
plt.plot(agent1_rewards, label="Agent 1 Rewards")
plt.legend()
plt.savefig("data.png")
plt.show(block=False)

# Save the trained networks
actor0_weights_path = "weights/actor0_weights.h5"
critic0_weights_path = "weights/critic0_weights.h5"
actor1_weights_path = "weights/actor1_weights.h5"
critic1_weights_path = "weights/critic1_weights.h5"

if not os.path.exists("weights"):
    os.mkdir("weights")

torch.save(agent0.actor.state_dict(), actor0_weights_path)
torch.save(agent0.critic.state_dict(), critic0_weights_path)
torch.save(agent1.actor.state_dict(), actor1_weights_path)
torch.save(agent1.critic.state_dict(), critic1_weights_path)

# Evaluate the trained agents with on-screen rendering
agent0.set_eval()
agent1.set_eval()
env = simple_reference_v3.parallel_env(render_mode="human")
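
# Optional reload path (a sketch): rebuild the agents from the saved state
# dicts instead of reusing the in-memory networks. The `load_weights` flag is
# an added switch and is disabled by default.
load_weights = False
if load_weights:
    agent0 = A2C(
        n_features=env.observation_space("agent_0").shape[0],
        n_actions=env.action_space("agent_0").n,
        device=device,
        critic_lr=critic_lr,
        actor_lr=actor_lr,
    )
    agent1 = A2C(
        n_features=env.observation_space("agent_1").shape[0],
        n_actions=env.action_space("agent_1").n,
        device=device,
        critic_lr=critic_lr,
        actor_lr=actor_lr,
    )
    agent0.actor.load_state_dict(torch.load(actor0_weights_path))
    agent0.critic.load_state_dict(torch.load(critic0_weights_path))
    agent1.actor.load_state_dict(torch.load(actor1_weights_path))
    agent1.critic.load_state_dict(torch.load(critic1_weights_path))
    agent0.set_eval()
    agent1.set_eval()
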
# Run evaluation episodes until interrupted; close the environment on exit.
try:
    while True:
        observations, infos = env.reset()
        while env.agents:
            plt.pause(0.001)  # keep the reward plot responsive while rendering
            actions = {}
            agent_0_action, _, _, _ = agent0.select_action(
                torch.FloatTensor(observations["agent_0"]).unsqueeze(0)
            )
            agent_1_action, _, _, _ = agent1.select_action(
                torch.FloatTensor(observations["agent_1"]).unsqueeze(0)
            )
            actions["agent_0"] = agent_0_action
            actions["agent_1"] = agent_1_action
            observations, rewards, terminations, truncations, infos = env.step(actions)
finally:
    env.close()