Files
simple_crypto/.ipynb_checkpoints/test-checkpoint.py

157 lines
5.2 KiB
Python

import torch
from pettingzoo.mpe import simple_reference_v3,simple_v3
import numpy as np
from IPython.display import clear_output
from IPython.core.debugger import set_trace
import matplotlib.pyplot as plt
class Model(torch.nn.Module):
def __init__(self, observation_space, action_space):
super(Model, self).__init__()
self.features = torch.nn.Sequential(
torch.nn.Linear(observation_space, 32),
torch.nn.ReLU(),
torch.nn.Linear(32, 128),
torch.nn.ReLu()
)
self.critic = torch.nn.Sequential(
torch.nn.Linear(128, 256),
torch.nn.ReLU(),
torch.nn.Linear(256, 1)
)
self.actor = torch.nn.Sequential(
torch.nn.Linear(128, 256),
torch.nn.ReLU(),
torch.nn.Linear(256, action_space)
)
def forward(self, x):
x = self.features(x)
value = self.critic(x)
actions = self.actor(x)
return value, actions
def get_critic(self, x):
x = self.features(x)
return self.critic(x)
def evaluate_action(self, state, action):
value, actor_features = self.forward(state)
dist = torch.distributions.Categorical(actor_features)
log_probs = dist.log_prob(action).view(-1, 1)
entropy = dist.entropy().mean()
return value, log_rpobs, entropy
def act(self, state):
value, actor_features = self.forward(state)
dist = torch.distributions.Categorical(actor_features)
chosen_action = dist.sample()
return chosen_action.item()
class Memory(object):
def __init__(self):
self.states, self.actions, self.true_values = [], [], []
def push(self, state, action, true_value):
self.states.append(state)
self.actions.append(action)
self.true_values.append(true_value)
def pop_all(self):
states = torch.stack(self.states)
actions = LongTensor(self.actions)
true_values = FloatTensor(self.true_values).unsqueeze(1)
self.states, self.actions, self.true_values = [], [], []
return states, actions, true_values
class Worker(object):
def __init__(self):
self.env = simple_v3.parallel_env()
self.episode_reward = 0
self.state = FloatTensor(self.env.reset()[0])
def get_batch(self):
states, actions, rewards, dones = [], [], [], []
for _ in range(batch_size):
action = model.act(torch.Tensor(self.state["agent_0"]).unsqueeze(0))
actions = []
actions["agent_0"] = action
next_state, rewards, terminations, truncations, _ = env.step(actions)
self.episode_reward += rewards["agent_0"]
states.append(torch.Tensor(self.state["agent_0"]))
actions.append(action)
rewards.append(reward["agent_0"])
done = terminations["agent_0"] or truncations["agent_0"]
dones.append(done)
if done:
self.state = FloatTensor(self.env.reset()[0])
data['episode_rewards'].append(self.episode_reward)
self.episode_reward = 0
else:
self.state = FloatTensor(next_state)
values = compute_true_values(states, rewards, dones).unsqueeze(1)
return states, actions, values
def compute_true_values(states, rewards, dones):
true_values = []
rewards = FloatTensor(rewards)
dones = FloatTensor(dones)
states = torch.stack(states)
if dones[-1] == True:
next_value = rewards[-1]
else:
next_value = model.get_critic(states[-1].unsqueeze(0))
true_values.append(next_value)
for i in reversed(range(0, len(rewards) -1)):
if not dones[i]:
next_value = rewards[i] + next_value * gamma
else:
next_value = rewards[i]
true_values.append(next_value)
true_values.reverse()
return FloatTensor(true_values)
def reflect(memory):
states, actions, true_values = memory.pop_all()
values, log_probs, entropy = model.evaluate_action(states, actions)
advantages = true_values - values
critic_loss = advantages.pow(2).mean()
actor_loss = -(log_probs * advantages.detach()).mean()
total_loss = (critic_coef * critic_loss) + actor_loss - (entropy_coef * entropy)
optimizer.zero_grad()
total_loss.backward()
tourch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
return values.mean().item()
def plot(data, frame_idx):
clear_output(True)
plt.figure(figsize=(20, 5))
if data['episode_rewards']:
ax = plt.subplot(121)
ax.plt.gca()
average_score = np.mean(data['episode_rewards'][-100:])
plt.title(f"Frame: {frame_idx} - Average Store: {average_score}")
plt.grid()
plt.plot(data['episode_rewards'])
if data['values']:
ax = plt.subplot(122)
average_value = np.mean(data['values'][-1000:0])
plt.title(f"Frame: {frame_idx} - Average Values: {average_value}")
plt.plot(data['values'])
plt.show()
env = simple_v3.parallel_env()
model = Model(env.observation_space("agent_0").shape[0], env.action_space("agent_0").n)
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, eps=1e-5)