improved plots and moved the training code into a separate function

2025-08-31 22:14:07 -06:00
parent 3bd7f70a94
commit fc04bdcd97

main.py

@@ -90,7 +90,7 @@ class A2C(nn.Module):
     ) -> tuple[torch.tensor, torch.tensor]:
         T = len(rewards)
-        advantages = torch.zeros(T, device=device)
+        advantages = torch.zeros(T, device=self.device)
         # compute the advantages using GAE
         gae = 0.0
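
The hunk above touches the GAE advantage buffer (the fix makes it allocate on self.device). For reference, here is a minimal sketch of the GAE recursion this kind of loop typically implements; the helper name and the lam / value_preds / masks arguments are illustrative assumptions, not identifiers taken from main.py:

import torch

def gae_advantages(rewards, value_preds, masks, gamma, lam, device):
    # rewards, value_preds, masks: sequences of length T (masks are 1 while the episode is running, 0 after it ends)
    T = len(rewards)
    advantages = torch.zeros(T, device=device)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = value_preds[t + 1] if t + 1 < T else 0.0
        # one-step TD error; the mask drops the bootstrap once the episode has ended
        delta = rewards[t] + gamma * next_value * masks[t] - value_preds[t]
        gae = delta + gamma * lam * masks[t] * gae
        advantages[t] = gae
    return advantages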
@@ -162,41 +162,52 @@ class A2C(nn.Module):
         self.critic.eval()
         self.actor.eval()
 
-#environment hyperparams
-n_episodes = 100
-
-#agent hyperparams
-gamma = 0.999
-ent_coef = 0.01 # coefficient for entropy bonus
-actor_lr = 0.001
-critic_lr = 0.005
-
-#environment setup
-#env = simple_reference_v3.parallel_env(render_mode="human")
-env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array")
-#obs_space
-#action_space
-
-device = torch.device("cpu")
-
-#init the agent
-#agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)
-
-#wrapper to record statistics
-#env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
-#    env, buffer_length=n_episodes
-#)
-#eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-#bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-#alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-
-agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-#print(env.action_space("agent_0").n)
-#print(env.observation_space("agent_0"))
+fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15,5))
+fig.suptitle(
+    f"training plots for the Simple Reference environment"
+)
+
+def drawPlots():
+    rolling_length = 20
+    agent0_average = []
+    agent1_average = []
+    agent0_average_closs = []
+    agent0_average_aloss = []
+    agent1_average_closs = []
+    agent1_average_aloss = []
+    window = 20
+    for ind in range(len(agent0_rewards) - window + 1):
+        agent0_average.append(np.mean(agent0_rewards[ind:ind+window]))
+    for ind in range(len(agent1_rewards) - window + 1):
+        agent1_average.append(np.mean(agent1_rewards[ind:ind+window]))
+    for ind in range(len(agent0_critic_loss) - window + 1):
+        agent0_average_closs.append(np.mean(agent0_critic_loss[ind:ind+window]))
+    for ind in range(len(agent0_actor_loss) - window + 1):
+        agent0_average_aloss.append(np.mean(agent0_actor_loss[ind:ind+window]))
+    for ind in range(len(agent1_critic_loss) - window + 1):
+        agent1_average_closs.append(np.mean(agent1_critic_loss[ind:ind+window]))
+    for ind in range(len(agent1_actor_loss) - window + 1):
+        agent1_average_aloss.append(np.mean(agent1_actor_loss[ind:ind+window]))
+    axs[0].cla()
+    axs[0].plot(agent0_average, label="Agent 0")
+    axs[0].plot(agent1_average, label="Agent 1")
+    axs[0].legend()
+    axs[0].set_title("Rewards over Time")
+    axs[1].cla()
+    axs[1].plot(agent0_average_closs, label="Agent 0")
+    axs[1].plot(agent1_average_closs, label="Agent 1")
+    axs[1].legend()
+    axs[1].set_title("Critic Loss over Time")
+    axs[2].cla()
+    axs[2].plot(agent0_average_aloss, label="Agent 0")
+    axs[2].plot(agent1_average_aloss, label="Agent 1")
+    axs[2].legend()
+    axs[2].set_title("Actor Loss over Time")
 
 agent0_critic_loss = []
 agent0_actor_loss = []
 agent1_critic_loss = []
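
The windowed means that drawPlots builds with explicit slicing loops can be cross-checked against a vectorized equivalent. A small sketch follows; np.convolve is used here purely as an illustrative alternative, not as what main.py actually does:

import numpy as np

def rolling_mean(values, window=20):
    # mean of each length-`window` slice; output length is len(values) - window + 1,
    # which matches the explicit loops in drawPlots
    return np.convolve(values, np.ones(window) / window, mode="valid")

# e.g. rolling_mean(agent0_rewards) reproduces the data behind the "Rewards over Time" panel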
@@ -204,85 +215,137 @@ agent1_actor_loss = []
 agent0_rewards = []
 agent1_rewards = []
 
-for episode in tqdm(range(n_episodes)):
-    #print("Episode " + str(episode) + "/" + str(n_episodes))
-    observations, infos = env.reset()
-    agent_0_rewards = []
-    agent_0_probs = []
-    agent_0_pred = []
-    agent_0_ents = []
-    agent_0_mask = []
-
-    agent_1_rewards = []
-    agent_1_probs = []
-    agent_1_pred = []
-    agent_1_ents = []
-    agent_1_mask = []
-    while env.agents:
-        actions = {}
-        #eve_action, eve_log_probs, eve_state_val, eve_ent = eve.select_action(observations["eve_0"])
-        #bob_action, bob_log_probs, bob_state_val, bob_ent = bob.select_action(observations["bob_0"])
-        #alice_action, alice_log_probs, alice_state_val, alice_ent = alice.select_action(observations["alice_0"])
-        #actions["eve_0"] = eve_action.item()
-        #actions["bob_0"] = bob_action.item()
-        #actions["alice_0"] = alice_action.item()
-        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
-        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
-        actions["agent_0"] = agent_0_action
-        actions["agent_1"] = agent_1_action
-        observations, rewards, terminations, truncations, infos = env.step(actions)
-        #print(rewards)
-        agent_0_rewards.append(rewards["agent_0"])
-        agent_0_probs.append(agent_0_log_probs)
-        agent_0_pred.append(agent_0_state_val)
-        agent_0_ents.append(agent_0_ent)
-        agent_0_mask.append( 1 if env.agents else 0)
-
-        agent_1_rewards.append(rewards["agent_1"])
-        agent_1_probs.append(agent_1_log_probs)
-        agent_1_pred.append(agent_1_state_val)
-        agent_1_ents.append(agent_1_ent)
-        agent_1_mask.append( 1 if env.agents else 0)
-    #eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef)
-    #print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item()))
-    #eve.update_params(eve_closs, eve_aloss)
-    agent_0_closs, agent_0_aloss = agent0.get_losses(agent_0_rewards, torch.stack(agent_0_probs), agent_0_pred, agent_0_ents, agent_0_mask, gamma, ent_coef)
-    #print(agent_0_rewards)
-    agent0_critic_loss.append(agent_0_closs.item())
-    agent0_actor_loss.append(agent_0_aloss.item())
-    #print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
-    agent0.update_params(agent_0_closs, agent_0_aloss)
-    agent_1_closs, agent_1_aloss = agent1.get_losses(agent_1_rewards, torch.stack(agent_1_probs), agent_1_pred, agent_1_ents, agent_1_mask, gamma, ent_coef)
-    agent1_critic_loss.append(agent_1_closs.item())
-    agent1_actor_loss.append(agent_1_aloss.item())
-    #print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
-    agent1.update_params(agent_1_closs, agent_1_aloss)
-    agent0_rewards.append(np.array(agent_0_rewards).sum())
-    agent1_rewards.append(np.array(agent_1_rewards).sum())
-
-    if episode % 500 == 0:
-        #rolling_length = 20
-        #fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12,5))
-        #fig.suptitle(
-        #    f"training plots for the Simple Reference environment"
-        #)
-        plt.plot(agent0_rewards, label="Agent 0 Rewards")
-        plt.plot(agent1_rewards, label="Agent 1 Rewards")
-        plt.legend()
-        plt.savefig('data.png')
-        plt.clf()
-
-plt.plot(agent0_rewards, label="Agent 0 Rewards")
-plt.plot(agent1_rewards, label="Agent 1 Rewards")
-plt.legend()
-plt.savefig('data.png')
-plt.show(block=False)
+def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr):
+    global agent0_critic_loss
+    global agent0_actor_loss
+    global agent1_critic_loss
+    global agent1_actor_loss
+    global agent0_rewards
+    global agent1_rewards
+    agent0_critic_loss = []
+    agent0_actor_loss = []
+    agent1_critic_loss = []
+    agent1_actor_loss = []
+    agent0_rewards = []
+    agent1_rewards = []
+    env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array")
+    #obs_space
+    #action_space
+
+    device = torch.device("cpu")
+
+    #init the agent
+    #agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)
+
+    #wrapper to record statistics
+    #env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
+    #    env, buffer_length=n_episodes
+    #)
+    #eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+    #bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+    #alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+
+    agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+    agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+    #print(env.action_space("agent_0").n)
+    #print(env.observation_space("agent_0"))
+
+    for episode in tqdm(range(n_episodes)):
+        #print("Episode " + str(episode) + "/" + str(n_episodes))
+        observations, infos = env.reset()
+        agent_0_rewards = []
+        agent_0_probs = []
+        agent_0_pred = []
+        agent_0_ents = []
+        agent_0_mask = []
+
+        agent_1_rewards = []
+        agent_1_probs = []
+        agent_1_pred = []
+        agent_1_ents = []
+        agent_1_mask = []
+        while env.agents:
+            actions = {}
+            #eve_action, eve_log_probs, eve_state_val, eve_ent = eve.select_action(observations["eve_0"])
+            #bob_action, bob_log_probs, bob_state_val, bob_ent = bob.select_action(observations["bob_0"])
+            #alice_action, alice_log_probs, alice_state_val, alice_ent = alice.select_action(observations["alice_0"])
+            #actions["eve_0"] = eve_action.item()
+            #actions["bob_0"] = bob_action.item()
+            #actions["alice_0"] = alice_action.item()
+            agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
+            agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
+            actions["agent_0"] = agent_0_action
+            actions["agent_1"] = agent_1_action
+            observations, rewards, terminations, truncations, infos = env.step(actions)
+            #print(rewards)
+            agent_0_rewards.append(rewards["agent_0"])
+            agent_0_probs.append(agent_0_log_probs)
+            agent_0_pred.append(agent_0_state_val)
+            agent_0_ents.append(agent_0_ent)
+            agent_0_mask.append( 1 if env.agents else 0)
+
+            agent_1_rewards.append(rewards["agent_1"])
+            agent_1_probs.append(agent_1_log_probs)
+            agent_1_pred.append(agent_1_state_val)
+            agent_1_ents.append(agent_1_ent)
+            agent_1_mask.append( 1 if env.agents else 0)
+        #eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef)
+        #print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item()))
+        #eve.update_params(eve_closs, eve_aloss)
+        agent_0_closs, agent_0_aloss = agent0.get_losses(agent_0_rewards, torch.stack(agent_0_probs), agent_0_pred, agent_0_ents, agent_0_mask, gamma, ent_coef)
+        #print(agent_0_rewards)
+        agent0_critic_loss.append(agent_0_closs.item())
+        agent0_actor_loss.append(agent_0_aloss.item())
+        #print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item()))
+        agent0.update_params(agent_0_closs, agent_0_aloss)
+        agent_1_closs, agent_1_aloss = agent1.get_losses(agent_1_rewards, torch.stack(agent_1_probs), agent_1_pred, agent_1_ents, agent_1_mask, gamma, ent_coef)
+        agent1_critic_loss.append(agent_1_closs.item())
+        agent1_actor_loss.append(agent_1_aloss.item())
+        #print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item()))
+        agent1.update_params(agent_1_closs, agent_1_aloss)
+        agent0_rewards.append(np.array(agent_0_rewards).sum())
+        agent1_rewards.append(np.array(agent_1_rewards).sum())
+        if episode % 500 == 0:
+            drawPlots()
+            plt.savefig('plots(gamma=' + str(gamma) + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr) + ',clr=' + str(critic_lr) + ').png')
+    drawPlots()
+    plt.savefig('plots(gamma=' + str(gamma) + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr) + ',clr=' + str(critic_lr) + ').png')
+    env.close()
+
+#environment hyperparams
+n_episodes = 1000
+train(10000, 0.999, 0, 0.0001, 0.0001)
+best = 1
+for gamma in np.arange(0.999, 0.99, -0.1):
+    for ent_coef in np.arange(0, 0.1, 0.01):
+        for actor_lr in np.arange(0.002,0.1, 0.002):
+            for critic_lr in np.arange(0.002,0.1,0.002):
+                #train(n_episodes, gamma, ent_coef, actor_lr, critic_lr)
+                if best == 1 or agent0_rewards[n_episodes-1] > best:
+                    best = agent0_rewards[n_episodes-1]
+                    print("New Best: " + str(best) + "\n\tWith Parameters (gamma=" + str(gamma) + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr) + ',clr=' + str(critic_lr) + ')')
+
+#agent hyperparams
+#gamma = 0.999
+#ent_coef = 0.01 # coefficient for entropy bonus
+#actor_lr = 0.001
+#critic_lr = 0.005
+
+#environment setup
+#env = simple_reference_v3.parallel_env(render_mode="human")
+#drawPlots()
+#plt.savefig('data.png')
+#plt.show(block=False)
 
 #plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
 #plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
@@ -290,18 +353,18 @@ plt.show(block=False)
 #plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
 
-actor0_weights_path = "weights/actor0_weights.h5"
-critic0_weights_path = "weights/critic0_weights.h5"
-actor1_weights_path = "weights/actor1_weights.h5"
-critic1_weights_path = "weights/critic1_weights.h5"
-if not os.path.exists("weights"):
-    os.mkdir("weights")
-torch.save(agent0.actor.state_dict(), actor0_weights_path)
-torch.save(agent0.critic.state_dict(), critic0_weights_path)
-torch.save(agent1.actor.state_dict(), actor1_weights_path)
-torch.save(agent1.critic.state_dict(), critic1_weights_path)
+#actor0_weights_path = "weights/actor0_weights.h5"
+#critic0_weights_path = "weights/critic0_weights.h5"
+#actor1_weights_path = "weights/actor1_weights.h5"
+#critic1_weights_path = "weights/critic1_weights.h5"
+#if not os.path.exists("weights"):
+#    os.mkdir("weights")
+#torch.save(agent0.actor.state_dict(), actor0_weights_path)
+#torch.save(agent0.critic.state_dict(), critic0_weights_path)
+#torch.save(agent1.actor.state_dict(), actor1_weights_path)
+#torch.save(agent1.critic.state_dict(), critic1_weights_path)
 
 #if load_weights:
 # agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
@@ -311,18 +374,18 @@ torch.save(agent1.critic.state_dict(), critic1_weights_path)
 # agent.actor.eval()
 # agent.critic.eval()
 
-agent0.set_eval()
-agent1.set_eval()
-env = simple_reference_v3.parallel_env(render_mode="human")
-while True:
-    observations, infos = env.reset()
-    while env.agents:
-        plt.pause(0.001)
-        actions = {}
-        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
-        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
-        actions["agent_0"] = agent_0_action
-        actions["agent_1"] = agent_1_action
-        observations, rewards, terminations, truncations, infos = env.step(actions)
-env.close()
+#agent0.set_eval()
+#agent1.set_eval()
+#env = simple_reference_v3.parallel_env(render_mode="human")
+#while True:
+#    observations, infos = env.reset()
+#    while env.agents:
+#        plt.pause(0.001)
+#        actions = {}
+#        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
+#        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
+#        actions["agent_0"] = agent_0_action
+#        actions["agent_1"] = agent_1_action
+#        observations, rewards, terminations, truncations, infos = env.step(actions)
+#env.close()
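
The last two hunks comment out weight checkpointing and the human-rendered evaluation loop. If checkpointing is wanted again, a minimal sketch of re-enabling it with the same paths and attribute names as the diff; torch.save, torch.load, and state_dict are standard PyTorch, and os.makedirs with exist_ok is an equivalent replacement for the exists()/mkdir() pair:

import os
import torch

os.makedirs("weights", exist_ok=True)  # same effect as the exists()/mkdir() check in the old code
torch.save(agent0.actor.state_dict(), "weights/actor0_weights.h5")
torch.save(agent0.critic.state_dict(), "weights/critic0_weights.h5")
torch.save(agent1.actor.state_dict(), "weights/actor1_weights.h5")
torch.save(agent1.critic.state_dict(), "weights/critic1_weights.h5")

# later, to restore a policy for evaluation:
agent0.actor.load_state_dict(torch.load("weights/actor0_weights.h5"))
agent0.actor.eval()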