Added loss function
This commit is contained in:
29
main.py
29
main.py
@ -64,8 +64,32 @@ class A2C(nn.Module):
|
|||||||
entropy = action_pd.entropy()
|
entropy = action_pd.entropy()
|
||||||
return actions, action_log_probs, state_values, entropy
|
return actions, action_log_probs, state_values, entropy
|
||||||
|
|
||||||
def get_losses(self):
|
def get_losses(
|
||||||
pass
|
self,
|
||||||
|
rewards: torch.Tensor,
|
||||||
|
action_log_probs: torch.Tensor,
|
||||||
|
value_preds: torch.Tensor,
|
||||||
|
entropy: torch.Tensor,
|
||||||
|
masks: torch.Tensor,
|
||||||
|
gamma: float,
|
||||||
|
ent_coef: float,
|
||||||
|
device: torch.device
|
||||||
|
) -> tuple[torch.tensor, torch.tensor]:
|
||||||
|
|
||||||
|
advantages = torch.zeros(len(rewards), device=device)
|
||||||
|
#compute advantages
|
||||||
|
#mask - 0 if end of episode
|
||||||
|
#gamma - coeffecient for value prediction
|
||||||
|
for t in range(len(rewards) - 1):
|
||||||
|
advantages[t] = (rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
|
||||||
|
|
||||||
|
#calculate critic loss - MSE
|
||||||
|
critic_loss = advantages.pow(2).mean()
|
||||||
|
|
||||||
|
#calculate actor loss - give bonus for entropy to encourage exploration
|
||||||
|
actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
|
||||||
|
|
||||||
|
return (critic_loss, actor_loss)
|
||||||
|
|
||||||
def update_params(self, critic_loss: torch.tensor, actor_loss: torch.tensor) -> None:
|
def update_params(self, critic_loss: torch.tensor, actor_loss: torch.tensor) -> None:
|
||||||
self.critic_optim.zero_grad()
|
self.critic_optim.zero_grad()
|
||||||
@ -80,6 +104,7 @@ class A2C(nn.Module):
|
|||||||
n_episodes = 10
|
n_episodes = 10
|
||||||
|
|
||||||
#agent hyperparams
|
#agent hyperparams
|
||||||
|
gamma = 0.999
|
||||||
ent_coef = 0.01 # coefficient for entropy bonus
|
ent_coef = 0.01 # coefficient for entropy bonus
|
||||||
actor_lr = 0.001
|
actor_lr = 0.001
|
||||||
critic_lr = 0.005
|
critic_lr = 0.005
|
||||||
|
Reference in New Issue
Block a user