# simple_crypto/main.py
import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import gymnasium as gym
from pettingzoo.mpe import simple_reference_v3
class A2C(nn.Module):
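    """Advantage actor-critic agent with separate actor and critic MLPs and per-network Adam optimizers."""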
def __init__(
self,
n_features: int,
n_actions: int,
device: torch.device,
critic_lr: float,
actor_lr: float
) -> None:
super().__init__()
self.device = device
        # critic head: maps an observation to a single state-value estimate
        critic_layers = [
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]
        # actor head: maps an observation to unnormalized action logits;
        # Categorical(logits=...) in select_action applies the softmax, so no
        # output activation is needed here
        actor_layers = [
            nn.Linear(n_features, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        ]
self.critic = nn.Sequential(*critic_layers).to(self.device)
self.actor = nn.Sequential(*actor_layers).to(self.device)
self.critic_optim = optim.Adam(self.critic.parameters(), lr=critic_lr)
self.actor_optim = optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=1)
self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=1)
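        # note: StepLR with gamma=1 multiplies the learning rate by 1, so these
        # schedulers never actually change the learning rates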
    def forward(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
        # accept either a numpy observation or an already-built tensor
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        state_values = self.critic(x)
        action_logits_vec = self.actor(x)
        return (state_values, action_logits_vec)
    def select_action(self, x: np.ndarray) -> tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
        # sample an action from the categorical policy defined by the actor's logits
        state_values, action_logits = self.forward(x)
        action_pd = torch.distributions.Categorical(logits=action_logits)
        actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()
        return actions.item(), action_log_probs, state_values, entropy
    def get_losses(
        self,
        rewards: torch.Tensor,
        action_log_probs: torch.Tensor,
        value_preds: torch.Tensor,
        entropy: torch.Tensor,
        masks: torch.Tensor,
        gamma: float,
        ent_coef: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        T = len(rewards)
        advantages = torch.zeros(T, device=self.device)
        # compute the advantages using GAE (lambda = 0.95); the final step's
        # advantage stays zero because no bootstrap value is available
        gae = 0.0
        for t in reversed(range(T - 1)):
            td_error = (
                rewards[t] + gamma * masks[t] * value_preds[t + 1] - value_preds[t]
            )
            gae = td_error + gamma * 0.95 * masks[t] * gae
            advantages[t] = gae
        # critic loss: mean squared advantage
        critic_loss = advantages.pow(2).mean()
        # actor loss: policy gradient with an entropy bonus to encourage exploration;
        # squeeze the singleton batch dimension so the product broadcasts as (T,) * (T,)
        actor_loss = (
            -(advantages.detach() * action_log_probs.squeeze(-1)).mean()
            - ent_coef * torch.stack(entropy).mean()
        )
return (critic_loss, actor_loss)
    def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
        # step the critic and actor separately, clipping gradient norms at 0.5
self.critic_optim.zero_grad()
critic_loss.backward()
torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
self.critic_optim.step()
self.critic_scheduler.step()
self.actor_optim.zero_grad()
actor_loss.backward()
torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
self.actor_optim.step()
self.actor_scheduler.step()
def set_eval(self):
self.critic.eval()
self.actor.eval()
#environment hyperparams
n_episodes = 100
#agent hyperparams
gamma = 0.999
ent_coef = 0.01 # coefficient for entropy bonus
actor_lr = 0.001
critic_lr = 0.005
#environment setup
#env = simple_reference_v3.parallel_env(render_mode="human")
env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array")
#obs_space
#action_space
device = torch.device("cpu")
#init the agent
#agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)
#wrapper to record statistics
#env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
# env, buffer_length=n_episodes
#)
#eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
#bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
#alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
agent0 = A2C(
    n_features=env.observation_space("agent_0").shape[0],
    n_actions=env.action_space("agent_0").n,
    device=device,
    critic_lr=critic_lr,
    actor_lr=actor_lr,
)
agent1 = A2C(
    n_features=env.observation_space("agent_1").shape[0],
    n_actions=env.action_space("agent_1").n,
    device=device,
    critic_lr=critic_lr,
    actor_lr=actor_lr,
)
#print(env.action_space("agent_0").n)
#print(env.observation_space("agent_0"))
agent0_critic_loss = []
agent0_actor_loss = []
agent1_critic_loss = []
agent1_actor_loss = []
agent0_rewards = []
agent1_rewards = []
for episode in tqdm(range(n_episodes)):
#print("Episode " + str(episode) + "/" + str(n_episodes))
observations, infos = env.reset()
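    # per-step buffers for this episode: rewards, log-probs, value predictions, entropies, masks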
agent_0_rewards = []
agent_0_probs = []
agent_0_pred = []
agent_0_ents = []
agent_0_mask = []
agent_1_rewards = []
agent_1_probs = []
agent_1_pred = []
agent_1_ents = []
agent_1_mask = []
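    # roll out one full episode; env.agents becomes empty once all agents terminate or truncate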
while env.agents:
actions = {}
#eve_action, eve_log_probs, eve_state_val, eve_ent = eve.select_action(observations["eve_0"])
#bob_action, bob_log_probs, bob_state_val, bob_ent = bob.select_action(observations["bob_0"])
#alice_action, alice_log_probs, alice_state_val, alice_ent = alice.select_action(observations["alice_0"])
#actions["eve_0"] = eve_action.item()
#actions["bob_0"] = bob_action.item()
#actions["alice_0"] = alice_action.item()
agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
actions["agent_0"] = agent_0_action
actions["agent_1"] = agent_1_action
observations, rewards, terminations, truncations, infos = env.step(actions)
#print(rewards)
agent_0_rewards.append(rewards["agent_0"])
agent_0_probs.append(agent_0_log_probs)
agent_0_pred.append(agent_0_state_val)
agent_0_ents.append(agent_0_ent)
        # mask is 0 on the final transition of the episode, 1 otherwise
        agent_0_mask.append(1 if env.agents else 0)
agent_1_rewards.append(rewards["agent_1"])
agent_1_probs.append(agent_1_log_probs)
agent_1_pred.append(agent_1_state_val)
agent_1_ents.append(agent_1_ent)
        agent_1_mask.append(1 if env.agents else 0)
#eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef)
#print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item()))
#eve.update_params(eve_closs, eve_aloss)
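    # one A2C update per agent at the end of each episode, using the full rollout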
agent_0_closs, agent_0_aloss = agent0.get_losses(agent_0_rewards, torch.stack(agent_0_probs), agent_0_pred, agent_0_ents, agent_0_mask, gamma, ent_coef)
#print(agent_0_rewards)
agent0_critic_loss.append(agent_0_closs.item())
agent0_actor_loss.append(agent_0_aloss.item())
#print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
agent0.update_params(agent_0_closs, agent_0_aloss)
agent_1_closs, agent_1_aloss = agent1.get_losses(agent_1_rewards, torch.stack(agent_1_probs), agent_1_pred, agent_1_ents, agent_1_mask, gamma, ent_coef)
agent1_critic_loss.append(agent_1_closs.item())
agent1_actor_loss.append(agent_1_aloss.item())
#print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
agent1.update_params(agent_1_closs, agent_1_aloss)
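    # record the total episode reward for each agent for the training curves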
agent0_rewards.append(np.array(agent_0_rewards).sum())
agent1_rewards.append(np.array(agent_1_rewards).sum())
    # periodically save a progress plot (with n_episodes = 100 and a 500-episode
    # interval, this only triggers at episode 0)
    if episode % 500 == 0:
#rolling_length = 20
#fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12,5))
#fig.suptitle(
# f"training plots for the Simple Reference environment"
#)
plt.plot(agent0_rewards, label="Agent 0 Rewards")
plt.plot(agent1_rewards, label="Agent 1 Rewards")
plt.legend()
plt.savefig('data.png')
plt.clf()
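# final training curves (overwrite the periodic data.png), shown non-blocking so the
# rendering loop below can run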
plt.plot(agent0_rewards, label="Agent 0 Rewards")
plt.plot(agent1_rewards, label="Agent 1 Rewards")
plt.legend()
plt.savefig('data.png')
plt.show(block=False)
#plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
#plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
#plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
#plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
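# paths for the saved actor/critic state_dicts (torch.save pickles them despite the .h5 extension)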
actor0_weights_path = "weights/actor0_weights.h5"
critic0_weights_path = "weights/critic0_weights.h5"
actor1_weights_path = "weights/actor1_weights.h5"
critic1_weights_path = "weights/critic1_weights.h5"
if not os.path.exists("weights"):
os.mkdir("weights")
torch.save(agent0.actor.state_dict(), actor0_weights_path)
torch.save(agent0.critic.state_dict(), critic0_weights_path)
torch.save(agent1.actor.state_dict(), actor1_weights_path)
torch.save(agent1.critic.state_dict(), critic1_weights_path)
# Optional: reload previously saved weights into the agents before evaluation.
# The load_weights flag is illustrative and defaults to False, so the freshly
# trained weights above are used unchanged.
load_weights = False
if load_weights:
    agent0.actor.load_state_dict(torch.load(actor0_weights_path))
    agent0.critic.load_state_dict(torch.load(critic0_weights_path))
    agent1.actor.load_state_dict(torch.load(actor1_weights_path))
    agent1.critic.load_state_dict(torch.load(critic1_weights_path))
agent0.set_eval()
agent1.set_eval()
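# render the trained (eval-mode) agents in a human-viewable environment; the loop
# below runs until the process is interrupted, so env.close() is never reached as written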
env = simple_reference_v3.parallel_env(render_mode="human")
while True:
observations, infos = env.reset()
while env.agents:
plt.pause(0.001)
actions = {}
agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
actions["agent_0"] = agent_0_action
actions["agent_1"] = agent_1_action
observations, rewards, terminations, truncations, infos = env.step(actions)
env.close()