improved plots, and moved training code into separate function
main.py (203 changed lines)
@@ -90,7 +90,7 @@ class A2C(nn.Module):
     ) -> tuple[torch.tensor, torch.tensor]:

         T = len(rewards)
-        advantages = torch.zeros(T, device=device)
+        advantages = torch.zeros(T, device=self.device)

         # compute the advantages using GAE
         gae = 0.0
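For context only (not part of this commit): the device fix above sits inside the advantage computation that the "# compute the advantages using GAE" comment refers to. A minimal, self-contained sketch of that computation follows; the function name compute_gae and the arguments value_preds, masks, gamma and lam are assumptions for illustration, not names taken from main.py.

    import torch

    def compute_gae(rewards, value_preds, masks, gamma, lam, device):
        # Generalized Advantage Estimation over one episode.
        # value_preds is assumed to hold len(rewards) + 1 values (bootstrap value at the end);
        # masks[t] is 0.0 at terminal steps, 1.0 otherwise.
        T = len(rewards)
        advantages = torch.zeros(T, device=device)
        gae = 0.0
        for t in reversed(range(T)):
            delta = rewards[t] + gamma * value_preds[t + 1] * masks[t] - value_preds[t]
            gae = delta + gamma * lam * masks[t] * gae
            advantages[t] = gae
        return advantages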
@@ -162,41 +162,52 @@ class A2C(nn.Module):
         self.critic.eval()
         self.actor.eval()

-#environment hyperparams
-n_episodes = 100

-#agent hyperparams
-gamma = 0.999
-ent_coef = 0.01 # coefficient for entropy bonus
-actor_lr = 0.001
-critic_lr = 0.005
+fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15,5))
+fig.suptitle(
+    f"training plots for the Simple Reference environment"
+)

-#environment setup
-#env = simple_reference_v3.parallel_env(render_mode="human")
-env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array")
+def drawPlots():
+    rolling_length = 20

+    agent0_average = []
+    agent1_average = []
+    agent0_average_closs = []
+    agent0_average_aloss = []
+    agent1_average_closs = []
+    agent1_average_aloss = []
+    window = 20
+    for ind in range(len(agent0_rewards) - window + 1):
+        agent0_average.append(np.mean(agent0_rewards[ind:ind+window]))
+    for ind in range(len(agent1_rewards) - window + 1):
+        agent1_average.append(np.mean(agent1_rewards[ind:ind+window]))
+    for ind in range(len(agent0_critic_loss) - window + 1):
+        agent0_average_closs.append(np.mean(agent0_critic_loss[ind:ind+window]))
+    for ind in range(len(agent0_actor_loss) - window + 1):
+        agent0_average_aloss.append(np.mean(agent0_actor_loss[ind:ind+window]))
+    for ind in range(len(agent1_critic_loss) - window + 1):
+        agent1_average_closs.append(np.mean(agent1_critic_loss[ind:ind+window]))
+    for ind in range(len(agent1_actor_loss) - window + 1):
+        agent1_average_aloss.append(np.mean(agent1_actor_loss[ind:ind+window]))
+    axs[0].cla()
+    axs[0].plot(agent0_average, label="Agent 0")
+    axs[0].plot(agent1_average, label="Agent 1")
+    axs[0].legend()
+    axs[0].set_title("Rewards over Time")

+    axs[1].cla()
+    axs[1].plot(agent0_average_closs, label="Agent 0")
+    axs[1].plot(agent1_average_closs, label="Agent 1")
+    axs[1].legend()
+    axs[1].set_title("Critic Loss over Time")

-#obs_space
-#action_space
+    axs[2].cla()
+    axs[2].plot(agent0_average_aloss, label="Agent 0")
+    axs[2].plot(agent1_average_aloss, label="Agent 1")
+    axs[2].legend()
+    axs[2].set_title("Actor Loss over Time")

-device = torch.device("cpu")
-
-#init the agent
-#agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)
-
-#wrapper to record statistics
-#env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
-# env, buffer_length=n_episodes
-#)
-#eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-#bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-#alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-
-agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
-#print(env.action_space("agent_0").n)
-#print(env.observation_space("agent_0"))
 agent0_critic_loss = []
 agent0_actor_loss = []
 agent1_critic_loss = []
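Aside, not part of the commit: the six window-averaging loops added in drawPlots all compute the same rolling mean and could share a single helper. A minimal sketch, assuming numpy is imported as np (the name rolling_mean is hypothetical):

    import numpy as np

    def rolling_mean(values, window=20):
        # Same result as the manual slicing loops above: one mean per full window position.
        values = np.asarray(values, dtype=float)
        if len(values) < window:
            return np.array([])
        return np.convolve(values, np.ones(window) / window, mode="valid")

Each series could then be smoothed with one call, e.g. agent0_average = rolling_mean(agent0_rewards, window).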
@@ -204,7 +215,42 @@ agent1_actor_loss = []
 agent0_rewards = []
 agent1_rewards = []

-for episode in tqdm(range(n_episodes)):
+def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr):
+    global agent0_critic_loss
+    global agent0_actor_loss
+    global agent1_critic_loss
+    global agent1_actor_loss
+    global agent0_rewards
+    global agent1_rewards
+    agent0_critic_loss = []
+    agent0_actor_loss = []
+    agent1_critic_loss = []
+    agent1_actor_loss = []
+    agent0_rewards = []
+    agent1_rewards = []
+    env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array")
+    #obs_space
+    #action_space
+
+    device = torch.device("cpu")
+
+    #init the agent
+    #agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr, n_envs)
+
+    #wrapper to record statistics
+    #env_wrapper_stats = gym.wrappers.vector.RecordEpisodeStatistics(
+    # env, buffer_length=n_episodes
+    #)
+    #eve = A2C(n_features = env.observation_space("eve_0").shape[0], n_actions = env.action_space("eve_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+    #bob = A2C(n_features = env.observation_space("bob_0").shape[0], n_actions = env.action_space("bob_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+    #alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+
+    agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+    agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr)
+    #print(env.action_space("agent_0").n)
+    #print(env.observation_space("agent_0"))
+
+    for episode in tqdm(range(n_episodes)):
         #print("Episode " + str(episode) + "/" + str(n_episodes))
         observations, infos = env.reset()
         agent_0_rewards = []
@@ -265,24 +311,41 @@ for episode in tqdm(range(n_episodes)):


         if episode % 500 == 0:
-            #rolling_length = 20
-            #fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12,5))
-            #fig.suptitle(
-            # f"training plots for the Simple Reference environment"
-            #)
+            drawPlots()
+            plt.savefig('plots(gamma=' + str(gamma) + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr) + ',clr=' + str(critic_lr) + ').png')
+    drawPlots()
+    plt.savefig('plots(gamma=' + str(gamma) + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr) + ',clr=' + str(critic_lr) + ').png')
+    env.close()


-plt.plot(agent0_rewards, label="Agent 0 Rewards")
-plt.plot(agent1_rewards, label="Agent 1 Rewards")
-plt.legend()
-plt.savefig('data.png')
-plt.clf()

-plt.plot(agent0_rewards, label="Agent 0 Rewards")
-plt.plot(agent1_rewards, label="Agent 1 Rewards")
-plt.legend()
-plt.savefig('data.png')
-plt.show(block=False)
+#environment hyperparams
+n_episodes = 1000
+train(10000, 0.999, 0, 0.0001, 0.0001)
+best = 1
+for gamma in np.arange(0.999, 0.99, -0.1):
+    for ent_coef in np.arange(0, 0.1, 0.01):
+        for actor_lr in np.arange(0.002,0.1, 0.002):
+            for critic_lr in np.arange(0.002,0.1,0.002):
+                #train(n_episodes, gamma, ent_coef, actor_lr, critic_lr)
+                if best == 1 or agent0_rewards[n_episodes-1] > best:
+                    best = agent0_rewards[n_episodes-1]
+                    print("New Best: " + str(best) + "\n\tWith Parameters (gamma=" + str(gamma) + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr) + ',clr=' + str(critic_lr) + ')')
+
+
+#agent hyperparams
+#gamma = 0.999
+#ent_coef = 0.01 # coefficient for entropy bonus
+#actor_lr = 0.001
+#critic_lr = 0.005
+
+#environment setup
+#env = simple_reference_v3.parallel_env(render_mode="human")
+
+
+#drawPlots()
+#plt.savefig('data.png')
+#plt.show(block=False)

 #plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
 #plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
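Aside, not part of the commit: the four nested sweep loops above can be flattened with itertools.product. The sketch below assumes the train function and the agent0_rewards / n_episodes globals defined in main.py, and keeps the same np.arange ranges. Note that np.arange(0.999, 0.99, -0.1) yields only the single value 0.999, so the gamma loop runs once; also, in the committed code the train call inside the loops is still commented out, so the sweep does not actually retrain.

    from itertools import product
    import numpy as np

    best = 1
    for gamma, ent_coef, actor_lr, critic_lr in product(
        np.arange(0.999, 0.99, -0.1),
        np.arange(0, 0.1, 0.01),
        np.arange(0.002, 0.1, 0.002),
        np.arange(0.002, 0.1, 0.002),
    ):
        # assumes train() repopulates the global agent0_rewards list, as in main.py
        train(n_episodes, gamma, ent_coef, actor_lr, critic_lr)
        if best == 1 or agent0_rewards[n_episodes - 1] > best:
            best = agent0_rewards[n_episodes - 1]
            print("New Best:", best, "with gamma=", gamma, "ent=", ent_coef,
                  "alr=", actor_lr, "clr=", critic_lr)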
@@ -290,18 +353,18 @@ plt.show(block=False)
 #plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")


-actor0_weights_path = "weights/actor0_weights.h5"
-critic0_weights_path = "weights/critic0_weights.h5"
-actor1_weights_path = "weights/actor1_weights.h5"
-critic1_weights_path = "weights/critic1_weights.h5"
+#actor0_weights_path = "weights/actor0_weights.h5"
+#critic0_weights_path = "weights/critic0_weights.h5"
+#actor1_weights_path = "weights/actor1_weights.h5"
+#critic1_weights_path = "weights/critic1_weights.h5"

-if not os.path.exists("weights"):
-    os.mkdir("weights")
+#if not os.path.exists("weights"):
+# os.mkdir("weights")

-torch.save(agent0.actor.state_dict(), actor0_weights_path)
-torch.save(agent0.critic.state_dict(), critic0_weights_path)
-torch.save(agent1.actor.state_dict(), actor1_weights_path)
-torch.save(agent1.critic.state_dict(), critic1_weights_path)
+#torch.save(agent0.actor.state_dict(), actor0_weights_path)
+#torch.save(agent0.critic.state_dict(), critic0_weights_path)
+#torch.save(agent1.actor.state_dict(), actor1_weights_path)
+#torch.save(agent1.critic.state_dict(), critic1_weights_path)

 #if load_weights:
 # agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
@@ -311,18 +374,18 @@ torch.save(agent1.critic.state_dict(), critic1_weights_path)
 # agent.actor.eval()
 # agent.critic.eval()

-agent0.set_eval()
-agent1.set_eval()
-env = simple_reference_v3.parallel_env(render_mode="human")
-while True:
-    observations, infos = env.reset()
-    while env.agents:
-        plt.pause(0.001)
-        actions = {}
-        agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
-        agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
-        actions["agent_0"] = agent_0_action
-        actions["agent_1"] = agent_1_action
-        observations, rewards, terminations, truncations, infos = env.step(actions)
+#agent0.set_eval()
+#agent1.set_eval()
+#env = simple_reference_v3.parallel_env(render_mode="human")
+#while True:
+# observations, infos = env.reset()
+# while env.agents:
+# plt.pause(0.001)
+# actions = {}
+# agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0))
+# agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0))
+# actions["agent_0"] = agent_0_action
+# actions["agent_1"] = agent_1_action
+# observations, rewards, terminations, truncations, infos = env.step(actions)

-env.close()
+#env.close()