How to Create a Custom Experiment
We will explore how to create your own custom experiment by walking through a simple example, Treasurehunt. In this example, the environment contains agents with full vision that can only move up, down, left, or right, as well as gems that have a random chance of spawning on empty spaces. The agents’ success is measured by the game score, which is determined by how many gems they pick up.
Overview
The file structure for our experiment will be as follows (under the examples/ directory):
treasurehunt
├── assets
│   └── <animation sprites>
├── data
│   └── <data records>
├── agents.py
├── entities.py
├── env.py
└── main.py
We will create a custom environment named Treasurehunt, custom entities EmptyEntity, Wall, Sand, and Gem, and a custom agent TreasurehuntAgent.
The environment will have two layers: TreasurehuntAgent and EmptyEntity will be on the top layer, and Sand will be on the bottom layer.
We will then write a main.py script that carries out the experiment, and render parts of the experiment as gifs.
Let’s get started!
The Entities
In entities.py, we will create the 4 entities that we require: EmptyEntity, Wall, Sand, and Gem.
All the custom entities will extend the base Entity class provided by Sorrel; see sorrel.entities.Entity
for its attributes (including their default values) and methods.
We begin by making the necessary imports:
from pathlib import Path
import numpy as np
from sorrel.entities import Entity
from sorrel.environments import GridworldEnv
Then, we create the classes Wall, Sand, and Gem, with custom constructors that overwrite default parent attribute values and set the sprites used for animation later on.
These sprites should be placed in a ./assets/ folder. None of these entities transition.
class Wall(Entity):
    """An entity that represents a wall in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        self.value = -1  # Walls penalize contact
        self.sprite = Path(__file__).parent / "./assets/wall.png"


class Sand(Entity):
    """An entity that represents a block of sand in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        # We technically don't need to make Sand passable here since it's on a different layer from Agent
        self.passable = True
        self.sprite = Path(__file__).parent / "./assets/sand.png"


class Gem(Entity):
    """An entity that represents a gem in the treasurehunt environment."""

    def __init__(self, gem_value):
        super().__init__()
        self.passable = True  # Agents can move onto Gems
        self.value = gem_value
        self.sprite = Path(__file__).parent / "./assets/gem.png"
Note
We use Path(__file__) to ensure that the animation sprite paths are always relative to the path to this entities.py file, no matter where one may be running this code from.
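To see why this helps, here is a small sketch that builds a sprite path with the same kind of expression. The /project/... location of entities.py is a made-up path used only for illustration, and PurePosixPath stands in for Path so that the result does not depend on the operating system:

```python
from pathlib import PurePosixPath

# Hypothetical location of entities.py; only used for illustration.
module_file = PurePosixPath("/project/examples/treasurehunt/entities.py")

# Same shape of expression as in the entity constructors above.
sprite = module_file.parent / "./assets/wall.png"

print(sprite)  # /project/examples/treasurehunt/assets/wall.png
```

The result depends only on where the module file lives, not on the current working directory.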
We then create EmptyEntity, which requires a custom transition method.
Here we note that the transition method requires information such as spawn probability and gem value which must be provided through the environment.
Therefore, we expect them to be attributes of our custom Treasurehunt environment.
class EmptyEntity(Entity):
    """An entity that represents an empty space in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        self.passable = True  # Agents can enter EmptyEntity spaces
        self.has_transitions = True  # EmptyEntity can transition into Gems
        self.sprite = Path(__file__).parent / "./assets/empty.png"

    def transition(self, env: GridworldEnv):
        """
        An EmptyEntity can randomly turn into a Gem, based on the spawn probability set in the environment.
        """
        if (  # NOTE: If the spawn prob is too high, the environment gets overrun
            np.random.random() < env.spawn_prob
        ):
            env.add(self.location, Gem(env.gem_value))
The Environment
In env.py, we will create the environment of our experiment: Treasurehunt.
It will extend the base GridworldEnv class provided by Sorrel;
see sorrel.environments.GridworldEnv for its attributes and methods.
We write the import statements:
# Import base packages
import numpy as np
# Import experiment specific classes
from examples.treasurehunt.entities import EmptyEntity, Gem, Sand, Wall
# Import primitive types
from sorrel.environments import GridworldEnv
We create the constructor first. In addition to the attributes from GridworldEnv, we add the attributes self.gem_value
and self.spawn_prob as noted above. We also add the attributes self.max_turns, self.agents, and self.game_score
so that we can access these attributes of the environment at the experiment level later.
class Treasurehunt(GridworldEnv):
    """
    Treasurehunt environment.
    """

    def __init__(self, height, width, gem_value, spawn_prob, max_turns, agents):
        layers = 2
        default_entity = EmptyEntity()
        super().__init__(height, width, layers, default_entity)

        self.gem_value = gem_value
        self.spawn_prob = spawn_prob
        self.agents = agents
        self.max_turns = max_turns
        self.game_score = 0
        self.populate()
We delegate the task of actually filling in the entities and constructing self.world to the method populate():
    def populate(self):
        """
        Populate the treasurehunt world by creating walls, then randomly spawning the agents.
        Note that every space is already filled with EmptyEntity as part of super().__init__().
        """
        valid_spawn_locations = []

        for index in np.ndindex(self.world.shape):
            y, x, z = index
            if y in [0, self.height - 1] or x in [0, self.width - 1]:
                # Add walls around the edge of the world (when indices are first or last)
                self.add(index, Wall())
            elif z == 0:
                # if location is on the bottom layer, put sand there
                self.add(index, Sand())
            elif z == 1:
                # if location is on the top layer, it is a valid spawn location for an agent
                valid_spawn_locations.append(index)

        # spawn the agents
        # using np.random.choice, we choose indices in valid_spawn_locations
        agent_locations_indices = np.random.choice(
            len(valid_spawn_locations), size=len(self.agents), replace=False
        )
        agent_locations = [valid_spawn_locations[i] for i in agent_locations_indices]
        for loc, agent in zip(agent_locations, self.agents):
            loc = tuple(loc)
            self.add(loc, agent)
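As a sanity check on the branching in populate(), the following standalone sketch repeats the same classification with plain tuples instead of Sorrel entities. It assumes a 10 x 10 world with 2 layers (the values used later in main.py); the list names are illustrative only:

```python
from itertools import product

height, width, layers = 10, 10, 2  # assumed world size, as in main.py later on

walls, sand, spawnable = [], [], []
for y, x, z in product(range(height), range(width), range(layers)):
    if y in (0, height - 1) or x in (0, width - 1):
        walls.append((y, x, z))      # border cell, on either layer
    elif z == 0:
        sand.append((y, x, z))       # interior cell on the bottom layer
    else:
        spawnable.append((y, x, z))  # interior cell on the top layer

print(len(walls), len(sand), len(spawnable))  # 72 64 64
```

Every one of the 200 cells falls into exactly one branch, and only the 64 interior cells of the top layer are candidates for agent spawns.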
Note
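np.random.choice only samples from one-dimensional inputs, so rather than passing it the list of location tuples directly, we sample indices into valid_spawn_locations and look up the locations afterwards.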
We have specifically avoided using random.choices because we would then need to seed np.random and random separately
for reproducible results. It’s generally a good idea to choose one random generator and only use that across the scope of your example.
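The index-based workaround and the single-generator advice can be tried in isolation. This is a minimal sketch with a hypothetical list of four candidate locations and a helper name (pick_locations) that does not exist in Sorrel:

```python
import numpy as np

# Hypothetical spawn candidates: (y, x, layer) tuples on the top layer.
valid_spawn_locations = [(1, 1, 1), (1, 2, 1), (2, 1, 1), (2, 2, 1)]


def pick_locations(seed, count):
    """Pick `count` distinct locations, reproducibly for a given seed."""
    np.random.seed(seed)  # one generator to seed, as recommended above
    # Sample positions in the list rather than the tuples themselves.
    indices = np.random.choice(
        len(valid_spawn_locations), size=count, replace=False
    )
    return [valid_spawn_locations[i] for i in indices]


first = pick_locations(seed=0, count=2)
second = pick_locations(seed=0, count=2)
print(first == second)  # True
```

Because replace=False, the two agents can never be assigned the same cell, and reseeding np.random is enough to reproduce the placement.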
We will also write a reset() method to reset the environment at the end of every game, using sorrel.environments.GridworldEnv.create_world():
    def reset(self):
        """Reset the environment and all its agents."""
        self.create_world()
        self.game_score = 0
        self.populate()
        for agent in self.agents:
            agent.reset()
The Agent
In agents.py, we will create the agent for our experiment: TreasurehuntAgent.
It will extend the base Agent class provided by Sorrel;
see sorrel.agents.Agent for its attributes and methods.
We make our imports:
from pathlib import Path
import numpy as np
from sorrel.agents import Agent
from sorrel.environments import GridworldEnv
We make our custom constructor:
class TreasurehuntAgent(Agent):
    """
    A treasurehunt agent that uses the iqn model.
    """

    def __init__(self, observation_spec, action_spec, model):
        super().__init__(observation_spec, action_spec, model)
        self.sprite = Path(__file__).parent / "./assets/hero.png"
We will use sorrel.observation.observation_spec.OneHotObservationSpec for TreasurehuntAgent’s observation, sorrel.action.action_spec.ActionSpec for TreasurehuntAgent’s actions, and sorrel.models.pytorch.PyTorchIQN for TreasurehuntAgent’s model.
We do not create them in this file (they will be passed into TreasurehuntAgent’s constructor externally),
but we will use the functionality that they provide by accessing the attributes of this class.
Note that unlike the other base classes we’ve worked on top of so far, Agent is an abstract class, and every custom agent that extends it must implement the methods
reset(), pov(), get_action(), act(), and is_done(). Let’s go through them one by one.
To implement sorrel.agents.Agent.reset(), we add all-zero SARD (state, action, reward, done) records to the agent’s model’s memory, one for each frame that the model can access.
The “zero state” is obtained by getting the shape of the state observed by this agent through self.model.input_size,
and then creating an all zeros array with the same shape.
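    def reset(self) -> None:
        """Resets the agent by filling the memory buffer with blank images."""
        # np.zeros builds an all-zero flattened state of the model's input size;
        # np.zeros_like(np.prod(...)) would only produce a single zero scalar
        state = np.zeros(np.prod(self.model.input_size))
        action = 0
        reward = 0.0
        done = False
        for i in range(self.model.num_frames):
            self.add_memory(state, action, reward, done)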
To implement sorrel.agents.Agent.pov(), we get the observed image (in Channels x Height x Width)
using the provided OneHotObservationSpec.observe() function, and then return the flattened image.
    def pov(self, env: GridworldEnv) -> np.ndarray:
        """Returns the state observed by the agent, from the flattened visual field."""
        image = self.observation_spec.observe(env, self.location)
        # flatten the image to get the state
        return image.reshape(1, -1)
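To make the Channels x Height x Width layout and the flattening step concrete, here is a generic one-hot encoding written with plain lists. It is not Sorrel's implementation: it assumes one channel per name in the entity list and a 5 x 5 view (vision radius 2), and the helper names one_hot_view and flatten are made up for this sketch:

```python
ENTITY_LIST = ["EmptyEntity", "Wall", "Sand", "Gem", "TreasurehuntAgent"]


def one_hot_view(grid):
    """Encode a 2D grid of entity names as channels x height x width."""
    return [
        [[1 if cell == kind else 0 for cell in row] for row in grid]
        for kind in ENTITY_LIST
    ]


def flatten(image):
    """Flatten a channels x height x width nested list into one vector."""
    return [value for channel in image for row in channel for value in row]


radius = 2                 # assumed vision radius
side = 2 * radius + 1      # the agent sees a 5 x 5 window
grid = [["EmptyEntity"] * side for _ in range(side)]
grid[radius][radius] = "TreasurehuntAgent"  # the agent sits in the centre
grid[0][3] = "Gem"                          # a hypothetical gem in view

state = flatten(one_hot_view(grid))
print(len(state))  # 125
```

Under these assumptions a single observation has 5 x 5 x 5 = 125 entries, exactly one of which is set per visible cell.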
To implement sorrel.agents.Agent.get_action(), we stack the current state with the previous states in the model’s memory buffer,
and pass the stacked frames (as a horizontal vector) into the model to obtain the action chosen. (See SorrelModel.take_action)
    def get_action(self, state: np.ndarray) -> int:
        """Gets the action from the model, using the stacked states."""
        prev_states = self.model.memory.current_state(
            stacked_frames=self.model.num_frames - 1
        )
        stacked_states = np.vstack((prev_states, state))

        model_input = stacked_states.reshape(1, -1)
        action = self.model.take_action(model_input)
        return action
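The stacking step can be checked on dummy arrays. This sketch assumes that the memory returns the previous states as an array of shape (num_frames - 1, features) and that one flattened observation has 125 features; both are assumptions made for illustration, not facts taken from the Sorrel API:

```python
import numpy as np

num_frames = 5   # same value as the model's num_frames in main.py
features = 125   # assumed size of one flattened observation

# Assumed shape of the stacked previous states: (num_frames - 1, features).
prev_states = np.zeros((num_frames - 1, features))
# Current observation, shaped like the output of pov(): (1, features).
state = np.ones((1, features))

stacked_states = np.vstack((prev_states, state))  # one row per frame
model_input = stacked_states.reshape(1, -1)       # single horizontal vector

print(stacked_states.shape, model_input.shape)  # (5, 125) (1, 625)
```

The newest frame ends up at the tail of the vector, so the model always sees the frames in chronological order.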
To implement sorrel.agents.Agent.act(), we calculate the new location based on the action taken,
record the reward obtained based on the entity at the new location, then try to move the agent to the new location using the provided GridworldEnv.move().
    def act(self, env: GridworldEnv, action: int) -> float:
        """Act on the environment, returning the reward."""

        # Translate the model output to an action string
        action = self.action_spec.get_readable_action(action)

        new_location = self.location
        if action == "up":
            new_location = (self.location[0] - 1, self.location[1], self.location[2])
        if action == "down":
            new_location = (self.location[0] + 1, self.location[1], self.location[2])
        if action == "left":
            new_location = (self.location[0], self.location[1] - 1, self.location[2])
        if action == "right":
            new_location = (self.location[0], self.location[1] + 1, self.location[2])

        # get reward obtained from object at new_location
        target_object = env.observe(new_location)
        reward = target_object.value
        env.game_score += reward

        # try moving to new_location
        env.move(self, new_location)

        return reward
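The four movement branches all follow one pattern: add an offset to the row or column and keep the layer unchanged. As an optional alternative, the same mapping can be written as a lookup table. MOVES and next_location are hypothetical names for this sketch and are not part of Sorrel:

```python
# (row offset, column offset) for each readable action, matching act() above.
MOVES = {"up": (-1, 0), "down": (1, 0), "left": (0, -1), "right": (0, 1)}


def next_location(location, action):
    """Return the (y, x, layer) cell that `action` points to from `location`."""
    dy, dx = MOVES.get(action, (0, 0))  # unknown actions leave the agent in place
    y, x, z = location
    return (y + dy, x + dx, z)


print(next_location((3, 4, 1), "up"))     # (2, 4, 1)
print(next_location((3, 4, 1), "right"))  # (3, 5, 1)
```

Note that "up" decreases the row index, because row 0 is the top edge of the grid.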
Finally, we implement sorrel.agents.Agent.is_done() by checking whether the current turn (tracked by default in GridworldEnv.turn)
has reached the maximum number of turns.
    def is_done(self, env: GridworldEnv) -> bool:
        """Returns whether this Agent is done."""
        return env.turn >= env.max_turns
Now, we are all done with our custom classes. Time to set up the actual experiment!
The Experiment Script: main.py
First, we make our imports as usual:
# general imports
from pathlib import Path
import numpy as np
import torch
# imports from our example
from examples.treasurehunt.agents import TreasurehuntAgent
from examples.treasurehunt.env import Treasurehunt
from sorrel.action.action_spec import ActionSpec
# sorrel imports
from sorrel.models.pytorch import PyTorchIQN
from sorrel.observation.observation_spec import OneHotObservationSpec
from sorrel.utils.visualization import (
    animate,
    image_from_array,
    visual_field_sprite,
)
Then, we will define our experiment parameters as global constants:
EPOCHS = 500
MAX_TURNS = 100
EPSILON_DECAY = 0.0001
ENTITY_LIST = ["EmptyEntity", "Wall", "Sand", "Gem", "TreasurehuntAgent"]
RECORD_PERIOD = 50 # how many epochs in each data recording period
These parameters, as well as the world configuration and model hyperparameters later, can be extracted from this script for faster and easier adjustments using configuration files. Here is a quick tutorial.
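As a rough illustration of the idea only (not the mechanism that the linked tutorial describes), the sketch below reads the same constants from JSON text using the standard library. The key names and the choice of JSON are assumptions made for this example:

```python
import json

# Hypothetical configuration text; in practice this would live in its own file.
CONFIG_TEXT = """
{
  "experiment": {
    "epochs": 500,
    "max_turns": 100,
    "epsilon_decay": 0.0001,
    "record_period": 50
  },
  "world": {"height": 10, "width": 10, "gem_value": 10, "spawn_prob": 0.002}
}
"""

config = json.loads(CONFIG_TEXT)

# The script-level constants now come from the configuration.
EPOCHS = config["experiment"]["epochs"]
MAX_TURNS = config["experiment"]["max_turns"]
EPSILON_DECAY = config["experiment"]["epsilon_decay"]
RECORD_PERIOD = config["experiment"]["record_period"]

print(EPOCHS, MAX_TURNS, EPSILON_DECAY, RECORD_PERIOD)  # 500 100 0.0001 50
```

Keeping the values outside the script means an experiment can be rerun with different settings without editing main.py.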
We will first create the observation specification, the models, the agents, and the environment. The entities will not need to be created explicitly as they will be generated by the environment.
def setup() -> Treasurehunt:
    """Set up the whole environment and everything within."""
    # object configurations
    world_height = 10
    world_width = 10
    gem_value = 10
    spawn_prob = 0.002
    agent_vision_radius = 2

    # make the agents
    agent_num = 2
    agents = []
    for _ in range(agent_num):
        observation_spec = OneHotObservationSpec(
            ENTITY_LIST, vision_radius=agent_vision_radius
        )
        observation_spec.override_input_size(
            np.array(observation_spec.input_size).reshape(1, -1)
        )
        action_spec = ActionSpec(["up", "down", "left", "right"])

        model = PyTorchIQN(
            # the agent can see r blocks on each side, so the size of the observation is (2r+1) * (2r+1)
            input_size=observation_spec.input_size,
            action_space=action_spec.n_actions,
            layer_size=250,
            epsilon=0.7,
            device="cpu",
            seed=torch.random.seed(),
            num_frames=5,
            n_step=3,
            sync_freq=200,
            model_update_freq=4,
            BATCH_SIZE=64,
            memory_size=1024,
            LR=0.00025,
            TAU=0.001,
            GAMMA=0.99,
            N=12,
        )

        agents.append(
            TreasurehuntAgent(
                observation_spec=observation_spec, action_spec=action_spec, model=model
            )
        )

    # make the environment
    env = Treasurehunt(
        world_height, world_width, gem_value, spawn_prob, MAX_TURNS, agents
    )
    return env
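The spawn_prob value is easier to judge with a quick back-of-the-envelope estimate. The sketch below uses the numbers from setup() and MAX_TURNS, and assumes that every interior top-layer cell not occupied by an agent is an EmptyEntity that rolls for a gem once per turn. It ignores cells that already hold gems, so it is only a rough upper estimate:

```python
world_height, world_width = 10, 10
spawn_prob = 0.002
agent_num = 2
max_turns = 100

# Interior cells of the top layer (the border is walls).
interior_cells = (world_height - 2) * (world_width - 2)
# Cells that start out as EmptyEntity once the agents are placed.
empty_cells = interior_cells - agent_num

expected_per_turn = empty_cells * spawn_prob
expected_per_epoch = expected_per_turn * max_turns

print(round(expected_per_turn, 3), round(expected_per_epoch, 1))  # 0.124 12.4
```

With roughly a dozen gems per 100-turn epoch the world stays sparse; raising spawn_prob by an order of magnitude would quickly fill most of the 64 interior cells.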
Then, we will run the experiment. Most of the work here is done by calling GridworldEnv.take_turn(),
which transitions every entity in the environment, then every agent, then increments the turn count by one.
In addition to printing information about each recording period on the terminal,
we also use functions from sorrel.utils.visualization to record states as images and animate them into a gif.
def run(env: Treasurehunt):
    """Run the experiment."""
    imgs = []
    total_score = 0
    total_loss = 0
    for epoch in range(EPOCHS + 1):
        # Reset the environment at the start of each epoch
        env.reset()
        for agent in env.agents:
            agent.model.start_epoch_action(**locals())

        while not env.turn >= env.max_turns:
            if epoch % RECORD_PERIOD == 0:
                full_sprite = visual_field_sprite(env)
                imgs.append(image_from_array(full_sprite))

            env.take_turn()

        # At the end of each epoch, train as long as the batch size is large enough.
        if epoch > 10:
            for agent in env.agents:
                loss = agent.model.train_step()
                total_loss += loss

        total_score += env.game_score

        if epoch % RECORD_PERIOD == 0:
            avg_score = total_score / RECORD_PERIOD
            # report this recording period on the terminal
            print(
                f"Epoch: {epoch}; Avg. score this period: {avg_score}; "
                f"Losses this period: {total_loss}"
            )
            animate(
                imgs, f"treasurehunt_epoch{epoch}", Path(__file__).parent / "./data/"
            )
            # reset the data
            imgs = []
            total_score = 0
            total_loss = 0

        # update epsilon
        for agent in env.agents:
            new_epsilon = agent.model.epsilon - EPSILON_DECAY
            agent.model.epsilon = max(new_epsilon, 0.01)
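It is worth checking how far this linear schedule actually moves epsilon. The sketch below replays the update in isolation, assuming that it runs once per epoch and starts from the epsilon=0.7 passed to the model in setup(); the function name epsilon_after is made up for this example:

```python
EPOCHS = 500
EPSILON_DECAY = 0.0001
START_EPSILON = 0.7   # the value passed to the model in setup()
FLOOR = 0.01          # the lower bound used in the update


def epsilon_after(updates):
    """Epsilon after `updates` applications of the floored linear decay."""
    epsilon = START_EPSILON
    for _ in range(updates):
        epsilon = max(epsilon - EPSILON_DECAY, FLOOR)
    return epsilon


final_epsilon = epsilon_after(EPOCHS + 1)  # the loop runs EPOCHS + 1 times
updates_to_floor = round((START_EPSILON - FLOOR) / EPSILON_DECAY)

print(round(final_epsilon, 4), updates_to_floor)  # 0.6499 6900
```

Under these assumptions the agents are still exploring about 65% of the time when the run ends, and the 0.01 floor would only be reached after about 6900 epochs, so EPOCHS or EPSILON_DECAY may need adjusting for longer experiments.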
Finally, write the main block:
if __name__ == "__main__":
    env = setup()
    run(env)
And we’re done! You can run this script from the command line and see the animations in treasurehunt/data.