
Rollout Storage for Motion Mimic PPO

Module: rsl_rl.storage.rollout_storage_mm

🎯 Overview

RolloutStorageMM provides efficient storage and batch generation for multi-modal reinforcement learning data. It manages observation-action trajectories with support for reference observations, privileged (critic-only) information, and auxiliary learning modules such as DAgger, RND, and AMP.


🧠 Core Components

🔄 Transition Class

Class: rsl_rl.storage.rollout_storage_mm.RolloutStorageMM.Transition

Stored Variables:

  • observations (torch.Tensor): Actor observations
  • reference_observations (torch.Tensor): Reference observation data
  • reference_observations_mask (torch.Tensor): Reference availability mask
  • critic_observations (torch.Tensor): Critic observations (privileged info)
  • critic_reference_observations (torch.Tensor): Critic reference data
  • critic_reference_observations_mask (torch.Tensor): Critic reference mask
  • actions (torch.Tensor): Executed actions
  • privileged_actions (torch.Tensor): Teacher actions (distillation only)
  • dagger_actions (torch.Tensor): DAgger network actions
  • rewards (torch.Tensor): Environment rewards
  • dones (torch.Tensor): Episode termination flags
  • values (torch.Tensor): Value function estimates
  • actions_log_prob (torch.Tensor): Action log probabilities
  • action_mean (torch.Tensor): Policy action means
  • action_sigma (torch.Tensor): Policy action standard deviations
  • hidden_states (tuple): RNN hidden states (if applicable)
  • rnd_state (torch.Tensor): RND module states
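
For orientation, a hedged sketch of how a Transition might be filled during one rollout step; policy, env, and the local tensors (obs, ref_obs, critic_obs, etc.) are hypothetical placeholders, not part of this module:

# Sketch: fill a Transition during one environment step (policy/env are hypothetical)
transition = RolloutStorageMM.Transition()
transition.observations = obs
transition.reference_observations = ref_obs
transition.reference_observations_mask = ref_obs_mask
transition.critic_observations = critic_obs
transition.actions = policy.act(obs)
transition.values = policy.evaluate(critic_obs)
transition.actions_log_prob = policy.get_actions_log_prob(transition.actions)
transition.action_mean = policy.action_mean
transition.action_sigma = policy.action_std
obs, rewards, dones, infos = env.step(transition.actions)
transition.rewards = rewards
transition.dones = dones
storage.add_transitions(transition)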

📦 RolloutStorageMM Class

Class: rsl_rl.storage.rollout_storage_mm.RolloutStorageMM

Definition:

class RolloutStorageMM:
    def __init__(self, training_type, num_envs, num_transitions_per_env,
                 obs_shape, ref_obs_shape, privileged_obs_shape,
                 privileged_ref_obs_shape, actions_shape,
                 apply_dagger_actions=False, rnd_state_shape=None,
                 amp_cfg=None, device="cpu")

📥 Parameters:

  • training_type (str): "rl" or "distillation"
  • num_envs (int): Number of parallel environments
  • num_transitions_per_env (int): Steps per rollout per environment
  • obs_shape (list): Actor observation shape
  • ref_obs_shape (list): Reference observation shape
  • privileged_obs_shape (list): Critic observation shape
  • privileged_ref_obs_shape (list): Critic reference observation shape
  • actions_shape (list): Action space shape
  • apply_dagger_actions (bool): Enable DAgger action storage. Default is False
  • rnd_state_shape (list | None): RND state shape. Default is None
  • amp_cfg (dict | None): AMP configuration. Default is None
  • device (str): Storage device. Default is "cpu"
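
For comparison with the RL example at the end of this page, a hypothetical distillation-style construction is sketched below; the [None] sentinel for unused shapes is inferred from the creation conditions that follow and may differ in the actual code:

# Sketch: storage for teacher-student distillation (parameter values are illustrative)
storage = RolloutStorageMM(
    training_type="distillation",
    num_envs=1024,
    num_transitions_per_env=24,
    obs_shape=[48],
    ref_obs_shape=[26],
    privileged_obs_shape=[64],
    privileged_ref_obs_shape=[None],   # assumed sentinel: first element None disables this buffer
    actions_shape=[19],
    device="cuda",
)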

🗄️ Storage Variables

📊 Core Storage (Always Created)

Observation Storage:

# Actor observations - Shape: (num_transitions_per_env, num_envs, *obs_shape)
self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=device)

# Core action-reward storage
self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=device).byte()

🎭 Conditional Storage Variables

๐Ÿ” Privileged Observationsโ€‹

Creation Condition: privileged_obs_shape[0] is not None

# Critic observations (privileged information)
self.privileged_observations = torch.zeros(
    num_transitions_per_env, num_envs, *privileged_obs_shape, device=device
)

📚 Reference Observations

Creation Condition: ref_obs_shape[0] is not None

# Reference observations and masks
self.reference_observations = torch.zeros(num_transitions_per_env, num_envs, *ref_obs_shape, device=device)
self.reference_observations_mask = torch.zeros(num_transitions_per_env, num_envs, device=device).bool()

🎯 Privileged Reference Observations

Creation Condition: ref_obs_shape[0] is not None AND privileged_ref_obs_shape[0] is not None

# Critic reference observations
self.privileged_reference_observations = torch.zeros(
    num_transitions_per_env, num_envs, *privileged_ref_obs_shape, device=device
)
self.privileged_reference_observations_mask = torch.zeros(num_transitions_per_env, num_envs, device=device).bool()

🎓 DAgger Actions

Creation Condition: apply_dagger_actions == True

# DAgger network actions for imitation learning
self.dagger_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)

๐Ÿซ Distillation Trainingโ€‹

Creation Condition: training_type == "distillation"

# Teacher actions for knowledge distillation
self.privileged_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)

💰 RL Training Variables

Creation Condition: training_type == "rl"

# PPO-specific storage
self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)

๐Ÿ” RND Storageโ€‹

Creation Condition: rnd_state_shape is not None

# Random Network Distillation states
self.rnd_state = torch.zeros(num_transitions_per_env, num_envs, *rnd_state_shape, device=device)

🎯 Core Methods

๐Ÿ“ Data Managementโ€‹

➕ Add Transitions

def add_transitions(self, transition: Transition)

📥 Input: transition (Transition): Single-step transition data

🔧 Function: Stores the transition data at the current step index, automatically copying the relevant fields based on the storage configuration.
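
A hedged sketch of the storage pattern, assuming the rsl_rl convention of copying each field into the buffers at the current write index self.step and then advancing it (the exact guard, attribute names, and field list may differ):

# Sketch: copy one transition into the buffers at self.step (assumed pattern)
def add_transitions(self, transition: Transition):
    if self.step >= self.num_transitions_per_env:
        raise OverflowError("Rollout buffer overflow; call clear() before adding new transitions")
    self.observations[self.step].copy_(transition.observations)
    self.actions[self.step].copy_(transition.actions)
    self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
    self.dones[self.step].copy_(transition.dones.view(-1, 1))
    if self.reference_observations is not None:   # only when ref_obs_shape is configured
        self.reference_observations[self.step].copy_(transition.reference_observations)
        self.reference_observations_mask[self.step].copy_(transition.reference_observations_mask)
    if self.training_type == "rl":                # PPO-specific fields
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
    self.step += 1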

🧹 Clear Storage

def clear(self)

🔧 Function: Resets the step counter to 0 for a new rollout collection.
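
A minimal sketch, assuming the write index is stored in self.step:

def clear(self):
    # Reset the write index; existing buffers are simply overwritten by the next rollout
    self.step = 0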

📈 Return Computation

💰 Compute Returns and Advantages

def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True)

📥 Input:

  • last_values (torch.Tensor): Bootstrap values for final step
  • gamma (float): Discount factor
  • lam (float): GAE lambda parameter
  • normalize_advantage (bool): Normalize advantages. Default is True

🔧 Function: Computes returns with Generalized Advantage Estimation (GAE) and optionally normalizes the advantages for PPO training.
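
A sketch of the standard backward GAE recursion used in rsl_rl-style storage; buffer names follow the RL storage variables above, and details such as the normalization epsilon may differ:

# Sketch: GAE returns/advantages computed backwards over the rollout
advantage = 0
for step in reversed(range(self.num_transitions_per_env)):
    next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1]
    next_is_not_terminal = 1.0 - self.dones[step].float()
    delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
    advantage = delta + next_is_not_terminal * gamma * lam * advantage
    self.returns[step] = advantage + self.values[step]

# Advantage = return - value, optionally normalized over the whole buffer
self.advantages = self.returns - self.values
if normalize_advantage:
    self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)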

🎲 Batch Generation

🔄 Mini-Batch Generator (RL Training)

def mini_batch_generator(self, num_mini_batches, num_epochs=8)

📥 Input:

  • num_mini_batches (int): Number of mini-batches per epoch
  • num_epochs (int): Number of training epochs. Default is 8

📤 Output Tuple:

(obs_batch,                      # Actor observations
 ref_obs_batch_rtn,              # (reference_obs, ref_mask) or None
 critic_observations_batch,      # Critic observations
 critic_ref_obs_batch_rtn,       # (critic_ref_obs, critic_ref_mask) or None
 actions_batch,                  # Actions
 target_values_batch,            # Value targets
 advantages_batch,               # GAE advantages
 returns_batch,                  # Returns
 old_actions_log_prob_batch,     # Old log probabilities
 old_mu_batch,                   # Old action means
 old_sigma_batch,                # Old action standard deviations
 dagger_actions_batch,           # DAgger actions or None
 hidden_states,                  # (None, None) - placeholder for RNN
 masks_batch,                    # None - placeholder for RNN
 rnd_state_batch,                # RND states or None
 obs_prev_state,                 # Previous observations for AMP
 ref_obs_prev_state,             # Previous reference observations for AMP
 ref_obs_prev_mask)              # Previous reference masks for AMP

🔧 Function: Generates randomized mini-batches for PPO training, organizing the data as follows:

Data Flattening:

# Shape transformation: (num_transitions_per_env, num_envs, ...) → (batch_size, ...)
batch_size = num_envs * num_transitions_per_env
observations = self.observations.flatten(0, 1)  # (batch_size, *obs_shape)

Reference Data Handling:

# Reference observations packaged as tuples for masking
ref_obs_batch_rtn = (ref_obs_batch, ref_obs_mask_batch) if ref_obs_batch is not None else None

AMP Integration:

# Previous state extraction for AMP discriminator
if self.amp_cfg and reference_observations is not None:
    obs_prev_state = amp_obs_extractor(observations[prev_batch_idx])
    ref_obs_prev_state, ref_obs_prev_mask = amp_ref_obs_extractor((
        reference_observations[prev_batch_idx],
        reference_observations_mask[prev_batch_idx],
    ))

🎓 Distillation Generator

def generator(self)

📤 Output Tuple:

(observations,                     # Actor observations
 privileged_observations,          # Teacher observations
 ref_obs_batch_rtn,                # (reference_obs, ref_mask)
 privileged_ref_obs_batch_rtn,     # (privileged_ref_obs, privileged_ref_mask) or None
 actions,                          # Student actions
 privileged_actions,               # Teacher actions
 dones)                            # Episode termination flags

🔧 Function: Sequential iteration over stored transitions for distillation training.
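
A usage sketch, assuming the tuple order shown above; student_policy and the MSE loss are illustrative, not part of this module:

# Sketch: one sequential pass over the rollout for a distillation update
for (obs, priv_obs, ref_obs_rtn, priv_ref_obs_rtn,
     actions, teacher_actions, dones) in storage.generator():
    student_actions = student_policy(obs)                      # hypothetical student network
    loss = (student_actions - teacher_actions).pow(2).mean()   # simple behavior-cloning loss
    # ... backward pass and optimizer step ...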

📊 Statistics

📈 Get Training Statistics

def get_statistics(self)

📤 Output: (mean_trajectory_length, mean_reward) (tuple)

🔧 Function: Computes episode-based statistics for monitoring training progress.
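
A sketch of how such statistics are commonly computed in rsl_rl-style storage, by treating each done flag as a trajectory boundary (the exact implementation may differ):

# Sketch: trajectory lengths from done flags, mean reward over the buffer
done = self.dones.clone()
done[-1] = 1  # force a boundary at the end of the rollout
flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
done_indices = torch.cat(
    (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0])
)
trajectory_lengths = done_indices[1:] - done_indices[:-1]
mean_trajectory_length = trajectory_lengths.float().mean()
mean_reward = self.rewards.mean()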


🔧 Data Organization Summary

📦 Storage Layout

# Time-major storage format
# Storage shape: (num_transitions_per_env, num_envs, feature_dim)
# Batch shape:   (batch_size, feature_dim)   # after flattening

# Index mapping
# step_idx in [0, num_transitions_per_env - 1]
# env_idx  in [0, num_envs - 1]
flat_idx = step_idx * num_envs + env_idx
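
A quick, self-contained sanity check of this mapping (the tensor sizes are arbitrary):

import torch

num_transitions_per_env, num_envs, feature_dim = 3, 4, 5
data = torch.arange(num_transitions_per_env * num_envs * feature_dim, dtype=torch.float32)
data = data.reshape(num_transitions_per_env, num_envs, feature_dim)

flat = data.flatten(0, 1)   # (num_transitions_per_env * num_envs, feature_dim)
step_idx, env_idx = 2, 1
flat_idx = step_idx * num_envs + env_idx
assert torch.equal(flat[flat_idx], data[step_idx, env_idx])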

🎲 Sampling Strategy

# Random shuffling for PPO
batch_size = num_envs * num_transitions_per_env
mini_batch_size = batch_size // num_mini_batches
indices = torch.randperm(batch_size) # Random permutation

# Mini-batch extraction
for i in range(num_mini_batches):
    start = i * mini_batch_size
    end = (i + 1) * mini_batch_size
    batch_idx = indices[start:end]
    # Extract data using batch_idx

🎭 Multi-Modal Data Packaging

# Reference observations are always packaged together with their masks
if reference_observations is not None:
    ref_obs_tuple = (reference_data, reference_mask)
else:
    ref_obs_tuple = None

# Actor and critic reference data are handled separately
actor_ref_tuple = (ref_obs, ref_mask) if ref_obs is not None else None
critic_ref_tuple = (critic_ref_obs, critic_ref_mask) if critic_ref_obs is not None else None

💡 Usage Pattern

# Initialize storage
storage = RolloutStorageMM(
    training_type="rl",
    num_envs=4096,
    num_transitions_per_env=24,
    obs_shape=[48],
    ref_obs_shape=[26],
    privileged_obs_shape=[48],
    privileged_ref_obs_shape=[26],
    actions_shape=[19],
    apply_dagger_actions=True,
    rnd_state_shape=[48],
    device="cuda",
)

# Data collection
for step in range(24):
    # ... environment interaction ...
    storage.add_transitions(transition)

# Compute returns
storage.compute_returns(last_values, gamma=0.99, lam=0.95)

# Training: the generator itself iterates over num_epochs passes of shuffled mini-batches
for batch_data in storage.mini_batch_generator(num_mini_batches=4, num_epochs=5):
    (obs_batch, ref_obs_batch, critic_obs_batch, critic_ref_obs_batch,
     actions_batch, values_batch, advantages_batch, returns_batch,
     old_log_prob_batch, old_mu_batch, old_sigma_batch, dagger_actions_batch,
     hidden_states, masks, rnd_state_batch,
     obs_prev_state, ref_obs_prev_state, ref_obs_prev_mask) = batch_data

    # PPO update using batch_data

# Clear for next rollout
storage.clear()

💡 Key Point: Storage variables are conditionally created based on training configuration, and data is efficiently organized for both sequential (distillation) and randomized (RL) access patterns.