How to Create a Custom Experiment

We will explore how you can create your own custom experiment by walking through a simple example, Treasure Hunt. In this example, the environment contains agents with full vision who can only move up, down, left, or right, as well as gems that have a random chance of spawning on empty spaces. The agents’ success is measured by the game score, which is determined by how many gems they pick up.

Overview

The file structure for our experiment will be as follows (under the examples/ directory):

treasurehunt
├── assets
│ └── <animation sprites>
├── data
│ └── <data records>
├── agents.py
├── entities.py
├── env.py
├── main.py
└── world.py

We will create a custom environment named TreasurehuntEnv, including a world TreasurehuntWorld, custom entities EmptyEntity, Wall, Sand, and Gem, and a custom agent TreasurehuntAgent. The world will have two layers: TreasurehuntAgent and EmptyEntity will be on the top layer, and Sand will be on the bottom layer. We will then write an env.py script that implements the custom environment TreasurehuntEnv, and a main.py script that runs and records the experiment.

Let’s get started!

The Entities

In entities.py, we will create the four entities that we require: EmptyEntity, Wall, Sand, and Gem. All the custom entities will extend the base Entity class provided by Sorrel; see sorrel.entities.Entity for its attributes (including their default values) and methods. Note that the Entity class uses a Generic type; we will specify the type as TreasurehuntWorld in our custom entities, as that is the world that our entities are compatible with.

We begin by making the necessary imports:

from pathlib import Path

import numpy as np

from sorrel.entities import Entity
from sorrel.examples.treasurehunt.world import TreasurehuntWorld

Then, we create the classes Wall, Sand, and Gem, with custom constructors that overwrite default parent attribute values and include sprites used for animation later on. These sprites should be placed in a ./assets/ folder. None of these entities transition.

class Wall(Entity[TreasurehuntWorld]):
    """An entity that represents a wall in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        self.value = -1  # Walls penalize contact
        self.sprite = Path(__file__).parent / "./assets/wall.png"


class Sand(Entity[TreasurehuntWorld]):
    """An entity that represents a block of sand in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        # We technically don't need to make Sand passable here since it's on a different layer from Agent
        self.passable = True
        self.sprite = Path(__file__).parent / "./assets/sand.png"
        self.kind = "EmptyEntity"


class Gem(Entity[TreasurehuntWorld]):
    """An entity that represents a gem in the treasurehunt environment."""

    def __init__(self, value):
        super().__init__()
        self.passable = True  # Agents can move onto Gems
        self.value = value
        self.sprite = Path(__file__).parent / "./assets/gem.png"

Note

We use Path(__file__) to ensure that the animation sprite paths are always relative to the path to this entities.py file, no matter where one may be running this code from.

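We then create EmptyEntity, which requires a custom transition method. The transition method needs information such as the spawn probability and the entity values, which must be provided through the world. Therefore, we expect them to be attributes of our custom TreasurehuntWorld. Note that the snippet below also spawns Food and Bone entities; this tutorial does not show them, and they are assumed to be defined in entities.py in the same way as Gem.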

class EmptyEntity(Entity[TreasurehuntWorld]):
    """An entity that represents an empty space in the treasurehunt environment."""

    def __init__(self):
        super().__init__()
        self.passable = True  # Agents can enter EmptySpaces
        self.has_transitions = True  # EmptyEntity can transition into Gems
        self.sprite = Path(__file__).parent / "./assets/empty.png"

    def transition(self, world: TreasurehuntWorld):
        """EmptySpaces can randomly spawn into Gems based on the item spawn
        probabilities dictated in the environment."""
        if (  # NOTE: If the spawn prob is too high, the environment gets overrun
            np.random.random() < world.spawn_prob
        ):
            entity: Entity = np.random.choice(
                np.array(
                    [
                        Gem(world.values["gem"]),
                        Food(world.values["food"]),
                        Bone(world.values["bone"]),
                    ],
                    dtype=object,
                )
            )
            world.add(self.location, entity)
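
To get a feel for the warning about the spawn probability, here is a standalone sketch that does not use Sorrel. It assumes that every empty cell transitions independently with probability spawn_prob on each turn, and that the spawnable area is the 19 x 19 interior of a 21 x 21 world. The helper names are hypothetical and only serve this illustration.

```python
import numpy as np


def expected_spawns_per_turn(n_empty_cells: int, spawn_prob: float) -> float:
    """Expected number of new entities per turn, assuming every empty cell
    spawns independently with probability spawn_prob."""
    return n_empty_cells * spawn_prob


def simulate_spawns(
    n_empty_cells: int, spawn_prob: float, turns: int, seed: int = 0
) -> float:
    """Monte Carlo estimate of the average number of spawns per turn."""
    rng = np.random.default_rng(seed)
    # One row per turn, one column per empty cell; True means "spawned".
    draws = rng.random((turns, n_empty_cells)) < spawn_prob
    return float(draws.sum(axis=1).mean())


interior_cells = 19 * 19  # 21 x 21 world minus the wall border
print(expected_spawns_per_turn(interior_cells, 0.005))
print(simulate_spawns(interior_cells, 0.005, turns=2000))
```

Under these assumptions a spawn probability of 0.005 already yields almost two new entities per turn, so raising it by an order of magnitude would fill the world quickly unless agents collect items at a similar rate.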

The World

In world.py, we will create the world for our environment: TreasurehuntWorld. It will extend the base Gridworld class provided by Sorrel; see sorrel.worlds.Gridworld for its attributes and methods.

We write the import statements:


from omegaconf import DictConfig, OmegaConf

from sorrel.worlds import Gridworld

We create the constructor. In addition to the attributes from Gridworld, we add the attributes self.values (a dictionary holding the value of each spawnable entity) and self.spawn_prob, as noted above. In the code below, the turn limit max_turns is not stored on the world; it is part of the experiment section of the config.

class TreasurehuntWorld(Gridworld):
    """Treasurehunt world."""

    def __init__(self, config: dict | DictConfig, default_entity):
        layers = 2
        if not isinstance(config, DictConfig):
            config = OmegaConf.create(config)
        super().__init__(
            config.world.height, config.world.width, layers, default_entity
        )

        self.values = {
            "gem": config.world.gem_value,
            "food": config.world.food_value,
            "bone": config.world.bone_value,
        }
        self.spawn_prob = config.world.spawn_prob


Note that the world is very barebones. The task of actually filling in the entities and constructing the world is delegated to our custom environment class, as we will see in a moment.

The Agent

In agents.py, we will create the agent for our experiment: TreasurehuntAgent. It will extend the base Agent class provided by Sorrel; see sorrel.agents.Agent for its attributes and methods. Note that Agent is a generic subclass of Entity, so just like with our custom entities, we need to specify the type of world that the custom agent is compatible with when inheriting from Agent.

We make our imports:

from pathlib import Path

import numpy as np

from sorrel.agents import Agent, MovingAgent
from sorrel.examples.treasurehunt.world import TreasurehuntWorld

We make our custom constructor:

class TreasurehuntAgent(MovingAgent[TreasurehuntWorld]):
    """A treasurehunt agent that uses the iqn model."""

    def __init__(self, observation_spec, action_spec, model):
        super().__init__(observation_spec, action_spec, model)
        self.sprite = Path(__file__).parent / "./assets/hero.png"

We will use sorrel.observation.observation_spec.OneHotObservationSpec for TreasurehuntAgent’s observation, sorrel.action.action_spec.ActionSpec for TreasurehuntAgent’s actions, and sorrel.models.pytorch.PyTorchIQN for TreasurehuntAgent’s model. We do not create them in this file (they will be passed into TreasurehuntAgent’s constructor externally), but we will use the functionality that they provide by accessing the attributes of this class.

Note that unlike the other base classes we’ve worked on top of so far, Agent is an abstract class, and every custom agent that extends it must implement the methods reset(), pov(), get_action(), act(), and is_done(). Let’s go through them one by one.

To implement sorrel.agents.Agent.reset(), we call the agent’s model’s reset function. This is required of all sorrel models that inherit the base model class.

    def reset(self) -> None:
        """Resets the agent by filling the memory buffer with blank images."""
        self.model.reset()

To implement sorrel.agents.Agent.pov(), we get the observed image (in Channels x Height x Width) using the provided OneHotObservationSpec.observe() function, and then return the flattened image.

    def pov(self, world: TreasurehuntWorld) -> np.ndarray:
        """Returns the state observed by the agent, from the flattened visual field."""
        image = self.observation_spec.observe(world, self.location)
        # flatten the image to get the state
        return image.reshape(1, -1)
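
The flattening step can be illustrated without Sorrel. The sketch below builds a fake one-hot image and reshapes it the same way pov() does. It assumes six entity kinds (the length of the entity list used later in this tutorial) and a square window of side 2 * vision_radius + 1; both are assumptions about how the observation is laid out, not a statement about the library internals.

```python
import numpy as np

n_entity_kinds = 6  # assumed: one channel per entry in the entity list
vision_radius = 2  # assumed: gives a (2r + 1) x (2r + 1) window
side = 2 * vision_radius + 1

# A fake observation in Channels x Height x Width layout:
# every cell is marked as the entity kind in channel 0.
image = np.zeros((n_entity_kinds, side, side))
image[0, :, :] = 1.0

# Same reshape as in pov(): one row holding every channel of every cell.
state = image.reshape(1, -1)
print(state.shape)
```

With these numbers the flattened state has 6 * 5 * 5 = 150 entries, which is the size a model would need to accept per frame.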

To implement sorrel.agents.Agent.get_action(), we stack the current state with the previous states in the model’s memory buffer, and pass the stacked frames (as a horizontal vector) into the model to obtain the action chosen. (See BaseModel.take_action)

    def get_action(self, state: np.ndarray) -> int:
        """Gets the action from the model, using the stacked states."""
        prev_states = self.model.memory.current_state()
        stacked_states = np.vstack((prev_states, state))

        model_input = stacked_states.reshape(1, -1)
        action = self.model.take_action(model_input)
        return action
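
The stacking step is easier to follow with concrete shapes. This standalone sketch assumes that the memory buffer returns the n_frames - 1 most recent states, one per row, and that each flattened state has 150 entries; prev_states is a hypothetical stand-in for self.model.memory.current_state().

```python
import numpy as np

n_frames = 5  # matches the n_frames value used for the model in this tutorial
state_size = 150  # assumed size of one flattened observation

# Hypothetical stand-in for the memory buffer: the previous four states.
prev_states = np.zeros((n_frames - 1, state_size))
# The current state, shaped (1, state_size) as returned by pov().
state = np.ones((1, state_size))

# Append the newest state below the older ones, then flatten to one row.
stacked_states = np.vstack((prev_states, state))
model_input = stacked_states.reshape(1, -1)

print(stacked_states.shape, model_input.shape)
```

Because the newest state is stacked last, it occupies the final state_size entries of the model input.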

To implement sorrel.agents.Agent.act(), we calculate the new location based on the action taken, record the reward obtained based on the entity at the new location, then try to move the agent to the new location using the provided Gridworld.move().

    def act(self, world: TreasurehuntWorld, action: int) -> float:
        """Act on the environment, returning the reward."""

        # Translate the model output to an action string
        action_name = self.action_spec.get_readable_action(action)

        new_location = self.location
        if action_name == "up":
            new_location = (self.location[0] - 1, self.location[1], self.location[2])
        elif action_name == "down":
            new_location = (self.location[0] + 1, self.location[1], self.location[2])
        elif action_name == "left":
            new_location = (self.location[0], self.location[1] - 1, self.location[2])
        elif action_name == "right":
            new_location = (self.location[0], self.location[1] + 1, self.location[2])

        # get reward obtained from object at new_location
        target_object = world.observe(new_location)
        reward = target_object.value

        # try moving to new_location
        world.move(self, new_location)

        return reward
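
The location arithmetic in act() can also be expressed as a lookup table. The following standalone sketch shows one way to do that; MOVES and next_location are hypothetical names, and locations are assumed to be (y, x, layer) tuples as in the code above.

```python
# Row/column offsets for each readable action name.
MOVES: dict[str, tuple[int, int]] = {
    "up": (-1, 0),
    "down": (1, 0),
    "left": (0, -1),
    "right": (0, 1),
}


def next_location(
    location: tuple[int, int, int], action_name: str
) -> tuple[int, int, int]:
    """Return the location reached by taking action_name from location.

    Unknown action names leave the location unchanged, and the layer
    index is never modified.
    """
    dy, dx = MOVES.get(action_name, (0, 0))
    y, x, z = location
    return (y + dy, x + dx, z)


print(next_location((5, 5, 1), "up"))
print(next_location((5, 5, 1), "right"))
```

A table like this keeps the mapping in one place, which helps if more actions are added to the action spec later.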

Finally, we implement sorrel.agents.Agent.is_done() by returning the world’s is_done flag, which is expected to be set once the current turn (tracked by default in Gridworld.turn) reaches the maximum number of turns.

    def is_done(self, world: TreasurehuntWorld) -> bool:
        """Returns whether this Agent is done."""
        return world.is_done

Now, we are all done with our custom classes. Time to set up the actual environment!

The Environment: env.py

First, we make our imports as usual:

# general imports
import numpy as np
import torch

from sorrel.action.action_spec import ActionSpec
from sorrel.environment import Environment

# imports from our example
from sorrel.examples.treasurehunt.agents import TreasurehuntAgent
from sorrel.examples.treasurehunt.entities import EmptyEntity, Sand, Wall
from sorrel.examples.treasurehunt.world import TreasurehuntWorld

# sorrel imports
from sorrel.models.pytorch import PyTorchIQN
from sorrel.observation.observation_spec import OneHotObservationSpec

We will now write our custom environment class by inheriting the sorrel.environment.Environment class that has an already implemented run_experiment() method which will run the experiment for us. Much like the custom entities and agents, we need to specify the world this custom environment is using when inheriting from the generic environment.

class TreasurehuntEnv(Environment[TreasurehuntWorld]):
    """The experiment for treasurehunt."""

    def __init__(self, world: TreasurehuntWorld, config: dict) -> None:
        super().__init__(world, config)

Note that the environment takes in a config that can be accessed at self.config which stores the configurations used for this experiment. Certain config values are required when using the default methods: see the documentation for more details.

Like Agent, Environment requires us to implement two abstract methods.

The first is sorrel.environment.Environment.setup_agents(), where we create the agents used in this specific environment and save them in the attribute self.agents:

    def setup_agents(self):
        """Create the agents for this experiment and assign them to self.agents.

        Requires self.config.model.agent_vision_radius to be defined.
        """
        agent_num = 2
        agents = []
        for _ in range(agent_num):
            # create the observation spec
            entity_list = [
                "EmptyEntity",
                "Wall",
                "Gem",
                "Bone",
                "Food",
                "TreasurehuntAgent",
            ]
            observation_spec = OneHotObservationSpec(
                entity_list,
                full_view=False,
                # note that here we require self.config to have the entry model.agent_vision_radius
                # don't forget to pass it in as part of config when creating this experiment!
                vision_radius=self.config.model.agent_vision_radius,
            )
            observation_spec.override_input_size(
                (int(np.prod(observation_spec.input_size)),)
            )

            # create the action spec
            action_spec = ActionSpec(["up", "down", "left", "right"])

            # create the model
            model = PyTorchIQN(
                input_size=observation_spec.input_size,
                action_space=action_spec.n_actions,
                layer_size=250,
                epsilon=0.6,
                device="cpu",
                seed=torch.random.seed(),
                n_frames=5,
                n_step=3,
                sync_freq=200,
                model_update_freq=4,
                batch_size=64,
                memory_size=1024,
                LR=0.00025,
                TAU=0.001,
                GAMMA=0.99,
                n_quantiles=12,
            )

            agents.append(
                TreasurehuntAgent(
                    observation_spec=observation_spec,
                    action_spec=action_spec,
                    model=model,
                )
            )

        self.agents = agents

The second is sorrel.environment.Environment.populate_environment(), where we create all entities and populate self.world with the entities as well as the agents.

    def populate_environment(self):
        """Populate the treasurehunt world by creating walls, then randomly spawning the
        agents.

        Note that self.world.map is already created with the specified dimensions, and
        every space is filled with EmptyEntity, as part of super().__init__() when this
        experiment is constructed.
        """
        valid_spawn_locations = []

        for index in np.ndindex(self.world.map.shape):
            y, x, z = index
            if (y in [0, self.world.height - 1] or x in [0, self.world.width - 1]) and (
                z == 1
            ):
                # Add walls around the edge of the world (when indices are first or last)
                self.world.add(index, Wall())
            elif z == 0:  # if location is on the bottom layer, put sand there
                self.world.add(index, Sand())
            elif (
                z == 1
            ):  # if location is on the top layer, indicate that it's possible for an agent to spawn there
                # valid spawn location
                valid_spawn_locations.append(index)

        # spawn the agents
        # using np.random.choice, we choose indices in valid_spawn_locations
        agent_locations_indices = np.random.choice(
            len(valid_spawn_locations), size=len(self.agents), replace=False
        )
        agent_locations = [valid_spawn_locations[i] for i in agent_locations_indices]
        for loc, agent in zip(agent_locations, self.agents):
            loc = tuple(loc)
            self.world.add(loc, agent)
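
To check the branching logic above without constructing a world, this standalone sketch applies the same three conditions to every index of a grid and counts the outcomes. It assumes the map has shape (height, width, layers) with layer 1 on top; classify_cells is a hypothetical helper.

```python
import numpy as np


def classify_cells(height: int, width: int, layers: int = 2) -> dict[str, int]:
    """Count how many cells become walls, sand, or valid spawn locations."""
    counts = {"wall": 0, "sand": 0, "spawnable": 0}
    for y, x, z in np.ndindex((height, width, layers)):
        on_border = y in (0, height - 1) or x in (0, width - 1)
        if on_border and z == 1:
            counts["wall"] += 1  # top-layer border cells become walls
        elif z == 0:
            counts["sand"] += 1  # every bottom-layer cell becomes sand
        elif z == 1:
            counts["spawnable"] += 1  # top-layer interior cells
    return counts


print(classify_cells(21, 21))
```

For a 21 x 21 world this gives 80 wall cells, 441 sand cells, and 361 possible spawn locations, which comfortably exceeds the two agents created earlier.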

Note

We had to work around np.random.choice a little in order to use it. We have specifically avoided using random.choices because we would then need to seed np.random and random separately for reproducible results. It’s generally a good idea to choose one random generator and only use that across the scope of your example.
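
The workaround consists of sampling indices into the list rather than the location tuples themselves, since the choice function expects a one-dimensional input. The sketch below shows the same idea with an explicitly seeded NumPy Generator instead of the global np.random state used in this tutorial; that substitution, and the helper name sample_spawn_locations, are illustrative choices rather than part of the example.

```python
import numpy as np


def sample_spawn_locations(
    valid_locations: list[tuple[int, int, int]], n_agents: int, seed: int
) -> list[tuple[int, int, int]]:
    """Pick n_agents distinct locations, reproducibly for a given seed."""
    rng = np.random.default_rng(seed)
    # Sample indices without replacement, then look the locations up.
    indices = rng.choice(len(valid_locations), size=n_agents, replace=False)
    return [valid_locations[i] for i in indices]


# Interior cells of a small 5 x 5 world, all on the top layer.
locations = [(y, x, 1) for y in range(1, 4) for x in range(1, 4)]
first = sample_spawn_locations(locations, n_agents=2, seed=42)
second = sample_spawn_locations(locations, n_agents=2, seed=42)
print(first == second)
```

Passing the same seed yields the same spawn locations, and replace=False guarantees that no two agents share a cell.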

The Experiment Script: main.py

Lastly, we will run the experiment. Most of the work is done by calling the Environment.run_experiment() method.

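from datetime import datetime
from pathlib import Path

from sorrel.examples.treasurehunt.entities import EmptyEntity
from sorrel.examples.treasurehunt.env import TreasurehuntEnv
from sorrel.examples.treasurehunt.world import TreasurehuntWorld

# NOTE: assumed import path for the logger; adjust it to match your Sorrel version
from sorrel.utils.logging import TensorboardLogger

if __name__ == "__main__":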

    # object configurations
    config = {
        "experiment": {
            "epochs": 1000,
            "max_turns": 100,
            "record_period": 50,
            "log_dir": Path(__file__).parent
            / f"./data/logs/{datetime.now().strftime('%Y-%m-%d %H-%M-%S')}",
        },
        "model": {
            "agent_vision_radius": 2,
            "epsilon_decay": 0.0005,
        },
        "world": {
            "height": 21,
            "width": 21,
            "gem_value": 10,
            "food_value": 5,
            "bone_value": -10,
            "spawn_prob": 0.005,
        },
    }

    # construct the world
    world = TreasurehuntWorld(config=config, default_entity=EmptyEntity())
    # construct the environment
    env = TreasurehuntEnv(world, config)
    # run the experiment with default parameters
    env.run_experiment(
        output_dir=Path(__file__).parent / "./data",
        logger=TensorboardLogger.from_config(config),
    )

Here, we use a dictionary to store our configs for the experiment and pass in constants for the environment parameters. In general, we recommend using configuration files for a cleaner and more centralized approach: here’s a quick tutorial.
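
Since several classes in this tutorial read specific config entries, it can help to check the dictionary before constructing anything. The sketch below is one possible approach in plain Python; missing_keys and REQUIRED_KEYS are hypothetical names, and the key list only covers entries that the code in this tutorial reads, not everything Sorrel may require.

```python
from typing import Any

# Entries read by the world and environment code in this tutorial.
REQUIRED_KEYS = [
    "model.agent_vision_radius",
    "world.height",
    "world.width",
    "world.gem_value",
    "world.spawn_prob",
]


def missing_keys(config: dict[str, Any], required: list[str]) -> list[str]:
    """Return the dotted keys from required that are absent from config."""
    missing = []
    for dotted in required:
        node: Any = config
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


incomplete = {
    "model": {"agent_vision_radius": 2},
    "world": {"height": 21, "width": 21, "gem_value": 10},
}
print(missing_keys(incomplete, REQUIRED_KEYS))
```

Here the check reports that world.spawn_prob is missing, which is easier to act on than an attribute error raised in the middle of a run.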

And we’re done! You can run this script from the command line and see the animations in ./data (under the current working directory).