Added loss function

Scirockses
2025-08-31 12:41:23 -06:00
parent b78e808108
commit 321225cf88

main.py (29 lines changed)

@@ -64,8 +64,32 @@ class A2C(nn.Module):
         entropy = action_pd.entropy()
         return actions, action_log_probs, state_values, entropy
 
-    def get_losses(self):
-        pass
+    def get_losses(
+        self,
+        rewards: torch.Tensor,
+        action_log_probs: torch.Tensor,
+        value_preds: torch.Tensor,
+        entropy: torch.Tensor,
+        masks: torch.Tensor,
+        gamma: float,
+        ent_coef: float,
+        device: torch.device,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        advantages = torch.zeros(len(rewards), device=device)
+        # compute advantages as one-step TD errors
+        # masks[t] - 0 if the episode ended at step t, 1 otherwise
+        # gamma - discount factor applied to the bootstrapped value prediction
+        for t in range(len(rewards) - 1):
+            advantages[t] = rewards[t] + masks[t] * gamma * value_preds[t + 1] - value_preds[t]
+        # critic loss - MSE of the TD errors
+        critic_loss = advantages.pow(2).mean()
+        # actor loss - with an entropy bonus to encourage exploration
+        actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
+        return (critic_loss, actor_loss)
 
     def update_params(self, critic_loss: torch.tensor, actor_loss: torch.tensor) -> None:
         self.critic_optim.zero_grad()
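
For orientation, a minimal sketch of how the new method might be driven from a training loop. The random tensors stand in for a real rollout, and `agent` is assumed to be an A2C instance with its optimizers already constructed; none of this is part of the commit itself.

import torch

T = 5  # illustrative rollout length; real values come from the environment loop
rewards = torch.rand(T)
action_log_probs = torch.randn(T, requires_grad=True)  # from the actor head
value_preds = torch.randn(T, requires_grad=True)       # from the critic head
entropy = torch.rand(T)
masks = torch.ones(T)  # set masks[t] = 0.0 where the episode terminated
device = torch.device("cpu")

critic_loss, actor_loss = agent.get_losses(
    rewards, action_log_probs, value_preds, entropy, masks,
    gamma=0.999, ent_coef=0.01, device=device,
)
agent.update_params(critic_loss, actor_loss)
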
@@ -80,6 +104,7 @@ class A2C(nn.Module):
 n_episodes = 10
 # agent hyperparams
 gamma = 0.999
+ent_coef = 0.01  # coefficient for entropy bonus
 actor_lr = 0.001
 critic_lr = 0.005
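
update_params relies on self.critic_optim (and presumably an analogous self.actor_optim) that this diff does not construct. A plausible sketch of how actor_lr and critic_lr would be wired into them inside A2C.__init__, assuming separate actor/critic modules and torch.optim.Adam; both the attribute names and the optimizer choice are assumptions, since __init__ is outside this diff.

# assumed fragment of A2C.__init__ (not part of this commit)
self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)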