How to Create a Custom Experiment

We will explore how you can create your own custom experiment by walking through a simple example, Treasurehunt. In this example, the environment contains agents that observe a square window around themselves and can only move up, down, left, or right, as well as gems that have a random chance of spawning on empty spaces. The agents’ level of success is measured by the game score, which is determined by how many gems they pick up (minus a small penalty for bumping into walls).

Overview

The file structure for our experiment will be as follows (under the examples/ directory):

treasurehunt
├── assets
│ └── <animation sprites>
├── data
│ └── <data records>
├── agents.py
├── entities.py
├── env.py
└── main.py

We will create a custom environment named Treasurehunt, custom entities EmptyEntity, Wall, Sand, and Gem, and a custom agent TreasurehuntAgent. The environment will have two layers: TreasurehuntAgent, EmptyEntity, and Gem will be on the top layer, Sand will be on the bottom layer, and Walls will border the world on both layers. We will then write a main.py script that runs the experiment and renders parts of it as GIFs.

Let’s get started!

The Entities

In entities.py, we will create the four entities that we require: EmptyEntity, Wall, Sand, and Gem. All the custom entities extend the base Entity class provided by Sorrel; see sorrel.entities.Entity for its attributes (including their default values) and methods.

We begin by making the necessary imports:

from pathlib import Path

import numpy as np

from sorrel.entities import Entity
from sorrel.environments import GridworldEnv

Then, we create the classes Wall, Sand, and Gem, with custom constructors that override default parent attribute values and set the sprites used for animation later on. These sprites should be placed in an ./assets/ folder next to entities.py. None of these entities transition.

class Wall(Entity):
    """An entity that represents a wall in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        self.value = -1  # Walls penalize contact
        self.sprite = Path(__file__).parent / "./assets/wall.png"


class Sand(Entity):
    """An entity that represents a block of sand in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        # We technically don't need to make Sand passable here since it's on a different layer from Agent
        self.passable = True
        self.sprite = Path(__file__).parent / "./assets/sand.png"


class Gem(Entity):
    """An entity that represents a gem in the treasurehunt environment."""

    def __init__(self, gem_value):
        super().__init__()
        self.passable = True  # Agents can move onto Gems
        self.value = gem_value
        self.sprite = Path(__file__).parent / "./assets/gem.png"

Note

We use Path(__file__) so that the sprite paths are always resolved relative to this entities.py file, regardless of the directory the code is run from.
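To see why this works, here is a small standalone sketch using pathlib.PurePosixPath and a made-up location for entities.py (the path is hypothetical; in the real file it comes from Path(__file__)). Taking .parent drops the file name, and joining the relative sprite path yields a location that does not depend on the working directory.

```python
from pathlib import PurePosixPath

# Hypothetical location of entities.py; in practice this comes from Path(__file__).
entities_file = PurePosixPath("/home/user/sorrel/examples/treasurehunt/entities.py")

# .parent drops the file name, leaving the directory that contains entities.py.
sprite = entities_file.parent / "./assets/wall.png"

# pathlib collapses the "./" component, so the result is anchored to the
# directory of entities.py rather than to the current working directory.
print(sprite)  # /home/user/sorrel/examples/treasurehunt/assets/wall.png
```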

We then create EmptyEntity, which needs a custom transition method: on each turn, an empty space may turn into a Gem. This transition needs the spawn probability and the gem value, which describe the game as a whole rather than any single entity, so they must be provided through the environment. We therefore expect them to be attributes of our custom Treasurehunt environment.

class EmptyEntity(Entity):
    """An entity that represents an empty space in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        self.passable = True  # Agents can enter empty spaces
        self.has_transitions = True  # EmptyEntity can transition into Gems
        self.sprite = Path(__file__).parent / "./assets/empty.png"

    def transition(self, env: GridworldEnv):
        """
        Empty spaces can randomly spawn Gems based on the spawn probability set in the environment.
        """
        # NOTE: If the spawn probability is too high, the environment gets overrun with Gems
        if np.random.random() < env.spawn_prob:
            env.add(self.location, Gem(env.gem_value))
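To get a feel for the NOTE in transition(), a back-of-the-envelope estimate helps. The sketch below assumes, as populate() sets up, that walls take the outer ring of a height x width world, that every other top-layer cell not holding an agent is an EmptyEntity, and that each of those cells spawns a Gem independently with probability spawn_prob on every turn. expected_new_gems() is a hypothetical helper, not part of Sorrel.

```python
def expected_new_gems(height: int, width: int, n_agents: int,
                      spawn_prob: float, turns: int) -> float:
    """Rough expected number of gems spawned in one game.

    Assumes every interior top-layer cell not occupied by an agent is an
    EmptyEntity that independently spawns a Gem with probability spawn_prob
    each turn. Cells already holding a gem are not excluded, so this
    overestimates once gems start to accumulate.
    """
    interior_cells = (height - 2) * (width - 2)  # walls occupy the border
    empty_cells = interior_cells - n_agents
    return empty_cells * spawn_prob * turns


# Values used in setup() and MAX_TURNS in main.py: about 12.4 gems per game.
print(expected_new_gems(10, 10, 2, 0.002, 100))

# With spawn_prob = 0.1 the naive estimate (about 620) far exceeds the
# 62 empty cells available, i.e. the board saturates with gems.
print(expected_new_gems(10, 10, 2, 0.1, 100))
```

With the values used later in main.py, roughly a dozen gems appear per 100-turn game, so gems stay scarce; at a much higher spawn probability the estimate exceeds the number of available cells, which is the "overrun" the comment warns about.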

The Environment

In env.py, we will create the environment of our experiment: Treasurehunt. It will extend the base GridworldEnv class provided by Sorrel; see sorrel.environments.GridworldEnv for its attributes and methods.

We write the import statements:

# Import base packages
import numpy as np

# Import experiment-specific classes
from examples.treasurehunt.entities import EmptyEntity, Gem, Sand, Wall
# Import Sorrel base classes
from sorrel.environments import GridworldEnv

We create the constructor first. In addition to the attributes from GridworldEnv, we add the attributes self.gem_value and self.spawn_prob as noted above. We also add the attributes self.max_turns, self.agents, and self.game_score so that we can access these attributes of the environment at the experiment level later.

class Treasurehunt(GridworldEnv):
    """
    Treasurehunt environment.
    """

    def __init__(self, height, width, gem_value, spawn_prob, max_turns, agents):
        layers = 2
        default_entity = EmptyEntity()
        super().__init__(height, width, layers, default_entity)

        self.gem_value = gem_value
        self.spawn_prob = spawn_prob
        self.agents = agents
        self.max_turns = max_turns

        self.game_score = 0
        self.populate()

We delegate the task of actually filling in the entities and constructing self.world to the method populate():

    def populate(self):
        """
        Populate the treasurehunt world by creating walls and sand, then randomly spawning the agents.
        Note that every space is already filled with EmptyEntity as part of super().__init__().
        """
        valid_spawn_locations = []

        for index in np.ndindex(self.world.shape):
            y, x, z = index
            if y in [0, self.height - 1] or x in [0, self.width - 1]:
                # Add walls around the edge of the world (when indices are first or last)
                self.add(index, Wall())
            elif z == 0:  # if location is on the bottom layer, put sand there
                self.add(index, Sand())
            elif z == 1:
                # if location is on the top layer, it is a valid spawn location for an agent
                valid_spawn_locations.append(index)

        # spawn the agents
        # using np.random.choice, we choose indices in valid_spawn_locations
        agent_locations_indices = np.random.choice(
            len(valid_spawn_locations), size=len(self.agents), replace=False
        )
        agent_locations = [valid_spawn_locations[i] for i in agent_locations_indices]
        for loc, agent in zip(agent_locations, self.agents):
            loc = tuple(loc)
            self.add(loc, agent)
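The placement rules in populate() can be checked without Sorrel. This sketch (the layout() helper is hypothetical) mimics the same loop over np.ndindex with plain strings standing in for entities, so you can print both layers of a small world and confirm that Wall rings the border on every layer while Sand fills the bottom layer.

```python
import numpy as np


def layout(height: int, width: int, layers: int = 2) -> np.ndarray:
    """Mimic the placement rules of Treasurehunt.populate() with strings.

    Every cell starts as "EmptyEntity" (like the default entity), the outer
    ring becomes "Wall" on every layer, and the bottom layer's interior
    becomes "Sand".
    """
    world = np.full((height, width, layers), "EmptyEntity", dtype=object)
    for y, x, z in np.ndindex(world.shape):
        if y in (0, height - 1) or x in (0, width - 1):
            world[y, x, z] = "Wall"
        elif z == 0:
            world[y, x, z] = "Sand"
    return world


world = layout(4, 5)
print(world[:, :, 0])  # bottom layer: Sand surrounded by Wall
print(world[:, :, 1])  # top layer: EmptyEntity (valid spawn cells) surrounded by Wall
```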

Note

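np.random.choice can only sample from a one-dimensional array, so it cannot pick directly from our list of (y, x, z) tuples; instead, we sample indices into valid_spawn_locations and look the locations up. We specifically avoided Python's built-in random module (such as random.choices), because we would then need to seed np.random and random separately to get reproducible results. It’s generally a good idea to choose one random number generator and use only that one throughout your example.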
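A minimal standalone sketch of both points in the note, using only NumPy: passing the list of tuples straight to np.random.choice fails because it is not one-dimensional, while sampling indices works, and reseeding np.random reproduces the same agent locations. The four locations are made up for illustration.

```python
import numpy as np

# Candidate spawn locations are (y, x, z) tuples, as in populate().
valid_spawn_locations = [(1, 1, 1), (1, 2, 1), (2, 1, 1), (2, 2, 1)]

# np.random.choice only samples from a 1-D array (or an int n, meaning
# range(n)); a list of tuples becomes a 2-D array and raises ValueError.
try:
    np.random.choice(valid_spawn_locations, size=2, replace=False)
except ValueError as err:
    print("direct sampling failed:", err)


def sample_locations(seed: int) -> list:
    """Seed np.random, then pick two distinct locations via their indices."""
    np.random.seed(seed)
    indices = np.random.choice(len(valid_spawn_locations), size=2, replace=False)
    return [valid_spawn_locations[i] for i in indices]


first = sample_locations(0)
second = sample_locations(0)
print(first == second)  # True: same seed, same agent locations
```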

We will also write a reset() method, which uses sorrel.environments.GridworldEnv.create_world() to rebuild the world so that every game starts from a fresh environment:

    def reset(self):
        """Reset the environment and all its agents."""
        self.create_world()
        self.game_score = 0
        self.populate()
        for agent in self.agents:
            agent.reset()

The Agent

In agents.py, we will create the agent for our experiment: TreasurehuntAgent. It will extend the base Agent class provided by Sorrel; see sorrel.agents.Agent for its attributes and methods.

We make our imports:

from pathlib import Path

import numpy as np

from sorrel.agents import Agent
from sorrel.environments import GridworldEnv

We make our custom constructor:


class TreasurehuntAgent(Agent):
    """
    A treasurehunt agent that uses the IQN model.
    """

    def __init__(self, observation_spec, action_spec, model):
        super().__init__(observation_spec, action_spec, model)
        self.sprite = Path(__file__).parent / "./assets/hero.png"

We will use sorrel.observation.observation_spec.OneHotObservationSpec for TreasurehuntAgent’s observations, sorrel.action.action_spec.ActionSpec for its actions, and sorrel.models.pytorch.PyTorchIQN for its model. We do not create them in this file (they are passed into TreasurehuntAgent’s constructor from main.py), but we use the functionality they provide through the agent’s attributes self.observation_spec, self.action_spec, and self.model.

Note that unlike the other base classes we’ve built on so far, Agent is an abstract class: every custom agent that extends it must implement the methods reset(), pov(), get_action(), act(), and is_done(). Let’s go through them one by one.

To implement sorrel.agents.Agent.reset(), we fill the memory buffer with all-zero SARD (state, action, reward, done) entries, one for each frame the model stacks (self.model.num_frames). The “zero state” is an all-zeros vector whose length is the size of this agent’s flattened observation, np.prod(self.model.input_size).

    def reset(self) -> None:
        """Resets the agent by filling the memory buffer with blank frames."""
        state = np.zeros(np.prod(self.model.input_size))
        action = 0
        reward = 0.0
        done = False
        for _ in range(self.model.num_frames):
            self.add_memory(state, action, reward, done)
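The zero state should be a whole vector of zeros, not a single number. This standalone snippet contrasts np.zeros_like and np.zeros when each is given np.prod of a shape; the (5, 5, 5) shape is a hypothetical stand-in for self.model.input_size. (A memory buffer that broadcasts a scalar into a preallocated row may hide the difference, but np.zeros states the intent directly.)

```python
import numpy as np

# Hypothetical observation shape: 5 one-hot channels over a 5x5 visual field.
input_size = (5, 5, 5)

# np.prod collapses the shape to a single number (125 here) ...
n_features = int(np.prod(input_size))

# ... so zeros_like of that number is a 0-d scalar array, not a state vector.
scalar_zero = np.zeros_like(np.prod(input_size))
print(scalar_zero.shape)  # ()

# np.zeros with that number gives one flat, all-zero state of the right length.
zero_state = np.zeros(n_features)
print(zero_state.shape)  # (125,)
```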

To implement sorrel.agents.Agent.pov(), we get the observed image (in Channels x Height x Width) using the provided OneHotObservationSpec.observe() method, and then return the image flattened into a single row.

    def pov(self, env: GridworldEnv) -> np.ndarray:
        """Returns the state observed by the agent, from the flattened visual field."""
        image = self.observation_spec.observe(env, self.location)
        # flatten the image to get the state
        return image.reshape(1, -1)
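For intuition about what pov() returns, here is a toy one-hot encoder over a 3 x 3 patch. one_hot_patch() is a hypothetical helper that assumes one channel per name in ENTITY_LIST (the list defined later in main.py); it is not how OneHotObservationSpec.observe() is implemented, but it produces the same kind of Channels x Height x Width image, which reshape(1, -1) then flattens into one row.

```python
import numpy as np

ENTITY_LIST = ["EmptyEntity", "Wall", "Sand", "Gem", "TreasurehuntAgent"]


def one_hot_patch(patch: list[list[str]]) -> np.ndarray:
    """Encode a 2-D patch of entity names as a (channels, height, width) array.

    Illustrative only: one channel per entry in ENTITY_LIST, with a 1.0 in the
    channel of the entity found at each cell.
    """
    height, width = len(patch), len(patch[0])
    image = np.zeros((len(ENTITY_LIST), height, width))
    for y, row in enumerate(patch):
        for x, name in enumerate(row):
            image[ENTITY_LIST.index(name), y, x] = 1.0
    return image


patch = [
    ["Wall", "Wall", "Wall"],
    ["Wall", "TreasurehuntAgent", "Gem"],
    ["Wall", "EmptyEntity", "EmptyEntity"],
]
image = one_hot_patch(patch)
state = image.reshape(1, -1)  # same flattening as pov()
print(image.shape, state.shape)  # (5, 3, 3) (1, 45)
```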

To implement sorrel.agents.Agent.get_action(), we stack the current state beneath the previous num_frames - 1 states from the model’s memory buffer, flatten the stacked frames into a single row vector, and pass it to the model to obtain the chosen action (see SorrelModel.take_action).

    def get_action(self, state: np.ndarray) -> int:
        """Gets the action from the model, using the stacked states."""
        prev_states = self.model.memory.current_state(
            stacked_frames=self.model.num_frames - 1
        )
        stacked_states = np.vstack((prev_states, state))

        model_input = stacked_states.reshape(1, -1)
        action = self.model.take_action(model_input)
        return action
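The shapes involved in get_action() can be traced with plain NumPy arrays. This sketch assumes that memory.current_state(stacked_frames=num_frames - 1) returns one previous state per row, and uses a hypothetical observation size of 125; zeros stand in for the earlier frames and ones for the current one, so you can see where the newest frame ends up.

```python
import numpy as np

num_frames = 5
n_features = 125  # flattened size of one observation (hypothetical)

# Stand-in for memory.current_state(stacked_frames=num_frames - 1):
# the previous num_frames - 1 states, one per row (assumed layout).
prev_states = np.zeros((num_frames - 1, n_features))

# Stand-in for the current observation returned by pov(): a single row.
state = np.ones((1, n_features))

stacked_states = np.vstack((prev_states, state))  # (num_frames, n_features)
model_input = stacked_states.reshape(1, -1)       # (1, num_frames * n_features)

print(stacked_states.shape, model_input.shape)  # (5, 125) (1, 625)
# The newest frame occupies the last n_features entries of the row vector.
print(model_input[0, -n_features:].all())  # True
```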

To implement sorrel.agents.Agent.act(), we calculate the new location based on the action taken, record the reward obtained based on the entity at the new location, then try to move the agent to the new location using the provided GridworldEnv.move().

    def act(self, env: GridworldEnv, action: int) -> float:
        """Act on the environment, returning the reward."""

        # Translate the model output to an action string
        action = self.action_spec.get_readable_action(action)

        new_location = self.location
        if action == "up":
            new_location = (self.location[0] - 1, self.location[1], self.location[2])
        elif action == "down":
            new_location = (self.location[0] + 1, self.location[1], self.location[2])
        elif action == "left":
            new_location = (self.location[0], self.location[1] - 1, self.location[2])
        elif action == "right":
            new_location = (self.location[0], self.location[1] + 1, self.location[2])

        # get reward obtained from object at new_location
        target_object = env.observe(new_location)
        reward = target_object.value
        env.game_score += reward

        # try moving to new_location
        env.move(self, new_location)

        return reward
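The if chain in act() can also be written as a lookup table of offsets, which keeps the four directions in one place. This is a hypothetical alternative (the MOVES table and next_location() helper are not part of Sorrel) that computes the same target (y, x, z) location, keeping the agent on its own layer and leaving it in place for any other action.

```python
# Hypothetical lookup table: each action maps to a (dy, dx) offset.
# Rows grow downward, so "up" decreases y.
MOVES = {
    "up": (-1, 0),
    "down": (1, 0),
    "left": (0, -1),
    "right": (0, 1),
}


def next_location(location: tuple[int, int, int], action: str) -> tuple[int, int, int]:
    """Return the (y, x, z) cell an action targets, keeping the layer z fixed."""
    dy, dx = MOVES.get(action, (0, 0))  # unknown actions leave the agent in place
    y, x, z = location
    return (y + dy, x + dx, z)


print(next_location((3, 4, 1), "up"))     # (2, 4, 1)
print(next_location((3, 4, 1), "right"))  # (3, 5, 1)
```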

Finally, we implement sorrel.agents.Agent.is_done() by checking whether the current turn (tracked by default in GridworldEnv.turn) has reached the maximum number of turns.

    def is_done(self, env: GridworldEnv) -> bool:
        """Returns whether this Agent is done."""
        return env.turn >= env.max_turns

Now, we are all done with our custom classes. Time to set up the actual experiment!

The Experiment Script: main.py

First, we make our imports as usual:

# general imports
from pathlib import Path

import numpy as np
import torch

# imports from our example
from examples.treasurehunt.agents import TreasurehuntAgent
from examples.treasurehunt.env import Treasurehunt

# sorrel imports
from sorrel.action.action_spec import ActionSpec
from sorrel.models.pytorch import PyTorchIQN
from sorrel.observation.observation_spec import OneHotObservationSpec
from sorrel.utils.visualization import animate, image_from_array, visual_field_sprite

Then, we will define our experiment parameters as global constants:

EPOCHS = 500
MAX_TURNS = 100
EPSILON_DECAY = 0.0001
ENTITY_LIST = ["EmptyEntity", "Wall", "Sand", "Gem", "TreasurehuntAgent"]
RECORD_PERIOD = 50  # how many epochs in each data recording period

These parameters, as well as the world configuration and model hyperparameters below, can be moved out of this script into configuration files for faster and easier adjustment; the Sorrel documentation includes a quick tutorial on doing so.

We will first create the observation specifications, action specifications, models, agents, and environment. The entities do not need to be created explicitly, since the environment generates them.

def setup() -> Treasurehunt:
    """Set up the whole environment and everything within it."""
    # object configurations
    world_height = 10
    world_width = 10
    gem_value = 10
    spawn_prob = 0.002
    agent_vision_radius = 2

    # make the agents
    agent_num = 2
    agents = []
    for _ in range(agent_num):
        observation_spec = OneHotObservationSpec(
            ENTITY_LIST, vision_radius=agent_vision_radius
        )
        observation_spec.override_input_size(
            np.array(observation_spec.input_size).reshape(1, -1)
        )
        action_spec = ActionSpec(["up", "down", "left", "right"])

        model = PyTorchIQN(
            # the agent can see r cells on each side, i.e. a (2r+1) x (2r+1) window, one-hot encoded over ENTITY_LIST
            input_size=observation_spec.input_size,
            action_space=action_spec.n_actions,
            layer_size=250,
            epsilon=0.7,
            device="cpu",
            seed=torch.random.seed(),
            num_frames=5,
            n_step=3,
            sync_freq=200,
            model_update_freq=4,
            BATCH_SIZE=64,
            memory_size=1024,
            LR=0.00025,
            TAU=0.001,
            GAMMA=0.99,
            N=12,
        )

        agents.append(
            TreasurehuntAgent(
                observation_spec=observation_spec, action_spec=action_spec, model=model
            )
        )

    # make the environment
    env = Treasurehunt(
        world_height, world_width, gem_value, spawn_prob, MAX_TURNS, agents
    )
    return env
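The comment in setup() describes a (2r+1) x (2r+1) window. Assuming OneHotObservationSpec adds one channel per entry of ENTITY_LIST, the input arithmetic for this configuration looks like this (the input_size tuple below is our assumption about its layout, not a value read from Sorrel):

```python
import math

ENTITY_LIST = ["EmptyEntity", "Wall", "Sand", "Gem", "TreasurehuntAgent"]
agent_vision_radius = 2
num_frames = 5

# Assumed layout: one channel per entity type over a (2r+1) x (2r+1) window.
side = 2 * agent_vision_radius + 1                   # 5 cells per side
input_size = (len(ENTITY_LIST), side, side)          # (5, 5, 5)
features_per_frame = math.prod(input_size)           # 125 numbers per observation
stacked_features = features_per_frame * num_frames   # 625 numbers per model input

print(input_size, features_per_frame, stacked_features)  # (5, 5, 5) 125 625
```

Under that assumption, each frame has 125 features, and the five stacked frames that get_action() passes to the model add up to 625 inputs.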

Then, we will run the experiment. Most of the work here is done by calling GridworldEnv.take_turn(), which transitions every entity in the environment, then every agent, then increments the turn count by one. In addition to printing information about each recording period on the terminal, we also use functions from sorrel.utils.visualization to record states as images and animate them into a gif.

def run(env: Treasurehunt):
    """Run the experiment."""
    imgs = []
    total_score = 0
    total_loss = 0
    for epoch in range(EPOCHS + 1):
        # Reset the environment at the start of each epoch
        env.reset()
        for agent in env.agents:
            agent.model.start_epoch_action(**locals())

        while env.turn < env.max_turns:
            if epoch % RECORD_PERIOD == 0:
                full_sprite = visual_field_sprite(env)
                imgs.append(image_from_array(full_sprite))

            env.take_turn()

        # At the end of each epoch, train once enough experience has been collected (here, after the first 10 epochs).
        if epoch > 10:
            for agent in env.agents:
                loss = agent.model.train_step()
                total_loss += loss

        total_score += env.game_score

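        if epoch % RECORD_PERIOD == 0:
            avg_score = total_score / RECORD_PERIOD
            # print a summary of this recording period
            print(f"Epoch {epoch}: average score {avg_score}, total loss {total_loss}")
            animate(
                imgs, f"treasurehunt_epoch{epoch}", Path(__file__).parent / "./data/"
            )
            # reset the data
            imgs = []
            total_score = 0
            total_loss = 0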

        # update epsilon
        for agent in env.agents:
            new_epsilon = agent.model.epsilon - EPSILON_DECAY
            agent.model.epsilon = max(new_epsilon, 0.01)

Finally, we write the main block:

if __name__ == "__main__":
    env = setup()
    run(env)

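And we’re done! Because main.py imports from the examples package, run the script from the repository root (for example, with python -m examples.treasurehunt.main), then find the animations in examples/treasurehunt/data/.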