A2C w/o loss function
main.py
@@ -11,21 +11,70 @@ from tqdm import tqdm
 import gymnasium as gym
 from pettingzoo.mpe import simple_crypto_v3
 
-class A2C():
-    def __init__(self):
-        pass
-
-    def forward(self):
-        pass
-
-    def select_action(self):
-        pass
+class A2C(nn.Module):
+    def __init__(
+        self,
+        n_features: int,
+        n_actions: int,
+        device: torch.device,
+        critic_lr: float,
+        actor_lr: float
+    ) -> None:
+
+        super().__init__()
+        self.device = device
+
+        critic_layers = [
+            nn.Linear(n_features, 8),
+            nn.ReLU(),
+            nn.Linear(8, 8),
+            nn.ReLU(),
+            nn.Linear(8, 1)
+        ]
+
+        actor_layers = [
+            nn.Linear(n_features, 8),
+            nn.ReLU(),
+            nn.Linear(8, 8),
+            nn.ReLU(),
+            nn.Linear(8, n_actions),
+            nn.Softmax()
+        ]
+
+        self.critic = nn.Sequential(*critic_layers).to(self.device)
+        self.actor = nn.Sequential(*actor_layers).to(self.device)
+
+        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
+        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)
+
+    def forward(self, x: np.array) -> tuple[torch.tensor, torch.tensor]:
+        x = torch.Tensor(x).to(self.device)
+        state_values = self.critic(x)
+        action_logits_vec = self.actor(x)
+        return (state_values, action_logits_vec)
+
+    def select_action(self, x: np.array) -> tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
+
+        state_values, action_logits = self.forward(x)
+        action_pd = torch.distributions.Categorical(
+            logits=action_logits
+        )
+        actions = action_pd.sample()
+        action_log_probs = action_pd.log_prob(actions)
+        entropy = action_pd.entropy()
+        return actions, action_log_probs, state_values, entropy
 
     def get_losses(self):
         pass
 
-    def update_params(self):
-        pass
+    def update_params(self, critic_loss: torch.tensor, actor_loss: torch.tensor) -> None:
+        self.critic_optim.zero_grad()
+        critic_loss.backward()
+        self.critic_optim.step()
+
+        self.actor_optim.zero_grad()
+        actor_loss.backward()
+        self.actor_optim.step()
 
 #environment hyperparams
 n_episodes = 10
@@ -56,5 +105,6 @@ while env.agents:
 
     actions = {agent: env.action_space(agent).sample() for agent in env.agents}
     observations, rewards, terminations, truncations, infos = env.step(actions)
+    print(observations)
 
 env.close()
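
Note: as the commit title says, get_losses is still a stub here. For orientation only, a minimal sketch of the standard advantage actor-critic losses such a method might later compute is shown below; the argument names (rewards, action_log_probs, value_preds, entropy) and the gamma/ent_coef hyperparameters are illustrative assumptions, not part of this commit.

    import torch

    # Sketch only: one possible loss computation for this A2C class.
    # Inputs are per-step tensors of shape [T], as collected via select_action().
    def get_losses(
        rewards: torch.Tensor,
        action_log_probs: torch.Tensor,
        value_preds: torch.Tensor,
        entropy: torch.Tensor,
        gamma: float = 0.99,      # illustrative discount factor
        ent_coef: float = 0.01,   # illustrative entropy bonus weight
    ) -> tuple[torch.Tensor, torch.Tensor]:
        T = len(rewards)
        returns = torch.zeros(T, device=value_preds.device)
        running = 0.0
        # Discounted returns, accumulated backwards over the episode.
        for t in reversed(range(T)):
            running = rewards[t] + gamma * running
            returns[t] = running
        # Advantage = return - value estimate; detach it in the actor loss so
        # the policy gradient does not backpropagate through the critic.
        advantages = returns - value_preds
        critic_loss = advantages.pow(2).mean()
        actor_loss = -(action_log_probs * advantages.detach()).mean() - ent_coef * entropy.mean()
        return critic_loss, actor_loss

The two losses returned by such a method would then be passed to update_params(), which already performs the optimizer steps added in this commit.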