Added loss function
main.py
@@ -64,8 +64,32 @@ class A2C(nn.Module):
         entropy = action_pd.entropy()
         return actions, action_log_probs, state_values, entropy
 
-    def get_losses(self):
-        pass
+    def get_losses(
+        self,
+        rewards: torch.Tensor,
+        action_log_probs: torch.Tensor,
+        value_preds: torch.Tensor,
+        entropy: torch.Tensor,
+        masks: torch.Tensor,
+        gamma: float,
+        ent_coef: float,
+        device: torch.device,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+
+        advantages = torch.zeros(len(rewards), device=device)
+        # compute advantages (one-step TD errors)
+        # masks - 0 if the episode ended at step t
+        # gamma - discount factor for the bootstrapped value prediction
+        for t in range(len(rewards) - 1):
+            advantages[t] = rewards[t] + gamma * masks[t] * value_preds[t + 1] - value_preds[t]
+
+        # calculate critic loss - MSE of the advantages
+        critic_loss = advantages.pow(2).mean()
+
+        # calculate actor loss - give bonus for entropy to encourage exploration
+        actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
+
+        return (critic_loss, actor_loss)
 
     def update_params(self, critic_loss: torch.Tensor, actor_loss: torch.Tensor) -> None:
         self.critic_optim.zero_grad()
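As a quick sanity check, the math added in get_losses can be reproduced on a dummy rollout; the five-step single-environment tensors below are made-up illustration values, not part of this commit:

```python
import torch

device = torch.device("cpu")
gamma, ent_coef = 0.999, 0.01  # same hyperparameters as in the script section below

# dummy single-environment rollout of 5 steps (illustration values only)
rewards = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])
value_preds = torch.tensor([0.5, 0.6, 0.7, 0.8, 0.9])
action_log_probs = torch.tensor([-0.4, -0.5, -0.6, -0.5, -0.4])
entropy = torch.tensor([0.7, 0.7, 0.6, 0.6, 0.5])
masks = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0])  # 0 marks the end of the episode

# one-step TD advantages, as computed in get_losses
advantages = torch.zeros(len(rewards), device=device)
for t in range(len(rewards) - 1):
    advantages[t] = rewards[t] + gamma * masks[t] * value_preds[t + 1] - value_preds[t]

critic_loss = advantages.pow(2).mean()  # MSE of the advantages
actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()

print(critic_loss.item(), actor_loss.item())
```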
@@ -80,6 +104,7 @@ class A2C(nn.Module):
 n_episodes = 10
 
 # agent hyperparams
 gamma = 0.999
+ent_coef = 0.01  # coefficient for entropy bonus
 actor_lr = 0.001
 critic_lr = 0.005
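Downstream, the new loss function presumably plugs into the per-episode update roughly as sketched here; agent, env, and collect_rollout are placeholders for code outside this diff, and only the call signature mirrors what the commit adds:

```python
# hypothetical wiring of get_losses / update_params into the training loop;
# collect_rollout stands in for the script's existing environment-interaction code
for episode in range(n_episodes):
    rewards, action_log_probs, value_preds, entropy, masks = collect_rollout(agent, env)

    critic_loss, actor_loss = agent.get_losses(
        rewards, action_log_probs, value_preds, entropy, masks, gamma, ent_coef, device
    )
    agent.update_params(critic_loss, actor_loss)
```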