Rollout Storage for Motion Mimic PPO
Module: rsl_rl.storage.rollout_storage_mm
🎯 Overview
RolloutStorageMM provides efficient storage and batch generation for multi-modal reinforcement learning data. It manages observation-action trajectories with support for reference observations, privileged information, and auxiliary learning modules such as RND, AMP, and DAgger.
🧠 Core Components
🔄 Transition Class
Module Name: rsl_rl.storage.rollout_storage_mm.RolloutStorageMM.Transition
Stored Variables:
- observations (torch.Tensor): Actor observations
- reference_observations (torch.Tensor): Reference observation data
- reference_observations_mask (torch.Tensor): Reference availability mask
- critic_observations (torch.Tensor): Critic observations (privileged info)
- critic_reference_observations (torch.Tensor): Critic reference data
- critic_reference_observations_mask (torch.Tensor): Critic reference mask
- actions (torch.Tensor): Executed actions
- privileged_actions (torch.Tensor): Teacher actions (distillation only)
- dagger_actions (torch.Tensor): DAgger network actions
- rewards (torch.Tensor): Environment rewards
- dones (torch.Tensor): Episode termination flags
- values (torch.Tensor): Value function estimates
- actions_log_prob (torch.Tensor): Action log probabilities
- action_mean (torch.Tensor): Policy action means
- action_sigma (torch.Tensor): Policy action standard deviations
- hidden_states (tuple): RNN hidden states (if applicable)
- rnd_state (torch.Tensor): RND module states
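A minimal sketch of the Transition container (field names follow the list above; the clear() helper and the None defaults are assumptions, not confirmed implementation details):
# Sketch: Transition is a plain attribute container reset between environment steps
class Transition:
    def __init__(self):
        self.observations = None
        self.reference_observations = None
        self.reference_observations_mask = None
        self.critic_observations = None
        self.critic_reference_observations = None
        self.critic_reference_observations_mask = None
        self.actions = None
        self.privileged_actions = None
        self.dagger_actions = None
        self.rewards = None
        self.dones = None
        self.values = None
        self.actions_log_prob = None
        self.action_mean = None
        self.action_sigma = None
        self.hidden_states = None
        self.rnd_state = None

    def clear(self):
        # Drop references to the previous step's tensors
        self.__init__()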
📦 RolloutStorageMM Class
Module Name: rsl_rl.storage.rollout_storage_mm.RolloutStorageMM
Definition:
class RolloutStorageMM:
    def __init__(self, training_type, num_envs, num_transitions_per_env, 
                 obs_shape, ref_obs_shape, privileged_obs_shape, 
                 privileged_ref_obs_shape, actions_shape, 
                 apply_dagger_actions=False, rnd_state_shape=None, 
                 amp_cfg=None, device="cpu")
📥 Parameters:
- training_type (str): "rl" or "distillation"
- num_envs (int): Number of parallel environments
- num_transitions_per_env (int): Steps per rollout per environment
- obs_shape (list): Actor observation shape
- ref_obs_shape (list): Reference observation shape
- privileged_obs_shape (list): Critic observation shape
- privileged_ref_obs_shape (list): Critic reference observation shape
- actions_shape (list): Action space shape
- apply_dagger_actions (bool): Enable DAgger action storage. Default: False
- rnd_state_shape (list | None): RND state shape. Default: None
- amp_cfg (dict | None): AMP configuration. Default: None
- device (str): Storage device. Default: "cpu"
🗄️ Storage Variables
📊 Core Storage (Always Created)
Observation Storage:
# Actor observations - Shape: (num_transitions_per_env, num_envs, *obs_shape)
self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=device)
# Core action-reward storage
self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=device).byte()
🎭 Conditional Storage Variables
🔍 Privileged Observations
Creation Condition: privileged_obs_shape[0] is not None
# Critic observations (privileged information)
self.privileged_observations = torch.zeros(
    num_transitions_per_env, num_envs, *privileged_obs_shape, device=device
)
📚 Reference Observations
Creation Condition: ref_obs_shape[0] is not None
# Reference observations and masks
self.reference_observations = torch.zeros(num_transitions_per_env, num_envs, *ref_obs_shape, device=device)
self.reference_observations_mask = torch.zeros(num_transitions_per_env, num_envs, device=device).bool()
🎯 Privileged Reference Observations
Creation Condition: ref_obs_shape[0] is not None AND privileged_ref_obs_shape[0] is not None
# Critic reference observations
self.privileged_reference_observations = torch.zeros(
    num_transitions_per_env, num_envs, *privileged_ref_obs_shape, device=device
)
self.privileged_reference_observations_mask = torch.zeros(num_transitions_per_env, num_envs, device=device).bool()
🎓 DAgger Actions
Creation Condition: apply_dagger_actions == True
# DAgger network actions for imitation learning
self.dagger_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
🏫 Distillation Training
Creation Condition: training_type == "distillation"
# Teacher actions for knowledge distillation
self.privileged_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
💰 RL Training Variables
Creation Condition: training_type == "rl"
# PPO-specific storage
self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
🔍 RND Storage
Creation Condition: rnd_state_shape is not None
# Random Network Distillation states
self.rnd_state = torch.zeros(num_transitions_per_env, num_envs, *rnd_state_shape, device=device)
🎯 Core Methods
📝 Data Management
➕ Add Transitions
def add_transitions(self, transition: Transition)
📥 Input: transition (Transition): Single step transition data
🔧 Function: Stores transition data at current step index, automatically copying relevant fields based on storage configuration.
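A minimal sketch of that copy-and-advance logic (buffer names follow the storage variables above; the guard conditions and exact field handling are assumptions):
# Sketch: copy one step of transition data into the pre-allocated buffers
def add_transitions(self, transition: Transition):
    if self.step >= self.num_transitions_per_env:
        raise OverflowError("Rollout buffer overflow; call clear() before adding new transitions.")
    self.observations[self.step].copy_(transition.observations)
    if self.privileged_observations is not None:
        self.privileged_observations[self.step].copy_(transition.critic_observations)
    if self.reference_observations is not None:
        self.reference_observations[self.step].copy_(transition.reference_observations)
        self.reference_observations_mask[self.step].copy_(transition.reference_observations_mask)
    self.actions[self.step].copy_(transition.actions)
    self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
    self.dones[self.step].copy_(transition.dones.view(-1, 1))
    if self.training_type == "rl":
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
    self.step += 1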
🧹 Clear Storage
def clear(self)
🔧 Function: Resets step counter to 0 for new rollout collection.
📈 Return Computation
💰 Compute Returns and Advantages
def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True)
📥 Input:
- last_values (torch.Tensor): Bootstrap values for final step
- gamma (float): Discount factor
- lam (float): GAE lambda parameter
- normalize_advantage (bool): Normalize advantages. Default: True
🔧 Function: Computes returns using GAE and normalizes advantages for PPO training.
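A minimal sketch of the GAE backward pass this method performs (buffer names follow the RL training variables above; the normalization epsilon is an assumption):
# Sketch: delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t, accumulated backwards with lambda
def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True):
    advantage = 0
    for step in reversed(range(self.num_transitions_per_env)):
        next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1]
        next_is_not_terminal = 1.0 - self.dones[step].float()
        delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
        advantage = delta + next_is_not_terminal * gamma * lam * advantage
        self.returns[step] = advantage + self.values[step]
    # Advantages are the difference between returns and value estimates
    self.advantages = self.returns - self.values
    if normalize_advantage:
        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)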
🎲 Batch Generation
🔄 Mini-Batch Generator (RL Training)
def mini_batch_generator(self, num_mini_batches, num_epochs=8)
📥 Input:
- num_mini_batches (int): Number of mini-batches per epoch
- num_epochs (int): Number of training epochs. Default: 8
📤 Output Tuple:
(obs_batch,                    # Actor observations
 ref_obs_batch_rtn,           # (reference_obs, ref_mask) or None
 critic_observations_batch,    # Critic observations
 critic_ref_obs_batch_rtn,    # (critic_ref_obs, critic_ref_mask) or None
 actions_batch,               # Actions
 target_values_batch,         # Value targets
 advantages_batch,            # GAE advantages
 returns_batch,               # Returns
 old_actions_log_prob_batch,  # Old log probabilities
 old_mu_batch,                # Old action means
 old_sigma_batch,             # Old action std deviations
 dagger_actions_batch,        # DAgger actions or None
 hidden_states,               # (None, None) - placeholder for RNN
 masks_batch,                 # None - placeholder for RNN
 rnd_state_batch,             # RND states or None
 obs_prev_state,              # Previous observations for AMP
 ref_obs_prev_state,          # Previous reference observations for AMP
 ref_obs_prev_mask)           # Previous reference masks for AMP
🔧 Function: Generates randomized mini-batches for PPO training with proper data organization:
Data Flattening:
# Shape transformation: (num_transitions_per_env, num_envs, ...) → (batch_size, ...)
batch_size = num_envs * num_transitions_per_env
observations = self.observations.flatten(0, 1)  # (batch_size, *obs_shape)
Reference Data Handling:
# Reference observations packaged as tuples for masking
ref_obs_batch_rtn = (ref_obs_batch, ref_obs_mask_batch) if ref_obs_batch is not None else None
AMP Integration:
# Previous state extraction for AMP discriminator
if self.amp_cfg and reference_observations is not None:
    obs_prev_state = amp_obs_extractor(observations[prev_batch_idx])
    ref_obs_prev_state, ref_obs_prev_mask = amp_ref_obs_extractor((
        reference_observations[prev_batch_idx],
        reference_observations_mask[prev_batch_idx]
    ))
🎓 Distillation Generator
def generator(self)
📤 Output Tuple:
(observations,                    # Actor observations
 privileged_observations,         # Teacher observations
 ref_obs_batch_rtn,              # (reference_obs, ref_mask)
 privileged_ref_obs_batch_rtn,   # (privileged_ref_obs, privileged_ref_mask) or None
 actions,                        # Student actions
 privileged_actions,             # Teacher actions
 dones)                          # Episode termination flags
🔧 Function: Sequential iteration over stored transitions for distillation training.
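A minimal sketch of that sequential iteration (the guard clause is an assumption, and unused reference buffers are assumed to be None; the yield order follows the output tuple above):
# Sketch: yield stored transitions step by step for distillation training
def generator(self):
    if self.training_type != "distillation":
        raise ValueError("This generator is only available for distillation training.")
    for step in range(self.num_transitions_per_env):
        ref_obs = None
        if self.reference_observations is not None:
            ref_obs = (self.reference_observations[step], self.reference_observations_mask[step])
        privileged_ref_obs = None
        if self.privileged_reference_observations is not None:
            privileged_ref_obs = (
                self.privileged_reference_observations[step],
                self.privileged_reference_observations_mask[step],
            )
        yield (
            self.observations[step],
            self.privileged_observations[step],
            ref_obs,
            privileged_ref_obs,
            self.actions[step],
            self.privileged_actions[step],
            self.dones[step],
        )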
📊 Statistics
📈 Get Training Statistics
def get_statistics(self)
📤 Output: (mean_trajectory_length, mean_reward) (tuple)
🔧 Function: Computes episode-based statistics for monitoring training progress.
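One way such statistics can be derived from the stored dones and rewards; this is a sketch only, and the actual implementation may differ:
# Sketch: estimate mean trajectory length and mean reward from the rollout buffer
def get_statistics(self):
    done = self.dones
    done[-1] = 1  # treat the final stored step as an episode boundary
    flat_dones = done.permute(1, 0, 2).reshape(-1, 1)  # env-major ordering
    done_indices = torch.cat(
        (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])
    )
    trajectory_lengths = done_indices[1:] - done_indices[:-1]
    return trajectory_lengths.float().mean(), self.rewards.mean()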
🔧 Data Organization Summary
📦 Storage Layout
# Time-major storage format
Storage Shape: (num_transitions_per_env, num_envs, feature_dim)
Batch Shape:   (batch_size, feature_dim)  # After flattening
# Index mapping
step_idx = 0 to num_transitions_per_env-1
env_idx = 0 to num_envs-1
flat_idx = step_idx * num_envs + env_idx
🎲 Sampling Strategy
# Random shuffling for PPO
batch_size = num_envs * num_transitions_per_env
mini_batch_size = batch_size // num_mini_batches
indices = torch.randperm(batch_size)  # Random permutation
# Mini-batch extraction
for i in range(num_mini_batches):
    start = i * mini_batch_size
    end = (i + 1) * mini_batch_size
    batch_idx = indices[start:end]
    # Extract data using batch_idx
🎭 Multi-Modal Data Packaging
# Reference observations always packaged with masks
if reference_observations is not None:
    ref_obs_tuple = (reference_data, reference_mask)
else:
    ref_obs_tuple = None
# Separate handling for actor and critic reference data
actor_ref_tuple = (ref_obs, ref_mask) if available else None
critic_ref_tuple = (critic_ref_obs, critic_ref_mask) if available else None
💡 Usage Pattern
# Initialize storage
storage = RolloutStorageMM(
    training_type="rl",
    num_envs=4096,
    num_transitions_per_env=24,
    obs_shape=[48],
    ref_obs_shape=[26],
    privileged_obs_shape=[48],
    privileged_ref_obs_shape=[26],
    actions_shape=[19],
    apply_dagger_actions=True,
    rnd_state_shape=[48],
    device="cuda"
)
# Data collection
for step in range(24):
    # ... environment interaction ...
    storage.add_transitions(transition)
# Compute returns
storage.compute_returns(last_values, gamma=0.99, lam=0.95)
# Training: the generator loops over epochs and shuffled mini-batches internally
for batch_data in storage.mini_batch_generator(num_mini_batches=4, num_epochs=5):
    obs_batch, ref_obs_batch, critic_obs_batch, critic_ref_obs_batch, \
    actions_batch, values_batch, advantages_batch, returns_batch, \
    old_log_prob_batch, old_mu_batch, old_sigma_batch, dagger_actions_batch, \
    hidden_states, masks, rnd_state_batch, \
    obs_prev_state, ref_obs_prev_state, ref_obs_prev_mask = batch_data

    # PPO update using batch_data

# Clear for next rollout
storage.clear()
💡 Key Point: Storage variables are conditionally created based on training configuration, and data is efficiently organized for both sequential (distillation) and randomized (RL) access patterns.