Files
simple_crypto/main.py

444 lines
18 KiB
Python

import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import gymnasium as gym
from pettingzoo.mpe import simple_reference_v3,simple_v3
import pettingzoo
class A2C(nn.Module):
    """Advantage actor-critic agent built from two independent MLPs.

    The critic maps an observation to a scalar state value; the actor maps the
    same observation to a probability vector over ``n_actions`` (its last layer
    is a softmax). Each network has its own Adam optimizer.

    Args:
        n_features: Size of the observation vector fed to both networks.
        n_actions: Number of discrete actions the actor chooses between.
        device: Device both networks are moved to.
        critic_lr: Learning rate of the critic's Adam optimizer.
        actor_lr: Learning rate of the actor's Adam optimizer.
    """

    def __init__(
        self,
        n_features: int,
        n_actions: int,
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
    ) -> None:
        super().__init__()
        self.device = device

        critic_layers = [
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        actor_layers = [
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
            # Normalise over the action axis explicitly; relying on the
            # implicit dimension is deprecated and emits a warning.
            nn.Softmax(dim=-1),
        ]

        self.critic = nn.Sequential(*critic_layers).to(self.device)
        self.actor = nn.Sequential(*actor_layers).to(self.device)

        self.critic_optim = optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=actor_lr)

    def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]:
        """Run both networks on ``x``.

        Args:
            x: Observation(s) as a NumPy array, tensor, or nested sequence whose
                last dimension is ``n_features``.

        Returns:
            ``(state_values, action_probs)`` where ``action_probs`` is the
            softmax output of the actor (sums to 1 over the last dimension).
        """
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        state_values = self.critic(x)
        action_probs = self.actor(x)
        return (state_values, action_probs)

    def select_action(
        self, x
    ) -> tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample one action from the current policy.

        Args:
            x: A single observation (optionally with a leading batch dimension
                of size 1); ``.item()`` is called on the sampled action, so
                exactly one action must be produced.

        Returns:
            ``(action, log_prob, state_value, entropy)``: the sampled action as
            a Python int, and the log-probability of that action, the critic's
            value estimate and the policy entropy as tensors that keep their
            autograd graph.
        """
        state_values, action_probs = self.forward(x)
        # The actor already applies a softmax, so its output must be passed as
        # `probs`. Passing it as `logits` would soft-max it a second time and
        # squash the policy towards uniform.
        action_pd = torch.distributions.Categorical(probs=action_probs)
        actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()
        return actions.item(), action_log_probs, state_values, entropy

    def get_losses(
        self,
        rewards,
        action_log_probs: torch.Tensor,
        value_preds,
        entropy,
        masks,
        gamma: float,
        ent_coef: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute critic and actor losses for one trajectory of length T.

        Args:
            rewards: Sequence of T per-step rewards.
            action_log_probs: Tensor holding the T log-probabilities of the
                actions taken (any shape with T elements, e.g. ``(T, 1)``).
            value_preds: Sequence of T critic outputs (one element each), or a
                tensor with T elements.
            entropy: Per-step policy entropies, either as plain numbers (no
                gradient flows through the bonus) or as tensors (the bonus is
                then differentiable).
            masks: Sequence of T values, 0 where the episode ended at that
                step and 1 otherwise; a 0 stops the return from bootstrapping
                across the boundary.
            gamma: Discount factor.
            ent_coef: Weight of the entropy bonus subtracted from the actor
                loss.

        Returns:
            ``(critic_loss, actor_loss)`` as scalar tensors.

        Raises:
            ValueError: If rewards, value predictions and log-probabilities do
                not describe the same number of steps.
        """
        # Discounted Monte-Carlo returns, accumulated back to front.
        returns = []
        running = 0.0
        for reward, mask in zip(reversed(rewards), reversed(masks)):
            running = float(reward) + gamma * running * float(mask)
            returns.append(running)
        returns.reverse()
        returns = torch.as_tensor(
            returns, dtype=torch.float32, device=self.device
        )

        if isinstance(value_preds, torch.Tensor):
            values = value_preds.reshape(-1)
        else:
            values = torch.stack(list(value_preds)).reshape(-1)
        log_probs = action_log_probs.reshape(-1)

        # Everything is flattened to shape (T,) on purpose: mixing (T,) with
        # (T, 1) would silently broadcast to a (T, T) matrix and corrupt both
        # losses.
        if not (returns.shape == values.shape == log_probs.shape):
            raise ValueError(
                "rewards, value_preds and action_log_probs must have the same "
                f"number of steps, got {returns.numel()}, {values.numel()} "
                f"and {log_probs.numel()}"
            )

        advantage = returns - values

        # Critic: mean squared error between returns and value estimates.
        critic_loss = advantage.pow(2).mean()

        if isinstance(entropy, torch.Tensor):
            entropy_t = entropy
        elif len(entropy) > 0 and isinstance(entropy[0], torch.Tensor):
            entropy_t = torch.stack(list(entropy))
        else:
            entropy_t = torch.as_tensor(entropy, dtype=torch.float32)
        mean_entropy = entropy_t.to(self.device).float().mean()

        # Actor: policy gradient with the advantage treated as a constant,
        # minus an entropy bonus that rewards exploration.
        actor_loss = (
            (-log_probs * advantage.detach()).mean() - ent_coef * mean_entropy
        )
        return (critic_loss, actor_loss)

    def update_params(
        self, critic_loss: torch.Tensor, actor_loss: torch.Tensor
    ) -> None:
        """Take one optimizer step on each network.

        Gradients of each network are clipped to a total norm of 0.5 before
        the step.
        """
        self.critic_optim.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optim.step()

    def set_eval(self):
        """Switch both networks to evaluation mode."""
        self.critic.eval()
        self.actor.eval()
