ref_buffer
Module: GBC.utils.buffer.ref_buffer
This module implements a high-efficiency, GPU-memory-friendly reference data buffer system for imitation learning. The BufferManager
provides optimized access to reference motion data while supporting domain randomization adaptation, integral-based computation of dynamic values (root positions, joint positions), and high-performance cyclic sequence processing. It serves as the foundation for the reference_observation_manager
and accelerates training through careful memory management and computational optimization.
🚀 Key Performance Advantages
⚡ Ultra-High Efficiency
- GPU-Native Operations: Leverages Warp kernels for parallel computation across thousands of environments
- Memory-Optimized Storage: Intelligent buffer concatenation minimizes GPU memory fragmentation
- Vectorized Access: Batch processing of reference data across all environments simultaneously
- Cached Index Computation: Avoids redundant calculations through intelligent caching mechanisms
🧠 Memory-Friendly Design
- Shared Buffer Architecture: Single consolidated buffer for all reference sequences
- Dynamic Buffer Types: Support for constant, singular, and recurrent data with different memory patterns
- Efficient Indexing: Compact index structures minimize memory overhead
- Zero-Copy Operations: Direct GPU tensor operations without CPU-GPU transfers
🔄 Domain Randomization Adaptation
- Dynamic Reference Assignment: Real-time environment-to-reference mapping for curriculum learning
- Efficient Reset Mechanisms: Fast environment resets with minimal computational overhead
- Multi-Reference Support: Simultaneous handling of diverse motion patterns across environments
📈 Advanced Integral Computation
- Cumulative Observation Calculation: High-performance integration for position/velocity data
- Cyclic Sequence Optimization: Specialized handling of periodic motions with mathematical precision
- Real-Time Base Pose Tracking: Efficient root position and orientation computation from velocity data
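Concretely, the integral computation recovers the root position from per-frame velocities. A hedged statement of the relation (with f the sequence frame rate and p₀ the initial position):

$$\mathbf{p}(t) \;=\; \mathbf{p}_0 + \int_0^{t} \mathbf{v}(\tau)\,d\tau \;\approx\; \mathbf{p}_0 + \Delta t \sum_{i=0}^{\lfloor t f \rfloor - 1} \mathbf{v}_i, \qquad \Delta t = \tfrac{1}{f}$$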
🏗️ Core Architecture
📊 BufferType Enumeration
Definition:
class BufferType:
    singular = 0          # Single-shot sequences (non-repeating)
    recurrent = 1         # Standard cyclic sequences
    recurrent_strict = 2  # Strictly enforced cyclic sequences
🎯 Buffer Type Characteristics:
- Singular: Non-repeating sequences, ideal for one-shot motions (e.g., getting up, specific gestures)
- Recurrent: Cyclic sequences with flexible repetition (e.g., walking, running gaits)
- Recurrent Strict: Enforced cyclic patterns with precise temporal constraints
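To make the distinction concrete, here is a minimal sketch (assuming a hypothetical 100-frame sequence whose cycle spans frames [20, 100)) of how a raw frame index behaves past the end of the data under each type; it mirrors the index kernel shown later:

raw_idx = 130                                     # requested frame past the data end
singular_idx = -1 if raw_idx >= 100 else raw_idx  # singular: marked invalid (-1)
recurrent_idx = 20 + (raw_idx - 20) % (100 - 20)  # recurrent: wraps into the cycle -> 50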
🏭 BufferManager Class
Module Name: GBC.utils.buffer.ref_buffer.BufferManager
Definition:
class BufferManager:
    def __init__(self, num_envs: int, num_ref: int, working_mode: str, device: str):
        self.num_envs = num_envs  # Number of parallel environments
        self.num_ref = num_ref    # Number of reference sequences
        self.device = device      # GPU device for operations

        # Core buffer management (working_mode selects the default buffer type;
        # assumed: the mode string names a BufferType attribute)
        buffer_type_id = getattr(BufferType, working_mode)
        self.buffer_type = torch.ones(num_ref, dtype=torch.int8, device=device) * buffer_type_id
        self.frame_rate = torch.zeros(num_ref, dtype=torch.float32, device=device)
        self.max_len = torch.zeros(num_ref, dtype=torch.int32, device=device)
        self.recurrent_subseq = torch.ones((num_ref, 2), dtype=torch.int32, device=device) * -1

        # Environment-reference mapping
        self.env_ref_id = torch.zeros(num_envs, dtype=torch.int32, device=device)

        # High-performance storage
        self.ref_buffer_list = dict()  # Temporary storage during setup
        self.ref_buffer = dict()       # Optimized consolidated buffers
        self.is_constant = dict()      # Constant vs. time-varying data flags
📥 Initialization Parameters:
- num_envs (int): Number of parallel training environments
- num_ref (int): Number of reference motion sequences
- working_mode (str): Buffer operation mode ("singular", "recurrent", "recurrent_strict")
- device (str): GPU device identifier (e.g., "cuda:0")
🔧 Core Functionality
📥 Reference Data Management
Method Signature:
def add_reference(self,
                  key: str,
                  ref_id: int,
                  buffer_raw: torch.Tensor,
                  is_constant: bool,
                  frame_rate: int,
                  cyclic_subseq: tuple = None):
📥 Input Parameters:
- key (str): Data identifier (e.g., "joint_pos", "lin_vel", "ang_vel")
- ref_id (int): Unique reference sequence identifier
- buffer_raw (torch.Tensor): Raw reference data of shape [T, ...]
- is_constant (bool): Whether data is time-invariant
- frame_rate (int): Sequence frame rate (Hz)
- cyclic_subseq (tuple): Optional cyclic subsequence bounds (start_idx, end_idx)
🔧 Processing Pipeline:
# 1. Configure buffer type and parameters
buffer_type_id = BufferType.recurrent
self.buffer_type[ref_id] = buffer_type_id
self.frame_rate[ref_id] = frame_rate
self.is_constant[key] = is_constant

# 2. Handle cyclic subsequences for data augmentation
if not is_constant and cyclic_subseq is not None:
    st, ed = cyclic_subseq
    buffer_raw = buffer_raw[:ed, ...]  # Drop frames past the cycle end
    self.recurrent_subseq[ref_id, 0] = st
    self.recurrent_subseq[ref_id, 1] = ed
self.max_len[ref_id] = buffer_raw.shape[0]

# 3. Store in temporary buffer list
if key not in self.ref_buffer_list:
    self.ref_buffer_list[key] = [None for _ in range(self.num_ref)]
self.ref_buffer_list[key][ref_id] = buffer_raw
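As a worked illustration (hypothetical 200-frame clip with a cycle detected at (20, 180)), the truncation step keeps only the frames up to the cycle end and records the bounds for the index kernel:

import torch

buffer_raw = torch.randn(200, 29)  # hypothetical [T, num_joints] clip
st, ed = 20, 180
buffer_raw = buffer_raw[:ed, ...]  # frames past the cycle end are dropped -> 180 frames
# recorded metadata: recurrent_subseq[ref_id] = (20, 180), max_len[ref_id] = 180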
🏗️ Buffer Optimization and Consolidation
Method Signature:
def prepare_buffers(self, key: str):
🔧 Optimization Process:
# Memory-efficient buffer consolidation
if self.is_constant[key]:
    # Constant data: simple stacking
    self.ref_buffer[key] = torch.stack(self.ref_buffer_list[key])
else:
    # Time-varying data: add padding and concatenate
    padding = torch.zeros_like(self.ref_buffer_list[key][0])[:1, ...]
    self.ref_buffer_list[key] = [padding] + self.ref_buffer_list[key]
    self.ref_buffer[key] = torch.concatenate(self.ref_buffer_list[key], dim=0).contiguous()

    # Compute efficient start indices for fast access
    if self.start_index is None:
        len_sum = torch.cumsum(self.max_len, dim=0)
        self.start_index = torch.zeros(self.num_ref, dtype=torch.int32, device=self.device)
        self.start_index[1:] = len_sum[:-1]
        self.start_index += 1  # Account for the zero-padding row at index 0
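A small worked example of the resulting layout, assuming two time-varying references of 180 and 120 frames: row 0 holds the zero-padding frame, reference 0 occupies rows [1, 181), and reference 1 occupies rows [181, 301):

import torch

max_len = torch.tensor([180, 120], dtype=torch.int32)
len_sum = torch.cumsum(max_len, dim=0)   # [180, 300]
start_index = torch.zeros(2, dtype=torch.int32)
start_index[1:] = len_sum[:-1]           # [0, 180]
start_index += 1                         # shift past the padding row
print(start_index)                       # tensor([  1, 181], dtype=torch.int32)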
⚡ High-Performance Access Methods
🎯 GPU-Accelerated Index Computation
Warp Kernel Implementation:
@wp.kernel
def compute_idx(
    current_idx: wp.array(dtype=int),
    current_time: wp.array(dtype=float),
    env_ref_id: wp.array(dtype=int),
    frame_rate: wp.array(dtype=float),
    start_index: wp.array(dtype=int),
    max_len: wp.array(dtype=int),
    recurrent_subseq: wp.array(dtype=wp.vec2i),
):
    tid = wp.tid()         # Thread ID for parallel processing
    rid = env_ref_id[tid]  # Reference ID for this environment

    # Convert time to frame index
    idx = wp.int32(current_time[tid] * frame_rate[rid])
    if current_time[tid] < 0.0:
        idx = 0
    elif recurrent_subseq[rid][0] == -1 or max_len[rid] == -1:
        # Singular buffer: linear progression
        if idx >= max_len[rid]:
            idx = -1  # End-of-sequence marker
        else:
            idx += start_index[rid]
    else:
        # Recurrent buffer: cyclic repetition
        rec_st = recurrent_subseq[rid][0]
        rec_ed = recurrent_subseq[rid][1]
        rec_len = rec_ed - rec_st
        if idx >= rec_st:
            # Cyclic wrapping for continuous playback
            idx = rec_st + (idx - rec_st) % rec_len
        idx += start_index[rid]
    current_idx[tid] = idx
Performance Benefits:
- Parallel Processing: Simultaneous index computation for all environments
- Cyclic Optimization: Efficient modulo operations for seamless looping
- Memory Coalescing: Optimized memory access patterns for GPU architectures
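For reference, here is a hedged pure-PyTorch equivalent of the kernel's logic, convenient for unit-testing the semantics on CPU. Argument names mirror the buffers above; this is a sketch for testing, not the module's API:

import torch

def compute_idx_torch(current_time, env_ref_id, frame_rate,
                      start_index, max_len, recurrent_subseq):
    rid = env_ref_id.long()
    idx = (current_time * frame_rate[rid]).to(torch.int64)
    rec_st = recurrent_subseq[rid, 0].long()
    rec_ed = recurrent_subseq[rid, 1].long()
    rec_len = (rec_ed - rec_st).clamp(min=1)  # avoid %0 where no cycle is set
    # Singular: -1 past the end, otherwise shifted into the consolidated buffer
    singular = torch.where(idx >= max_len[rid].long(),
                           torch.full_like(idx, -1),
                           idx + start_index[rid].long())
    # Recurrent: wrap indices past the cycle start back into [rec_st, rec_ed)
    wrapped = torch.where(idx >= rec_st, rec_st + (idx - rec_st) % rec_len, idx)
    recurrent = wrapped + start_index[rid].long()
    out = torch.where(rec_st == -1, singular, recurrent)
    # Negative query times resolve to the zero-padding row at index 0
    return torch.where(current_time < 0, torch.zeros_like(out), out)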
📊 Standard Observation Access
Method Signature:
def calc_obs(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
🔧 Implementation:
def calc_obs(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
    """High-performance observation retrieval with intelligent caching"""
    if self.is_constant[key]:
        # Constant data: direct environment-based indexing
        return self.ref_buffer[key][self.env_ref_id, ...]
    # Time-varying data: efficient index computation with caching
    current_idx = self.calc_idx(current_time)
    return self.ref_buffer[key][current_idx, ...]
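Constant entries (is_constant=True) skip index computation entirely and are looked up through the environment-to-reference mapping. A hedged sketch with a hypothetical per-reference key:

# Hypothetical per-reference scalar (e.g., a motion style label), shape [1]
style = torch.tensor([3.0], device="cuda:0")
buffer_manager.add_reference("style_id", ref_id=0, buffer_raw=style,
                             is_constant=True, frame_rate=30)
buffer_manager.prepare_buffers("style_id")
styles = buffer_manager.calc_obs("style_id", current_time)  # [num_envs, 1], time ignored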
🧮 Advanced Cumulative Computation
Method Signature:
def calc_cumulative_obs_v2(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
🔧 Optimized Implementation:
def calc_cumulative_obs_v2(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
    """Ultra-high-performance cumulative observation calculation"""
    if self.is_constant[key]:
        return self.ref_buffer[key][self.env_ref_id, ...]

    # Compute cyclic sequence parameters
    current_num_cyclic_subseq, current_idx, current_begin_idx, current_end_idx = \
        self.calc_num_cyclic_subseq(current_time)
    start_idx = self.start_index[self.env_ref_id]  # shape: (num_envs,)

    # Initialize output tensor
    cumulative_obs = torch.zeros(
        (self.num_envs, *self.ref_buffer[key].shape[1:]),
        dtype=self.ref_buffer[key].dtype,
        device=self.device,
    )

    # Vectorized index computation for all environments
    begin_indices = current_begin_idx + start_idx
    end_indices = current_end_idx + start_idx
    current_indices = current_idx + start_idx

    # Dynamic sequence length determination
    max_first_len = (begin_indices - start_idx).max()
    max_cyclic_len = (end_indices - begin_indices).max()
    max_last_len = (current_indices - begin_indices).max()

    # Parallel index tensor creation
    first_indices = torch.arange(max_first_len, device=self.device).unsqueeze(0) + start_idx.unsqueeze(1)
    cyclic_indices = torch.arange(max_cyclic_len, device=self.device).unsqueeze(0) + begin_indices.unsqueeze(1)
    last_indices = torch.arange(max_last_len, device=self.device).unsqueeze(0) + begin_indices.unsqueeze(1)

    # Efficient masking for valid indices
    first_mask = first_indices < begin_indices.unsqueeze(1)
    cyclic_mask = cyclic_indices < end_indices.unsqueeze(1)
    last_mask = last_indices < current_indices.unsqueeze(1)

    # Boundary protection for buffer access
    buffer_max_len = self.ref_buffer[key].shape[0]
    first_indices = torch.where(first_indices < buffer_max_len, first_indices, buffer_max_len - 1)
    cyclic_indices = torch.where(cyclic_indices < buffer_max_len, cyclic_indices, buffer_max_len - 1)
    last_indices = torch.where(last_indices < buffer_max_len, last_indices, buffer_max_len - 1)

    # High-performance vectorized data retrieval
    first_sequences = self.ref_buffer[key][first_indices] * first_mask.unsqueeze(-1)
    cyclic_sequences = self.ref_buffer[key][cyclic_indices] * cyclic_mask.unsqueeze(-1)
    last_sequences = self.ref_buffer[key][last_indices] * last_mask.unsqueeze(-1)

    # Efficient integration with cyclic repetition support
    cumulative_obs += torch.sum(first_sequences, dim=1)
    if torch.any(current_end_idx > current_begin_idx):
        cumulative_obs += torch.sum(cyclic_sequences, dim=1) * current_num_cyclic_subseq.to(torch.float32).unsqueeze(1)
    cumulative_obs += torch.sum(last_sequences, dim=1)
    return cumulative_obs
⚡ Performance Optimizations:
- Vectorized Operations: Batch processing across all environments
- Memory Coalescing: Optimized GPU memory access patterns
- Intelligent Masking: Efficient boundary handling without branching
- Cyclic Integration: Mathematical precision for periodic sequences
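The decomposition behind this method can be checked directly: the sum over a wrapped index stream equals prefix + n_cycles · cycle_sum + partial_sum. A small, hedged standalone verification (float64 for exact comparison; the masks above are exclusive of the current index, so the brute-force loop is too):

import torch

seq = torch.randn(180, 3, dtype=torch.float64)  # one reference, cyclic_subseq=(20, 180)
st, ed = 20, 180
raw_idx = 437                                   # unwrapped frame count to accumulate
n_cycles = (raw_idx - st) // (ed - st)          # 2 complete cycles
partial = st + (raw_idx - st) % (ed - st)       # 117: position inside the cycle
direct = seq[:st].sum(0) + n_cycles * seq[st:ed].sum(0) + seq[st:partial].sum(0)

brute = torch.zeros(3, dtype=torch.float64)     # frame-by-frame reference
for i in range(raw_idx):
    j = i if i < st else st + (i - st) % (ed - st)
    brute += seq[j]
print(torch.allclose(direct, brute))            # True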
🌐 Dynamic Base Pose Computation
🎯 Real-Time Pose Integration
Method Signature:
def calc_base_pose(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
🔧 Implementation:
def calc_base_pose(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
    """Real-time base pose computation from velocity data"""
    # Retrieve velocity observations
    lin_vel_yaw_frame = self.calc_obs(lin_vel_name, current_time)
    ang_vel = self.calc_obs(ang_vel_name, current_time)
    # Incremental pose integration
    self.step_robot_base_pose(current_time, lin_vel_yaw_frame, ang_vel)
    # Return combined pose (position + quaternion)
    return torch.cat([self.base_pos, self.base_quat], dim=1)

def step_robot_base_pose(self, current_time: torch.Tensor, lin_vel_yaw_frame: torch.Tensor, ang_vel: torch.Tensor):
    """High-precision incremental pose integration"""
    # Compute time delta with safety checks
    dt = torch.where(self.last_pose_tme < 0, torch.zeros_like(self.last_pose_tme), current_time - self.last_pose_tme)
    dt = dt.clamp(min=0)  # Ensure non-negative time steps

    # Transform linear velocity to world frame
    quat_yaw = yaw_quat(self.base_quat)  # Extract yaw-only quaternion
    lin_vel = quat_apply(quat_yaw, lin_vel_yaw_frame)

    # Integrate position
    self.base_pos += lin_vel * dt.unsqueeze(1)

    # Integrate orientation using quaternion mathematics
    rot_vec = quat_apply(quat_inv(self.base_quat), ang_vel) * dt.unsqueeze(1)
    self.base_quat = quat_mul(self.base_quat, quat_inv(angle_axis_to_quaternion(rot_vec)))

    # Update time tracking
    self.last_pose_tme = torch.where(
        torch.logical_and(self.last_pose_tme < 0, current_time < 1e-5),
        self.last_pose_tme,
        current_time.clone(),
    )
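The quaternion helpers (quat_apply, quat_mul, quat_inv, yaw_quat, angle_axis_to_quaternion) appear to come from Isaac Lab's math utilities, an assumption based on the env.scene API used elsewhere in this module. For readers without that dependency, a minimal sketch of yaw extraction under the (w, x, y, z) convention, not the module's own implementation:

import torch

def yaw_quat_sketch(quat: torch.Tensor) -> torch.Tensor:
    """Keep only the heading (yaw) component of a (w, x, y, z) quaternion batch."""
    w, x, y, z = quat.unbind(-1)
    yaw = torch.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))
    zeros = torch.zeros_like(yaw)
    return torch.stack([torch.cos(yaw / 2), zeros, zeros, torch.sin(yaw / 2)], dim=-1)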
📈 Cumulative Base Pose Computation
Method Signature:
def calc_base_pose_cumulative(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
🔧 Advanced Implementation:
def calc_base_pose_cumulative(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
    """Ultra-efficient cumulative pose computation using integral methods"""
    # Frame rate-based time step calculation
    dt = 1.0 / self.frame_rate[self.env_ref_id]
    # High-performance cumulative integration
    lin_pos = self.calc_cumulative_obs_v2(lin_vel_name, current_time) * dt.unsqueeze(1)
    ang_pos = self.calc_cumulative_obs_v2(ang_vel_name, current_time) * dt.unsqueeze(1)
    # Normalize angular position to the [0, 2π) range
    ang_pos = ang_pos % (2 * np.pi)
    # Convert to quaternion representation
    base_quat = quat_from_euler_xyz(*ang_pos.T)
    return torch.cat([lin_pos, base_quat], dim=1)
📊 Performance Benefits:
- Integral-Based Computation: Direct mathematical integration without iterative steps
- Vectorized Processing: Simultaneous computation across all environments
- Numerical Stability: Robust handling of angular wrap-around and discontinuities
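A hedged consistency check of the integral form against step-by-step integration, using a constant linear velocity (trivial but illustrative; at constant velocity the two agree):

import torch

frame_rate = 30.0
dt = 1.0 / frame_rate
v = torch.tensor([0.5, 0.0, 0.0])  # constant linear velocity (m/s)
T = 90                             # 3 s of frames

lin_pos_integral = v * T * dt      # the cumulative (integral) form
lin_pos_stepped = torch.zeros(3)   # incremental integration
for _ in range(T):
    lin_pos_stepped += v * dt
print(torch.allclose(lin_pos_integral, lin_pos_stepped))  # True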
🎛️ Environment Management
🔄 Intelligent Reset System
Method Signature:
def reset(self, env, env_ids: Sequence[int] | None = None):
🔧 Implementation:
def reset(self, env, env_ids: Sequence[int] | None = None):
    """Efficient environment reset with domain randomization support"""
    if env_ids is None:
        env_ids = slice(None)  # Reset all environments

    # Random reference assignment for curriculum learning
    self.env_ref_id[env_ids] = torch.randint(
        0, self.num_ref,
        size=self.env_ref_id[env_ids].shape,
        dtype=self.env_ref_id.dtype,
        device=self.env_ref_id.device,
    )

    # Reset pose tracking
    self.last_pose_tme[env_ids] = -1

    # Initialize base pose from environment state
    root_pose_w = env.scene["robot"].data.root_state_w[env_ids, :7].clone()
    self.base_pos[env_ids] = root_pose_w[:, :3] - env.scene.env_origins[env_ids]
    self.base_quat[env_ids] = root_pose_w[:, 3:]
📊 Sequence Validation
Method Signature:
def calc_mask(self, current_time: torch.Tensor) -> torch.Tensor:
🔧 Implementation:
def calc_mask(self, current_time: torch.Tensor) -> torch.Tensor:
    """Efficient sequence validity checking"""
    current_idx = self.calc_idx(current_time)
    return current_idx >= 0  # Valid sequence positions
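In practice the mask is typically inverted to flag environments whose singular reference has run out, e.g. as a time-out signal (hedged usage sketch):

# Environments whose singular reference ended (index == -1) are flagged done
ref_time_out = ~buffer_manager.calc_mask(current_time)  # [num_envs] boolean
env_ids_to_reset = torch.where(ref_time_out)[0]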
💡 Usage Examples
🚀 Basic Buffer Setup
from GBC.utils.buffer.ref_buffer import BufferManager
import torch

# Initialize high-performance buffer manager
buffer_manager = BufferManager(
    num_envs=4096,             # Large-scale parallel training
    num_ref=100,               # Multiple reference sequences
    working_mode="recurrent",  # Cyclic sequence support
    device="cuda:0",
)

# Add joint position reference data
joint_positions = torch.randn(200, 29, device="cuda:0")  # [T, num_joints]
buffer_manager.add_reference(
    key="joint_pos",
    ref_id=0,
    buffer_raw=joint_positions,
    is_constant=False,
    frame_rate=30,
    cyclic_subseq=(20, 180),  # Detected walking cycle
)

# Add velocity data for dynamic pose computation
lin_velocities = torch.randn(200, 3, device="cuda:0")  # [T, 3]
ang_velocities = torch.randn(200, 3, device="cuda:0")  # [T, 3]
buffer_manager.add_reference("lin_vel", 0, lin_velocities, False, 30, (20, 180))
buffer_manager.add_reference("ang_vel", 0, ang_velocities, False, 30, (20, 180))

# Optimize buffers for maximum performance
buffer_manager.prepare_buffers("joint_pos")
buffer_manager.prepare_buffers("lin_vel")
buffer_manager.prepare_buffers("ang_vel")
🎯 High-Performance Training Loop
# Training loop with ultra-efficient data access
current_time = torch.rand(4096, device="cuda:0") * 10.0 # Random time points
# Standard observation retrieval (sub-millisecond access)
joint_targets = buffer_manager.calc_obs("joint_pos", current_time) # [4096, 29]
# Advanced cumulative computation for integral-based features
cumulative_displacement = buffer_manager.calc_cumulative_obs_v2("lin_vel", current_time) # [4096, 3]
# Real-time base pose computation
base_poses = buffer_manager.calc_base_pose(current_time, "lin_vel", "ang_vel") # [4096, 7]
# Sequence validity checking
valid_mask = buffer_manager.calc_mask(current_time) # [4096] boolean mask
print(f"Joint targets shape: {joint_targets.shape}")
print(f"Valid environments: {valid_mask.sum().item()}/{len(valid_mask)}")
🔄 Domain Randomization Integration
# Efficient environment reset with automatic reference randomization
env_ids_to_reset = torch.where(episode_done)[0] # Completed episodes
buffer_manager.reset(env, env_ids_to_reset) # Automatic reference shuffling
# Reference assignment for curriculum learning
# Easy sequences for new environments (one reference id per environment)
easy_ref_ids = torch.randint(0, 20, (4096,), device="cuda:0")  # First 20 refs are easy
buffer_manager.set_all_env_ref_id(easy_ref_ids)
# Advanced sequences for experienced environments
hard_ref_ids = torch.randint(80, 100, (1024,), device="cuda:0")  # Last 20 refs are hard
buffer_manager.env_ref_id[2048:3072] = hard_ref_ids  # Assign to specific envs
📈 Performance Benchmarking
import time
# Benchmark cumulative computation efficiency
num_iterations = 1000
current_time = torch.rand(4096, device="cuda:0") * 10.0
# Standard method timing (synchronize so asynchronous GPU work is fully measured)
torch.cuda.synchronize()
start_time = time.time()
for _ in range(num_iterations):
    result_v1 = buffer_manager.calc_cumulative_obs("lin_vel", current_time)
torch.cuda.synchronize()
v1_time = time.time() - start_time

# Optimized method timing
start_time = time.time()
for _ in range(num_iterations):
    result_v2 = buffer_manager.calc_cumulative_obs_v2("lin_vel", current_time)
torch.cuda.synchronize()
v2_time = time.time() - start_time
print(f"Standard method: {v1_time:.4f}s")
print(f"Optimized method: {v2_time:.4f}s")
print(f"Speedup: {v1_time/v2_time:.2f}x")
print(f"Results match: {torch.allclose(result_v1, result_v2)}")
🚨 Performance Optimization Guidelines
✅ Memory Efficiency Best Practices
- Buffer Consolidation: Always call prepare_buffers() after adding all references for a key
- Device Consistency: Ensure all tensors are on the same GPU device
- Data Type Optimization: Use appropriate precision (float32 vs. float64) based on requirements
- Batch Size Tuning: Optimize num_envs for your GPU memory capacity
⚡ Computational Efficiency
- Cyclic Sequence Design: Use cyclic_subseq parameters for walking/running motions (see the detection sketch below)
- Reference Diversity: Balance the number of references against memory usage
- Access Pattern Optimization: Minimize random access; prefer sequential patterns
- Caching Utilization: Leverage built-in caching for repeated time queries
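Where cycle bounds are not annotated, they can be estimated offline. A hypothetical helper (not part of the module) sketching autocorrelation-based period detection to feed cyclic_subseq:

import torch
import torch.nn.functional as F

def detect_cycle(traj: torch.Tensor, min_period: int = 10) -> tuple[int, int]:
    """traj: [T, D] joint trajectory; returns (start, end) of one candidate cycle."""
    flat = (traj - traj.mean(0)).reshape(traj.shape[0], -1)
    T = flat.shape[0]
    # Cosine similarity between the trajectory and its lag-shifted copy
    scores = torch.stack([
        F.cosine_similarity(flat[: T - lag].reshape(-1), flat[lag:].reshape(-1), dim=0)
        for lag in range(min_period, T // 2)
    ])
    period = int(scores.argmax().item()) + min_period
    return 0, period  # crude sketch: assumes the cycle starts at frame 0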
🔧 Integration Optimization
# Optimal setup for maximum performance
def setup_high_performance_buffer():
    """Configure buffer for maximum training efficiency"""
    buffer_manager = BufferManager(
        num_envs=8192,  # Power of 2 for optimal GPU utilization
        num_ref=128,    # Sufficient diversity for robust training
        working_mode="recurrent",
        device="cuda:0",
    )

    # Batch reference addition for efficiency
    reference_data = load_reference_dataset()  # Your data loading function
    for ref_id, (key, data, cyclic_info) in enumerate(reference_data):
        buffer_manager.add_reference(
            key=key,
            ref_id=ref_id,
            buffer_raw=data.to("cuda:0"),
            is_constant=False,
            frame_rate=30,
            cyclic_subseq=cyclic_info,
        )

    # Optimize all buffers simultaneously
    for key in ["joint_pos", "lin_vel", "ang_vel", "foot_contact"]:
        buffer_manager.prepare_buffers(key)

    return buffer_manager
This high-performance buffer system provides the foundation for efficient imitation-learning training, enabling massive-scale parallel environments while maintaining sub-millisecond access times and a minimal GPU memory footprint. Its integral-based computation and cyclic sequence handling make it an essential component of production-grade reinforcement learning pipelines.