Source code for aind_behavior_gym.dynamic_foraging.task.base

"""A general gymnasium environment for dynamic foraging tasks in AIND.

Adapted from Han's code for the project in Neuromatch Academy: Deep Learning
https://github.com/hanhou/meta_rl/blob/bd9b5b1d6eb93d217563ff37608aaa2f572c08e6/han/environment/dynamic_bandit_env.py

See also Po-Chen Kuo's implementation:
https://github.com/pckuo/meta_rl/blob/main/environments/bandit/bandit.py
"""

import gymnasium as gym
import numpy as np
from gymnasium import spaces

L = 0
R = 1
IGNORE = 2


[docs] class DynamicForagingTaskBase(gym.Env): """ A general gymnasium environment for dynamic bandit task Adapted from https://github.com/thinkjrs/gym-bandit-environments/blob/master/gym_bandits/bandit.py # noqa E501 """ def __init__( self, reward_baiting: bool = False, # Whether the reward is baited allow_ignore: bool = False, # Allow the agent to ignore the task num_arms: int = 2, # Number of arms in the bandit num_trials: int = 1000, # Number of trials in the session seed=None, ): """Init""" self.num_trials = num_trials self.reward_baiting = reward_baiting self.num_arms = num_arms self.allow_ignore = allow_ignore # State space # - Time (trial number) is the only observable state to the agent self.observation_space = spaces.Dict( { "trial": spaces.Box(low=0, high=self.num_trials, dtype=np.int64), } ) # Action space num_actions = num_arms + int(allow_ignore) # Add the last action as ignore if allowed self.action_space = spaces.Discrete(num_actions) # Random seed self.rng = np.random.default_rng(seed)
[docs] def reset(self, options={}): """ The reset method will be called to initiate a new episode. You may assume that the `step` method will not be called before `reset` has been called. Moreover, `reset` should be called whenever a done signal has been issued. This should *NOT* automatically reset the task! Resetting the task is handled in the wrapper. """ # Some mandatory initialization for any dynamic foraging task self.trial = 0 self.trial_p_reward = np.empty((self.num_trials, self.num_arms)) self.reward_assigned_before_action = np.zeros_like( self.trial_p_reward ) # Whether the reward exists in a certain trial before action self.reward_assigned_after_action = np.zeros_like( self.trial_p_reward ) # Whether the reward exists in a certain trial after action self.random_numbers = np.empty_like( self.trial_p_reward ) # Cache the generated random numbers self.action = np.empty(self.num_trials, dtype=int) self.reward = np.empty(self.num_trials) self.generate_new_trial() # Generate a new p_reward for the first trial return self._get_obs(), self._get_info()
[docs] def step(self, action): """ Execute one step in the environment. Should return: (observation, reward, terminated, truncated, info) If terminated or truncated is true, the user needs to call reset(). """ # Action should be type integer in [0, num_arms-1] if not allow_ignore else [0, num_arms] assert self.action_space.contains(action) self.action[self.trial] = action # Generate reward reward = self.generate_reward(action) self.reward[self.trial] = reward # Decide termination before trial += 1 terminated = bool((self.trial == self.num_trials - 1)) # self.trial starts from 0 # State transition if not terminated (trial += 1 here) if not terminated: self.trial += 1 # tick time here self.generate_new_trial() return self._get_obs(), reward, terminated, False, self._get_info()
[docs] def generate_reward(self, action): """Compute reward, could be overridden by subclasses for more complex reward structures""" # -- Refilling rewards on this trial -- self.random_numbers[self.trial] = self.rng.uniform(0, 1, size=self.num_arms) reward_assigned = ( self.random_numbers[self.trial] < self.trial_p_reward[self.trial] ).astype(float) # -- Reward baited from the last trial -- if self.reward_baiting and self.trial > 0: reward_assigned = np.logical_or( reward_assigned, self.reward_assigned_after_action[self.trial - 1] ).astype(float) # Cache the reward assignment self.reward_assigned_before_action[self.trial] = reward_assigned self.reward_assigned_after_action[self.trial] = reward_assigned # -- Reward delivery -- if action == IGNORE: # Note that reward may be still refilled even if the agent ignores the trial return 0 # Clear up the reward_assigned_after_action slot and return the reward self.reward_assigned_after_action[self.trial, action] = 0 return reward_assigned[action]
[docs] def generate_new_trial(self): """Generate p_reward for a new trial Note that self.trial already increased by 1 here """ raise NotImplementedError("generate_next_trial() should be overridden by subclasses")
[docs] def get_choice_history(self): """Return the history of actions in format that is compatible with other library such as aind_dynamic_foraging_basic_analysis """ actions = self.action.astype(float) actions[actions == IGNORE] = np.nan return actions
[docs] def get_reward_history(self): """Return the history of rewards in format that is compatible with other library such as aind_dynamic_foraging_basic_analysis """ return self.reward
[docs] def get_p_reward(self): """Return the reward probabilities for each arm in each trial which is compatible with other library such as aind_dynamic_foraging_basic_analysis """ return self.trial_p_reward.T
def _get_obs(self): """Return the observation""" return {"trial": self.trial} def _get_info(self): """ Info about the environment that the agents is not supposed to know. For instance, info can reveal the index of the optimal arm, or the value of prior parameter. Can be useful to evaluate the agent's perfomance """ return { "trial": self.trial, "task_object": self, # Return the whole task object for debugging }