# Module-level 2x2 figure; drawPlots() redraws its axes in place.
fig, axs = plt.subplots(2, 2, figsize=(15, 5))
fig.suptitle("training plots for the Simple Reference environment")
def _rolling_mean(values, window):
    """Return the mean of every full ``window``-long slice of ``values``.

    The result has ``len(values) - window + 1`` entries, or none at all when
    there are fewer than ``window`` values.
    """
    series = np.asarray(values, dtype=float)
    # np.convolve(mode="valid") swaps its operands when the kernel is the
    # longer one, so short inputs have to be handled explicitly.
    if series.size < window:
        return np.empty(0)
    kernel = np.full(window, 1.0 / window)
    return np.convolve(series, kernel, mode="valid")


def drawPlots():
    """Redraw the four training panels from the module-level history lists.

    Each panel shows a 100-episode rolling mean for agent 0 and agent 1:
    rewards (top-left), actor entropy (top-right), critic loss (bottom-left)
    and actor loss (bottom-right). Series shorter than the window are drawn
    as empty lines. Draws onto the module-level ``axs``; saving or showing
    the figure is left to the caller.
    """
    window = 100
    # (axis, title, agent 0 series, agent 1 series, force plain tick labels)
    panels = (
        (axs[0][0], "Rewards over Time", agent0_rewards, agent1_rewards, False),
        (axs[1][0], "Critic Loss over Time", agent0_critic_loss, agent1_critic_loss, False),
        (axs[1][1], "Actor Loss over Time", agent0_actor_loss, agent1_actor_loss, False),
        (axs[0][1], "Actor Entropy over Time", agent0_entropy, agent1_entropy, True),
    )
    for ax, title, series0, series1, plain_ticks in panels:
        ax.cla()
        if plain_ticks:
            # Keep the entropy axis free of scientific-notation offsets.
            ax.ticklabel_format(style='plain')
        ax.plot(_rolling_mean(series0, window), label="Agent 0")
        ax.plot(_rolling_mean(series1, window), label="Agent 1")
        ax.legend()
        ax.set_title(title)
# Per-episode training history, one list per agent and metric.
# train() rebinds these and drawPlots() reads them.
agent0_critic_loss, agent0_actor_loss = [], []
agent1_critic_loss, agent1_actor_loss = [], []
agent0_rewards, agent1_rewards = [], []
agent0_entropy, agent1_entropy = [], []
def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr, demo_episodes=None):
    """Train one A2C agent on ``simple_v3`` and then replay it on screen.

    The module-level history lists are reset, the agent is updated once per
    episode from the whole episode's trajectory, plots are written to a PNG
    every 500 episodes and once more at the end, and the actor/critic weights
    are saved under ``weights/``.

    Args:
        n_episodes: Number of training episodes.
        gamma: Discount factor passed to ``A2C.get_losses``.
        ent_coef: Entropy-bonus coefficient passed to ``A2C.get_losses``.
        actor_lr: Learning rate of the actor optimizer.
        critic_lr: Learning rate of the critic optimizer.
        demo_episodes: Number of episodes to replay in a ``render_mode="human"``
            environment after training. ``None`` (the default) replays forever,
            so the call only ends when it is interrupted; ``0`` skips the
            replay so the function returns right after saving the weights.
    """
    global agent0_critic_loss, agent0_actor_loss
    global agent1_critic_loss, agent1_actor_loss
    global agent0_rewards, agent1_rewards
    global agent0_entropy, agent1_entropy

    agent0_critic_loss = []
    agent0_actor_loss = []
    agent1_critic_loss = []
    agent1_actor_loss = []
    agent0_rewards = []
    agent1_rewards = []
    agent0_entropy = []
    agent1_entropy = []

    plot_path = (
        f"plots(gamma={gamma},ent={ent_coef},alr={actor_lr},clr={critic_lr}).png"
    )
    device = torch.device("cpu")

    env = simple_v3.parallel_env(max_cycles=50, render_mode="rgb_array")
    try:
        agent0 = A2C(
            n_features=env.observation_space("agent_0").shape[0],
            n_actions=env.action_space("agent_0").n,
            device=device,
            critic_lr=critic_lr,
            actor_lr=actor_lr,
        )

        for episode in tqdm(range(n_episodes)):
            observations, _infos = env.reset()
            ep_rewards = []
            ep_log_probs = []
            ep_values = []
            ep_entropies = []
            ep_masks = []

            while env.agents:
                obs = torch.FloatTensor(observations["agent_0"]).unsqueeze(0)
                action, log_prob, state_val, ent = agent0.select_action(obs)
                observations, rewards, _terminations, _truncations, _infos = env.step(
                    {"agent_0": action}
                )
                ep_rewards.append(rewards["agent_0"])
                ep_log_probs.append(log_prob)
                ep_values.append(state_val)
                # NOTE: stored as a plain float, so the entropy term in the
                # loss carries no gradient; it is used for logging as well.
                ep_entropies.append(ent.item())
                # env.agents is empty once the episode has ended, which
                # marks the final step with a 0 mask.
                ep_masks.append(1 if env.agents else 0)

            critic_loss, actor_loss = agent0.get_losses(
                ep_rewards,
                torch.stack(ep_log_probs),
                ep_values,
                ep_entropies,
                ep_masks,
                gamma,
                ent_coef,
            )
            agent0_critic_loss.append(critic_loss.item())
            agent0_actor_loss.append(actor_loss.item())
            agent0.update_params(critic_loss, actor_loss)

            agent0_rewards.append(np.array(ep_rewards).sum())
            agent0_entropy.append(np.array(ep_entropies).mean())

            if episode % 500 == 0:
                drawPlots()
                plt.savefig(plot_path)
    finally:
        # Release the training environment even if training is interrupted.
        env.close()

    drawPlots()
    plt.savefig(plot_path)

    os.makedirs("weights", exist_ok=True)
    torch.save(agent0.actor.state_dict(), "weights/actor0_weights.h5")
    torch.save(agent0.critic.state_dict(), "weights/critic0_weights.h5")

    agent0.set_eval()

    if demo_episodes == 0:
        return

    env = simple_v3.parallel_env(render_mode="human")
    try:
        played = 0
        while demo_episodes is None or played < demo_episodes:
            observations, _infos = env.reset()
            while env.agents:
                plt.pause(0.001)
                obs = torch.FloatTensor(observations["agent_0"]).unsqueeze(0)
                # No parameter updates happen here, so skip graph building.
                with torch.no_grad():
                    action, _log_prob, _state_val, _ent = agent0.select_action(obs)
                observations, rewards, _terminations, _truncations, _infos = env.step(
                    {"agent_0": action}
                )
            played += 1
    finally:
        # Runs on normal completion and when the replay is interrupted.
        env.close()
