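"""Independent A2C agents for the PettingZoo MPE Simple Reference environment.

Two actor-critic agents are trained in parallel on simple_reference_v3, with one
update per agent at the end of each episode, GAE advantages, and an entropy
bonus. Rolling averages of rewards, losses, and policy entropy are plotted with
matplotlib and periodically written to a PNG.
"""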
import os

import numpy as np
import torch
import torch.nn as nn
from torch import optim

import matplotlib.pyplot as plt
from tqdm import tqdm

import gymnasium as gym
import pettingzoo
from pettingzoo.mpe import simple_reference_v3

class A2C(nn.Module):
    """Advantage Actor-Critic agent with separate actor and critic networks."""

    def __init__(
        self,
        n_features: int,
        n_actions: int,
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
    ) -> None:
        super().__init__()
        self.device = device

        # Critic: maps an observation to a scalar state-value estimate.
        critic_layers = [
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]

        # Actor: maps an observation to unnormalized action logits.
        # No Softmax here -- select_action() builds a Categorical with logits=,
        # which applies the softmax internally (a trailing Softmax would have
        # been treated as logits again and flattened the policy).
        actor_layers = [
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        ]

        self.critic = nn.Sequential(*critic_layers).to(self.device)
        self.actor = nn.Sequential(*actor_layers).to(self.device)

        self.critic_optim = optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=actor_lr)

        # gamma=1 keeps the learning rates constant; lower it to enable decay.
        self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=1)
        self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=1)

    def forward(self, x: np.ndarray | torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the critic's state value and the actor's action logits for x."""
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        state_values = self.critic(x)
        action_logits_vec = self.actor(x)
        return (state_values, action_logits_vec)

    def select_action(
        self, x: np.ndarray | torch.Tensor
    ) -> tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample an action and return (action, log-prob, state value, entropy)."""
        state_values, action_logits = self.forward(x)
        action_pd = torch.distributions.Categorical(logits=action_logits)
        actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()
        return actions.item(), action_log_probs, state_values, entropy

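    # get_losses() below computes advantages with GAE(lambda):
    #   delta_t = r_t + gamma * m_t * V(s_{t+1}) - V(s_t)
    #   A_t     = delta_t + gamma * lambda * m_t * A_{t+1}
    # where m_t is the mask (0 at the end of the episode) and lambda is fixed
    # at 0.95. The critic is regressed on the squared advantages and the actor
    # is trained on -A_t * log pi(a_t|s_t) minus an entropy bonus.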
    def get_losses(
        self,
        rewards: list[float],
        action_log_probs: torch.Tensor,
        value_preds: list[torch.Tensor],
        entropy: list[float],
        masks: list[int],
        gamma: float,
        ent_coef: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the critic and actor losses for one episode rollout."""
        T = len(rewards)
        advantages = torch.zeros(T, device=self.device)

        # Compute the advantages using GAE (lambda fixed at 0.95).
        # masks[t] is 0 on the final transition, which zeroes the bootstrap term,
        # so the last step contributes r_T - V(s_T).
        gae = 0.0
        for t in reversed(range(T)):
            next_value = value_preds[t + 1].squeeze() if t < T - 1 else 0.0
            td_error = rewards[t] + gamma * masks[t] * next_value - value_preds[t].squeeze()
            gae = td_error + gamma * 0.95 * masks[t] * gae
            advantages[t] = gae

        # Critic loss: mean squared GAE error.
        critic_loss = advantages.pow(2).mean()

        # Actor loss: policy gradient with an entropy bonus to encourage exploration.
        # action_log_probs arrives with shape (T, 1); flatten it so it lines up
        # with the (T,) advantages instead of broadcasting to (T, T).
        actor_loss = (
            -(advantages.detach() * action_log_probs.flatten()).mean()
            - ent_coef * torch.tensor(entropy, device=self.device).mean()
        )

        return (critic_loss, actor_loss)

    def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
        """Backpropagate both losses, clip gradients, and step optimizers and schedulers."""
        self.critic_optim.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optim.step()
        self.critic_scheduler.step()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optim.step()
        self.actor_scheduler.step()

    def set_eval(self) -> None:
        """Put both networks into evaluation mode (for rendering trained agents)."""
        self.critic.eval()
        self.actor.eval()

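# A minimal usage sketch of the A2C class (illustrative only; the feature and
# action sizes below are placeholders, not necessarily the environment's actual
# dimensions -- in train() they are read from observation_space / action_space):
#
#   agent = A2C(n_features=21, n_actions=50, device=torch.device("cpu"),
#               critic_lr=0.0005, actor_lr=0.0001)
#   obs = np.zeros(21, dtype=np.float32)
#   action, log_prob, value, ent = agent.select_action(
#       torch.FloatTensor(obs).unsqueeze(0))
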
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 5))
fig.suptitle("Training plots for the Simple Reference environment")

def rolling_mean(values, window=100):
    """Rolling average of `values` over a fixed window (empty list if too short)."""
    return [np.mean(values[i:i + window]) for i in range(len(values) - window + 1)]


def drawPlots():
    """Redraw the four panels: rewards, entropy, critic loss, and actor loss."""
    window = 100
    agent0_average = rolling_mean(agent0_rewards, window)
    agent1_average = rolling_mean(agent1_rewards, window)
    agent0_average_closs = rolling_mean(agent0_critic_loss, window)
    agent0_average_aloss = rolling_mean(agent0_actor_loss, window)
    agent1_average_closs = rolling_mean(agent1_critic_loss, window)
    agent1_average_aloss = rolling_mean(agent1_actor_loss, window)
    agent0_average_ent = rolling_mean(agent0_entropy, window)
    agent1_average_ent = rolling_mean(agent1_entropy, window)

    axs[0][0].cla()
    axs[0][0].plot(agent0_average, label="Agent 0")
    axs[0][0].plot(agent1_average, label="Agent 1")
    axs[0][0].legend()
    axs[0][0].set_title("Rewards over Time")

    axs[1][0].cla()
    axs[1][0].plot(agent0_average_closs, label="Agent 0")
    axs[1][0].plot(agent1_average_closs, label="Agent 1")
    axs[1][0].legend()
    axs[1][0].set_title("Critic Loss over Time")

    axs[1][1].cla()
    axs[1][1].plot(agent0_average_aloss, label="Agent 0")
    axs[1][1].plot(agent1_average_aloss, label="Agent 1")
    axs[1][1].legend()
    axs[1][1].set_title("Actor Loss over Time")

    axs[0][1].cla()
    axs[0][1].ticklabel_format(style='plain')
    axs[0][1].plot(agent0_average_ent, label="Agent 0")
    axs[0][1].plot(agent1_average_ent, label="Agent 1")
    axs[0][1].legend()
    axs[0][1].set_title("Actor Entropy over Time")

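# Per-episode training histories, filled in by train() and read by drawPlots().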
agent0_critic_loss = []
agent0_actor_loss = []
agent1_critic_loss = []
agent1_actor_loss = []
agent0_rewards = []
agent1_rewards = []
agent0_entropy = []
agent1_entropy = []

def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr):
    """Train two independent A2C agents on simple_reference_v3 for n_episodes."""
    global agent0_critic_loss
    global agent0_actor_loss
    global agent1_critic_loss
    global agent1_actor_loss
    global agent0_rewards
    global agent1_rewards
    global agent0_entropy
    global agent1_entropy
    agent0_critic_loss = []
    agent0_actor_loss = []
    agent1_critic_loss = []
    agent1_actor_loss = []
    agent0_rewards = []
    agent1_rewards = []
    agent0_entropy = []
    agent1_entropy = []

    env = simple_reference_v3.parallel_env(max_cycles=50, render_mode="rgb_array")
    device = torch.device("cpu")

    # One A2C learner per PettingZoo agent, sized from the environment's spaces.
    agent0 = A2C(
        n_features=env.observation_space("agent_0").shape[0],
        n_actions=env.action_space("agent_0").n,
        device=device,
        critic_lr=critic_lr,
        actor_lr=actor_lr,
    )
    agent1 = A2C(
        n_features=env.observation_space("agent_1").shape[0],
        n_actions=env.action_space("agent_1").n,
        device=device,
        critic_lr=critic_lr,
        actor_lr=actor_lr,
    )

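    # Each episode: roll out the full episode (up to max_cycles steps), storing
    # per-step rewards, log-probs, value estimates, entropies, and masks for each
    # agent, then perform a single actor-critic update per agent at the end.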
    for episode in tqdm(range(n_episodes)):
        observations, infos = env.reset()

        # Per-step buffers for this episode.
        agent_0_rewards = []
        agent_0_probs = []
        agent_0_pred = []
        agent_0_ents = []
        agent_0_mask = []

        agent_1_rewards = []
        agent_1_probs = []
        agent_1_pred = []
        agent_1_ents = []
        agent_1_mask = []

        while env.agents:
            actions = {}
            agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(
                torch.FloatTensor(observations["agent_0"]).unsqueeze(0)
            )
            agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(
                torch.FloatTensor(observations["agent_1"]).unsqueeze(0)
            )
            actions["agent_0"] = agent_0_action
            actions["agent_1"] = agent_1_action

            observations, rewards, terminations, truncations, infos = env.step(actions)

            agent_0_rewards.append(rewards["agent_0"])
            agent_0_probs.append(agent_0_log_probs)
            agent_0_pred.append(agent_0_state_val)
            agent_0_ents.append(agent_0_ent.item())
            # Mask is 0 once the episode has ended (no bootstrapping past the end).
            agent_0_mask.append(1 if env.agents else 0)

            agent_1_rewards.append(rewards["agent_1"])
            agent_1_probs.append(agent_1_log_probs)
            agent_1_pred.append(agent_1_state_val)
            agent_1_ents.append(agent_1_ent.item())
            agent_1_mask.append(1 if env.agents else 0)

        # One update per agent at the end of the episode.
        agent_0_closs, agent_0_aloss = agent0.get_losses(
            agent_0_rewards, torch.stack(agent_0_probs), agent_0_pred, agent_0_ents,
            agent_0_mask, gamma, ent_coef,
        )
        agent0_critic_loss.append(agent_0_closs.item())
        agent0_actor_loss.append(agent_0_aloss.item())
        agent0.update_params(agent_0_closs, agent_0_aloss)

        agent_1_closs, agent_1_aloss = agent1.get_losses(
            agent_1_rewards, torch.stack(agent_1_probs), agent_1_pred, agent_1_ents,
            agent_1_mask, gamma, ent_coef,
        )
        agent1_critic_loss.append(agent_1_closs.item())
        agent1_actor_loss.append(agent_1_aloss.item())
        agent1.update_params(agent_1_closs, agent_1_aloss)

        # Episode-level statistics for the plots.
        agent0_rewards.append(np.array(agent_0_rewards).sum())
        agent1_rewards.append(np.array(agent_1_rewards).sum())
        agent0_entropy.append(np.array(agent_0_ents).sum())
        agent1_entropy.append(np.array(agent_1_ents).sum())

        if episode % 500 == 0:
            drawPlots()
            plt.savefig(f"plots(gamma={gamma},ent={ent_coef},alr={actor_lr},clr={critic_lr}).png")

    drawPlots()
    plt.savefig(f"plots(gamma={gamma},ent={ent_coef},alr={actor_lr},clr={critic_lr}).png")
    env.close()


#environment hyperparams
n_episodes = 1000

train(10000, 0.999, 0.01, 0.0001, 0.0005)

# Hyperparameter sweep scaffold; the train() call inside is currently disabled,
# so the loop only re-reads the rewards from the run above.
# Note: with a step of -0.1 the gamma range only ever yields 0.999.
best = 1
for gamma in np.arange(0.999, 0.99, -0.1):
    for ent_coef in np.arange(0, 0.1, 0.01):
        for actor_lr in np.arange(0.002, 0.1, 0.002):
            for critic_lr in np.arange(0.002, 0.1, 0.002):
                #train(n_episodes, gamma, ent_coef, actor_lr, critic_lr)
                if best == 1 or agent0_rewards[n_episodes - 1] > best:
                    best = agent0_rewards[n_episodes - 1]
                    print("New Best: " + str(best) + "\n\tWith Parameters (gamma=" + str(gamma)
                          + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr)
                          + ',clr=' + str(critic_lr) + ')')


#agent hyperparams
#gamma = 0.999
#ent_coef = 0.01  # coefficient for entropy bonus
#actor_lr = 0.001
#critic_lr = 0.005

#environment setup
#env = simple_reference_v3.parallel_env(render_mode="human")

#drawPlots()
#plt.savefig('data.png')
#plt.show(block=False)

#plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
#plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
#plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
#plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")

#actor0_weights_path = "weights/actor0_weights.h5"
#critic0_weights_path = "weights/critic0_weights.h5"
#actor1_weights_path = "weights/actor1_weights.h5"
#critic1_weights_path = "weights/critic1_weights.h5"

#if not os.path.exists("weights"):
#    os.mkdir("weights")

#torch.save(agent0.actor.state_dict(), actor0_weights_path)
#torch.save(agent0.critic.state_dict(), critic0_weights_path)
#torch.save(agent1.actor.state_dict(), actor1_weights_path)
#torch.save(agent1.critic.state_dict(), critic1_weights_path)

#if load_weights:
#    agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
#
#    agent.actor.load_state_dict(torch.load(actor_weights_path))
#    agent.critic.load_state_dict(torch.load(critic_weights_path))
#    agent.actor.eval()
#    agent.critic.eval()

#agent0.set_eval()
#agent1.set_eval()
#env = simple_reference_v3.parallel_env(render_mode="human")
#while True:
#    observations, infos = env.reset()
#    while env.agents:
#        plt.pause(0.001)
#        actions = {}
#        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
#        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
#        actions["agent_0"] = agent_0_action
#        actions["agent_1"] = agent_1_action
#        observations, rewards, terminations, truncations, infos = env.step(actions)

#env.close()