diff --git a/main.py b/main.py
index ba97137..f32e79b 100644
--- a/main.py
+++ b/main.py
@@ -11,21 +11,70 @@ from tqdm import tqdm
 import gymnasium as gym
 from pettingzoo.mpe import simple_crypto_v3
 
-class A2C():
-    def __init__(self):
-        pass
+class A2C(nn.Module):
+    def __init__(
+        self,
+        n_features: int,
+        n_actions: int,
+        device: torch.device,
+        critic_lr: float,
+        actor_lr: float
+    ) -> None:
+
+        super().__init__()
+        self.device = device
 
-    def forward(self):
-        pass
+        critic_layers = [
+            nn.Linear(n_features, 8),
+            nn.ReLU(),
+            nn.Linear(8, 8),
+            nn.ReLU(),
+            nn.Linear(8, 1)
+        ]
 
-    def select_action(self):
-        pass
+        # the actor outputs raw logits; Categorical(logits=...) applies the softmax
+        actor_layers = [
+            nn.Linear(n_features, 8),
+            nn.ReLU(),
+            nn.Linear(8, 8),
+            nn.ReLU(),
+            nn.Linear(8, n_actions)
+        ]
+
+        self.critic = nn.Sequential(*critic_layers).to(self.device)
+        self.actor = nn.Sequential(*actor_layers).to(self.device)
+
+        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
+        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)
+
+    def forward(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
+        x = torch.Tensor(x).to(self.device)
+        state_values = self.critic(x)
+        action_logits_vec = self.actor(x)
+        return (state_values, action_logits_vec)
+
+    def select_action(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        state_values, action_logits = self.forward(x)
+        action_pd = torch.distributions.Categorical(
+            logits=action_logits
+        )
+        actions = action_pd.sample()
+        action_log_probs = action_pd.log_prob(actions)
+        entropy = action_pd.entropy()
+        return actions, action_log_probs, state_values, entropy
 
     def get_losses(self):
         pass
 
-    def update_params(self):
-        pass
+    def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
+        self.critic_optim.zero_grad()
+        critic_loss.backward()
+        self.critic_optim.step()
+
+        self.actor_optim.zero_grad()
+        actor_loss.backward()
+        self.actor_optim.step()
 
 #environment hyperparams
 n_episodes = 10
@@ -56,5 +105,6 @@ while env.agents:
     actions = {agent: env.action_space(agent).sample() for agent in env.agents}
     observations, rewards, terminations, truncations, infos = env.step(actions)
+    print(observations)
 
 env.close()
\ No newline at end of file
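
Note: get_losses() is still a stub in this patch. A minimal sketch of what it could eventually compute for A2C is the standard advantage-based actor and critic losses (plus an entropy bonus) from the quantities already produced by select_action(). The argument names (rewards, action_log_probs, value_preds, entropy, masks) and the gamma / ent_coef hyperparameters below are illustrative assumptions, not part of this diff; the detached advantage keeps the two losses compatible with the separate backward passes in update_params():

    def get_losses(
        self,
        rewards: torch.Tensor,           # shape [T], rewards collected over one rollout
        action_log_probs: torch.Tensor,  # shape [T], log pi(a_t | s_t) from select_action
        value_preds: torch.Tensor,       # shape [T], critic estimates V(s_t)
        entropy: torch.Tensor,           # shape [T], policy entropy at each step
        masks: torch.Tensor,             # shape [T], 0.0 where an episode terminated
        gamma: float = 0.99,
        ent_coef: float = 0.01,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # discounted returns, accumulated backwards through the rollout
        returns = torch.zeros_like(rewards)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * masks[t] * running
            returns[t] = running

        # advantage = return - value baseline
        advantages = returns - value_preds

        # critic regresses V(s_t) toward the empirical returns;
        # the advantage is detached in the actor loss so policy gradients
        # do not flow back into the critic
        critic_loss = advantages.pow(2).mean()
        actor_loss = (
            -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
        )
        return critic_loss, actor_loss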