Rollout Storage for Motion Mimic PPO

Module: rsl_rl.storage.rollout_storage_mm

🎯 Overview

RolloutStorageMM provides efficient storage and batch generation for multi-modal reinforcement learning data. It manages observation-action trajectories with support for reference observations, privileged information, and auxiliary learning modules.


🧠 Core Components

🔄 Transition Class

Class: rsl_rl.storage.rollout_storage_mm.RolloutStorageMM.Transition

Stored Variables:

  • observations (torch.Tensor): Actor observations
  • reference_observations (torch.Tensor): Reference observation data
  • reference_observations_mask (torch.Tensor): Reference availability mask
  • critic_observations (torch.Tensor): Critic observations (privileged info)
  • critic_reference_observations (torch.Tensor): Critic reference data
  • critic_reference_observations_mask (torch.Tensor): Critic reference mask
  • actions (torch.Tensor): Executed actions
  • privileged_actions (torch.Tensor): Teacher actions (distillation only)
  • dagger_actions (torch.Tensor): DAgger network actions
  • rewards (torch.Tensor): Environment rewards
  • dones (torch.Tensor): Episode termination flags
  • values (torch.Tensor): Value function estimates
  • actions_log_prob (torch.Tensor): Action log probabilities
  • action_mean (torch.Tensor): Policy action means
  • action_sigma (torch.Tensor): Policy action standard deviations
  • hidden_states (tuple): RNN hidden states (if applicable)
  • rnd_state (torch.Tensor): RND module states
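
During collection, one Transition per environment step is typically filled from the policy outputs and the environment step before being handed to the storage. The sketch below only illustrates that flow; policy, env, and the exact attribute sources are assumptions, not part of this module:

# Illustrative per-step population of a Transition (policy/env interfaces are hypothetical)
transition = RolloutStorageMM.Transition()
transition.observations = obs
transition.critic_observations = privileged_obs
transition.reference_observations = ref_obs
transition.reference_observations_mask = ref_obs_mask
transition.actions = policy.act(obs)                       # hypothetical actor call
transition.values = policy.evaluate(privileged_obs)        # hypothetical critic call
transition.actions_log_prob = policy.get_actions_log_prob(transition.actions)
transition.action_mean = policy.action_mean
transition.action_sigma = policy.action_std
obs, rewards, dones, infos = env.step(transition.actions)  # hypothetical env API
transition.rewards = rewards
transition.dones = dones
storage.add_transitions(transition)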

📦 RolloutStorageMM Class

Class: rsl_rl.storage.rollout_storage_mm.RolloutStorageMM

Definition:

class RolloutStorageMM:
    def __init__(self, training_type, num_envs, num_transitions_per_env,
                 obs_shape, ref_obs_shape, privileged_obs_shape,
                 privileged_ref_obs_shape, actions_shape,
                 apply_dagger_actions=False, rnd_state_shape=None,
                 amp_cfg=None, device="cpu")

📥 Parameters:

  • training_type (str): "rl" or "distillation"
  • num_envs (int): Number of parallel environments
  • num_transitions_per_env (int): Steps per rollout per environment
  • obs_shape (list): Actor observation shape
  • ref_obs_shape (list): Reference observation shape
  • privileged_obs_shape (list): Critic observation shape
  • privileged_ref_obs_shape (list): Critic reference observation shape
  • actions_shape (list): Action space shape
  • apply_dagger_actions (bool): Enable DAgger action storage. Default is False
  • rnd_state_shape (list | None): RND state shape. Default is None
  • amp_cfg (dict | None): AMP configuration. Default is None
  • device (str): Storage device. Default is "cpu"
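
The optional buffers are only allocated when their shapes are meaningful (see the creation conditions below); passing a shape whose first entry is None presumably disables the corresponding storage. The configuration below is purely illustrative:

# Illustrative RL configuration without critic reference data, DAgger actions, or RND
storage = RolloutStorageMM(
    training_type="rl",
    num_envs=1024,
    num_transitions_per_env=24,
    obs_shape=[48],
    ref_obs_shape=[26],
    privileged_obs_shape=[48],
    privileged_ref_obs_shape=[None],  # first entry None -> critic reference storage is skipped
    actions_shape=[19],
    device="cuda",
)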

🗄️ Storage Variables

📊 Core Storage (Always Created)

Observation Storage:

# Actor observations - Shape: (num_transitions_per_env, num_envs, *obs_shape)
self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=device)

# Core action-reward storage
self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=device).byte()
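
For a concrete sense of scale, assuming the illustrative configuration from the usage example at the end of this page (num_envs=4096, num_transitions_per_env=24, obs_shape=[48], actions_shape=[19]), the core buffers resolve to the following shapes:

# Assumed illustrative values; not part of the source
assert storage.observations.shape == (24, 4096, 48)
assert storage.actions.shape == (24, 4096, 19)
assert storage.rewards.shape == (24, 4096, 1)
assert storage.dones.shape == (24, 4096, 1) and storage.dones.dtype == torch.uint8  # stored via .byte()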

🎭 Conditional Storage Variables

🔍 Privileged Observations

Creation Condition: privileged_obs_shape[0] is not None

# Critic observations (privileged information)
self.privileged_observations = torch.zeros(
    num_transitions_per_env, num_envs, *privileged_obs_shape, device=device
)

📚 Reference Observations

Creation Condition: ref_obs_shape[0] is not None

# Reference observations and masks
self.reference_observations = torch.zeros(num_transitions_per_env, num_envs, *ref_obs_shape, device=device)
self.reference_observations_mask = torch.zeros(num_transitions_per_env, num_envs, device=device).bool()

🎯 Privileged Reference Observations

Creation Condition: ref_obs_shape[0] is not None AND privileged_ref_obs_shape[0] is not None

# Critic reference observations
self.privileged_reference_observations = torch.zeros(
    num_transitions_per_env, num_envs, *privileged_ref_obs_shape, device=device
)
self.privileged_reference_observations_mask = torch.zeros(num_transitions_per_env, num_envs, device=device).bool()

🎓 DAgger Actions

Creation Condition: apply_dagger_actions == True

# DAgger network actions for imitation learning
self.dagger_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)

🏫 Distillation Training

Creation Condition: training_type == "distillation"

# Teacher actions for knowledge distillation
self.privileged_actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)

💰 RL Training Variables

Creation Condition: training_type == "rl"

# PPO-specific storage
self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=device)
self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)
self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=device)

🔍 RND Storage

Creation Condition: rnd_state_shape is not None

# Random Network Distillation states
self.rnd_state = torch.zeros(num_transitions_per_env, num_envs, *rnd_state_shape, device=device)

🎯 Core Methods

📝 Data Management

➕ Add Transitions

def add_transitions(self, transition: Transition)

📥 Input:

  • transition (Transition): Single-step transition data

🔧 Function: Stores the transition at the current step index, copying only the fields that exist for the current storage configuration.
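
A simplified sketch of what this presumably does, following the standard rsl_rl storage pattern (not the verbatim source):

# Sketch only: write each field at the current step, then advance the cursor
if self.step >= self.num_transitions_per_env:
    raise OverflowError("Rollout buffer overflow; call clear() before adding new transitions")
self.observations[self.step].copy_(transition.observations)
self.actions[self.step].copy_(transition.actions)
self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
self.dones[self.step].copy_(transition.dones.view(-1, 1))
# Optional fields are copied only when their buffers were created, e.g.
if self.reference_observations is not None:
    self.reference_observations[self.step].copy_(transition.reference_observations)
    self.reference_observations_mask[self.step].copy_(transition.reference_observations_mask)
self.step += 1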

🧹 Clear Storage

def clear(self)

🔧 Function: Resets step counter to 0 for new rollout collection.
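
In effect this presumably amounts to resetting the internal write cursor; the buffers themselves are simply overwritten during the next rollout:

# Sketch: reset the write cursor used by add_transitions
self.step = 0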

📈 Return Computation

💰 Compute Returns and Advantages

def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True)

📥 Input:

  • last_values (torch.Tensor): Bootstrap values for final step
  • gamma (float): Discount factor
  • lam (float): GAE lambda parameter
  • normalize_advantage (bool): Normalize advantages. Default is True

🔧 Function: Computes returns using GAE and normalizes advantages for PPO training.
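
The computation presumably follows the standard GAE recursion used elsewhere in rsl_rl; the sketch below illustrates it and may differ from the actual implementation in minor details:

# Standard GAE recursion (sketch)
advantage = 0
for step in reversed(range(self.num_transitions_per_env)):
    # Bootstrap from last_values at the final step, otherwise use the next stored value
    next_values = last_values if step == self.num_transitions_per_env - 1 else self.values[step + 1]
    next_is_not_terminal = 1.0 - self.dones[step].float()
    # TD error and discounted advantage accumulation
    delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
    advantage = delta + next_is_not_terminal * gamma * lam * advantage
    self.returns[step] = advantage + self.values[step]

# Advantages and optional normalization
self.advantages = self.returns - self.values
if normalize_advantage:
    self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)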

🎲 Batch Generation

🔄 Mini-Batch Generator (RL Training)

def mini_batch_generator(self, num_mini_batches, num_epochs=8)

📥 Input:

  • num_mini_batches (int): Number of mini-batches per epoch
  • num_epochs (int): Number of training epochs. Default is 8

📤 Output Tuple:

(obs_batch,                      # Actor observations
 ref_obs_batch_rtn,              # (reference_obs, ref_mask) or None
 critic_observations_batch,      # Critic observations
 critic_ref_obs_batch_rtn,       # (critic_ref_obs, critic_ref_mask) or None
 actions_batch,                  # Actions
 target_values_batch,            # Value targets
 advantages_batch,               # GAE advantages
 returns_batch,                  # Returns
 old_actions_log_prob_batch,     # Old log probabilities
 old_mu_batch,                   # Old action means
 old_sigma_batch,                # Old action standard deviations
 dagger_actions_batch,           # DAgger actions or None
 hidden_states,                  # (None, None) - placeholder for RNN
 masks_batch,                    # None - placeholder for RNN
 rnd_state_batch,                # RND states or None
 obs_prev_state,                 # Previous observations for AMP
 ref_obs_prev_state,             # Previous reference observations for AMP
 ref_obs_prev_mask)              # Previous reference masks for AMP

🔧 Function: Generates randomly shuffled mini-batches for PPO training, iterating over the stored data num_epochs times. The data is organized as follows:

Data Flattening:

# Shape transformation: (num_transitions_per_env, num_envs, ...) → (batch_size, ...)
batch_size = num_envs * num_transitions_per_env
observations = self.observations.flatten(0, 1) # (batch_size, *obs_shape)

Reference Data Handling:

# Reference observations packaged as tuples for masking
ref_obs_batch_rtn = (ref_obs_batch, ref_obs_mask_batch) if ref_obs_batch is not None else None

AMP Integration:

# Previous state extraction for AMP discriminator
if self.amp_cfg and reference_observations is not None:
    obs_prev_state = amp_obs_extractor(observations[prev_batch_idx])
    ref_obs_prev_state, ref_obs_prev_mask = amp_ref_obs_extractor((
        reference_observations[prev_batch_idx],
        reference_observations_mask[prev_batch_idx],
    ))
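
prev_batch_idx is not defined in this excerpt. Given the time-major flattening described under Data Organization (flat_idx = step_idx * num_envs + env_idx), one plausible construction is to shift each sampled flat index back by one environment step; this is an assumption, not the documented behavior:

# Assumption: previous-step flat indices under time-major flattening.
# For step 0 there is no earlier transition, so the index is left unchanged.
prev_batch_idx = torch.where(batch_idx >= num_envs, batch_idx - num_envs, batch_idx)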

🎓 Distillation Generator

def generator(self)

📤 Output Tuple:

(observations,                     # Actor observations
 privileged_observations,          # Teacher observations
 ref_obs_batch_rtn,                # (reference_obs, ref_mask)
 privileged_ref_obs_batch_rtn,     # (privileged_ref_obs, privileged_ref_mask) or None
 actions,                          # Student actions
 privileged_actions,               # Teacher actions
 dones)                            # Episode termination flags

🔧 Function: Sequential iteration over stored transitions for distillation training.
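
A minimal consumption pattern might look like the sketch below; the student network, optimizer, and loss are illustrative and not part of this module:

# Illustrative behavior-cloning update over the sequential generator
for (obs, priv_obs, ref_obs_tuple, priv_ref_obs_tuple,
     actions, privileged_actions, dones) in storage.generator():
    student_actions = student_policy(obs)  # hypothetical student network
    loss = torch.nn.functional.mse_loss(student_actions, privileged_actions)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()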

📊 Statistics

📈 Get Training Statistics

def get_statistics(self)

📤 Output: (mean_trajectory_length, mean_reward) (tuple)

🔧 Function: Computes episode-based statistics for monitoring training progress.
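
The base rsl_rl RolloutStorage derives these numbers from the dones and rewards buffers, and this class presumably does the same. A sketch of that approach (may differ in details):

# Sketch, mirroring the standard rsl_rl implementation
done = self.dones
done[-1] = 1  # treat the final stored step as a trajectory boundary
flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
done_indices = torch.cat(
    (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])
)
trajectory_lengths = done_indices[1:] - done_indices[:-1]
mean_trajectory_length = trajectory_lengths.float().mean()
mean_reward = self.rewards.mean()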


🔧 Data Organization Summary

📦 Storage Layout

# Time-major storage format
# Storage shape: (num_transitions_per_env, num_envs, feature_dim)
# Batch shape:   (batch_size, feature_dim)   # after flattening

# Index mapping after flatten(0, 1)
# step_idx in [0, num_transitions_per_env - 1]
# env_idx  in [0, num_envs - 1]
flat_idx = step_idx * num_envs + env_idx

🎲 Sampling Strategy

# Random shuffling for PPO
batch_size = num_envs * num_transitions_per_env
mini_batch_size = batch_size // num_mini_batches
indices = torch.randperm(batch_size) # Random permutation

# Mini-batch extraction
for i in range(num_mini_batches):
    start = i * mini_batch_size
    end = (i + 1) * mini_batch_size
    batch_idx = indices[start:end]
    # Extract data using batch_idx

🎭 Multi-Modal Data Packaging

# Reference observations are always packaged together with their masks
if reference_observations is not None:
    ref_obs_tuple = (reference_data, reference_mask)
else:
    ref_obs_tuple = None

# Actor and critic reference data are packaged separately (pseudocode)
actor_ref_tuple = (ref_obs, ref_mask) if available else None
critic_ref_tuple = (critic_ref_obs, critic_ref_mask) if available else None

💡 Usage Pattern

# Initialize storage
storage = RolloutStorageMM(
    training_type="rl",
    num_envs=4096,
    num_transitions_per_env=24,
    obs_shape=[48],
    ref_obs_shape=[26],
    privileged_obs_shape=[48],
    privileged_ref_obs_shape=[26],
    actions_shape=[19],
    apply_dagger_actions=True,
    rnd_state_shape=[48],
    device="cuda",
)

# Data collection
for step in range(24):
    # ... environment interaction ...
    storage.add_transitions(transition)

# Compute returns
storage.compute_returns(last_values, gamma=0.99, lam=0.95)

# Training: the generator handles epoch repetition internally via num_epochs
for batch_data in storage.mini_batch_generator(num_mini_batches=4, num_epochs=5):
    (obs_batch, ref_obs_batch, critic_obs_batch, critic_ref_obs_batch,
     actions_batch, values_batch, advantages_batch, returns_batch,
     old_log_prob_batch, old_mu_batch, old_sigma_batch, dagger_actions_batch,
     hidden_states, masks, rnd_state_batch,
     obs_prev_state, ref_obs_prev_state, ref_obs_prev_mask) = batch_data

    # PPO update using batch_data

# Clear for next rollout
storage.clear()

💡 Key Point: Storage variables are conditionally created based on training configuration, and data is efficiently organized for both sequential (distillation) and randomized (RL) access patterns.