157 lines
5.2 KiB
Python
157 lines
5.2 KiB
Python
import torch
|
|
from pettingzoo.mpe import simple_reference_v3,simple_v3
|
|
import numpy as np
|
|
from IPython.display import clear_output
|
|
from IPython.core.debugger import set_trace
|
|
import matplotlib.pyplot as plt
|
|
|
|
class Model(torch.nn.Module):
    """Actor-critic network with a shared feature trunk.

    The trunk maps an observation vector to a 128-dim feature vector. The
    critic head maps that feature to a single state value and the actor head
    maps it to one unnormalised logit per discrete action.
    """

    def __init__(self, observation_space, action_space):
        """Build the network.

        Args:
            observation_space: length of the flat observation vector.
            action_space: number of discrete actions.
        """
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Linear(observation_space, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 128),
            torch.nn.ReLU(),  # was `ReLu`, which does not exist in torch.nn
        )
        self.critic = torch.nn.Sequential(
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1),
        )
        self.actor = torch.nn.Sequential(
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, action_space),
        )

    def forward(self, x):
        """Return ``(value, logits)`` for a batch of observations ``x``."""
        x = self.features(x)
        return self.critic(x), self.actor(x)

    def get_critic(self, x):
        """Return only the critic's state-value estimate for ``x``."""
        return self.critic(self.features(x))

    def evaluate_action(self, state, action):
        """Score already-taken actions under the current policy.

        Args:
            state: batch of observations.
            action: integer action index per row of ``state``.

        Returns:
            ``(value, log_probs, entropy)``. ``log_probs`` has shape (N, 1)
            and ``entropy`` is the batch-mean policy entropy (a scalar).
        """
        value, logits = self.forward(state)
        # The actor head is a plain Linear layer, so its outputs can be
        # negative. They must be interpreted as logits, not probabilities.
        dist = torch.distributions.Categorical(logits=logits)
        log_probs = dist.log_prob(action).view(-1, 1)
        entropy = dist.entropy().mean()
        return value, log_probs, entropy

    def act(self, state):
        """Sample one action for a single observation batch of size 1.

        Returns:
            The sampled action index as a Python int.
        """
        # Sampling needs no gradients, so skip building an autograd graph.
        with torch.no_grad():
            _, logits = self.forward(state)
            dist = torch.distributions.Categorical(logits=logits)
            return dist.sample().item()
|
|
|
|
class Memory(object):
    """Buffer of (state, action, target value) triples awaiting an update."""

    def __init__(self):
        self.states, self.actions, self.true_values = [], [], []

    def push(self, state, action, true_value):
        """Append one transition's state tensor, action and target value."""
        self.states.append(state)
        self.actions.append(action)
        self.true_values.append(true_value)

    def pop_all(self):
        """Return everything stored as batched tensors and clear the buffer.

        Returns:
            ``(states, actions, true_values)``:
            - ``states`` is the pushed state tensors stacked along a new dim 0.
            - ``actions`` is an int64 tensor.
            - ``true_values`` is a float32 tensor with shape (N, 1).
        """
        # `LongTensor` / `FloatTensor` were never defined in this file, so
        # build the tensors with explicit dtypes instead.
        states = torch.stack(self.states)
        actions = torch.tensor(self.actions, dtype=torch.long)
        true_values = torch.tensor(self.true_values, dtype=torch.float32).unsqueeze(1)

        self.states, self.actions, self.true_values = [], [], []
        return states, actions, true_values
|
|
|
|
class Worker(object):
    """Steps one ``simple_v3`` parallel environment and collects rollouts.

    Only ``"agent_0"`` is controlled. Reads the module-level ``model``,
    ``batch_size`` and ``data`` (``data['episode_rewards']`` must be a list),
    and calls ``compute_true_values``.
    """

    def __init__(self):
        self.env = simple_v3.parallel_env()
        self.episode_reward = 0
        # PettingZoo's parallel reset() returns (observations, infos). Keep the
        # observations dict, which get_batch indexes by agent name.
        self.state, _ = self.env.reset()

    def get_batch(self):
        """Run ``batch_size`` environment steps for ``agent_0``.

        Whenever an episode ends, its total reward is appended to
        ``data['episode_rewards']`` and the environment is reset.

        Returns:
            ``(states, actions, values)``:
            - ``states`` is a list of float32 observation tensors.
            - ``actions`` is a list of int action indices.
            - ``values`` is the discounted target-value tensor with shape (N, 1).
        """
        states, actions, rewards, dones = [], [], [], []

        for _ in range(batch_size):
            obs = torch.as_tensor(self.state["agent_0"], dtype=torch.float32)
            action = model.act(obs.unsqueeze(0))

            # step() takes a dict of agent -> action and returns per-agent dicts.
            next_state, step_rewards, terminations, truncations, _ = self.env.step(
                {"agent_0": action}
            )
            reward = step_rewards["agent_0"]
            done = terminations["agent_0"] or truncations["agent_0"]
            self.episode_reward += reward

            states.append(obs)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)

            if done:
                self.state, _ = self.env.reset()
                data['episode_rewards'].append(self.episode_reward)
                self.episode_reward = 0
            else:
                self.state = next_state

        values = compute_true_values(states, rewards, dones).unsqueeze(1)
        return states, actions, values
