Rollout Storage for Motion Mimic PPO
Module: rsl_rl.storage.rollout_storage_mm
🎯 Overview
RolloutStorageMM provides efficient storage and batch generation for multi-modal reinforcement-learning data. It manages observation-action trajectories with support for reference observations, privileged (critic-only) information, and auxiliary learning modules such as DAgger, RND, and AMP.
🧠 Core Components
🔄 Transition Class
Module Name: rsl_rl.storage.rollout_storage_mm.RolloutStorageMM.Transition
Stored Variables:
observations (torch.Tensor): Actor observations
reference_observations (torch.Tensor): Reference observation data
reference_observations_mask (torch.Tensor): Reference availability mask
critic_observations (torch.Tensor): Critic observations (privileged info)
critic_reference_observations (torch.Tensor): Critic reference data
critic_reference_observations_mask (torch.Tensor): Critic reference mask
actions (torch.Tensor): Executed actions
privileged_actions (torch.Tensor): Teacher actions (distillation only)
dagger_actions (torch.Tensor): DAgger network actions
rewards (torch.Tensor): Environment rewards
dones (torch.Tensor): Episode termination flags
values (torch.Tensor): Value function estimates
actions_log_prob (torch.Tensor): Action log probabilities
action_mean (torch.Tensor): Policy action means
action_sigma (torch.Tensor): Policy action standard deviations
hidden_states (tuple): RNN hidden states (if applicable)
rnd_state (torch.Tensor): RND module states
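For orientation, a minimal population sketch that fills the fields listed above with dummy tensors of illustrative shapes. The no-argument Transition constructor and direct attribute assignment mirror standard rsl_rl storages and are assumptions here, not guarantees about this module.
import torch

from rsl_rl.storage.rollout_storage_mm import RolloutStorageMM

num_envs, obs_dim, ref_dim, act_dim = 4, 48, 26, 19

# Assumed: Transition() takes no arguments and its fields are plain attributes.
transition = RolloutStorageMM.Transition()
transition.observations = torch.zeros(num_envs, obs_dim)
transition.reference_observations = torch.zeros(num_envs, ref_dim)
transition.reference_observations_mask = torch.ones(num_envs, dtype=torch.bool)
transition.critic_observations = torch.zeros(num_envs, obs_dim)  # privileged copy of the actor obs here
transition.actions = torch.zeros(num_envs, act_dim)
transition.rewards = torch.zeros(num_envs, 1)
transition.dones = torch.zeros(num_envs, 1)
transition.values = torch.zeros(num_envs, 1)
transition.actions_log_prob = torch.zeros(num_envs, 1)
transition.action_mean = torch.zeros(num_envs, act_dim)
transition.action_sigma = torch.ones(num_envs, act_dim)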
📦 RolloutStorageMM Class
Module Name: rsl_rl.storage.rollout_storage_mm.RolloutStorageMM
Definition:
class RolloutStorageMM:
    def __init__(self, training_type, num_envs, num_transitions_per_env,
                 obs_shape, ref_obs_shape, privileged_obs_shape,
                 privileged_ref_obs_shape, actions_shape,
                 apply_dagger_actions=False, rnd_state_shape=None,
                 amp_cfg=None, device="cpu")
📥 Parameters:
training_type (str): "rl" or "distillation"
num_envs (int): Number of parallel environments
num_transitions_per_env (int): Steps per rollout per environment
obs_shape (list): Actor observation shape
ref_obs_shape (list): Reference observation shape
privileged_obs_shape (list): Critic observation shape
privileged_ref_obs_shape (list): Critic reference observation shape
actions_shape (list): Action space shape
apply_dagger_actions (bool): Enable DAgger action storage. Default is False
rnd_state_shape (list | None): RND state shape. Default is None
amp_cfg (dict | None): AMP configuration. Default is None
device (str): Storage device. Default is "cpu"
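As a complement to the RL example at the end of this page, a hedged sketch of a distillation-mode instantiation; the shape values are placeholders and only the keyword names come from the parameter list above.
from rsl_rl.storage.rollout_storage_mm import RolloutStorageMM

# Placeholder shapes; training_type="distillation" switches on privileged_actions storage.
storage = RolloutStorageMM(
    training_type="distillation",
    num_envs=1024,
    num_transitions_per_env=24,
    obs_shape=[48],
    ref_obs_shape=[26],
    privileged_obs_shape=[48],
    privileged_ref_obs_shape=[26],
    actions_shape=[19],
    device="cuda",
)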
🗄️ Storage Variables
📊 Core Storage (Always Created)
Observation Storage:
# Actor observations - Shape: (num_transitions_per_env, num_envs, *obs_shape)
self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=device)
# Core action-reward storage
self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=device).byte()
🎭 Conditional Storage Variables
🔍 Privileged Observations
Creation Condition: privileged_obs_shape[0] is not None
# Critic observations (privileged information)
self.privileged_observations = torch.zeros(
    num_transitions_per_env, num_envs, *privileged_obs_shape, device=device
)
📚 Reference Observations
Creation Condition: ref_obs_shape[0] is not None
# Reference observations and masks
self.reference_observations = torch.zeros(num_transitions_per_env, num_envs, *ref_obs_shape, device=device)
self.reference_observations_mask = torch.zeros(num_transitions_per_env, num_envs, device=device).bool()
🎯 Privileged Reference Observations
Creation Condition: ref_obs_shape[0] is not None AND privileged_ref_obs_shape[0] is not None
# Critic reference observations
self.privileged_reference_observations = torch.zeros(
    num_transitions_per_env, num_envs, *privileged_ref_obs_shape, device=device
)
self.privileged_reference_observations_mask = torch.zeros(num_transitions_per_env, num_envs, device=device).bool()
🎓 DAgger Actions
Creation Condition: apply_dagger_actions == True
# DAgger network actions for imitation learning
self.dagger_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
🏫 Distillation Training
Creation Condition: training_type == "distillation"
# Teacher actions for knowledge distillation
self.privileged_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
💰 RL Training Variables
Creation Condition: training_type == "rl"
# PPO-specific storage
self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
🔍 RND Storage
Creation Condition: rnd_state_shape is not None
# Random Network Distillation states
self.rnd_state = torch.zeros(num_transitions_per_env, num_envs, *rnd_state_shape, device=device)
🎯 Core Methods
📝 Data Management
➕ Add Transitions
def add_transitions(self, transition: Transition)
📥 Input: transition (Transition): Single-step transition data
🔧 Function: Stores transition data at current step index, automatically copying relevant fields based on storage configuration.
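A schematic of the copy that add_transitions is described to perform, written as a standalone helper; the step attribute name, the overflow check, and the getattr guard for optional buffers are assumptions modeled on standard rsl_rl, not the exact implementation.
def add_transitions_sketch(storage, transition):
    # Guard against writing past the rollout length (assumed behavior).
    if storage.step >= storage.num_transitions_per_env:
        raise OverflowError("Rollout buffer overflow; call clear() before the next rollout.")
    # Core fields are always stored at the current step index.
    storage.observations[storage.step].copy_(transition.observations)
    storage.actions[storage.step].copy_(transition.actions)
    storage.rewards[storage.step].copy_(transition.rewards.view(-1, 1))
    storage.dones[storage.step].copy_(transition.dones.view(-1, 1))
    # Optional buffers are written only if they were created at construction time.
    if getattr(storage, "reference_observations", None) is not None:
        storage.reference_observations[storage.step].copy_(transition.reference_observations)
        storage.reference_observations_mask[storage.step].copy_(transition.reference_observations_mask)
    storage.step += 1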
🧹 Clear Storage
def clear(self)
🔧 Function: Resets step counter to 0 for new rollout collection.
📈 Return Computation
💰 Compute Returns and Advantages
def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True)
📥 Input:
last_values (torch.Tensor): Bootstrap values for the final step
gamma (float): Discount factor
lam (float): GAE lambda parameter
normalize_advantage (bool): Normalize advantages. Default is True
🔧 Function: Computes returns using GAE and normalizes advantages for PPO training.
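To make the return computation concrete, here is a standalone GAE sketch over the same time-major buffers; it mirrors the standard rsl_rl recursion, while the real method writes into self.returns and self.advantages instead of returning them.
import torch

def compute_returns_sketch(values, rewards, dones, last_values, gamma, lam, normalize_advantage=True):
    # values/rewards/dones: (num_transitions_per_env, num_envs, 1); last_values: (num_envs, 1).
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    advantage = 0.0
    for step in reversed(range(num_steps)):
        # Bootstrap from last_values at the final step, otherwise from the next stored value.
        next_values = last_values if step == num_steps - 1 else values[step + 1]
        next_not_terminal = 1.0 - dones[step].float()
        # TD error followed by the recursive GAE accumulation.
        delta = rewards[step] + next_not_terminal * gamma * next_values - values[step]
        advantage = delta + next_not_terminal * gamma * lam * advantage
        advantages[step] = advantage
    returns = advantages + values
    if normalize_advantage:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    return returns, advantages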
🎲 Batch Generation
🔄 Mini-Batch Generator (RL Training)
def mini_batch_generator(self, num_mini_batches, num_epochs=8)
📥 Input:
num_mini_batches (int): Number of mini-batches per epoch
num_epochs (int): Number of training epochs. Default is 8
📤 Output Tuple:
(obs_batch, # Actor observations
ref_obs_batch_rtn, # (reference_obs, ref_mask) or None
critic_observations_batch, # Critic observations
critic_ref_obs_batch_rtn, # (critic_ref_obs, critic_ref_mask) or None
actions_batch, # Actions
target_values_batch, # Value targets
advantages_batch, # GAE advantages
returns_batch, # Returns
old_actions_log_prob_batch, # Old log probabilities
old_mu_batch, # Old action means
old_sigma_batch, # Old action std deviations
dagger_actions_batch, # DAgger actions or None
hidden_states, # (None, None) - placeholder for RNN
masks_batch, # None - placeholder for RNN
rnd_state_batch, # RND states or None
obs_prev_state, # Previous observations for AMP
ref_obs_prev_state, # Previous reference observations for AMP
ref_obs_prev_mask) # Previous reference masks for AMP
🔧 Function: Generates randomized mini-batches for PPO training with proper data organization:
Data Flattening:
# Shape transformation: (num_transitions_per_env, num_envs, ...) → (batch_size, ...)
batch_size = num_envs * num_transitions_per_env
observations = self.observations.flatten(0, 1) # (batch_size, *obs_shape)
Reference Data Handling:
# Reference observations packaged as tuples for masking
ref_obs_batch_rtn = (ref_obs_batch, ref_obs_mask_batch) if ref_obs_batch is not None else None
AMP Integration:
# Previous state extraction for AMP discriminator
if self.amp_cfg and reference_observations is not None:
    obs_prev_state = amp_obs_extractor(observations[prev_batch_idx])
    ref_obs_prev_state, ref_obs_prev_mask = amp_ref_obs_extractor((
        reference_observations[prev_batch_idx],
        reference_observations_mask[prev_batch_idx],
    ))
🎓 Distillation Generator
def generator(self)
📤 Output Tuple:
(observations, # Actor observations
privileged_observations, # Teacher observations
ref_obs_batch_rtn, # (reference_obs, ref_mask)
privileged_ref_obs_batch_rtn, # (privileged_ref_obs, privileged_ref_mask) or None
actions, # Student actions
privileged_actions, # Teacher actions
dones) # Episode termination flags
🔧 Function: Sequential iteration over stored transitions for distillation training.
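A hypothetical consumption loop for the distillation generator; the student module, optimizer, and MSE behavior-cloning loss are placeholders, and only the tuple layout comes from the description above.
import torch.nn.functional as F

for (obs, privileged_obs, ref_obs_tuple, privileged_ref_obs_tuple,
     actions, privileged_actions, dones) in storage.generator():
    student_actions = student(obs)                           # placeholder student policy
    loss = F.mse_loss(student_actions, privileged_actions)   # imitate the teacher actions
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()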
📊 Statistics
📈 Get Training Statistics
def get_statistics(self)
📤 Output (tuple): (mean_trajectory_length, mean_reward)
🔧 Function: Computes episode-based statistics for monitoring training progress.
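A sketch of how these statistics can be derived from the stored dones and rewards buffers, following the standard rsl_rl approach; the MM implementation may differ in detail.
import torch

def get_statistics_sketch(dones, rewards):
    # Work on a copy so the stored buffer is not modified in place.
    done = dones.clone()
    done[-1] = 1  # treat the rollout boundary as an episode end
    # Flatten env-major so consecutive steps of one env are contiguous.
    flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
    done_indices = torch.cat((
        flat_dones.new_tensor([-1], dtype=torch.int64),
        flat_dones.nonzero(as_tuple=False)[:, 0],
    ))
    # Trajectory lengths are the gaps between successive done flags.
    trajectory_lengths = done_indices[1:] - done_indices[:-1]
    return trajectory_lengths.float().mean(), rewards.mean()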
🔧 Data Organization Summary
📦 Storage Layout
# Time-major storage format
Storage Shape: (num_transitions_per_env, num_envs, feature_dim)
Batch Shape: (batch_size, feature_dim) # After flattening
# Index mapping
step_idx = 0 to num_transitions_per_env-1
env_idx = 0 to num_envs-1
flat_idx = step_idx * num_envs + env_idx
🎲 Sampling Strategy
# Random shuffling for PPO
batch_size = num_envs * num_transitions_per_env
mini_batch_size = batch_size // num_mini_batches
indices = torch.randperm(batch_size) # Random permutation
# Mini-batch extraction
for i in range(num_mini_batches):
    start = i * mini_batch_size
    end = (i + 1) * mini_batch_size
    batch_idx = indices[start:end]
    # Extract data using batch_idx
🎭 Multi-Modal Data Packaging
# Reference observations always packaged with masks
if reference_observations is not None:
    ref_obs_tuple = (reference_observations, reference_observations_mask)
else:
    ref_obs_tuple = None
# Separate handling for actor and critic reference data
actor_ref_tuple = (ref_obs, ref_mask) if available else None
critic_ref_tuple = (critic_ref_obs, critic_ref_mask) if available else None
💡 Usage Pattern
# Initialize storage
storage = RolloutStorageMM(
    training_type="rl",
    num_envs=4096,
    num_transitions_per_env=24,
    obs_shape=[48],
    ref_obs_shape=[26],
    privileged_obs_shape=[48],
    privileged_ref_obs_shape=[26],
    actions_shape=[19],
    apply_dagger_actions=True,
    rnd_state_shape=[48],
    device="cuda"
)
# Data collection
for step in range(24):
    # ... environment interaction ...
    storage.add_transitions(transition)
# Compute returns
storage.compute_returns(last_values, gamma=0.99, lam=0.95)
# Training: mini_batch_generator shuffles and iterates over epochs internally
for batch_data in storage.mini_batch_generator(num_mini_batches=4, num_epochs=5):
    obs_batch, ref_obs_batch, critic_obs_batch, critic_ref_obs_batch, \
        actions_batch, values_batch, advantages_batch, returns_batch, \
        old_log_prob_batch, old_mu_batch, old_sigma_batch, dagger_actions_batch, \
        hidden_states, masks, rnd_state_batch, \
        obs_prev_state, ref_obs_prev_state, ref_obs_prev_mask = batch_data
    # PPO update using batch_data
# Clear for next rollout
storage.clear()
💡 Key Point: Storage variables are conditionally created based on training configuration, and data is efficiently organized for both sequential (distillation) and randomized (RL) access patterns.