# simple_crypto/main.py
import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import gymnasium as gym
from pettingzoo.mpe import simple_crypto_v3


class A2C(nn.Module):
    """Actor-critic agent with separate actor and critic MLPs."""

    def __init__(
        self,
        n_features: int,
        n_actions: int,
        device: torch.device,
        critic_lr: float,
        actor_lr: float,
    ) -> None:
        super().__init__()
        self.device = device
        critic_layers = [
            nn.Linear(n_features, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
        ]
        actor_layers = [
            nn.Linear(n_features, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.ReLU(),
            # raw logits; Categorical(logits=...) in select_action applies the softmax
            nn.Linear(8, n_actions),
        ]
        self.critic = nn.Sequential(*critic_layers).to(self.device)
        self.actor = nn.Sequential(*actor_layers).to(self.device)
        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)

    def forward(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        state_values = self.critic(x)
        action_logits_vec = self.actor(x)
        return (state_values, action_logits_vec)

    def select_action(
        self, x: np.ndarray
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        state_values, action_logits = self.forward(x)
        action_pd = torch.distributions.Categorical(logits=action_logits)
        actions = action_pd.sample()
        action_log_probs = action_pd.log_prob(actions)
        entropy = action_pd.entropy()
        return actions, action_log_probs, state_values, entropy

    def get_losses(
        self,
        rewards: torch.Tensor,
        action_log_probs: torch.Tensor,
        value_preds: torch.Tensor,
        entropy: torch.Tensor,
        masks: torch.Tensor,
        gamma: float,
        ent_coef: float,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # one-step TD advantage: r_t + gamma * V(s_{t+1}) - V(s_t)
        # masks[t] is 0 at the end of an episode, so the bootstrap term is dropped there
        # gamma is the discount factor for future value estimates
        # (the last step is left at 0; including it would need a bootstrap value for the final state)
        advantages = torch.zeros(len(rewards), device=device)
        for t in range(len(rewards) - 1):
            advantages[t] = rewards[t] + gamma * masks[t] * value_preds[t + 1] - value_preds[t]
        # critic loss: mean squared TD error
        critic_loss = advantages.pow(2).mean()
        # actor loss: policy gradient with an entropy bonus to encourage exploration
        actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
        return (critic_loss, actor_loss)

    def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
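

# Hedged sketch (not from the original file): how one update could be driven from a
# finished rollout for a single learner. The rollout_* arguments are assumed to be
# Python lists collected per step from select_action and the environment; nothing
# here is called by the demo loop at the bottom of the file.
def a2c_update_from_rollout(
    agent: A2C,
    rollout_log_probs: list,
    rollout_values: list,
    rollout_entropies: list,
    rollout_rewards: list,
    rollout_dones: list,
    gamma: float = 0.999,
    ent_coef: float = 0.01,
) -> tuple[float, float]:
    device = agent.device
    rewards = torch.as_tensor(rollout_rewards, dtype=torch.float32, device=device)
    # mask is 1 while the episode continues, 0 once it has terminated
    masks = 1.0 - torch.as_tensor(rollout_dones, dtype=torch.float32, device=device)
    log_probs = torch.stack(rollout_log_probs)
    values = torch.stack(rollout_values).squeeze(-1)
    entropies = torch.stack(rollout_entropies)
    critic_loss, actor_loss = agent.get_losses(
        rewards, log_probs, values, entropies, masks, gamma, ent_coef, device
    )
    agent.update_params(critic_loss, actor_loss)
    return critic_loss.item(), actor_loss.item()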

# environment hyperparams
n_episodes = 10

# agent hyperparams
gamma = 0.999
ent_coef = 0.01  # coefficient for the entropy bonus
actor_lr = 0.001
critic_lr = 0.005

# environment setup
env = simple_crypto_v3.parallel_env(render_mode="human")
# per-agent observation and action spaces (obs_space / action_space) can be read
# from the parallel env; see the sketch below
device = torch.device("cpu")
# init the agent
# agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)
# wrapper to record statistics
# env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
#     env, buffer_length=n_episodes
# )
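
# Hedged sketch (not from the original file): one way to fill in the commented-out
# init above is to build one A2C learner per PettingZoo agent, reading each agent's
# observation/action sizes from the env. The `learners` dict is illustrative only;
# the loop below still samples random actions and does not use it.
learners = {
    name: A2C(
        n_features=env.observation_space(name).shape[0],
        n_actions=env.action_space(name).n,
        device=device,
        critic_lr=critic_lr,
        actor_lr=actor_lr,
    )
    for name in env.possible_agents
}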

# demo loop: sample random actions for every live agent until the episode ends
observations, infos = env.reset()
done = False
while env.agents:
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
    print(observations)
env.close()
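

# Hedged sketch (not from the original file): how the random-action loop above could
# instead query each agent's policy via select_action. `learners` is assumed to be a
# per-agent dict like the one sketched after the env setup; this function is only
# defined here, not called.
def run_policy_episode(env, learners: dict) -> None:
    observations, infos = env.reset()
    while env.agents:
        actions = {}
        for name in env.agents:
            action, log_prob, value, entropy = learners[name].select_action(observations[name])
            actions[name] = int(action.item())
        observations, rewards, terminations, truncations, infos = env.step(actions)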