diff --git a/.gitignore b/.gitignore index 6276225..33e399a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ *.png -*.h5 \ No newline at end of file +*.h5 +*.pt \ No newline at end of file diff --git a/.ipynb_checkpoints/Test-checkpoint.ipynb b/.ipynb_checkpoints/Test-checkpoint.ipynb new file mode 100644 index 0000000..dcbf104 --- /dev/null +++ b/.ipynb_checkpoints/Test-checkpoint.ipynb @@ -0,0 +1,353 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "5ebc5698-7678-4730-8bdb-26d39cc3969d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/piwalker/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pygame/pkgdata.py:25: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " from pkg_resources import resource_stream, resource_exists\n" + ] + } + ], + "source": [ + "import torch\n", + "from pettingzoo.mpe import simple_reference_v3,simple_v3\n", + "import numpy as np\n", + "from IPython.display import clear_output\n", + "from IPython.core.debugger import set_trace\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c1eaeab6-4b9e-4b5c-8db2-bd79b3887413", + "metadata": {}, + "outputs": [], + "source": [ + "max_frames = 5000000\n", + "batch_size = 5\n", + "learning_rate = 7e-4\n", + "gamma = 0.99\n", + "entropy_coef = 0.01\n", + "critic_coef = 0.5\n", + "no_of_workers = 16" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5da2c091-9ac7-4142-803c-4e8b2b4c586e", + "metadata": {}, + "outputs": [], + "source": [ + "FloatTensor = torch.FloatTensor\n", + "LongTensor = torch.LongTensor" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": 
"04f39958-d3f8-4484-a7fd-97f7da05cd4b", + "metadata": {}, + "outputs": [], + "source": [ + "class Model(torch.nn.Module):\n", + " def __init__(self, observation_space, action_space):\n", + " super(Model, self).__init__()\n", + " self.features = torch.nn.Sequential(\n", + " torch.nn.Linear(observation_space, 32),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(32, 128),\n", + " torch.nn.ReLU()\n", + " )\n", + "\n", + " self.critic = torch.nn.Sequential(\n", + " torch.nn.Linear(128, 256),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(256, 1)\n", + " )\n", + "\n", + " self.actor = torch.nn.Sequential(\n", + " torch.nn.Linear(128, 256),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(256, action_space),\n", + " torch.nn.Softmax(dim=-1)\n", + " )\n", + " \n", + " def forward(self, x):\n", + " x = self.features(x)\n", + " value = self.critic(x)\n", + " actions = self.actor(x)\n", + " return value, actions\n", + "\n", + " def get_critic(self, x):\n", + " x = self.features(x)\n", + " return self.critic(x)\n", + " \n", + " def evaluate_action(self, state, action):\n", + " value, actor_features = self.forward(state)\n", + " dist = torch.distributions.Categorical(actor_features)\n", + " log_probs = dist.log_prob(action).view(-1, 1)\n", + " entropy = dist.entropy().mean()\n", + "\n", + " return value, log_probs, entropy\n", + " \n", + " def act(self, state):\n", + " value, actor_features = self.forward(state)\n", + " dist = torch.distributions.Categorical(actor_features)\n", + "\n", + " chosen_action = dist.sample()\n", + " return chosen_action.item()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "58dcf7c9-6405-4d83-9ea5-bdebdd614db8", + "metadata": {}, + "outputs": [], + "source": [ + "class Memory(object):\n", + " def __init__(self):\n", + " self.states, self.actions, self.true_values = [], [], []\n", + " \n", + " def push(self, state, action, true_value):\n", + " self.states.append(state)\n", + " self.actions.append(action)\n", + " 
self.true_values.append(true_value)\n", + " \n", + " def pop_all(self):\n", + " states = torch.stack(self.states)\n", + " actions = LongTensor(self.actions)\n", + " true_values = FloatTensor(self.true_values).unsqueeze(1)\n", + "\n", + " self.states, self.actions, self.true_values = [], [], []\n", + " return states, actions, true_values" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4edd19f6-193e-4183-9eab-498f176b32a2", + "metadata": {}, + "outputs": [], + "source": [ + "class Worker(object):\n", + " def __init__(self):\n", + " self.env = simple_v3.parallel_env()\n", + " self.episode_reward = 0\n", + " self.state = FloatTensor(self.env.reset()[0][\"agent_0\"])\n", + "\n", + " def get_batch(self):\n", + " states, actions, rewards, dones = [], [], [], []\n", + " for _ in range(batch_size):\n", + " action = model.act(self.state.unsqueeze(0))\n", + " actiondict = {}\n", + " actiondict[\"agent_0\"] = action\n", + " next_state, reward, terminations, truncations, _ = self.env.step(actiondict)\n", + " self.episode_reward += reward[\"agent_0\"]\n", + " states.append(self.state)\n", + " actions.append(action)\n", + " rewards.append(reward[\"agent_0\"])\n", + " done = False if \"agent_0\" in terminations else True\n", + " dones.append(done)\n", + "\n", + " if done:\n", + " self.state = FloatTensor(self.env.reset()[0][\"agent_0\"])\n", + " data['episode_rewards'].append(self.episode_reward)\n", + " self.episode_reward = 0\n", + " else:\n", + " self.state = FloatTensor(next_state[\"agent_0\"])\n", + " values = compute_true_values(states, rewards, dones).unsqueeze(1)\n", + " return states, actions, values" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a5bb4113-8511-437b-ae9d-1e9f45005073", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_true_values(states, rewards, dones):\n", + " true_values = []\n", + " rewards = FloatTensor(rewards)\n", + " dones = FloatTensor(dones)\n", + " states = torch.stack(states)\n", + 
"\n", + " if dones[-1] == True:\n", + " next_value = rewards[-1]\n", + " else:\n", + " next_value = model.get_critic(states[-1].unsqueeze(0))\n", + "\n", + " true_values.append(next_value)\n", + " for i in reversed(range(0, len(rewards) -1)):\n", + " if not dones[i]:\n", + " next_value = rewards[i] + next_value * gamma\n", + " else:\n", + " next_value = rewards[i]\n", + " true_values.append(next_value)\n", + " true_values.reverse()\n", + " return FloatTensor(true_values)\n", + "\n", + "def reflect(memory):\n", + " states, actions, true_values = memory.pop_all()\n", + " values, log_probs, entropy = model.evaluate_action(states, actions)\n", + " advantages = true_values - values\n", + " critic_loss = advantages.pow(2).mean()\n", + " actor_loss = -(log_probs * advantages.detach()).mean()\n", + " total_loss = (critic_coef * critic_loss) + actor_loss - (entropy_coef * entropy)\n", + " optimizer.zero_grad()\n", + " total_loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n", + " optimizer.step()\n", + " return values.mean().item()\n", + "\n", + "def plot(data, frame_idx):\n", + " clear_output(True)\n", + " plt.figure(figsize=(20, 5))\n", + " if data['episode_rewards']:\n", + " ax = plt.subplot(121)\n", + " ax = plt.gca()\n", + " average_score = np.mean(data['episode_rewards'][-100:])\n", + " plt.title(f\"Frame: {frame_idx} - Average Store: {average_score}\")\n", + " plt.grid()\n", + " plt.plot(data['episode_rewards'])\n", + " if data['values']:\n", + " ax = plt.subplot(122)\n", + " average_value = np.mean(data['values'][-1000:])\n", + " plt.title(f\"Frame: {frame_idx} - Average Values: {average_value}\")\n", + " plt.plot(data['values'])\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3f0c10ca-c344-450d-ad78-bd418860ea5d", + "metadata": {}, + "outputs": [], + "source": [ + "env = simple_v3.parallel_env()\n", + "model = Model(env.observation_space(\"agent_0\").shape[0], 
env.action_space(\"agent_0\").n)\n", + "optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, eps=1e-5)\n", + "memory = Memory()\n", + "workers = []\n", + "for _ in range(no_of_workers):\n", + " workers.append(Worker())\n", + "frame_idx = 0\n", + "data = {\n", + " 'episode_rewards': [],\n", + " 'values': []\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b3bae48-f47e-4b8b-be8a-271d7e6871c0", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%debug\n", + "state = FloatTensor(env.reset()[0][\"agent_0\"])\n", + "episode_reward = 0\n", + "while frame_idx < max_frames:\n", + " for worker in workers:\n", + " states, actions, true_values = worker.get_batch()\n", + " for i, _ in enumerate(states):\n", + " memory.push(states[i], actions[i], true_values[i])\n", + " frame_idx += batch_size\n", + " value = reflect(memory)\n", + " data['values'].append(value)\n", + " if frame_idx % 1000 == 0:\n", + " plot(data, frame_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c6db61c0-6828-4283-855d-0d734b36a1fa", + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(model, \"MyModel.pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9ef651cb-69d4-4c37-a92b-de5d4a9fe5a2", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'self' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m env = simple_v3.parallel_env(render_mode=\u001b[33m\"\u001b[39m\u001b[33mhuman\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28mself\u001b[39m.state = FloatTensor(\u001b[38;5;28;43mself\u001b[39;49m.env.reset()[\u001b[32m0\u001b[39m][\u001b[33m\"\u001b[39m\u001b[33magent_0\u001b[39m\u001b[33m\"\u001b[39m])\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m env.agents():\n\u001b[32m 5\u001b[39m action = model.act(\u001b[38;5;28mself\u001b[39m.state.unsqueeze(\u001b[32m0\u001b[39m))\n", + "\u001b[31mNameError\u001b[39m: name 
'self' is not defined" + ] + } + ], + "source": [ + "env = simple_v3.parallel_env(render_mode=\"human\")\n", + "while True:\n", + " state = FloatTensor(env.reset()[0][\"agent_0\"])\n", + " while env.agents():\n", + " action = model.act(state.unsqueeze(0))\n", + " actiondict = {}\n", + " actiondict[\"agent_0\"] = action\n", + " next_state, reward, terminations, truncations, _ = env.step(actiondict)\n", + " done = False if \"agent_0\" in terminations else True\n", + "\n", + " if done:\n", + " state = FloatTensor(env.reset()[0][\"agent_0\"])\n", + " else:\n", + " state = FloatTensor(next_state[\"agent_0\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "machine_learning", + "language": "python", + "name": "machine_learning" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
diff --git a/.ipynb_checkpoints/test-checkpoint.py b/.ipynb_checkpoints/test-checkpoint.py
new file mode 100644
index 0000000..8e5e7f1
--- /dev/null
+++ b/.ipynb_checkpoints/test-checkpoint.py
@@ -0,0 +1,228 @@
+import torch
+from pettingzoo.mpe import simple_reference_v3,simple_v3
+import numpy as np
+from IPython.display import clear_output
+from IPython.core.debugger import set_trace
+import matplotlib.pyplot as plt
+
+# Hyperparameters (same values as the Test.ipynb notebook added in this change).
+max_frames = 5000000
+batch_size = 5
+learning_rate = 7e-4
+gamma = 0.99
+entropy_coef = 0.01
+critic_coef = 0.5
+no_of_workers = 16
+
+# Tensor constructors used throughout this module.
+FloatTensor = torch.FloatTensor
+LongTensor = torch.LongTensor
+
+
+class Model(torch.nn.Module):
+    """Actor-critic network: shared feature trunk, value head and policy head."""
+
+    def __init__(self, observation_space, action_space):
+        """Build the layers.
+
+        Args:
+            observation_space: Number of input features per observation.
+            action_space: Number of discrete actions.
+        """
+        super(Model, self).__init__()
+        self.features = torch.nn.Sequential(
+            torch.nn.Linear(observation_space, 32),
+            torch.nn.ReLU(),
+            torch.nn.Linear(32, 128),
+            torch.nn.ReLU()
+        )
+
+        self.critic = torch.nn.Sequential(
+            torch.nn.Linear(128, 256),
+            torch.nn.ReLU(),
+            torch.nn.Linear(256, 1)
+        )
+
+        # Softmax makes the head emit probabilities, which is what the
+        # positional argument of Categorical is interpreted as.
+        self.actor = torch.nn.Sequential(
+            torch.nn.Linear(128, 256),
+            torch.nn.ReLU(),
+            torch.nn.Linear(256, action_space),
+            torch.nn.Softmax(dim=-1)
+        )
+
+    def forward(self, x):
+        """Return ``(value, action_probabilities)`` for observations ``x``."""
+        x = self.features(x)
+        value = self.critic(x)
+        actions = self.actor(x)
+        return value, actions
+
+    def get_critic(self, x):
+        """Return only the critic output for observations ``x``."""
+        x = self.features(x)
+        return self.critic(x)
+
+    def evaluate_action(self, state, action):
+        """Score already-taken actions under the current policy.
+
+        Returns:
+            ``(value, log_probs, entropy)``: critic output, log-probability
+            of each action reshaped to a column, and the mean entropy.
+        """
+        value, actor_features = self.forward(state)
+        dist = torch.distributions.Categorical(actor_features)
+        log_probs = dist.log_prob(action).view(-1, 1)
+        entropy = dist.entropy().mean()
+
+        return value, log_probs, entropy
+
+    def act(self, state):
+        """Sample an action for a single observation and return it as a Python int."""
+        value, actor_features = self.forward(state)
+        dist = torch.distributions.Categorical(actor_features)
+
+        chosen_action = dist.sample()
+        return chosen_action.item()
+
+
+class Memory(object):
+    """Buffer of (state, action, target value) samples awaiting one update."""
+
+    def __init__(self):
+        self.states, self.actions, self.true_values = [], [], []
+
+    def push(self, state, action, true_value):
+        """Append one sample."""
+        self.states.append(state)
+        self.actions.append(action)
+        self.true_values.append(true_value)
+
+    def pop_all(self):
+        """Return all stored samples as tensors and empty the buffer."""
+        states = torch.stack(self.states)
+        actions = LongTensor(self.actions)
+        true_values = FloatTensor(self.true_values).unsqueeze(1)
+
+        self.states, self.actions, self.true_values = [], [], []
+        return states, actions, true_values
+
+
+class Worker(object):
+    """Owns one simple_v3 environment and collects rollouts for ``agent_0``."""
+
+    def __init__(self):
+        self.env = simple_v3.parallel_env()
+        self.episode_reward = 0
+        # The first element of reset()'s result is indexed by agent name.
+        self.state = FloatTensor(self.env.reset()[0]["agent_0"])
+
+    def get_batch(self):
+        """Step the environment ``batch_size`` times with the current policy.
+
+        Returns:
+            ``(states, actions, values)``: visited states, chosen actions and
+            the targets from ``compute_true_values`` as a column tensor.
+        """
+        states, actions, rewards, dones = [], [], [], []
+        for _ in range(batch_size):
+            action = model.act(self.state.unsqueeze(0))
+            # The parallel API takes and returns dicts keyed by agent name.
+            next_state, reward, terminations, truncations, _ = self.env.step({"agent_0": action})
+            self.episode_reward += reward["agent_0"]
+            states.append(self.state)
+            actions.append(action)
+            rewards.append(reward["agent_0"])
+            # A missing key is treated as "agent is gone", i.e. episode over.
+            done = terminations.get("agent_0", True) or truncations.get("agent_0", True)
+            dones.append(done)
+
+            if done:
+                self.state = FloatTensor(self.env.reset()[0]["agent_0"])
+                data['episode_rewards'].append(self.episode_reward)
+                self.episode_reward = 0
+            else:
+                self.state = FloatTensor(next_state["agent_0"])
+        values = compute_true_values(states, rewards, dones).unsqueeze(1)
+        return states, actions, values
+
+
+def compute_true_values(states, rewards, dones):
+    """Compute discounted value targets for one rollout.
+
+    Args:
+        states: List of state tensors, one per step.
+        rewards: List of per-step rewards.
+        dones: List of per-step episode-end flags.
+
+    Returns:
+        1-D FloatTensor with one target per step, in rollout order.
+    """
+    # NOTE(review): when the rollout does not end an episode the last target
+    # is the critic's estimate for the last *visited* state and rewards[-1]
+    # is not used - confirm this is intended.
+    if dones[-1]:
+        next_value = float(rewards[-1])
+    else:
+        # The target is a constant, so no graph is needed for the bootstrap.
+        with torch.no_grad():
+            next_value = model.get_critic(states[-1].unsqueeze(0)).item()
+
+    true_values = [next_value]
+    for i in reversed(range(len(rewards) - 1)):
+        if dones[i]:
+            next_value = float(rewards[i])
+        else:
+            next_value = float(rewards[i]) + next_value * gamma
+        true_values.append(next_value)
+    true_values.reverse()
+    return FloatTensor(true_values)
+
+
+def reflect(memory):
+    """Run one optimisation step on everything in ``memory``.
+
+    Returns:
+        Mean critic value over the batch, as a float.
+    """
+    states, actions, true_values = memory.pop_all()
+    values, log_probs, entropy = model.evaluate_action(states, actions)
+    advantages = true_values - values
+    critic_loss = advantages.pow(2).mean()
+    # detach(): the policy term must not push gradients into the critic.
+    actor_loss = -(log_probs * advantages.detach()).mean()
+    total_loss = (critic_coef * critic_loss) + actor_loss - (entropy_coef * entropy)
+    optimizer.zero_grad()
+    total_loss.backward()
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optimizer.step()
+    return values.mean().item()
+
+
+def plot(data, frame_idx):
+    """Redraw the episode-reward and value curves for the current frame."""
+    clear_output(True)
+    plt.figure(figsize=(20, 5))
+    if data['episode_rewards']:
+        plt.subplot(121)
+        average_score = np.mean(data['episode_rewards'][-100:])
+        plt.title(f"Frame: {frame_idx} - Average Store: {average_score}")
+        plt.grid()
+        plt.plot(data['episode_rewards'])
+    if data['values']:
+        plt.subplot(122)
+        average_value = np.mean(data['values'][-1000:])
+        plt.title(f"Frame: {frame_idx} - Average Values: {average_value}")
+        plt.plot(data['values'])
+    plt.show()
+
+
+env = simple_v3.parallel_env()
+model = Model(env.observation_space("agent_0").shape[0], env.action_space("agent_0").n)
+optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, eps=1e-5)
+# Training statistics; Worker.get_batch appends to 'episode_rewards'.
+data = {
+    'episode_rewards': [],
+    'values': []
+}
diff --git a/Test.ipynb b/Test.ipynb new file
mode 100644 index 0000000..502a8e9 --- /dev/null +++ b/Test.ipynb @@ -0,0 +1,360 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "5ebc5698-7678-4730-8bdb-26d39cc3969d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/piwalker/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pygame/pkgdata.py:25: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " from pkg_resources import resource_stream, resource_exists\n" + ] + } + ], + "source": [ + "import torch\n", + "from pettingzoo.mpe import simple_reference_v3,simple_v3\n", + "import numpy as np\n", + "from IPython.display import clear_output\n", + "from IPython.core.debugger import set_trace\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c1eaeab6-4b9e-4b5c-8db2-bd79b3887413", + "metadata": {}, + "outputs": [], + "source": [ + "max_frames = 5000000\n", + "batch_size = 5\n", + "learning_rate = 7e-4\n", + "gamma = 0.99\n", + "entropy_coef = 0.01\n", + "critic_coef = 0.5\n", + "no_of_workers = 16" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5da2c091-9ac7-4142-803c-4e8b2b4c586e", + "metadata": {}, + "outputs": [], + "source": [ + "FloatTensor = torch.FloatTensor\n", + "LongTensor = torch.LongTensor" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "04f39958-d3f8-4484-a7fd-97f7da05cd4b", + "metadata": {}, + "outputs": [], + "source": [ + "class Model(torch.nn.Module):\n", + " def __init__(self, observation_space, action_space):\n", + " super(Model, self).__init__()\n", + " self.features = torch.nn.Sequential(\n", + " torch.nn.Linear(observation_space, 32),\n", + " torch.nn.ReLU(),\n", + " 
torch.nn.Linear(32, 128),\n", + " torch.nn.ReLU()\n", + " )\n", + "\n", + " self.critic = torch.nn.Sequential(\n", + " torch.nn.Linear(128, 256),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(256, 1)\n", + " )\n", + "\n", + " self.actor = torch.nn.Sequential(\n", + " torch.nn.Linear(128, 256),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(256, action_space),\n", + " torch.nn.Softmax(dim=-1)\n", + " )\n", + " \n", + " def forward(self, x):\n", + " x = self.features(x)\n", + " value = self.critic(x)\n", + " actions = self.actor(x)\n", + " return value, actions\n", + "\n", + " def get_critic(self, x):\n", + " x = self.features(x)\n", + " return self.critic(x)\n", + " \n", + " def evaluate_action(self, state, action):\n", + " value, actor_features = self.forward(state)\n", + " dist = torch.distributions.Categorical(actor_features)\n", + " log_probs = dist.log_prob(action).view(-1, 1)\n", + " entropy = dist.entropy().mean()\n", + "\n", + " return value, log_probs, entropy\n", + " \n", + " def act(self, state):\n", + " value, actor_features = self.forward(state)\n", + " dist = torch.distributions.Categorical(actor_features)\n", + "\n", + " chosen_action = dist.sample()\n", + " return chosen_action.item()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "58dcf7c9-6405-4d83-9ea5-bdebdd614db8", + "metadata": {}, + "outputs": [], + "source": [ + "class Memory(object):\n", + " def __init__(self):\n", + " self.states, self.actions, self.true_values = [], [], []\n", + " \n", + " def push(self, state, action, true_value):\n", + " self.states.append(state)\n", + " self.actions.append(action)\n", + " self.true_values.append(true_value)\n", + " \n", + " def pop_all(self):\n", + " states = torch.stack(self.states)\n", + " actions = LongTensor(self.actions)\n", + " true_values = FloatTensor(self.true_values).unsqueeze(1)\n", + "\n", + " self.states, self.actions, self.true_values = [], [], []\n", + " return states, actions, true_values" + ] + }, + { + 
"cell_type": "code", + "execution_count": 6, + "id": "4edd19f6-193e-4183-9eab-498f176b32a2", + "metadata": {}, + "outputs": [], + "source": [ + "class Worker(object):\n", + " def __init__(self):\n", + " self.env = simple_v3.parallel_env()\n", + " self.episode_reward = 0\n", + " self.state = FloatTensor(self.env.reset()[0][\"agent_0\"])\n", + "\n", + " def get_batch(self):\n", + " states, actions, rewards, dones = [], [], [], []\n", + " for _ in range(batch_size):\n", + " action = model.act(self.state.unsqueeze(0))\n", + " actiondict = {}\n", + " actiondict[\"agent_0\"] = action\n", + " next_state, reward, terminations, truncations, _ = self.env.step(actiondict)\n", + " self.episode_reward += reward[\"agent_0\"]\n", + " states.append(self.state)\n", + " actions.append(action)\n", + " rewards.append(reward[\"agent_0\"])\n", + " done = False if \"agent_0\" in terminations else True\n", + " dones.append(done)\n", + "\n", + " if done:\n", + " self.state = FloatTensor(self.env.reset()[0][\"agent_0\"])\n", + " data['episode_rewards'].append(self.episode_reward)\n", + " self.episode_reward = 0\n", + " else:\n", + " self.state = FloatTensor(next_state[\"agent_0\"])\n", + " values = compute_true_values(states, rewards, dones).unsqueeze(1)\n", + " return states, actions, values" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a5bb4113-8511-437b-ae9d-1e9f45005073", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_true_values(states, rewards, dones):\n", + " true_values = []\n", + " rewards = FloatTensor(rewards)\n", + " dones = FloatTensor(dones)\n", + " states = torch.stack(states)\n", + "\n", + " if dones[-1] == True:\n", + " next_value = rewards[-1]\n", + " else:\n", + " next_value = model.get_critic(states[-1].unsqueeze(0))\n", + "\n", + " true_values.append(next_value)\n", + " for i in reversed(range(0, len(rewards) -1)):\n", + " if not dones[i]:\n", + " next_value = rewards[i] + next_value * gamma\n", + " else:\n", + " next_value = 
rewards[i]\n", + " true_values.append(next_value)\n", + " true_values.reverse()\n", + " return FloatTensor(true_values)\n", + "\n", + "def reflect(memory):\n", + " states, actions, true_values = memory.pop_all()\n", + " values, log_probs, entropy = model.evaluate_action(states, actions)\n", + " advantages = true_values - values\n", + " critic_loss = advantages.pow(2).mean()\n", + " actor_loss = -(log_probs * advantages.detach()).mean()\n", + " total_loss = (critic_coef * critic_loss) + actor_loss - (entropy_coef * entropy)\n", + " optimizer.zero_grad()\n", + " total_loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n", + " optimizer.step()\n", + " return values.mean().item()\n", + "\n", + "def plot(data, frame_idx):\n", + " clear_output(True)\n", + " plt.figure(figsize=(20, 5))\n", + " if data['episode_rewards']:\n", + " ax = plt.subplot(121)\n", + " ax = plt.gca()\n", + " average_score = np.mean(data['episode_rewards'][-100:])\n", + " plt.title(f\"Frame: {frame_idx} - Average Store: {average_score}\")\n", + " plt.grid()\n", + " plt.plot(data['episode_rewards'])\n", + " if data['values']:\n", + " ax = plt.subplot(122)\n", + " average_value = np.mean(data['values'][-1000:])\n", + " plt.title(f\"Frame: {frame_idx} - Average Values: {average_value}\")\n", + " plt.plot(data['values'])\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3f0c10ca-c344-450d-ad78-bd418860ea5d", + "metadata": {}, + "outputs": [], + "source": [ + "env = simple_v3.parallel_env()\n", + "model = Model(env.observation_space(\"agent_0\").shape[0], env.action_space(\"agent_0\").n)\n", + "optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, eps=1e-5)\n", + "memory = Memory()\n", + "workers = []\n", + "for _ in range(no_of_workers):\n", + " workers.append(Worker())\n", + "frame_idx = 0\n", + "data = {\n", + " 'episode_rewards': [],\n", + " 'values': []\n", + "}" + ] + }, + { + "cell_type": "code", + 
"execution_count": 9, + "id": "1b3bae48-f47e-4b8b-be8a-271d7e6871c0", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%debug\n", + "state = FloatTensor(env.reset()[0][\"agent_0\"])\n", + "episode_reward = 0\n", + "while frame_idx < max_frames:\n", + " for worker in workers:\n", + " states, actions, true_values = worker.get_batch()\n", + " for i, _ in enumerate(states):\n", + " memory.push(states[i], actions[i], true_values[i])\n", + " frame_idx += batch_size\n", + " value = reflect(memory)\n", + " data['values'].append(value)\n", + " if frame_idx % 1000 == 0:\n", + " plot(data, frame_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c6db61c0-6828-4283-855d-0d734b36a1fa", + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(model, \"MyModel.pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "9ef651cb-69d4-4c37-a92b-de5d4a9fe5a2", + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[15]\u001b[39m\u001b[32m, line 8\u001b[39m\n\u001b[32m 6\u001b[39m actiondict = {}\n\u001b[32m 7\u001b[39m actiondict[\u001b[33m\"\u001b[39m\u001b[33magent_0\u001b[39m\u001b[33m\"\u001b[39m] = action\n\u001b[32m----> \u001b[39m\u001b[32m8\u001b[39m next_state, reward, terminations, truncations, _ = \u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mactiondict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 9\u001b[39m done = \u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33magent_0\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m terminations \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m 11\u001b[39m 
\u001b[38;5;28;01mif\u001b[39;00m done:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pettingzoo/utils/conversions.py:207\u001b[39m, in \u001b[36maec_to_parallel_wrapper.step\u001b[39m\u001b[34m(self, actions)\u001b[39m\n\u001b[32m 203\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(\n\u001b[32m 204\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mexpected agent \u001b[39m\u001b[38;5;132;01m{\u001b[39;00magent\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m got agent \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.aec_env.agent_selection\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m, Parallel environment wrapper expects agents to step in a cycle.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 205\u001b[39m )\n\u001b[32m 206\u001b[39m obs, rew, termination, truncation, info = \u001b[38;5;28mself\u001b[39m.aec_env.last()\n\u001b[32m--> \u001b[39m\u001b[32m207\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43maec_env\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mactions\u001b[49m\u001b[43m[\u001b[49m\u001b[43magent\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 208\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m agent \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.aec_env.agents:\n\u001b[32m 209\u001b[39m rewards[agent] += \u001b[38;5;28mself\u001b[39m.aec_env.rewards[agent]\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/order_enforcing.py:96\u001b[39m, in \u001b[36mOrderEnforcingWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 94\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 95\u001b[39m \u001b[38;5;28mself\u001b[39m._has_updated = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m---> 
\u001b[39m\u001b[32m96\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/base.py:47\u001b[39m, in \u001b[36mBaseWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 46\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m, action: ActionType) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m47\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/assert_out_of_bounds.py:26\u001b[39m, in \u001b[36mAssertOutOfBoundsWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m, action: ActionType) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m (\n\u001b[32m 18\u001b[39m action \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 19\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m (\n\u001b[32m (...)\u001b[39m\u001b[32m 24\u001b[39m action\n\u001b[32m 25\u001b[39m ), \u001b[33m\"\u001b[39m\u001b[33maction is not in action space\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m26\u001b[39m 
\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pettingzoo/utils/wrappers/base.py:47\u001b[39m, in \u001b[36mBaseWrapper.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 46\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m, action: ActionType) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m47\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pettingzoo/mpe/_mpe_utils/simple_env.py:264\u001b[39m, in \u001b[36mSimpleEnv.step\u001b[39m\u001b[34m(self, action)\u001b[39m\n\u001b[32m 261\u001b[39m \u001b[38;5;28mself\u001b[39m._accumulate_rewards()\n\u001b[32m 263\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.render_mode == \u001b[33m\"\u001b[39m\u001b[33mhuman\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m264\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mrender\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/code/machine_learning/simple_crypto/venv/lib/python3.13/site-packages/pettingzoo/mpe/_mpe_utils/simple_env.py:287\u001b[39m, in \u001b[36mSimpleEnv.render\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 285\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.render_mode == \u001b[33m\"\u001b[39m\u001b[33mhuman\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 286\u001b[39m 
pygame.display.flip()\n\u001b[32m--> \u001b[39m\u001b[32m287\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mclock\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtick\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmetadata\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mrender_fps\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 288\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], + "source": [ + "env = simple_v3.parallel_env(render_mode=\"human\")\n", + "while True:\n", + " state = FloatTensor(env.reset()[0][\"agent_0\"])\n", + " while env.agents:\n", + " action = model.act(state.unsqueeze(0))\n", + " actiondict = {}\n", + " actiondict[\"agent_0\"] = action\n", + " next_state, reward, terminations, truncations, _ = env.step(actiondict)\n", + " done = False if \"agent_0\" in terminations else True\n", + "\n", + " if done:\n", + " state = FloatTensor(env.reset()[0][\"agent_0\"])\n", + " else:\n", + " state = FloatTensor(next_state[\"agent_0\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "machine_learning", + "language": "python", + "name": "machine_learning" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/main.py b/main.py index ff5ea2f..cbc6e77 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ import matplotlib.pyplot as plt from tqdm import tqdm import gymnasium as gym -from pettingzoo.mpe import simple_reference_v3 +from pettingzoo.mpe import simple_reference_v3,simple_v3 import pettingzoo class A2C(nn.Module): @@ -53,8 +53,8 @@ class A2C(nn.Module): self.critic_optim = 
optim.Adam(self.critic.parameters(), lr=critic_lr) self.actor_optim = optim.Adam(self.actor.parameters(), lr=actor_lr) - self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=1) - self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=1) + #self.critic_scheduler = optim.lr_scheduler.StepLR(self.critic_optim, step_size=100, gamma=0.9) + #self.actor_scheduler = optim.lr_scheduler.StepLR(self.actor_optim, step_size=100, gamma=0.9) def forward(self, x: np.array) -> tuple[torch.tensor, torch.tensor]: x = torch.Tensor(x).to(self.device) @@ -89,32 +89,34 @@ class A2C(nn.Module): ent_coef: float, ) -> tuple[torch.tensor, torch.tensor]: - T = len(rewards) - advantages = torch.zeros(T, device=self.device) + #T = len(rewards) + #rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5) + #advantages = torch.zeros(T, device=self.device) # compute the advantages using GAE - gae = 0.0 - for t in reversed(range(T - 1)): - td_error = ( - rewards[t] + gamma * masks[t] * value_preds[t+1] - value_preds[t] - ) - gae = td_error + gamma * 0.95 * masks[t] * gae - advantages[t] = gae + #gae = 0.0 + #for t in reversed(range(T - 1)): + # td_error = ( + # rewards[t] + gamma * masks[t] * value_preds[t+1] - value_preds[t] + # ) + # gae = td_error + gamma * 0.95 * masks[t] * gae + # advantages[t] = gae + + #advantages = (advantages - advantages.mean()) / advantages.std() # calculate the loss of the minibatch for actor and critic - critic_loss = advantages.pow(2).mean() + #critic_loss = advantages.pow(2).mean() #give a bonus for higher entropy to encourage exploration - actor_loss = ( - -(advantages.detach() * action_log_probs).mean() - ent_coef * torch.Tensor(entropy).mean() - ) + #actor_loss = ( + # -(advantages.detach() * action_log_probs).mean() - ent_coef * torch.Tensor(entropy).mean() + #) #advantages = torch.zeros(len(rewards), device=self.device) #compute advantages #mask - 0 if end of episode #gamma - 
coeffecient for value prediction - #rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5) #for t in range(len(rewards) - 1): #advantages[t] = (rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t])) #print(advantages[t]) @@ -122,25 +124,25 @@ class A2C(nn.Module): #(rewards[t] + masks[t] * gamma * (value_preds[t+1] - value_preds[t])) #rewards = np.array(rewards) #rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-5) - #returns = [] - #R = 0 - #for r, mask in zip(reversed(rewards), reversed(masks)): - # R = r + gamma * R * mask - # returns.insert(0, R) + returns = [] + R = 0 + for r, mask in zip(reversed(rewards), reversed(masks)): + R = r + gamma * R * mask + returns.insert(0, R) - #returns = torch.FloatTensor(returns) - #values = torch.stack(value_preds).squeeze(1) + returns = torch.FloatTensor(returns) + values = torch.stack(value_preds).squeeze(1) - #advantage = returns - values + advantage = returns - values #calculate critic loss - MSE - #critic_loss = advantages.pow(2).mean() + critic_loss = advantage.pow(2).mean() #critic_loss = advantages.pow(2).mean() #calculate actor loss - give bonus for entropy to encourage exploration #actor_loss = -(advantages.detach() * action_log_probs).mean() - ent_coef * entropy.mean() #entropy = -torch.stack(entropy).sum(dim=-1).mean() - #actor_loss = (-action_log_probs * advantages.detach()).mean() - ent_coef * torch.Tensor(entropy).mean() + actor_loss = (-action_log_probs * advantage.detach()).mean() - ent_coef * torch.Tensor(entropy).mean() #print(action_log_probs) #print(actor_loss) return (critic_loss, actor_loss) @@ -150,19 +152,18 @@ class A2C(nn.Module): critic_loss.backward() torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5) self.critic_optim.step() - self.critic_scheduler.step() + #self.critic_scheduler.step() self.actor_optim.zero_grad() actor_loss.backward() torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5) self.actor_optim.step() - 
self.actor_scheduler.step() + #self.actor_scheduler.step() def set_eval(self): self.critic.eval() self.actor.eval() - fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,5)) fig.suptitle( f"training plots for the Simple Reference environment" @@ -247,7 +248,8 @@ def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr): agent1_rewards = [] agent0_entropy = [] agent1_entropy = [] - env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array") + #env = simple_reference_v3.parallel_env(max_cycles = 50, render_mode="rgb_array") + env = simple_v3.parallel_env(max_cycles = 50, render_mode="rgb_array") #obs_space #action_space @@ -265,7 +267,7 @@ def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr): #alice = A2C(n_features = env.observation_space("alice_0").shape[0], n_actions = env.action_space("alice_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr) agent0 = A2C(n_features = env.observation_space("agent_0").shape[0], n_actions = env.action_space("agent_0").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr) - agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr) + #agent1 = A2C(n_features = env.observation_space("agent_1").shape[0], n_actions = env.action_space("agent_1").n, device = device, critic_lr = critic_lr, actor_lr = actor_lr) #print(env.action_space("agent_0").n) #print(env.observation_space("agent_0")) @@ -293,9 +295,9 @@ def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr): #actions["bob_0"] = bob_action.item() #actions["alice_0"] = alice_action.item() agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0)) - agent_1_action, agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0)) + #agent_1_action, 
agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0)) actions["agent_0"] = agent_0_action - actions["agent_1"] = agent_1_action + #actions["agent_1"] = agent_1_action observations, rewards, terminations, truncations, infos = env.step(actions) #print(rewards) agent_0_rewards.append(rewards["agent_0"]) @@ -304,11 +306,11 @@ def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr): agent_0_ents.append(agent_0_ent.item()) agent_0_mask.append( 1 if env.agents else 0) - agent_1_rewards.append(rewards["agent_1"]) - agent_1_probs.append(agent_1_log_probs) - agent_1_pred.append(agent_1_state_val) - agent_1_ents.append(agent_1_ent.item()) - agent_1_mask.append( 1 if env.agents else 0) + #agent_1_rewards.append(rewards["agent_1"]) + #agent_1_probs.append(agent_1_log_probs) + #agent_1_pred.append(agent_1_state_val) + #agent_1_ents.append(agent_1_ent.item()) + #agent_1_mask.append( 1 if env.agents else 0) #eve_closs, eve_aloss = eve.get_losses([rewards["eve_0"]], eve_log_probs, eve_state_val, eve_ent, [1], gamma, ent_coef) #print("Eve: Critic Loss: " + str(eve_closs.item()) + " Actor Loss: " + str(eve_aloss.item())) #eve.update_params(eve_closs, eve_aloss) @@ -318,17 +320,17 @@ def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr): agent0_actor_loss.append(agent_0_aloss.item()) #print("Agent 0 loss: Critic: " + str(agent_0_closs.item()) + ", Actor: " + str(agent_0_aloss.item())) agent0.update_params(agent_0_closs, agent_0_aloss) - agent_1_closs, agent_1_aloss = agent1.get_losses(agent_1_rewards, torch.stack(agent_1_probs), agent_1_pred, agent_1_ents, agent_1_mask, gamma, ent_coef) - agent1_critic_loss.append(agent_1_closs.item()) - agent1_actor_loss.append(agent_1_aloss.item()) + #agent_1_closs, agent_1_aloss = agent1.get_losses(agent_1_rewards, torch.stack(agent_1_probs), agent_1_pred, agent_1_ents, agent_1_mask, gamma, ent_coef) + #agent1_critic_loss.append(agent_1_closs.item()) + 
#agent1_actor_loss.append(agent_1_aloss.item()) #print("Agent 1 loss: Critic: " + str(agent_1_closs.item()) + ", Actor: " + str(agent_1_aloss.item())) - agent1.update_params(agent_1_closs, agent_1_aloss) + #agent1.update_params(agent_1_closs, agent_1_aloss) agent0_rewards.append(np.array(agent_0_rewards).sum()) - agent1_rewards.append(np.array(agent_1_rewards).sum()) + #agent1_rewards.append(np.array(agent_1_rewards).sum()) #print(np.array(agent_0_ents).sum()) - agent0_entropy.append(np.array(agent_0_ents).sum()) - agent1_entropy.append(np.array(agent_1_ents).sum()) + agent0_entropy.append(np.array(agent_0_ents).mean()) + #agent1_entropy.append(np.array(agent_1_ents).sum()) @@ -337,13 +339,42 @@ def train(n_episodes, gamma, ent_coef, actor_lr, critic_lr): plt.savefig('plots(gamma=' + str(gamma) + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr) + ',clr=' + str(critic_lr) + ').png') drawPlots() plt.savefig('plots(gamma=' + str(gamma) + ',ent=' + str(ent_coef) + ',alr=' + str(actor_lr) + ',clr=' + str(critic_lr) + ').png') + + actor0_weights_path = "weights/actor0_weights.h5" + critic0_weights_path = "weights/critic0_weights.h5" + actor1_weights_path = "weights/actor1_weights.h5" + critic1_weights_path = "weights/critic1_weights.h5" + + if not os.path.exists("weights"): + os.mkdir("weights") + + torch.save(agent0.actor.state_dict(), actor0_weights_path) + torch.save(agent0.critic.state_dict(), critic0_weights_path) + #torch.save(agent1.actor.state_dict(), actor1_weights_path) + #torch.save(agent1.critic.state_dict(), critic1_weights_path) + + agent0.set_eval() + #agent1.set_eval() + #env = simple_reference_v3.parallel_env(render_mode="human") + env = simple_v3.parallel_env(render_mode="human") + while True: + observations, infos = env.reset() + while env.agents: + plt.pause(0.001) + actions = {} + agent_0_action, agent_0_log_probs, agent_0_state_val, agent_0_ent = agent0.select_action(torch.FloatTensor(observations["agent_0"]).unsqueeze(0)) + #agent_1_action, 
agent_1_log_probs, agent_1_state_val, agent_1_ent = agent1.select_action(torch.FloatTensor(observations["agent_1"]).unsqueeze(0)) + actions["agent_0"] = agent_0_action + #actions["agent_1"] = agent_1_action + observations, rewards, terminations, truncations, infos = env.step(actions) + env.close() #environment hyperparams n_episodes = 1000 -train(10000, 0.999, 0.01, 0.0001, 0.0005) +train(10000, 0.9, 0.03, 0.001, 0.005) best = 1 for gamma in np.arange(0.999, 0.99, -0.1): for ent_coef in np.arange(0, 0.1, 0.01):