|
|
|
|
|
|
|
|
def compute_true_values(states, rewards, dones, discount=None):
    """Compute discounted return targets for a rollout, walking backwards.

    The last step's target is its own reward if that step ended an episode.
    Otherwise it is the critic's value estimate for the last state. Each
    earlier step gets ``reward + discount * next_target``; at episode
    boundaries (``done``) this resets to the step's own reward.

    Args:
        states: sequence of observation tensors, one per step.
        rewards: per-step rewards.
        dones: per-step episode-end flags.
        discount: discount factor. Defaults to the module-level ``gamma``.

    Returns:
        A float32 tensor of shape (len(rewards),). It carries no autograd
        history, because these values are regression targets.
    """
    if len(rewards) == 0:
        return torch.zeros(0, dtype=torch.float32)
    if discount is None:
        discount = gamma

    rewards = [float(r) for r in rewards]
    dones = [bool(d) for d in dones]

    if dones[-1]:
        next_value = rewards[-1]
    else:
        # Bootstrap from the critic. Targets must not backpropagate.
        with torch.no_grad():
            next_value = model.get_critic(states[-1].unsqueeze(0)).item()

    true_values = [next_value]
    for reward, done in zip(reversed(rewards[:-1]), reversed(dones[:-1])):
        next_value = reward if done else reward + discount * next_value
        true_values.append(next_value)

    true_values.reverse()
    return torch.tensor(true_values, dtype=torch.float32)
|
|
|
|
def reflect(memory, max_grad_norm=0.5):
    """Run one actor-critic update on everything stored in ``memory``.

    Uses the module-level ``model``, ``optimizer``, ``critic_coef`` and
    ``entropy_coef``. The buffer is emptied by ``memory.pop_all()``.

    Args:
        memory: a ``Memory``-like object whose ``pop_all()`` returns
            ``(states, actions, true_values)``.
        max_grad_norm: total gradient-norm clip applied before the step.

    Returns:
        The batch-mean predicted state value, as a float, for logging.
    """
    states, actions, true_values = memory.pop_all()
    values, log_probs, entropy = model.evaluate_action(states, actions)

    advantages = true_values - values
    critic_loss = advantages.pow(2).mean()
    # Detach so the policy-gradient term does not push gradients into the critic.
    actor_loss = -(log_probs * advantages.detach()).mean()
    # Subtracting the entropy bonus encourages exploration.
    total_loss = critic_coef * critic_loss + actor_loss - entropy_coef * entropy

    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # was `tourch`
    optimizer.step()

    return values.mean().item()
|
|
|
|
def plot(data, frame_idx):
    """Redraw the training dashboard in the current notebook cell.

    Left panel: the episode-reward history, titled with the mean of the last
    100 episodes. Right panel: the value history, titled with the mean of the
    last 1000 entries. Each panel is drawn only when its list is non-empty.

    Args:
        data: dict with ``'episode_rewards'`` and ``'values'`` lists.
        frame_idx: current frame/step counter, shown in the titles.
    """
    clear_output(True)
    plt.figure(figsize=(20, 5))

    if data['episode_rewards']:
        plt.subplot(121)
        average_score = np.mean(data['episode_rewards'][-100:])
        plt.title(f"Frame: {frame_idx} - Average Score: {average_score}")
        plt.grid()
        plt.plot(data['episode_rewards'])

    if data['values']:
        plt.subplot(122)
        # Was `[-1000:0]`, which is always empty and produced a nan mean.
        average_value = np.mean(data['values'][-1000:])
        plt.title(f"Frame: {frame_idx} - Average Values: {average_value}")
        plt.plot(data['values'])

    plt.show()
|
|
|
|
# Hyperparameters and shared training state. Worker, compute_true_values,
# reflect and plot read these as module globals, but the file never defined
# them (`learning_rate` below raised NameError). The numbers are assumed
# starting values, not tuned results -- adjust as needed.
learning_rate = 7e-4
gamma = 0.99
critic_coef = 0.5
entropy_coef = 0.01
batch_size = 16
data = {'episode_rewards': [], 'values': []}

env = simple_v3.parallel_env()
model = Model(env.observation_space("agent_0").shape[0], env.action_space("agent_0").n)
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, eps=1e-5)
|