#environment hyperparams
n_episodes = 1000

# Single training run with hand-picked hyperparameters.
train(10000, 0.9, 0.03, 0.001, 0.005)

# Grid-search scaffold. The per-combination train() call is currently disabled,
# so every iteration looks at the history left behind by the run above.
# NOTE: np.arange(0.999, 0.99, -0.1) yields the single value 0.999.
best = None
for gamma in np.arange(0.999, 0.99, -0.1):
    for ent_coef in np.arange(0, 0.1, 0.01):
        for actor_lr in np.arange(0.002, 0.1, 0.002):
            for critic_lr in np.arange(0.002, 0.1, 0.002):
                #train(n_episodes, gamma, ent_coef, actor_lr, critic_lr)
                if not agent0_rewards:
                    # Nothing was recorded, so there is no score to compare.
                    continue
                # Score a run by its final episode. Indexing from the end
                # stays correct whatever episode count train() was given.
                latest = agent0_rewards[-1]
                if best is None or latest > best:
                    best = latest
                    print(
                        f"New Best: {best}\n\tWith Parameters "
                        f"(gamma={gamma},ent={ent_coef},alr={actor_lr},clr={critic_lr})"
                    )
#agent hyperparams
#gamma = 0.999
#ent_coef = 0.01 # coefficient for entropy bonus
#actor_lr = 0.001
#critic_lr = 0.005
#environment setup
#env = simple_reference_v3.parallel_env(render_mode="human")
#drawPlots()
#plt.savefig('data.png')
#plt.show(block=False)
#plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
#plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
#plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
#plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
#actor0_weights_path = "weights/actor0_weights.h5"
#critic0_weights_path = "weights/critic0_weights.h5"
#actor1_weights_path = "weights/actor1_weights.h5"
#critic1_weights_path = "weights/critic1_weights.h5"
#if not os.path.exists("weights"):
# os.mkdir("weights")
#torch.save(agent0.actor.state_dict(), actor0_weights_path)
#torch.save(agent0.critic.state_dict(), critic0_weights_path)
#torch.save(agent1.actor.state_dict(), actor1_weights_path)
#torch.save(agent1.critic.state_dict(), critic1_weights_path)
#if load_weights:
# agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
#
# agent.actor.load_state_dict(torch.load(actor_weights_path))
# agent.critic.load_state_dict(torch.load(critic_weights_path))
# agent.actor.eval()
# agent.critic.eval()
#agent0.set_eval()
#agent1.set_eval()
#env = simple_reference_v3.parallel_env(render_mode="human")
#while True:
# observations, infos = env.reset()
# while env.agents:
# plt.pause(0.001)
# actions = {}
# agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
# agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
# actions["agent_0"] = agent_0_action
# actions["agent_1"] = agent_1_action
# observations, rewards, terminations, truncations, infos = env.step(actions)
#env.close()