A2C w/o loss function

Scirockses
2025-08-31 11:47:27 -06:00
parent 366bc88355
commit b78e808108

main.py

@@ -11,21 +11,70 @@ from tqdm import tqdm
import gymnasium as gym
from pettingzoo.mpe import simple_crypto_v3
-class A2C():
-    def __init__(self):
-        pass
+class A2C(nn.Module):
+    def __init__(
+        self,
+        n_features: int,
+        n_actions: int,
+        device: torch.device,
+        critic_lr: float,
+        actor_lr: float
+    ) -> None:
+        super().__init__()
+        self.device = device
-    def forward(self):
-        pass
+        # critic head: maps a feature vector to a scalar state-value estimate
+        critic_layers = [
+            nn.Linear(n_features, 8),
+            nn.ReLU(),
+            nn.Linear(8, 8),
+            nn.ReLU(),
+            nn.Linear(8, 1)
+        ]
-    def select_action(self):
-        pass
+        # actor head: maps a feature vector to unnormalized action logits
+        # (no trailing Softmax: select_action builds a Categorical directly from logits)
+        actor_layers = [
+            nn.Linear(n_features, 8),
+            nn.ReLU(),
+            nn.Linear(8, 8),
+            nn.ReLU(),
+            nn.Linear(8, n_actions)
+        ]
+        self.critic = nn.Sequential(*critic_layers).to(self.device)
+        self.actor = nn.Sequential(*actor_layers).to(self.device)
+        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
+        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)
+    def forward(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor]:
+        # convert the observation to a float tensor on the target device
+        x = torch.Tensor(x).to(self.device)
+        state_values = self.critic(x)
+        action_logits_vec = self.actor(x)
+        return (state_values, action_logits_vec)
+    def select_action(self, x: np.ndarray) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        state_values, action_logits = self.forward(x)
+        # build a categorical policy over the logits and sample an action from it
+        action_pd = torch.distributions.Categorical(
+            logits=action_logits
+        )
+        actions = action_pd.sample()
+        action_log_probs = action_pd.log_prob(actions)
+        entropy = action_pd.entropy()
+        return actions, action_log_probs, state_values, entropy
    def get_losses(self):
        pass
-    def update_params(self):
-        pass
+    def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
+        # one gradient step per network, each with its own optimizer
+        self.critic_optim.zero_grad()
+        critic_loss.backward()
+        self.critic_optim.step()
+        self.actor_optim.zero_grad()
+        actor_loss.backward()
+        self.actor_optim.step()
#environment hyperparams
n_episodes = 10
@@ -56,5 +105,6 @@ while env.agents:
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)
    print(observations)
env.close()
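
get_losses is still a stub in this commit (hence "A2C w/o loss function"). As a reference point, here is a minimal sketch of a standard advantage actor-critic loss that would fit the select_action and update_params methods above; the input tensors (per-step rewards, action log-probs, value estimates, entropies) and the hyperparameters gamma and ent_coef are assumptions for illustration, not part of the commit, and it relies only on the torch import already used in main.py.

    def get_losses(
        self,
        rewards: torch.Tensor,           # assumed shape [T]: rewards from one rollout
        action_log_probs: torch.Tensor,  # assumed shape [T]: log-probs of taken actions
        value_preds: torch.Tensor,       # assumed shape [T] or [T, 1]: critic estimates V(s_t)
        entropy: torch.Tensor,           # assumed shape [T]: per-step policy entropy
        gamma: float = 0.99,             # hypothetical discount factor
        ent_coef: float = 0.01,          # hypothetical entropy-bonus weight
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # discounted Monte Carlo returns, accumulated backwards through the rollout
        returns = torch.zeros_like(rewards)
        running = torch.tensor(0.0, device=rewards.device)
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        # advantage = return minus the critic's baseline
        advantages = returns - value_preds.squeeze(-1)
        # critic: mean squared error against the discounted returns
        critic_loss = advantages.pow(2).mean()
        # actor: policy gradient with an entropy bonus; detach the advantage so the
        # actor loss does not backpropagate into the critic
        actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
        return critic_loss, actor_loss

In a training loop this would pair with the methods already defined in the commit, e.g. critic_loss, actor_loss = agent.get_losses(rewards, action_log_probs, value_preds, entropies) followed by agent.update_params(critic_loss, actor_loss).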