made lots of changes
main.py

@@ -25,19 +25,15 @@ class A2C(nn.Module):
         self.device = device
 
         critic_layers = [
-            nn.Linear(n_features, 8),
+            nn.Linear(n_features, 128),
             nn.ReLU(),
-            nn.Linear(8, 8),
-            nn.ReLU(),
-            nn.Linear(8, 1)
+            nn.Linear(128, 1),
         ]
 
         actor_layers = [
-            nn.Linear(n_features, 8),
+            nn.Linear(n_features, 128),
             nn.ReLU(),
-            nn.Linear(8, 8),
-            nn.ReLU(),
-            nn.Linear(8, n_actions),
+            nn.Linear(128, n_actions),
             nn.Softmax()
         ]
 
@@ -47,6 +43,9 @@ class A2C(nn.Module):
         self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
         self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)
+
+        self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=0.9)
+        self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=0.9)
 
     def forward(self, x: np.array) -> tuple[torch.tensor, torch.tensor]:
         x = torch.Tensor(x).to(self.device)
         state_values = self.critic(x)
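Note (not part of the commit): a minimal, self-contained sketch of how the widened 128-unit heads, the RMSprop optimizers, and the new StepLR schedulers fit together. The sizes n_features and n_actions, the device, and the learning rates are placeholders (the learning rates match values that appear later in this diff); dim=-1 is passed to Softmax explicitly here, whereas the commit calls nn.Softmax() with no argument.

    import torch
    import torch.nn as nn
    import torch.optim as optim

    n_features, n_actions = 21, 50          # placeholder sizes, not taken from the commit
    device = torch.device("cpu")
    critic_lr, actor_lr = 0.005, 0.001      # values that appear later in the diff

    # critic: observation features -> a single state value
    critic = nn.Sequential(nn.Linear(n_features, 128), nn.ReLU(), nn.Linear(128, 1)).to(device)
    # actor: observation features -> a probability per discrete action
    actor = nn.Sequential(nn.Linear(n_features, 128), nn.ReLU(), nn.Linear(128, n_actions), nn.Softmax(dim=-1)).to(device)

    critic_optim = optim.RMSprop(critic.parameters(), lr=critic_lr)
    actor_optim = optim.RMSprop(actor.parameters(), lr=actor_lr)
    # decay each learning rate by 10% every 100 scheduler steps
    critic_scheduler = optim.lr_scheduler.StepLR(critic_optim, step_size=100, gamma=0.9)
    actor_scheduler = optim.lr_scheduler.StepLR(actor_optim, step_size=100, gamma=0.9)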
@@ -56,12 +55,13 @@ class A2C(nn.Module):
     def select_action(self, x: np.array) -> tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
 
         state_values, action_logits = self.forward(x)
-        action_pd = torch.distributions.Categorical(
-            logits=action_logits
-        )
-        actions = action_pd.sample()
-        action_log_probs = action_pd.log_prob(actions)
-        entropy = action_pd.entropy()
+        #action_pd = torch.distributions.Categorical(
+        #    logits=action_logits
+        #)
+        #actions = action_pd.sample()
+        actions = torch.multinomial(action_logits, 1).item()
+        action_log_probs = torch.log(action_logits.squeeze(0)[actions])
+        entropy = action_logits * action_log_probs  #action_pd.entropy()
         return actions, action_log_probs, state_values, entropy
 
     def get_losses(
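Note (not part of the commit): a small sketch relating the manual sampling this hunk switches to with the torch.distributions.Categorical path it comments out. Since the actor ends in Softmax, action_logits here are already probabilities, so Categorical would take probs=..., not logits=.... The probability values below are made up; the per-action p*log(p) terms are what the later get_losses change negates and sums to recover the entropy H = -sum_a p(a) log p(a).

    import torch

    probs = torch.tensor([[0.1, 0.6, 0.3]])    # example softmax output for one state

    # manual route, mirroring this hunk
    action = torch.multinomial(probs, 1).item()
    log_prob = torch.log(probs.squeeze(0)[action])
    p_log_p = probs * torch.log(probs)          # per-action p*log(p) terms (all negative)
    entropy = -p_log_p.sum(dim=-1)              # H = -sum_a p(a) log p(a)

    # equivalent quantities from the distribution object
    pd = torch.distributions.Categorical(probs=probs)
    sampled = pd.sample()
    same_log_prob = pd.log_prob(sampled)
    same_entropy = pd.entropy()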
@@ -79,28 +79,51 @@ class A2C(nn.Module):
         #compute advantages
         #mask - 0 if end of episode
         #gamma - coeffecient for value prediction
-        for t in range(len(rewards) - 1):
-            advantages[t] = (rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
+        #for t in range(len(rewards) - 1):
+        #    advantages[t] = rewards[t] + masks[t] * gamma * value_preds[t+1]  #(rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
+        rewards = np.array(rewards)
+        rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5)
+        returns = []
+        R = 0
+        for r, mask in zip(reversed(rewards), reversed(masks)):
+            R = r + gamma * R * mask
+            returns.insert(0, R)
+
+        returns = torch.FloatTensor(returns)
+        values = torch.stack(value_preds).squeeze(1)
+
+        advantage = returns - values
+
 
         #calculate critic loss - MSE
-        critic_loss = advantages.pow(2).mean()
+        #critic_loss = advantages.pow(2).mean()
+        critic_loss = advantage.pow(2).mean()
         #calculate actor loss - give bonus for entropy to encourage exploration
-        actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
+        #actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
+        entropy = -torch.stack(entropy).sum(dim=-1).mean()
+        actor_loss = -(action_log_probs * advantage.detach()).mean() - ent_coef * entropy
 
         return (critic_loss, actor_loss)
 
     def update_params(self, critic_loss: torch.tensor, actor_loss: torch.tensor) -> None:
         self.critic_optim.zero_grad()
         critic_loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
         self.critic_optim.step()
+        self.critic_scheduler.step()
 
         self.actor_optim.zero_grad()
         actor_loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
         self.actor_optim.step()
+        self.actor_scheduler.step()
 
+    def set_eval(self):
+        self.critic.eval()
+        self.actor.eval()
 
 #environment hyperparams
-n_episodes = 1
+n_episodes = 10000
 
 #agent hyperparams
 gamma = 0.999
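Note (not part of the commit): a worked sketch of the new return/advantage computation in get_losses, with made-up rewards, masks, and value predictions. The commit also normalizes the rewards before this step; that is omitted here for brevity.

    import torch

    gamma = 0.999
    rewards = [1.0, 0.0, 2.0]                     # made-up episode rewards
    masks = [1, 1, 0]                             # 0 marks the end of the episode
    value_preds = [torch.tensor([0.5]), torch.tensor([0.2]), torch.tensor([1.0])]

    # discounted return, computed backwards: R_t = r_t + gamma * R_{t+1} * mask_t
    returns = []
    R = 0.0
    for r, mask in zip(reversed(rewards), reversed(masks)):
        R = r + gamma * R * mask
        returns.insert(0, R)
    # returns is now [2.996002, 1.998, 2.0]

    returns = torch.FloatTensor(returns)
    values = torch.stack(value_preds).squeeze(1)  # critic prediction per step
    advantage = returns - values                  # positive => outcome beat the critic's estimate

    critic_loss = advantage.pow(2).mean()         # MSE between returns and value predictions
    # the actor loss then uses advantage.detach() so its gradient does not reach the critic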
@@ -109,7 +132,8 @@ actor_lr = 0.001
 critic_lr = 0.005
 
 #environment setup
-env = simple_reference_v3.parallel_env(render_mode="human")
+#env = simple_reference_v3.parallel_env(render_mode="human")
+env = simple_reference_v3.parallel_env()
 
 #obs_space
 #action_space
@ -130,7 +154,13 @@ device = torch.device("cpu")
|
|||||||
agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
|
agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
|
||||||
agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
|
agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
|
||||||
|
|
||||||
for _ in range(0, n_episodes):
|
agent0_critic_loss = []
|
||||||
|
agent0_actor_loss = []
|
||||||
|
agent1_critic_loss = []
|
||||||
|
agent1_actor_loss = []
|
||||||
|
|
||||||
|
for episode in range(0, n_episodes):
|
||||||
|
print("Episode " + str(episode) + "/" + str(n_episodes))
|
||||||
observations, infos = env.reset()
|
observations, infos = env.reset()
|
||||||
agent_0_rewards = []
|
agent_0_rewards = []
|
||||||
agent_0_probs = []
|
agent_0_probs = []
|
||||||
@@ -152,10 +182,10 @@ for _ in range(0, n_episodes):
         #actions["eve_0"] = eve_action.item()
         #actions["bob_0"] = bob_action.item()
         #actions["alice_0"] = alice_action.item()
-        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(observations["agent_0"])
-        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(observations["agent_1"])
-        actions["agent_0"] = agent_0_action.item()
-        actions["agent_1"] = agent_1_action.item()
+        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
+        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
+        actions["agent_0"] = agent_0_action
+        actions["agent_1"] = agent_1_action
         observations, rewards, terminations, truncations, infos = env.step(actions)
         agent_0_rewards.append(rewards["agent_0"])
         agent_0_probs.append(agent_0_log_probs)
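Note (not part of the commit): because select_action now returns a plain int via .item(), the joint-action dict maps each agent name directly to a discrete action index. A minimal sketch of that dict shape against the PettingZoo parallel API, assuming the usual from pettingzoo.mpe import simple_reference_v3 import (the diff does not show the file's imports) and using random actions in place of the trained agents:

    from pettingzoo.mpe import simple_reference_v3

    env = simple_reference_v3.parallel_env()
    observations, infos = env.reset()
    # one plain integer action per agent name
    actions = {agent: int(env.action_space(agent).sample()) for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
    env.close()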
@@ -171,12 +201,37 @@ for _ in range(0, n_episodes):
     #eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef)
     #print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item()))
     #eve.update_params(eve_closs, eve_aloss)
-    agent_0_closs, agent_0_aloss = agent0.get_losses(torch.Tensor(agent_0_rewards), torch.Tensor(agent_0_probs), torch.Tensor(agent_0_pred), torch.Tensor(agent_0_ents), torch.Tensor(agent_0_mask), gamma, ent_coef)
-    print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
+    agent_0_closs, agent_0_aloss = agent0.get_losses(agent_0_rewards, torch.stack(agent_0_probs), agent_0_pred, agent_0_ents, agent_0_mask, gamma, ent_coef)
+    #print(agent_0_rewards)
+    agent0_critic_loss.append(agent_0_closs.item())
+    agent0_actor_loss.append(agent_0_aloss.item())
+    #print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
     agent0.update_params(agent_0_closs, agent_0_aloss)
-    agent_1_closs, agent_1_aloss = agent1.get_losses(torch.Tensor(agent_1_rewards), torch.Tensor(agent_1_probs), torch.Tensor(agent_1_pred), torch.Tensor(agent_1_ents), torch.Tensor(agent_1_mask), gamma, ent_coef)
-    print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
+    agent_1_closs, agent_1_aloss = agent1.get_losses(agent_1_rewards, torch.stack(agent_1_probs), agent_1_pred, agent_1_ents, agent_1_mask, gamma, ent_coef)
+    agent1_critic_loss.append(agent_1_closs.item())
+    agent1_actor_loss.append(agent_1_aloss.item())
+    #print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
     agent1.update_params(agent_1_closs, agent_1_aloss)
 
+plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
+plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
+plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
+plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
+plt.legend()
+plt.show(block=False)
+
+agent0.set_eval()
+agent1.set_eval()
+env = simple_reference_v3.parallel_env(render_mode="human")
+while True:
+    observations, infos = env.reset()
+    while env.agents:
+        plt.pause(0.001)
+        actions = {}
+        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
+        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
+        actions["agent_0"] = agent_0_action
+        actions["agent_1"] = agent_1_action
+        observations, rewards, terminations, truncations, infos = env.step(actions)
 
 env.close()