Bunchostuff
main.py | 125
@@ -10,6 +10,7 @@ from tqdm import tqdm

import gymnasium as gym
from pettingzoo.mpe import simple_reference_v3
import pettingzoo

class A2C(nn.Module):
    def __init__(
@@ -25,26 +26,35 @@ class A2C(nn.Module):
        self.device = device

        critic_layers = [
            nn.Linear(n_features, 128),
            nn.Linear(n_features, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        ]

        actor_layers = [
            nn.Linear(n_features, 128),
            nn.Linear(n_features, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
            nn.Softmax()
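            # note: nn.Softmax() is constructed without an explicit dim; recent PyTorch
            # versions emit a deprecation warning and pick the dimension implicitly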
            #nn.Sigmoid()
        ]

        self.critic = nn.Sequential(*critic_layers).to(self.device)
        self.actor = nn.Sequential(*actor_layers).to(self.device)

        self.critic_optim = optim.RMSprop(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.RMSprop(self.actor.parameters(), lr=actor_lr)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=actor_lr)

        self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=0.9)
        self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=0.9)
        self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=1)
        self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=1)

    def forward(self, x: np.array) -> tuple[torch.tensor, torch.tensor]:
        x = torch.Tensor(x).to(self.device)
@@ -55,14 +65,18 @@ class A2C(nn.Module):
    def select_action(self, x: np.array) -> tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:

        state_values, action_logits = self.forward(x)
        #action_pd = torch.distributions.Categorical(
        #    logits=action_logits
        #)
        #actions = action_pd.sample()
        actions = torch.multinomial(action_logits, 1).item()
        action_log_probs = torch.log(action_logits.squeeze(0)[actions])
        entropy = action_logits * action_log_probs  #action_pd.entropy()
        return actions, action_log_probs, state_values, entropy
        action_pd = torch.distributions.Categorical(
            logits=action_logits
        )
        actions = action_pd.sample()
        #actions = torch.multinomial(action_logits, 1).item()
        #action_log_probs = torch.log(action_logits.squeeze(0)[actions])
        action_log_probs = action_pd.log_prob(actions)
        #entropy = action_logits * action_log_probs
        entropy = action_pd.entropy()
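        # Categorical(logits=...) re-normalizes its input with a log-softmax, so
        # log_prob and entropy here are taken over that re-normalized distribution
        # of the actor's (already softmaxed) output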
        #print(entropy.item())
        #print(action_logits)
        return actions.item(), action_log_probs, state_values, entropy

    def get_losses(
        self,
@@ -75,34 +89,60 @@ class A2C(nn.Module):
        ent_coef: float,
    ) -> tuple[torch.tensor, torch.tensor]:

        advantages = torch.zeros(len(rewards), device=self.device)
        T = len(rewards)
        advantages = torch.zeros(T, device=device)

        # compute the advantages using GAE
        gae = 0.0
        for t in reversed(range(T - 1)):
            td_error = (
                rewards[t] + gamma * masks[t] * value_preds[t+1] - value_preds[t]
            )
            gae = td_error + gamma * 0.95 * masks[t] * gae
            advantages[t] = gae
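        # the loop above is the GAE recursion A_t = delta_t + gamma * lambda * mask_t * A_{t+1}
        # with delta_t = r_t + gamma * mask_t * V(s_{t+1}) - V(s_t) and lambda hard-coded to 0.95;
        # t only runs over 0..T-2, so the last step keeps an advantage of zero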

        # calculate the loss of the minibatch for actor and critic
        critic_loss = advantages.pow(2).mean()

        #give a bonus for higher entropy to encourage exploration
        actor_loss = (
            -(advantages.detach() * action_log_probs).mean() - ent_coef * torch.Tensor(entropy).mean()
        )


        #advantages = torch.zeros(len(rewards), device=self.device)
        #compute advantages
        #mask - 0 if end of episode
        #gamma - coefficient for value prediction
        #rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5)
        #for t in range(len(rewards) - 1):
        #    advantages[t] = rewards[t] + masks[t] * gamma * value_preds[t+1]  #(rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
        rewards = np.array(rewards)
        rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5)
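        # standardize the rollout's rewards to zero mean / unit variance; the 1e-5
        # keeps the division finite when every reward in the rollout is identical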
        returns = []
        R = 0
        for r, mask in zip(reversed(rewards), reversed(masks)):
            R = r + gamma * R * mask
            returns.insert(0, R)
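        # discounted returns built back to front: R_t = r_t + gamma * mask_t * R_{t+1},
        # with the running sum reset wherever mask == 0 (end of an episode)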
        #advantages[t] = (rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
        #print(advantages[t])
        #rewards[t] + masks[t] * gamma * value_preds[t+1]
        #(rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t]))
        #rewards = np.array(rewards)
        #rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5)
        #returns = []
        #R = 0
        #for r, mask in zip(reversed(rewards), reversed(masks)):
        #    R = r + gamma * R * mask
        #    returns.insert(0, R)

        returns = torch.FloatTensor(returns)
        values = torch.stack(value_preds).squeeze(1)
        #returns = torch.FloatTensor(returns)
        #values = torch.stack(value_preds).squeeze(1)

        advantage = returns - values
        #advantage = returns - values

        #calculate critic loss - MSE
        #critic_loss = advantages.pow(2).mean()
        critic_loss = advantage.pow(2).mean()
        #critic_loss = advantages.pow(2).mean()
        #calculate actor loss - give bonus for entropy to encourage exploration
        #actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean()
        entropy = -torch.stack(entropy).sum(dim=-1).mean()
        actor_loss = -(action_log_probs * advantage.detach()).mean() - ent_coef * entropy
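        # sign note: `entropy` was just redefined as the *negative* sum of the per-step
        # entropies, so `- ent_coef * entropy` adds ent_coef * sum(entropies) to the loss,
        # i.e. it penalizes entropy rather than granting the bonus the comment above describes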

        #entropy = -torch.stack(entropy).sum(dim=-1).mean()
        #actor_loss = (-action_log_probs * advantages.detach()).mean() - ent_coef * torch.Tensor(entropy).mean()
        #print(action_log_probs)
        #print(actor_loss)
        return (critic_loss, actor_loss)

    def update_params(self, critic_loss: torch.tensor, actor_loss: torch.tensor) -> None:
@@ -123,7 +163,7 @@ class A2C(nn.Module):
        self.actor.eval()

#environment hyperparams
n_episodes = 10000
n_episodes = 1000

#agent hyperparams
gamma = 0.999
@@ -133,7 +173,9 @@ critic_lr = 0.005

#environment setup
#env = simple_reference_v3.parallel_env(render_mode="human")
env = simple_reference_v3.parallel_env()
env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array")
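# max_cycles=50 caps each episode at 50 environment steps (the MPE default is 25);
# render_mode="rgb_array" makes env.render() return frames instead of opening a window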



#obs_space
#action_space
@@ -153,11 +195,14 @@ device = torch.device("cpu")

agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
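# two independent actor-critic networks, one per agent, each sized from that
# agent's own observation and action spaces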

#print(env.action_space("agent_0").n)
#print(env.observation_space("agent_0"))
agent0_critic_loss = []
agent0_actor_loss = []
agent1_critic_loss = []
agent1_actor_loss = []
agent0_rewards = []
agent1_rewards = []

for episode in range(0, n_episodes):
    print("Episode " + str(episode) + "/" + str(n_episodes))
@@ -187,6 +232,7 @@ for episode in range(0, n_episodes):
        actions["agent_0"] = agent_0_action
        actions["agent_1"] = agent_1_action
        observations, rewards, terminations, truncations, infos = env.step(actions)
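        # PettingZoo parallel API: step() takes a dict of actions keyed by agent name and
        # returns per-agent dicts of observations, rewards, terminations, truncations and infos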
        #print(rewards)
        agent_0_rewards.append(rewards["agent_0"])
        agent_0_probs.append(agent_0_log_probs)
        agent_0_pred.append(agent_0_state_val)
@@ -213,10 +259,15 @@ for episode in range(0, n_episodes):
    #print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
    agent1.update_params(agent_1_closs, agent_1_aloss)

    plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
    plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
    plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
    plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
    agent0_rewards.append(np.array(agent_0_rewards).sum())
    agent1_rewards.append(np.array(agent_1_rewards).sum())
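    # track the per-episode sum of each agent's step rewards for the plots below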

    #plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
    #plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
    #plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
    #plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
    plt.plot(agent0_rewards, label="Agent 0 Rewards")
    plt.plot(agent1_rewards, label="Agent 1 Rewards")
    plt.legend()
    plt.show(block=False)