MoreSTUFF
main.py
@@ -26,21 +26,21 @@ class A2C(nn.Module):
         self.device = device

         critic_layers = [
-            nn.Linear(n_features, 64),
+            nn.Linear(n_features, 128),
             nn.ReLU(),
-            nn.Linear(64,128),
+            nn.Linear(128,256),
             nn.ReLU(),
-            nn.Linear(128,128),
+            nn.Linear(256,128),
             nn.ReLU(),
             nn.Linear(128, 1),
         ]

         actor_layers = [
-            nn.Linear(n_features, 64),
+            nn.Linear(n_features, 128),
             nn.ReLU(),
-            nn.Linear(64,128),
+            nn.Linear(128,256),
             nn.ReLU(),
-            nn.Linear(128,128),
+            nn.Linear(256,128),
             nn.ReLU(),
             nn.Linear(128, n_actions),
             nn.Softmax()
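The hunk above only touches the layer lists; the rest of the constructor is not part of this diff. As a rough sketch, assuming the lists are wrapped in nn.Sequential and moved to the device (an assumption, not shown here), the new sizes come out as below. n_features, n_actions and the batch size are placeholder values, and nn.Softmax(dim=-1) is used instead of the bare nn.Softmax() to avoid PyTorch's implicit-dimension warning.

# Minimal sketch of the widened critic/actor stacks, assuming nn.Sequential
# wrapping; n_features and n_actions are hypothetical values for illustration.
import torch
import torch.nn as nn

n_features, n_actions = 21, 5
device = "cuda" if torch.cuda.is_available() else "cpu"

critic = nn.Sequential(
    nn.Linear(n_features, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 1),           # scalar state-value estimate V(s)
).to(device)

actor = nn.Sequential(
    nn.Linear(n_features, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, n_actions),
    nn.Softmax(dim=-1),          # explicit dim avoids the implicit-dimension warning
).to(device)

obs = torch.randn(4, n_features, device=device)    # batch of 4 observations
print(critic(obs).shape, actor(obs).shape)          # torch.Size([4, 1]) torch.Size([4, 5])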
@@ -163,7 +163,7 @@ class A2C(nn.Module):
         self.actor.eval()

 #environment hyperparams
-n_episodes = 1000
+n_episodes = 100

 #agent hyperparams
 gamma = 0.999
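For context on gamma = 0.999: the update code is outside this hunk, but a discount factor like this is typically applied backwards over an episode's rewards before computing the A2C losses. The helper below is purely illustrative and not taken from main.py.

# Illustration only: turning per-step rewards into discounted returns with
# gamma = 0.999. Function name and reward values are hypothetical.
gamma = 0.999

def discounted_returns(rewards, gamma):
    """G_t = r_t + gamma * G_{t+1}, accumulated backwards over one episode."""
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return list(reversed(returns))

print(discounted_returns([0.0, 0.0, 1.0], gamma))  # [0.998001, 0.999, 1.0]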
@@ -204,8 +204,8 @@ agent1_actor_loss = []
 agent0_rewards = []
 agent1_rewards = []

-for episode in range(0, n_episodes):
-    print("Episode " + str(episode) + "/" + str(n_episodes))
+for episode in tqdm(range(n_episodes)):
+    #print("Episode " + str(episode) + "/" + str(n_episodes))
     observations, infos = env.reset()
     agent_0_rewards = []
     agent_0_probs = []
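The change above swaps the per-episode print for a tqdm progress bar (assuming `from tqdm import tqdm` exists elsewhere in main.py, since the import is not shown in this diff). A minimal, self-contained sketch of the pattern, with the environment rollout stubbed out:

# Sketch of the tqdm pattern: wrapping range() gives a live progress bar in
# place of per-episode prints. Requires `pip install tqdm`; the rollout is a stub.
from tqdm import tqdm

n_episodes = 100
progress = tqdm(range(n_episodes))
for episode in progress:
    episode_reward = 0.0              # placeholder for the real rollout
    progress.set_postfix(reward=episode_reward)  # optional: show a running metric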
@@ -261,15 +261,55 @@ for episode in range(0, n_episodes):
     agent0_rewards.append(np.array(agent_0_rewards).sum())
     agent1_rewards.append(np.array(agent_1_rewards).sum())

+
+    if episode % 500 == 0:
+        #rolling_length = 20
+        #fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12,5))
+        #fig.suptitle(
+        #    f"training plots for the Simple Reference environment"
+        #)
+
+        plt.plot(agent0_rewards, label="Agent 0 Rewards")
+        plt.plot(agent1_rewards, label="Agent 1 Rewards")
+        plt.legend()
+        plt.savefig('data.png')
+        plt.clf()
+
+plt.plot(agent0_rewards, label="Agent 0 Rewards")
+plt.plot(agent1_rewards, label="Agent 1 Rewards")
+plt.legend()
+plt.savefig('data.png')
+plt.show(block=False)
+
 #plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
 #plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
 #plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
 #plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
 plt.plot(agent0_rewards, label="Agent 0 Rewards")
 plt.plot(agent1_rewards, label="Agent 1 Rewards")
 plt.legend()
 plt.show(block=False)
+
+actor0_weights_path = "weights/actor0_weights.h5"
+critic0_weights_path = "weights/critic0_weights.h5"
+actor1_weights_path = "weights/actor1_weights.h5"
+critic1_weights_path = "weights/critic1_weights.h5"
+
+if not os.path.exists("weights"):
+    os.mkdir("weights")
+
+torch.save(agent0.actor.state_dict(), actor0_weights_path)
+torch.save(agent0.critic.state_dict(), critic0_weights_path)
+torch.save(agent1.actor.state_dict(), actor1_weights_path)
+torch.save(agent1.critic.state_dict(), critic1_weights_path)
+
+#if load_weights:
+#    agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
+#
+#    agent.actor.load_state_dict(torch.load(actor_weights_path))
+#    agent.critic.load_state_dict(torch.load(critic_weights_path))
+#    agent.actor.eval()
+#    agent.critic.eval()
+
+agent0.set_eval()
+agent1.set_eval()
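The new block writes a snapshot of the reward curves to data.png during training and again at the end. Note that with n_episodes reduced to 100, the `episode % 500 == 0` check only fires at episode 0. Below is a compact, runnable sketch of the periodic-plot pattern with random placeholder rewards, a smaller interval, and a headless matplotlib backend so it runs without a display; only the plotting flow is the point.

# Sketch of the periodic snapshot pattern: every `plot_every` episodes the
# reward curves so far are written to disk and the figure is cleared.
import random
import matplotlib
matplotlib.use("Agg")             # headless backend; no GUI needed for the sketch
import matplotlib.pyplot as plt

n_episodes, plot_every = 100, 50
agent0_rewards, agent1_rewards = [], []

for episode in range(n_episodes):
    agent0_rewards.append(random.random())   # stand-in for the real episode return
    agent1_rewards.append(random.random())

    if episode % plot_every == 0:
        plt.plot(agent0_rewards, label="Agent 0 Rewards")
        plt.plot(agent1_rewards, label="Agent 1 Rewards")
        plt.legend()
        plt.savefig("data.png")               # overwrites the previous snapshot
        plt.clf()                             # clear so the next snapshot starts fresh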
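The hunk saves the four state_dicts but leaves the loading path commented out. Here is a hedged sketch of the corresponding load-and-eval step, assuming an A2C module with .actor and .critic submodules as used above; the constructor call in the commented block is not reproduced, so construct the agent however training did. Note that torch.save writes PyTorch's own serialization format regardless of the ".h5" suffix in the filenames.

# Sketch of the load path that the commit leaves commented out; assumes the
# agent object exposes .actor and .critic, matching the save calls above.
import torch

def load_agent(agent, actor_path, critic_path, device="cpu"):
    """Restore a trained agent's networks and switch them to evaluation mode."""
    agent.actor.load_state_dict(torch.load(actor_path, map_location=device))
    agent.critic.load_state_dict(torch.load(critic_path, map_location=device))
    agent.actor.eval()
    agent.critic.eval()
    return agent

# usage, assuming agent0 was constructed the same way as during training:
# agent0 = load_agent(agent0, "weights/actor0_weights.h5", "weights/critic0_weights.h5")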