From 3bd7f70a941f541d128021dc6334fd046c8273db Mon Sep 17 00:00:00 2001
From: Scirockses <33169720+Scirockses@users.noreply.github.com>
Date: Sun, 31 Aug 2025 19:05:39 -0600
Subject: [PATCH] MoreSTUFF

---
 .gitignore |  2 ++
 main.py    | 66 +++++++++++++++++++++++++++++++++++++++++++-----------
 2 files changed, 55 insertions(+), 13 deletions(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6276225
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.png
+*.h5
\ No newline at end of file
diff --git a/main.py b/main.py
index c13ca66..73680b0 100644
--- a/main.py
+++ b/main.py
@@ -26,21 +26,21 @@ class A2C(nn.Module):
         self.device = device
 
         critic_layers = [
-            nn.Linear(n_features, 64),
+            nn.Linear(n_features, 128),
             nn.ReLU(),
-            nn.Linear(64,128),
+            nn.Linear(128,256),
             nn.ReLU(),
-            nn.Linear(128,128),
+            nn.Linear(256,128),
             nn.ReLU(),
             nn.Linear(128, 1),
         ]
 
         actor_layers = [
-            nn.Linear(n_features, 64),
+            nn.Linear(n_features, 128),
             nn.ReLU(),
-            nn.Linear(64,128),
+            nn.Linear(128,256),
             nn.ReLU(),
-            nn.Linear(128,128),
+            nn.Linear(256,128),
             nn.ReLU(),
             nn.Linear(128, n_actions),
             nn.Softmax()
@@ -163,7 +163,7 @@ class A2C(nn.Module):
         self.actor.eval()
 
 #environment hyperparams
-n_episodes = 1000
+n_episodes = 100
 
 #agent hyperparams
 gamma = 0.999
@@ -204,8 +204,8 @@ agent1_actor_loss = []
 agent0_rewards = []
 agent1_rewards = []
 
-for episode in range(0, n_episodes):
-    print("Episode " + str(episode) + "/" + str(n_episodes))
+for episode in tqdm(range(n_episodes)):
+    #print("Episode " + str(episode) + "/" + str(n_episodes))
     observations, infos = env.reset()
     agent_0_rewards = []
     agent_0_probs = []
@@ -261,15 +261,55 @@ for episode in range(0, n_episodes):
     agent0_rewards.append(np.array(agent_0_rewards).sum())
     agent1_rewards.append(np.array(agent_1_rewards).sum())
+
+
+
+    if episode % 500 == 0:
+        #rolling_length = 20
+        #fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(12,5))
+        #fig.suptitle(
+        #    f"training plots for the Simple Reference environment"
+        #)
+
+
+        plt.plot(agent0_rewards, label="Agent 0 Rewards")
+        plt.plot(agent1_rewards, label="Agent 1 Rewards")
+        plt.legend()
+        plt.savefig('data.png')
+        plt.clf()
+
+plt.plot(agent0_rewards, label="Agent 0 Rewards")
+plt.plot(agent1_rewards, label="Agent 1 Rewards")
+plt.legend()
+plt.savefig('data.png')
+plt.show(block=False)
 
 
 #plt.plot(agent0_critic_loss, label="Agent 0 Critic Loss")
 #plt.plot(agent0_actor_loss, label="Agent 0 Actor Loss")
 #plt.plot(agent1_critic_loss, label="Agent 1 Critic Loss")
 #plt.plot(agent1_actor_loss, label="Agent 1 Actor Loss")
-plt.plot(agent0_rewards, label="Agent 0 Rewards")
-plt.plot(agent1_rewards, label="Agent 1 Rewards")
-plt.legend()
-plt.show(block=False)
+
+
+actor0_weights_path = "weights/actor0_weights.h5"
+critic0_weights_path = "weights/critic0_weights.h5"
+actor1_weights_path = "weights/actor1_weights.h5"
+critic1_weights_path = "weights/critic1_weights.h5"
+
+if not os.path.exists("weights"):
+    os.mkdir("weights")
+
+torch.save(agent0.actor.state_dict(), actor0_weights_path)
+torch.save(agent0.critic.state_dict(), critic0_weights_path)
+torch.save(agent1.actor.state_dict(), actor1_weights_path)
+torch.save(agent1.critic.state_dict(), critic1_weights_path)
+
+#if load_weights:
+#    agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
+#
+#    agent.actor.load_state_dict(torch.load(actor_weights_path))
+#    agent.critic.load_state_dict(torch.load(critic_weights_path))
+#    agent.actor.eval()
+#    agent.critic.eval()
 
 agent0.set_eval()
 agent1.set_eval()
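
Note (not part of the patch): a minimal sketch of how the weights saved above could be loaded back for evaluation, extending the commented-out single-agent load block to the two-agent setup. It assumes the A2C class, device, obs_shape, action_shape, critic_lr, and actor_lr are already defined in main.py as they are during training; load_agent is a hypothetical helper introduced only for illustration.

import torch

def load_agent(actor_path, critic_path):
    # Rebuild an agent with the same constructor arguments used for training,
    # then restore the saved actor/critic parameters.
    agent = A2C(obs_shape, action_shape, device, critic_lr, actor_lr)
    agent.actor.load_state_dict(torch.load(actor_path, map_location=device))
    agent.critic.load_state_dict(torch.load(critic_path, map_location=device))
    agent.set_eval()  # switch both networks to eval mode, as done at the end of training
    return agent

agent0 = load_agent("weights/actor0_weights.h5", "weights/critic0_weights.h5")
agent1 = load_agent("weights/actor1_weights.h5", "weights/critic1_weights.h5")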