changed to simple reference and set up agents for decision making
main.py (69 lines changed)
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
 from tqdm import tqdm
 
 import gymnasium as gym
-from pettingzoo.mpe import simple_crypto_v3
+from pettingzoo.mpe import simple_reference_v3
 
 class A2C(nn.Module):
     def __init__(
@@ -73,10 +73,9 @@ class A2C(nn.Module):
         masks: torch.Tensor,
         gamma: float,
         ent_coef: float,
-        device: torch.device
     ) -> tuple[torch.tensor, torch.tensor]:
 
-        advantages = torch.zeros(len(rewards), device=device)
+        advantages = torch.zeros(len(rewards), device=self.device)
         #compute advantages
         #mask - 0 if end of episode
         #gamma - coeffecient for value prediction
@@ -101,7 +100,7 @@ class A2C(nn.Module):
         self.actor_optim.step()
 
 #environment hyperparams
-n_episodes = 10
+n_episodes = 1
 
 #agent hyperparams
 gamma = 0.999
@@ -110,7 +109,7 @@ actor_lr = 0.001
 critic_lr = 0.005
 
 #environment setup
-env = simple_crypto_v3.parallel_env(render_mode="human")
+env = simple_reference_v3.parallel_env(render_mode="human")
 
 #obs_space
 #action_space
@@ -124,12 +123,60 @@ device = torch.device("cpu")
 #env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
 #    env, buffer_length=n_episodes
 #)
-observations, infos = env.reset()
-done = False
-while env.agents:
+#eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+#bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+#alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
 
-    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
-    observations, rewards, terminations, truncations, infos = env.step(actions)
-    print(observations)
+agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+
+for _ in range(0, n_episodes):
+    observations, infos = env.reset()
+    agent_0_rewards = []
+    agent_0_probs = []
+    agent_0_pred = []
+    agent_0_ents = []
+    agent_0_mask = []
+
+    agent_1_rewards = []
+    agent_1_probs = []
+    agent_1_pred = []
+    agent_1_ents = []
+    agent_1_mask = []
+    while env.agents:
+
+        actions = {}
+        #eve_action, eve_log_probs, eve_state_val, eve_ent = eve.select_action(observations["eve_0"])
+        #bob_action, bob_log_probs, bob_state_val, bob_ent = bob.select_action(observations["bob_0"])
+        #alice_action, alice_log_probs, alice_state_val, alice_ent = alice.select_action(observations["alice_0"])
+        #actions["eve_0"] = eve_action.item()
+        #actions["bob_0"] = bob_action.item()
+        #actions["alice_0"] = alice_action.item()
+        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(observations["agent_0"])
+        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(observations["agent_1"])
+        actions["agent_0"] = agent_0_action.item()
+        actions["agent_1"] = agent_1_action.item()
+        observations, rewards, terminations, truncations, infos = env.step(actions)
+        agent_0_rewards.append(rewards["agent_0"])
+        agent_0_probs.append(agent_0_log_probs)
+        agent_0_pred.append(agent_0_state_val)
+        agent_0_ents.append(agent_0_ent)
+        agent_0_mask.append( 1 if env.agents else 0)
+
+        agent_1_rewards.append(rewards["agent_1"])
+        agent_1_probs.append(agent_1_log_probs)
+        agent_1_pred.append(agent_1_state_val)
+        agent_1_ents.append(agent_1_ent)
+        agent_1_mask.append( 1 if env.agents else 0)
+        #eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef)
+        #print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item()))
+        #eve.update_params(eve_closs, eve_aloss)
+        agent_0_closs, agent_0_aloss = agent0.get_losses(torch.Tensor(agent_0_rewards), torch.Tensor(agent_0_probs), torch.Tensor(agent_0_pred), torch.Tensor(agent_0_ents), torch.Tensor(agent_0_mask), gamma, ent_coef)
+        print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
+        agent0.update_params(agent_0_closs, agent_0_aloss)
+        agent_1_closs, agent_1_aloss = agent1.get_losses(torch.Tensor(agent_1_rewards), torch.Tensor(agent_1_probs), torch.Tensor(agent_1_pred), torch.Tensor(agent_1_ents), torch.Tensor(agent_1_mask), gamma, ent_coef)
+        print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
+        agent1.update_params(agent_1_closs, agent_1_aloss)
+
 
 env.close()
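Note: switching from simple_crypto_v3 to simple_reference_v3 also changes the agent IDs, which is why the eve_0/bob_0/alice_0 constructors are commented out and everything is now keyed by agent_0 and agent_1. A standalone check (not part of the commit) that prints the IDs and the per-agent space sizes that feed n_features and n_actions could look like this:

from pettingzoo.mpe import simple_reference_v3

env = simple_reference_v3.parallel_env()
observations, infos = env.reset()
for agent in env.agents:  # the diff keys actions by "agent_0" and "agent_1"
    print(agent,
          env.observation_space(agent).shape[0],  # n_features passed to A2C
          env.action_space(agent).n)              # n_actions passed to A2C
env.close()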
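The get_losses hunk describes its advantage step only through comments (compute advantages; mask is 0 at the end of an episode; gamma as the discount coefficient on the value prediction) and now allocates the buffer on self.device. For reference, here is a minimal sketch of one common way to fill that buffer with masked discounted returns minus the critic's predictions; the names rewards and value_preds are stand-ins for whatever the method actually calls its reward and agent_*_pred inputs, and the real implementation in main.py may differ:

import torch

def compute_advantages(rewards: torch.Tensor,
                       value_preds: torch.Tensor,
                       masks: torch.Tensor,
                       gamma: float,
                       device: torch.device) -> torch.Tensor:
    # Sketch only: walk the trajectory backwards, accumulating a discounted
    # return that is cut off wherever masks[t] == 0 (end of an episode).
    advantages = torch.zeros(len(rewards), device=device)
    ret = torch.tensor(0.0, device=device)
    for t in reversed(range(len(rewards))):
        ret = rewards[t] + gamma * masks[t] * ret
        advantages[t] = ret - value_preds[t]
    return advantages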