MoreSTUFF

commit 3bd7f70a94 (parent 40203becdb)
Author: Scirockses
Date:   2025-08-31 19:05:39 -06:00

2 changed files with 55 additions and 13 deletions

main.py (66 changed lines)

@@ -26,21 +26,21 @@ class A2C(nn.Module):
         self.device = device
         critic_layers = [
-            nn.Linear(n_features, 64),
+            nn.Linear(n_features, 128),
             nn.ReLU(),
-            nn.Linear(64,128),
+            nn.Linear(128,256),
             nn.ReLU(),
-            nn.Linear(128,128),
+            nn.Linear(256,128),
             nn.ReLU(),
             nn.Linear(128, 1),
         ]
         actor_layers = [
-            nn.Linear(n_features, 64),
+            nn.Linear(n_features, 128),
             nn.ReLU(),
-            nn.Linear(64,128),
+            nn.Linear(128,256),
             nn.ReLU(),
-            nn.Linear(128,128),
+            nn.Linear(256,128),
             nn.ReLU(),
             nn.Linear(128, n_actions),
             nn.Softmax()
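
For context, a minimal sketch of how the widened critic and actor heads from this hunk compose end to end. The layer sizes come from the new (+) side of the diff; wrapping the layer lists in nn.Sequential, passing dim=-1 to Softmax, and the concrete n_features/n_actions values are assumptions for illustration, not the repo's exact code.

import torch
import torch.nn as nn

n_features, n_actions = 21, 50  # hypothetical sizes, only for this sketch

critic = nn.Sequential(
    nn.Linear(n_features, 128), nn.ReLU(),
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 1),            # scalar state-value estimate
)
actor = nn.Sequential(
    nn.Linear(n_features, 128), nn.ReLU(),
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, n_actions),
    nn.Softmax(dim=-1),           # action probabilities over the last axis
)

obs = torch.randn(1, n_features)
print(critic(obs).shape, actor(obs).shape)  # torch.Size([1, 1]) torch.Size([1, 50])
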
@@ -163,7 +163,7 @@ class A2C(nn.Module):
         self.actor.eval()
 #environment hyperparams
-n_episodes = 1000
+n_episodes = 100
 #agent hyperparams
 gamma = 0.999
@@ -204,8 +204,8 @@ agent1_actor_loss = []
 agent0_rewards = []
 agent1_rewards = []
-for episode in range(0, n_episodes):
-    print("Episode " + str(episode) + "/" + str(n_episodes))
+for episode in tqdm(range(n_episodes)):
+    #print("Episode " + str(episode) + "/" + str(n_episodes))
     observations, infos = env.reset()
     agent_0_rewards = []
     agent_0_probs = []
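
The change above swaps the per-episode print for a tqdm progress bar. A small self-contained sketch of that pattern, with the PettingZoo environment and reward bookkeeping stubbed out (the stub names are illustrative, not from the repo):

from tqdm import tqdm

n_episodes = 100
episode_returns = []

for episode in tqdm(range(n_episodes)):   # renders a live progress bar on stderr
    # observations, infos = env.reset()   # the real loop resets the env here
    episode_returns.append(0.0)           # placeholder for the episode's summed reward
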
@@ -261,15 +261,55 @@ for episode in range(0, n_episodes):
     agent0_rewards.append(np.array(agent_0_rewards).sum())
     agent1_rewards.append(np.array(agent_1_rewards).sum())
     if episode % 500 == 0:
         #rolling_length = 20
         #fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12,5))
         #fig.suptitle(
         #    f"training plots for the Simple Reference environment"
         #)
-        plt.plot(agent0_rewards, label="Agent 0 Rewards")
-        plt.plot(agent1_rewards, label="Agent 1 Rewards")
-        plt.legend()
-        plt.savefig('data.png')
-        plt.clf()
+        plt.plot(agent0_rewards, label="Agent 0 Rewards")
+        plt.plot(agent1_rewards, label="Agent 1 Rewards")
+        plt.legend()
+        plt.savefig('data.png')
+        plt.show(block=False)
+
+#plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
+#plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
+#plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
+#plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
+plt.plot(agent0_rewards, label="Agent 0 Rewards")
+plt.plot(agent1_rewards, label="Agent 1 Rewards")
+plt.legend()
+plt.show(block=False)
+
+actor0_weights_path = "weights/actor0_weights.h5"
+critic0_weights_path = "weights/critic0_weights.h5"
+actor1_weights_path = "weights/actor1_weights.h5"
+critic1_weights_path = "weights/critic1_weights.h5"
+if not os.path.exists("weights"):
+    os.mkdir("weights")
+torch.save(agent0.actor.state_dict(), actor0_weights_path)
+torch.save(agent0.critic.state_dict(), critic0_weights_path)
+torch.save(agent1.actor.state_dict(), actor1_weights_path)
+torch.save(agent1.critic.state_dict(), critic1_weights_path)
+
+#if load_weights:
+#    agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
+#
+#    agent.actor.load_state_dict(torch.load(actor_weights_path))
+#    agent.critic.load_state_dict(torch.load(critic_weights_path))
+#    agent.actor.eval()
+#    agent.critic.eval()
+
+agent0.set_eval()
+agent1.set_eval()
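
A hedged sketch of how the checkpoints saved above could be restored in a later run, adapted from the commented-out single-agent load block to the two-agent paths introduced in this commit. It assumes the surrounding script's definitions: the A2C constructor arguments (obs_shape, action_shape, device, critic_lr, actor_lr) are copied from that comment and may not match the current signature, and set_eval() is the helper already used at the end of training.

import torch

# assumes A2C, obs_shape, action_shape, device, critic_lr, actor_lr from the training script
agent0 = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
agent1 = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)

agent0.actor.load_state_dict(torch.load("weights/actor0_weights.h5", map_location=device))
agent0.critic.load_state_dict(torch.load("weights/critic0_weights.h5", map_location=device))
agent1.actor.load_state_dict(torch.load("weights/actor1_weights.h5", map_location=device))
agent1.critic.load_state_dict(torch.load("weights/critic1_weights.h5", map_location=device))

agent0.set_eval()  # switch actor/critic to eval mode, mirroring the end of training
agent1.set_eval()

Note that the .h5 suffix is only a filename choice here; torch.save and torch.load use PyTorch's own serialization regardless of the extension.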