diff --git a/main.py b/main.py
index 6bd63fc..56792b6 100644
--- a/main.py
+++ b/main.py
@@ -9,7 +9,7 @@
 import matplotlib.pyplot as plt
 from tqdm import tqdm
 import gymnasium as gym
-from pettingzoo.mpe import simple_crypto_v3
+from pettingzoo.mpe import simple_reference_v3
 
 class A2C(nn.Module):
     def __init__(
@@ -73,10 +73,9 @@ class A2C(nn.Module):
         masks: torch.Tensor,
         gamma: float,
         ent_coef: float,
-        device: torch.device
     ) -> tuple[torch.tensor, torch.tensor]:
-        advantages = torch.zeros(len(rewards), device=device)
+        advantages = torch.zeros(len(rewards), device=self.device)
         #compute advantages
         #mask - 0 if end of episode
         #gamma - coeffecient for value prediction
@@ -101,7 +100,7 @@ class A2C(nn.Module):
         self.actor_optim.step()
 
 #environment hyperparams
-n_episodes = 10
+n_episodes = 1
 
 #agent hyperparams
 gamma = 0.999
@@ -110,7 +109,7 @@ actor_lr = 0.001
 critic_lr = 0.005
 
 #environment setup
-env = simple_crypto_v3.parallel_env(render_mode="human")
+env = simple_reference_v3.parallel_env(render_mode="human")
 
 #obs_space
 #action_space
@@ -124,12 +123,60 @@ device = torch.device("cpu")
 #env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
 #    env, buffer_length=n_episodes
 #)
-observations, infos = env.reset()
-done = False
-while env.agents:
-
-    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
-    observations, rewards, terminations, truncations, infos = env.step(actions)
-    print(observations)
+#eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+#bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+#alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+
+agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+
+for _ in range(0, n_episodes):
+    observations, infos = env.reset()
+    agent_0_rewards = []
+    agent_0_probs = []
+    agent_0_pred = []
+    agent_0_ents = []
+    agent_0_mask = []
+
+    agent_1_rewards = []
+    agent_1_probs = []
+    agent_1_pred = []
+    agent_1_ents = []
+    agent_1_mask = []
+    while env.agents:
+
+        actions = {}
+        #eve_action, eve_log_probs, eve_state_val, eve_ent = eve.select_action(observations["eve_0"])
+        #bob_action, bob_log_probs, bob_state_val, bob_ent = bob.select_action(observations["bob_0"])
+        #alice_action, alice_log_probs, alice_state_val, alice_ent = alice.select_action(observations["alice_0"])
+        #actions["eve_0"] = eve_action.item()
+        #actions["bob_0"] = bob_action.item()
+        #actions["alice_0"] = alice_action.item()
+        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(observations["agent_0"])
+        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(observations["agent_1"])
+        actions["agent_0"] = agent_0_action.item()
+        actions["agent_1"] = agent_1_action.item()
+        observations, rewards, terminations, truncations, infos = env.step(actions)
+        agent_0_rewards.append(rewards["agent_0"])
+        agent_0_probs.append(agent_0_log_probs)
+        agent_0_pred.append(agent_0_state_val)
+        agent_0_ents.append(agent_0_ent)
+        agent_0_mask.append( 1 if env.agents else 0)
+
+        agent_1_rewards.append(rewards["agent_1"])
+        agent_1_probs.append(agent_1_log_probs)
+        agent_1_pred.append(agent_1_state_val)
+        agent_1_ents.append(agent_1_ent)
+        agent_1_mask.append( 1 if env.agents else 0)
+    #eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef)
+    #print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item()))
+    #eve.update_params(eve_closs, eve_aloss)
+    agent_0_closs, agent_0_aloss = agent0.get_losses(torch.Tensor(agent_0_rewards), torch.Tensor(agent_0_probs), torch.Tensor(agent_0_pred), torch.Tensor(agent_0_ents), torch.Tensor(agent_0_mask), gamma, ent_coef)
+    print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
+    agent0.update_params(agent_0_closs, agent_0_aloss)
+    agent_1_closs, agent_1_aloss = agent1.get_losses(torch.Tensor(agent_1_rewards), torch.Tensor(agent_1_probs), torch.Tensor(agent_1_pred), torch.Tensor(agent_1_ents), torch.Tensor(agent_1_mask), gamma, ent_coef)
+    print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
+    agent1.update_params(agent_1_closs, agent_1_aloss)
+
 env.close()
\ No newline at end of file
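Review note on the loss bookkeeping in the last hunk (an aside, not part of the patch): wrapping the per-step buffers with torch.Tensor(...), e.g. torch.Tensor(agent_0_probs), copies the values of the 0-dim log-prob / state-value / entropy tensors returned by select_action into a fresh tensor with no autograd history, so the actor and critic losses built in get_losses may not backpropagate into the networks when update_params() calls backward(). Whether that matters depends on the bodies of get_losses and update_params, which this diff does not show in full; if differentiable inputs are intended, torch.stack keeps the graph. A minimal sketch under that assumption, with stand-in values in place of the real buffers:

import torch

# Stand-ins for two steps of an episode; in main.py these entries would be the 0-dim
# tensors appended to agent_0_probs and the float rewards appended to agent_0_rewards.
log_prob_steps = [torch.tensor(-1.2, requires_grad=True),
                  torch.tensor(-0.7, requires_grad=True)]
reward_steps = [0.1, -0.3]

log_probs = torch.stack(log_prob_steps)   # shape (T,), still attached to the autograd graph
rewards = torch.tensor(reward_steps)      # rewards carry no gradient, a plain tensor is fine

toy_actor_loss = -(log_probs * rewards).sum()  # toy surrogate loss, only to demonstrate backprop
toy_actor_loss.backward()                      # gradients flow back to the per-step log-probs
print(log_prob_steps[0].grad)                  # non-None; torch.Tensor(log_prob_steps) would break this chain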