changed to simple reference and set up agents for decision making
main.py (69 changed lines)
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
 from tqdm import tqdm

 import gymnasium as gym
-from pettingzoo.mpe import simple_crypto_v3
+from pettingzoo.mpe import simple_reference_v3

 class A2C(nn.Module):
     def __init__(
@@ -73,10 +73,9 @@ class A2C(nn.Module):
         masks: torch.Tensor,
         gamma: float,
         ent_coef: float,
-        device: torch.device
     ) -> tuple[torch.tensor, torch.tensor]:

-        advantages = torch.zeros(len(rewards), device=device)
+        advantages = torch.zeros(len(rewards), device=self.device)
         #compute advantages
         #mask - 0 if end of episode
         #gamma - coeffecient for value prediction
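The comments in this hunk describe a masked, discounted advantage. As a reference point for the change from `device` to `self.device`, here is a minimal sketch of how that computation might look; `value_preds` and the function name are illustrative and not copied from main.py.

import torch

def compute_advantages(rewards: torch.Tensor,
                       value_preds: torch.Tensor,
                       masks: torch.Tensor,
                       gamma: float,
                       device: torch.device) -> torch.Tensor:
    # Discounted return minus the critic's prediction.
    # masks[t] is 0 when step t ended the episode, so the running return
    # does not leak across episode boundaries; gamma discounts future rewards.
    advantages = torch.zeros(len(rewards), device=device)
    running_return = 0.0
    for t in reversed(range(len(rewards))):
        running_return = rewards[t] + gamma * masks[t] * running_return
        advantages[t] = running_return - value_preds[t]
    return advantages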
@@ -101,7 +100,7 @@ class A2C(nn.Module):
         self.actor_optim.step()

 #environment hyperparams
-n_episodes = 10
+n_episodes = 1

 #agent hyperparams
 gamma = 0.999
@@ -110,7 +109,7 @@ actor_lr = 0.001
 critic_lr = 0.005

 #environment setup
-env = simple_crypto_v3.parallel_env(render_mode="human")
+env = simple_reference_v3.parallel_env(render_mode="human")

 #obs_space
 #action_space
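For reviewers unfamiliar with simple_reference_v3, a quick throwaway snippet (not part of the commit) can confirm the per-agent spaces that the A2C constructors in the next hunk rely on:

from pettingzoo.mpe import simple_reference_v3

env = simple_reference_v3.parallel_env()
env.reset()
for agent in env.possible_agents:
    # .shape[0] feeds n_features, .n feeds n_actions in the A2C setup below
    print(agent, env.observation_space(agent).shape, env.action_space(agent).n)
env.close()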
@@ -124,12 +123,60 @@ device = torch.device("cpu")
 #env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
 # env, buffer_length=n_episodes
 #)
-observations, infos = env.reset()
-done = False
-while env.agents:
+#eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+#bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+#alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+
+agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+
+for _ in range(0, n_episodes):
+    observations, infos = env.reset()
+    agent_0_rewards = []
+    agent_0_probs = []
+    agent_0_pred = []
+    agent_0_ents = []
+    agent_0_mask = []
+
+    agent_1_rewards = []
+    agent_1_probs = []
+    agent_1_pred = []
+    agent_1_ents = []
+    agent_1_mask = []
+    while env.agents:
+
+        actions = {}
+        #eve_action, eve_log_probs, eve_state_val, eve_ent = eve.select_action(observations["eve_0"])
+        #bob_action, bob_log_probs, bob_state_val, bob_ent = bob.select_action(observations["bob_0"])
+        #alice_action, alice_log_probs, alice_state_val, alice_ent = alice.select_action(observations["alice_0"])
+        #actions["eve_0"] = eve_action.item()
+        #actions["bob_0"] = bob_action.item()
+        #actions["alice_0"] = alice_action.item()
+        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(observations["agent_0"])
+        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(observations["agent_1"])
+        actions["agent_0"] = agent_0_action.item()
+        actions["agent_1"] = agent_1_action.item()
+        observations, rewards, terminations, truncations, infos = env.step(actions)
+        agent_0_rewards.append(rewards["agent_0"])
+        agent_0_probs.append(agent_0_log_probs)
+        agent_0_pred.append(agent_0_state_val)
+        agent_0_ents.append(agent_0_ent)
+        agent_0_mask.append( 1 if env.agents else 0)
+
+        agent_1_rewards.append(rewards["agent_1"])
+        agent_1_probs.append(agent_1_log_probs)
+        agent_1_pred.append(agent_1_state_val)
+        agent_1_ents.append(agent_1_ent)
+        agent_1_mask.append( 1 if env.agents else 0)
+        #eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef)
+        #print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item()))
+        #eve.update_params(eve_closs, eve_aloss)
+        agent_0_closs, agent_0_aloss = agent0.get_losses(torch.Tensor(agent_0_rewards), torch.Tensor(agent_0_probs), torch.Tensor(agent_0_pred), torch.Tensor(agent_0_ents), torch.Tensor(agent_0_mask), gamma, ent_coef)
+        print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
+        agent0.update_params(agent_0_closs, agent_0_aloss)
+        agent_1_closs, agent_1_aloss = agent1.get_losses(torch.Tensor(agent_1_rewards), torch.Tensor(agent_1_probs), torch.Tensor(agent_1_pred), torch.Tensor(agent_1_ents), torch.Tensor(agent_1_mask), gamma, ent_coef)
+        print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
+        agent1.update_params(agent_1_closs, agent_1_aloss)
+
-    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
-    observations, rewards, terminations, truncations, infos = env.step(actions)
-    print(observations)

 env.close()
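The loop above unpacks a 4-tuple from select_action for each agent; the commit does not show that method, but the call sites imply a categorical actor-critic step roughly like the following sketch. The function and parameter names here are illustrative assumptions, not the A2C implementation in main.py.

import torch
from torch.distributions import Categorical

def select_action_sketch(actor: torch.nn.Module,
                         critic: torch.nn.Module,
                         obs,
                         device: torch.device):
    # Returns the same 4-tuple the training loop unpacks:
    # action, log_prob, state_value, entropy.
    x = torch.as_tensor(obs, dtype=torch.float32, device=device)
    dist = Categorical(logits=actor(x))   # one logit per Discrete action
    action = dist.sample()
    return action, dist.log_prob(action), critic(x), dist.entropy()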