diff --git a/main.py b/main.py
index b288d8f..c13ca66 100644
--- a/main.py
+++ b/main.py
@@ -10,6 +10,7 @@ from tqdm import tqdm
 
 import gymnasium as gym
 from pettingzoo.mpe import simple_reference_v3
+import pettingzoo
 
 class A2C(nn.Module):
     def __init__(
@@ -25,26 +26,35 @@ class A2C(nn.Module):
         self.device = device
 
         critic_layers = [
-            nn.Linear(n_features, 128),
+            nn.Linear(n_features, 64),
+            nn.ReLU(),
+            nn.Linear(64,128),
+            nn.ReLU(),
+            nn.Linear(128,128),
             nn.ReLU(),
             nn.Linear(128, 1),
         ]
 
         actor_layers = [
-            nn.Linear(n_features, 128),
+            nn.Linear(n_features, 64),
+            nn.ReLU(),
+            nn.Linear(64,128),
+            nn.ReLU(),
+            nn.Linear(128,128),
             nn.ReLU(),
             nn.Linear(128, n_actions),
             nn.Softmax()
+            #nn.Sigmoid()
         ]
 
         self.critic = nn.Sequential(*critic_layers).to(self.device)
         self.actor = nn.Sequential(*actor_layers).to(self.device)
 
-        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
-        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)
+        self.critic_optim = optim.Adam(self.critic.parameters(), lr=critic_lr)
+        self.actor_optim = optim.Adam(self.actor.parameters(), lr=actor_lr)
 
-        self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=0.9)
-        self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=0.9)
+        self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=1)
+        self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=1)
 
     def forward(self, x: np.array) -> tuple[torch.tensor, torch.tensor]:
         x = torch.Tensor(x).to(self.device)
@@ -55,14 +65,18 @@ class A2C(nn.Module):
 
     def select_action(self, x: np.array) -> tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
         state_values, action_logits = self.forward(x)
-        #action_pd = torch.distributions.Categorical(
-        #    logits=action_logits
-        #)
-        #actions = action_pd.sample()
-        actions = torch.multinomial(action_logits, 1).item()
-        action_log_probs = torch.log(action_logits.squeeze(0)[actions])
-        entropy = action_logits * action_log_probs#action_pd.entropy()
-        return actions, action_log_probs, state_values, entropy
+        action_pd = torch.distributions.Categorical(
+            logits=action_logits
+        )
+        actions = action_pd.sample()
+        #actions = torch.multinomial(action_logits, 1).item()
+        #action_log_probs = torch.log(action_logits.squeeze(0)[actions])
+        action_log_probs = action_pd.log_prob(actions)
+        #entropy = action_logits * action_log_probs
+        entropy = action_pd.entropy()
+        #print(entropy.item())
+        #print(action_logits)
+        return actions.item(), action_log_probs, state_values, entropy
 
     def get_losses(
         self,
@@ -75,34 +89,60 @@ class A2C(nn.Module):
         ent_coef: float,
     ) -> tuple[torch.tensor, torch.tensor]:
 
-        advantages = torch.zeros(len(rewards), device=self.device)
+        T = len(rewards)
+        advantages = torch.zeros(T, device=device)
+
+        # compute the advantages using GAE
+        gae = 0.0
+        for t in reversed(range(T - 1)):
+            td_error = (
+                rewards[t] + gamma * masks[t] * value_preds[t+1] - value_preds[t]
+            )
+            gae = td_error + gamma * 0.95 * masks[t] * gae
+            advantages[t] = gae
+
+        # calculate the loss of the minibatch for actor and critic
+        critic_loss = advantages.pow(2).mean()
+
+        #give a bonus for higher entropy to encourage exploration
+        actor_loss = (
+            -(advantages.detach() * action_log_probs).mean() - ent_coef * torch.Tensor(entropy).mean()
+        )
+
+
+        #advantages = torch.zeros(len(rewards), device=self.device)
         #compute advantages
         #mask - 0 if end of episode
         #gamma - coeffecient for value prediction
+        #rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5)
         #for t in range(len(rewards) - 1):
-        #    advantages[t] = rewards[t] + masks[t] * gamma * value_preds[t+1]#(rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
-        rewards = np.array(rewards)
-        rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5)
-        returns = []
-        R = 0
-        for r, mask in zip(reversed(rewards), reversed(masks)):
-            R = r + gamma * R * mask
-            returns.insert(0, R)
+        #advantages[t] = (rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
+        #print(advantages[t])
+        #rewards[t] + masks[t] * gamma * value_preds[t+1]
+        #(rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
+        #rewards = np.array(rewards)
+        #rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5)
+        #returns = []
+        #R = 0
+        #for r, mask in zip(reversed(rewards), reversed(masks)):
+        #    R = r + gamma * R * mask
+        #    returns.insert(0, R)
 
-        returns = torch.FloatTensor(returns)
-        values = torch.stack(value_preds).squeeze(1)
+        #returns = torch.FloatTensor(returns)
+        #values = torch.stack(value_preds).squeeze(1)
 
-        advantage = returns - values
+        #advantage = returns - values
 
         #calculate critic loss - MSE
         #critic_loss = advantages.pow(2).mean()
-        critic_loss = advantage.pow(2).mean()
+        #critic_loss = advantages.pow(2).mean()
 
         #calculate actor loss - give bonus for entropy to encourage exploration
         #actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
-        entropy = -torch.stack(entropy).sum(dim=-1).mean()
-        actor_loss = -(action_log_probs * advantage.detach()).mean() - ent_coef * entropy
-
+        #entropy = -torch.stack(entropy).sum(dim=-1).mean()
+        #actor_loss = (-action_log_probs * advantages.detach()).mean() - ent_coef * torch.Tensor(entropy).mean()
+        #print(action_log_probs)
+        #print(actor_loss)
         return (critic_loss, actor_loss)
 
     def update_params(self, critic_loss: torch.tensor, actor_loss: torch.tensor) -> None:
@@ -123,7 +163,7 @@ class A2C(nn.Module):
         self.actor.eval()
 
 #environment hyperparams
-n_episodes = 10000
+n_episodes = 1000
 
 #agent hyperparams
 gamma = 0.999
@@ -133,7 +173,9 @@ critic_lr = 0.005
 
 #environment setup
 #env = simple_reference_v3.parallel_env(render_mode="human")
-env = simple_reference_v3.parallel_env()
+env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array")
+
+
 
 #obs_space
 #action_space
@@ -153,11 +195,14 @@ device = torch.device("cpu")
 
 agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
 agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-
+#print(env.action_space("agent_0").n)
+#print(env.observation_space("agent_0"))
 agent0_critic_loss = []
 agent0_actor_loss = []
 agent1_critic_loss = []
 agent1_actor_loss = []
+agent0_rewards = []
+agent1_rewards = []
 
 for episode in range(0, n_episodes):
     print("Episode " + str(episode) + "/" + str(n_episodes))
@@ -187,6 +232,7 @@ for episode in range(0, n_episodes):
         actions["agent_0"] = agent_0_action
         actions["agent_1"] = agent_1_action
         observations, rewards, terminations, truncations, infos = env.step(actions)
+        #print(rewards)
         agent_0_rewards.append(rewards["agent_0"])
         agent_0_probs.append(agent_0_log_probs)
         agent_0_pred.append(agent_0_state_val)
@@ -213,10 +259,15 @@ for episode in range(0, n_episodes):
     #print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
     agent1.update_params(agent_1_closs, agent_1_aloss)
 
-plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
-plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
-plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
-plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
+    agent0_rewards.append(np.array(agent_0_rewards).sum())
+    agent1_rewards.append(np.array(agent_1_rewards).sum())
+
+#plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
+#plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
+#plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
+#plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
+plt.plot(agent0_rewards, label="Agent 0 Rewards")
+plt.plot(agent1_rewards, label="Agent 1 Rewards")
 
 plt.legend()
 plt.show(block=False)
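
Note on the get_losses hunk above: the hand-rolled discounted-return code is replaced by a GAE-style advantage estimate with the lambda factor hard-coded to 0.95, and because the loop runs over reversed(range(T - 1)) the advantage of the final step is left at zero. A minimal standalone sketch of that computation, assuming per-step reward and mask lists plus a 1-D tensor of value predictions; the function name and the lam parameter are illustrative and do not appear in main.py:

    import torch

    def gae_advantages(rewards, masks, value_preds, gamma=0.999, lam=0.95):
        # Mirrors the loop added in get_losses: bootstrap from value_preds[t+1],
        # zero the carried estimate at episode boundaries (masks[t] == 0),
        # and decay the running GAE term by gamma * lam each step.
        T = len(rewards)
        advantages = torch.zeros(T)
        gae = 0.0
        for t in reversed(range(T - 1)):
            td_error = rewards[t] + gamma * masks[t] * value_preds[t + 1] - value_preds[t]
            gae = td_error + gamma * lam * masks[t] * gae
            advantages[t] = gae
        return advantages

    # example: gae_advantages([1.0, 0.5, 0.0], [1, 1, 0], torch.tensor([0.2, 0.1, 0.0]))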