changed to simple reference and set up agents for decision making

2025-08-31 13:51:49 -06:00
parent 321225cf88
commit 33b2581c48

main.py (71 changed lines)

@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
from tqdm import tqdm
import gymnasium as gym
-from pettingzoo.mpe import simple_crypto_v3
+from pettingzoo.mpe import simple_reference_v3
class A2C(nn.Module):
def __init__(
@@ -73,10 +73,9 @@ class A2C(nn.Module):
masks: torch.Tensor,
gamma: float,
ent_coef: float,
-device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
-advantages = torch.zeros(len(rewards), device=device)
+advantages = torch.zeros(len(rewards), device=self.device)
#compute advantages
#mask - 0 if end of episode, 1 otherwise
#gamma - discount factor applied to future value predictions
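The loop that fills this buffer is not shown in the hunk; a minimal sketch of a masked, one-step TD advantage computation consistent with the comments above, where rewards, value_preds and masks are assumed to be equally sized 1-D tensors and all names are illustrative rather than the file's actual code:

# Illustrative sketch only, not the file's actual loop body.
for t in range(len(rewards)):
    # mask is 0 at an episode boundary, which cuts off the bootstrap term
    next_value = value_preds[t + 1] if t + 1 < len(rewards) else 0.0
    advantages[t] = rewards[t] + gamma * masks[t] * next_value - value_preds[t]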
@@ -101,7 +100,7 @@
self.actor_optim.step()
#environment hyperparams
-n_episodes = 10
+n_episodes = 1
#agent hyperparams
gamma = 0.999
@@ -110,7 +109,7 @@ actor_lr = 0.001
critic_lr = 0.005
#environment setup
-env = simple_crypto_v3.parallel_env(render_mode="human")
+env = simple_reference_v3.parallel_env(render_mode="human")
#obs_space
#action_space
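These two placeholders can be filled in by querying the parallel environment directly; a small illustrative snippet, assuming the default discrete-action configuration of simple_reference_v3:

# Inspect each agent's spaces; the A2C networks below are sized from
# observation_space(agent).shape[0] and action_space(agent).n.
for agent in env.possible_agents:
    print(agent, env.observation_space(agent), env.action_space(agent))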
@@ -124,12 +123,60 @@ device = torch.device("cpu")
#env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
# env, buffer_length=n_episodes
#)
-observations, infos = env.reset()
-done = False
-while env.agents:
-actions = {agent: env.action_space(agent).sample() for agent in env.agents}
-observations, rewards, terminations, truncations, infos = env.step(actions)
-print(observations)
#eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
#bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
#alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
for _ in range(0, n_episodes):
observations, infos = env.reset()
agent_0_rewards = []
agent_0_probs = []
agent_0_pred = []
agent_0_ents = []
agent_0_mask = []
agent_1_rewards = []
agent_1_probs = []
agent_1_pred = []
agent_1_ents = []
agent_1_mask = []
while env.agents:
actions = {}
#eve_action, eve_log_probs, eve_state_val, eve_ent = eve.select_action(observations["eve_0"])
#bob_action, bob_log_probs, bob_state_val, bob_ent = bob.select_action(observations["bob_0"])
#alice_action, alice_log_probs, alice_state_val, alice_ent = alice.select_action(observations["alice_0"])
#actions["eve_0"] = eve_action.item()
#actions["bob_0"] = bob_action.item()
#actions["alice_0"] = alice_action.item()
agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(observations["agent_0"])
agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(observations["agent_1"])
actions["agent_0"] = agent_0_action.item()
actions["agent_1"] = agent_1_action.item()
observations, rewards, terminations, truncations, infos = env.step(actions)
agent_0_rewards.append(rewards["agent_0"])
agent_0_probs.append(agent_0_log_probs)
agent_0_pred.append(agent_0_state_val)
agent_0_ents.append(agent_0_ent)
agent_0_mask.append(1 if env.agents else 0)
agent_1_rewards.append(rewards["agent_1"])
agent_1_probs.append(agent_1_log_probs)
agent_1_pred.append(agent_1_state_val)
agent_1_ents.append(agent_1_ent)
agent_1_mask.append(1 if env.agents else 0)
#eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef)
#print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item()))
#eve.update_params(eve_closs, eve_aloss)
agent_0_closs, agent_0_aloss = agent0.get_losses(torch.Tensor(agent_0_rewards), torch.Tensor(agent_0_probs), torch.Tensor(agent_0_pred), torch.Tensor(agent_0_ents), torch.Tensor(agent_0_mask), gamma, ent_coef)
print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
agent0.update_params(agent_0_closs, agent_0_aloss)
agent_1_closs, agent_1_aloss = agent1.get_losses(torch.Tensor(agent_1_rewards), torch.Tensor(agent_1_probs), torch.Tensor(agent_1_pred), torch.Tensor(agent_1_ents), torch.Tensor(agent_1_mask), gamma, ent_coef)
print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
agent1.update_params(agent_1_closs, agent_1_aloss)
env.close()
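One caveat about the get_losses calls above: wrapping the collected log-probabilities, value predictions, and entropies in torch.Tensor(...) copies their values into a new tensor with no graph attached, so the history built in select_action is dropped and the actor/critic losses have no gradient path back to the networks. If gradients are meant to flow through those losses, torch.stack keeps the graph; a hedged sketch of the agent_0 call under that assumption, with the same argument order as the existing get_losses call:

agent_0_closs, agent_0_aloss = agent0.get_losses(
    torch.tensor(agent_0_rewards),  # rewards are plain floats, so a fresh tensor is fine
    torch.stack(agent_0_probs),     # keep grad history for the actor loss
    torch.stack(agent_0_pred),      # keep grad history for the critic loss
    torch.stack(agent_0_ents),
    torch.tensor(agent_0_mask),
    gamma,
    ent_coef,
)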