
ref_buffer

Module: GBC.utils.buffer.ref_buffer

This module implements a highly efficient, GPU-memory-friendly reference data buffer system for imitation learning. The BufferManager provides fast access to reference motion data and supports domain randomization, integral-based computation of dynamic values from velocity data (e.g. root positions, joint positions), and cyclic sequence processing. It serves as the foundation for the reference_observation_manager and speeds up training through consolidated GPU storage and vectorized computation.

🚀 Key Performance Advantages

Ultra-High Efficiency

  • GPU-Native Operations: Leverages Warp kernels for parallel computation across thousands of environments
  • Memory-Optimized Storage: Intelligent buffer concatenation minimizes GPU memory fragmentation
  • Vectorized Access: Batch processing of reference data across all environments simultaneously
  • Cached Index Computation: Avoids redundant calculations through intelligent caching mechanisms

🧠 Memory-Friendly Design

  • Shared Buffer Architecture: Single consolidated buffer for all reference sequences
  • Dynamic Buffer Types: Support for constant, singular, and recurrent data with different memory patterns
  • Efficient Indexing: Compact index structures minimize memory overhead
  • Zero-Copy Operations: Direct GPU tensor operations without CPU-GPU transfers

🔄 Domain Randomization Adaptation

  • Dynamic Reference Assignment: Real-time environment-to-reference mapping for curriculum learning
  • Efficient Reset Mechanisms: Fast environment resets with minimal computational overhead
  • Multi-Reference Support: Simultaneous handling of diverse motion patterns across environments

📈 Advanced Integral Computation

  • Cumulative Observation Calculation: Vectorized integration of velocity data into positions
  • Cyclic Sequence Optimization: Each full cycle of a periodic motion is summed once and multiplied by the number of completed cycles
  • Real-Time Base Pose Tracking: Efficient root position and orientation computation from velocity data

🏗️ Core Architecture

📊 BufferType Enumeration

Definition:

class BufferType:
    singular = 0  # Single-shot sequences (non-repeating)
    recurrent = 1  # Standard cyclic sequences
    recurrent_strict = 2  # Strictly enforced cyclic sequences

🎯 Buffer Type Characteristics:

  • Singular: Non-repeating sequences, ideal for one-shot motions (e.g., getting up, specific gestures)
  • Recurrent: Cyclic sequences with flexible repetition (e.g., walking, running gaits)
  • Recurrent Strict: Enforced cyclic patterns with precise temporal constraints

🏭 BufferManager Class

Module Name: GBC.utils.buffer.ref_buffer.BufferManager

Definition:

class BufferManager:
    def __init__(self, num_envs: int, num_ref: int, working_mode: str, device: str):
        self.num_envs = num_envs  # Number of parallel environments
        self.num_ref = num_ref  # Number of reference sequences
        self.device = device  # GPU device for operations

        # Core buffer management
        # (buffer_type_id is the BufferType id for working_mode; its derivation is not shown in this excerpt)
        self.buffer_type = torch.ones(num_ref, dtype=torch.int8, device=device) * buffer_type_id
        self.frame_rate = torch.zeros(num_ref, dtype=torch.float32, device=device)
        self.max_len = torch.zeros(num_ref, dtype=torch.int32, device=device)
        self.recurrent_subseq = torch.ones((num_ref, 2), dtype=torch.int32, device=device) * -1
        self.start_index = None  # Computed lazily by prepare_buffers()

        # Environment-reference mapping
        self.env_ref_id = torch.zeros(num_envs, dtype=torch.int32, device=device)

        # High-performance storage
        self.ref_buffer_list = dict()  # Temporary storage during setup
        self.ref_buffer = dict()  # Optimized consolidated buffers
        self.is_constant = dict()  # Constant vs. time-varying data flags

        # Base-pose state (base_pos, base_quat, last_pose_tme) is also allocated; not shown here

📥 Initialization Parameters:

  • num_envs (int): Number of parallel training environments
  • num_ref (int): Number of reference motion sequences
  • working_mode (str): Buffer operation mode ("singular", "recurrent", "recurrent_strict")
  • device (str): GPU device identifier (e.g., "cuda:0")
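The __init__ excerpt does not show how buffer_type_id is obtained from working_mode. A minimal sketch of one way to do it, assuming the accepted mode strings are exactly the BufferType attribute names listed above; buffer_type_from_mode and its lookup table are hypothetical, not part of GBC.

```python
class BufferType:
    singular = 0
    recurrent = 1
    recurrent_strict = 2


# Hypothetical lookup table: assumes working_mode strings match the BufferType names
_MODE_TO_BUFFER_TYPE = {
    "singular": BufferType.singular,
    "recurrent": BufferType.recurrent,
    "recurrent_strict": BufferType.recurrent_strict,
}


def buffer_type_from_mode(working_mode: str) -> int:
    """Return the BufferType id for a working_mode string (hypothetical helper)."""
    try:
        return _MODE_TO_BUFFER_TYPE[working_mode]
    except KeyError:
        raise ValueError(f"unknown working_mode: {working_mode!r}") from None
```

Raising on an unknown string catches typos such as "recurent" when the manager is constructed, instead of silently falling back to a default buffer type.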

🔧 Core Functionality

📥 Reference Data Management

Method Signature:

def add_reference(self,
                  key: str,
                  ref_id: int,
                  buffer_raw: torch.Tensor,
                  is_constant: bool,
                  frame_rate: int,
                  cyclic_subseq: tuple = None):

📥 Input Parameters:

  • key (str): Data identifier (e.g., "joint_pos", "lin_vel", "ang_vel")
  • ref_id (int): Unique reference sequence identifier
  • buffer_raw (torch.Tensor): Raw reference data [T, ...]
  • is_constant (bool): Whether data is time-invariant
  • frame_rate (int): Sequence frame rate (Hz)
  • cyclic_subseq (tuple): Optional cyclic subsequence bounds (start_idx, end_idx), end exclusive: frames from end_idx onward are discarded and frames [start_idx, end_idx) repeat

🔧 Processing Pipeline:

# 1. Configure buffer type and parameters
buffer_type_id = BufferType.recurrent
self.buffer_type[ref_id] = buffer_type_id
self.frame_rate[ref_id] = frame_rate
self.is_constant[key] = is_constant

# 2. Time-varying data: truncate to the end of the cyclic subsequence and record its bounds
if not is_constant:
    if cyclic_subseq is not None:
        st, ed = cyclic_subseq
        buffer_raw = buffer_raw[:ed, ...]  # Keep frames [0, ed); frames [st, ed) form the loop
        self.recurrent_subseq[ref_id, 0] = st
        self.recurrent_subseq[ref_id, 1] = ed
    self.max_len[ref_id] = buffer_raw.shape[0]

# 3. Store in temporary buffer list
if key not in self.ref_buffer_list:
    self.ref_buffer_list[key] = [None for _ in range(self.num_ref)]
self.ref_buffer_list[key][ref_id] = buffer_raw
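Because buffer_raw is truncated to [:ed], cyclic_subseq behaves as a half-open interval. A small sketch that computes the stored length and loop length for one reference, with a bounds check that the pipeline above does not perform; cyclic_layout is a hypothetical helper, and -1 mirrors the default value of recurrent_subseq.

```python
from typing import Optional, Tuple


def cyclic_layout(num_frames: int, cyclic_subseq: Optional[Tuple[int, int]]) -> Tuple[int, int, int]:
    """Return (stored_len, loop_start, loop_len) for one reference (hypothetical helper).

    Frames at or after the end bound are dropped; frames [start, end) form the loop.
    Without a cyclic subsequence the whole clip is stored and nothing loops.
    """
    if cyclic_subseq is None:
        return num_frames, -1, 0
    st, ed = cyclic_subseq
    if not (0 <= st < ed <= num_frames):
        raise ValueError(f"invalid cyclic_subseq {cyclic_subseq} for {num_frames} frames")
    return ed, st, ed - st
```

For a 200-frame clip with cyclic_subseq=(20, 180), 180 frames are stored and the loop is 160 frames long.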

🏗️ Buffer Optimization and Consolidation

Method Signature:

def prepare_buffers(self, key: str):

🔧 Optimization Process:

# Memory-efficient buffer consolidation
if self.is_constant[key]:
    # Constant data: simple stacking
    self.ref_buffer[key] = torch.stack(self.ref_buffer_list[key])
else:
    # Time-varying data: add padding and concatenate
    padding = torch.zeros_like(self.ref_buffer_list[key][0])[:1, ...]
    self.ref_buffer_list[key] = [padding] + self.ref_buffer_list[key]
    self.ref_buffer[key] = torch.concatenate(self.ref_buffer_list[key], dim=0).contiguous()

# Compute efficient start indices for fast access
if self.start_index is None:
    len_sum = torch.cumsum(self.max_len, dim=0)
    self.start_index = torch.zeros(self.num_ref, dtype=torch.int32, device=self.device)
    self.start_index[1:] = len_sum[:-1]
    self.start_index += 1  # Account for padding
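The start_index computation is an exclusive prefix sum of max_len shifted by one for the padding row. A plain-Python sketch of the resulting layout; start_indices and total_rows are hypothetical helpers that only restate the arithmetic above.

```python
from itertools import accumulate


def start_indices(max_len: list) -> list:
    """Row offset of each reference's first frame in the consolidated buffer (hypothetical helper).

    Row 0 holds the zero padding frame, so reference 0 starts at row 1
    and reference i starts at 1 + sum(max_len[:i]).
    """
    exclusive_prefix = list(accumulate(max_len, initial=0))[:-1]
    return [1 + offset for offset in exclusive_prefix]


def total_rows(max_len: list) -> int:
    """Rows in one consolidated time-varying buffer: one padding row plus every frame."""
    return 1 + sum(max_len)
```

With three references of 200, 150 and 80 frames, they start at rows 1, 201 and 351 of a 431-row buffer.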

⚡ High-Performance Access Methods

🎯 GPU-Accelerated Index Computation

Warp Kernel Implementation:

import warp as wp

@wp.kernel
def compute_idx(
    current_idx: wp.array(dtype=int),
    current_time: wp.array(dtype=float),
    env_ref_id: wp.array(dtype=int),
    frame_rate: wp.array(dtype=float),
    start_index: wp.array(dtype=int),
    max_len: wp.array(dtype=int),
    recurrent_subseq: wp.array(dtype=wp.vec2i),
):  # Warp kernels write results into arrays; they cannot return values
    tid = wp.tid()  # Thread ID: one thread per environment
    rid = env_ref_id[tid]  # Reference ID for this environment

    # Convert time to frame index
    idx = wp.int32(current_time[tid] * frame_rate[rid])

    if current_time[tid] < 0.0:
        idx = 0  # Row 0 of the consolidated buffer is the zero padding frame
    elif recurrent_subseq[rid][0] == -1 or max_len[rid] == -1:
        # Singular buffer: linear progression
        if idx >= max_len[rid]:
            idx = -1  # End of sequence marker
        else:
            idx += start_index[rid]
    else:
        # Recurrent buffer: cyclic repetition
        rec_st = recurrent_subseq[rid][0]
        rec_ed = recurrent_subseq[rid][1]
        rec_len = rec_ed - rec_st

        if idx >= rec_st:
            # Cyclic wrapping for continuous playback
            idx = rec_st + (idx - rec_st) % rec_len
        idx += start_index[rid]

    current_idx[tid] = idx

Performance Benefits:

  • Parallel Processing: Simultaneous index computation for all environments
  • Cyclic Optimization: Efficient modulo operations for seamless looping
  • Memory Coalescing: Optimized memory access patterns for GPU architectures
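A scalar Python transcription of compute_idx is a convenient reference when unit-testing the kernel on a handful of environments. The sketch mirrors the kernel's branches for a single environment; compute_idx_ref is a hypothetical name, and its arguments are that environment's entries of the kernel's input arrays.

```python
def compute_idx_ref(t: float, rate: float, start: int, max_len: int, rec_st: int, rec_ed: int) -> int:
    """Scalar mirror of the compute_idx kernel for one environment (hypothetical helper)."""
    idx = int(t * rate)  # Truncating float-to-int conversion, like the kernel's cast
    if t < 0.0:
        return 0  # Padding row
    if rec_st == -1 or max_len == -1:
        # Singular reference: -1 once the clip has run out
        return -1 if idx >= max_len else idx + start
    if idx >= rec_st:
        # Recurrent reference: wrap back into [rec_st, rec_ed)
        idx = rec_st + (idx - rec_st) % (rec_ed - rec_st)
    return idx + start
```

With frame_rate=30 and cyclic_subseq=(20, 180), t = 10 s maps to frame 300, which wraps to frame 140 of the reference before the start offset is added.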

📊 Standard Observation Access

Method Signature:

def calc_obs(self, key: str, current_time: torch.Tensor) -> torch.Tensor:

🔧 Implementation:

def calc_obs(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
    """Return the reference value of `key` for every environment at current_time"""

    if self.is_constant[key]:
        # Constant data: one row per reference, selected by each env's reference ID
        return self.ref_buffer[key][self.env_ref_id, ...]

    # Time-varying data: per-env buffer rows from calc_idx (see compute_idx above)
    current_idx = self.calc_idx(current_time)
    return self.ref_buffer[key][current_idx, ...]

🧮 Advanced Cumulative Computation

Method Signature:

def calc_cumulative_obs_v2(self, key: str, current_time: torch.Tensor) -> torch.Tensor:

🔧 Optimized Implementation:

def calc_cumulative_obs_v2(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
    """Ultra-high performance cumulative observation calculation"""

    if self.is_constant[key]:
        return self.ref_buffer[key][self.env_ref_id, ...]

    # Compute cyclic sequence parameters
    current_num_cyclic_subseq, current_idx, current_begin_idx, current_end_idx = \
        self.calc_num_cyclic_subseq(current_time)

    start_idx = self.start_index[self.env_ref_id]  # shape: (num_envs,)

    # Initialize output tensor
    cumulative_obs = torch.zeros(
        (self.num_envs, *self.ref_buffer[key].shape[1:]),
        dtype=self.ref_buffer[key].dtype,
        device=self.device
    )

    # Vectorized index computation for all environments
    begin_indices = current_begin_idx + start_idx
    end_indices = current_end_idx + start_idx
    current_indices = current_idx + start_idx

    # Dynamic sequence length determination
    max_first_len = (begin_indices - start_idx).max()
    max_cyclic_len = (end_indices - begin_indices).max()
    max_last_len = (current_indices - begin_indices).max()

    # Parallel index tensor creation
    first_indices = torch.arange(max_first_len, device=self.device).unsqueeze(0) + start_idx.unsqueeze(1)
    cyclic_indices = torch.arange(max_cyclic_len, device=self.device).unsqueeze(0) + begin_indices.unsqueeze(1)
    last_indices = torch.arange(max_last_len, device=self.device).unsqueeze(0) + begin_indices.unsqueeze(1)

    # Efficient masking for valid indices
    first_mask = first_indices < begin_indices.unsqueeze(1)
    cyclic_mask = cyclic_indices < end_indices.unsqueeze(1)
    last_mask = last_indices < current_indices.unsqueeze(1)

    # Boundary protection for buffer access
    buffer_max_len = self.ref_buffer[key].shape[0]
    first_indices = torch.where(first_indices < buffer_max_len, first_indices, buffer_max_len - 1)
    cyclic_indices = torch.where(cyclic_indices < buffer_max_len, cyclic_indices, buffer_max_len - 1)
    last_indices = torch.where(last_indices < buffer_max_len, last_indices, buffer_max_len - 1)

    # High-performance vectorized data retrieval
    first_sequences = self.ref_buffer[key][first_indices] * first_mask.unsqueeze(-1)
    cyclic_sequences = self.ref_buffer[key][cyclic_indices] * cyclic_mask.unsqueeze(-1)
    last_sequences = self.ref_buffer[key][last_indices] * last_mask.unsqueeze(-1)

    # Efficient integration with cyclic repetition support
    cumulative_obs += torch.sum(first_sequences, dim=1)
    if torch.any(current_end_idx > current_begin_idx):
        cumulative_obs += torch.sum(cyclic_sequences, dim=1) * current_num_cyclic_subseq.to(torch.float32).unsqueeze(1)
    cumulative_obs += torch.sum(last_sequences, dim=1)

    return cumulative_obs
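The v2 method avoids summing every elapsed frame: it sums the prefix before the loop once, sums one full cycle and multiplies it by the number of completed cycles, then adds the partial cycle. A scalar sketch of that decomposition, checked against a brute-force sum over the looped frames. It illustrates the idea only; it is not the module's calc_num_cyclic_subseq, whose exact conventions are not shown here. Both helper names are hypothetical, and n stands for the number of elapsed frames.

```python
def looped_frames(seq, st, ed, n):
    """Yield the first n frames of seq played as prefix [0, st) followed by loop [st, ed)."""
    for i in range(n):
        yield seq[i] if i < ed else seq[st + (i - st) % (ed - st)]


def cumulative_closed_form(seq, st, ed, n):
    """Sum of the first n looped frames without visiting each frame (hypothetical helper)."""
    if n <= st:
        return sum(seq[:n])  # Still inside the non-looping prefix
    cycles, rem = divmod(n - st, ed - st)
    return sum(seq[:st]) + cycles * sum(seq[st:ed]) + sum(seq[st:st + rem])
```

The cost of the closed form depends on the clip length rather than on how many times the loop has repeated, which is what makes long rollouts cheap.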

⚡ Performance Optimizations:

  • Vectorized Operations: Batch processing across all environments
  • Memory Coalescing: Optimized GPU memory access patterns
  • Intelligent Masking: Efficient boundary handling without branching
  • Cyclic Integration: Each full cycle is summed once and scaled by the number of completed cycles instead of summing every elapsed frame

🌐 Dynamic Base Pose Computation

🎯 Real-Time Pose Integration

Method Signature:

def calc_base_pose(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:

🔧 Implementation:

def calc_base_pose(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
    """Real-time base pose computation from velocity data"""

    # Retrieve velocity observations
    lin_vel_yaw_frame = self.calc_obs(lin_vel_name, current_time)
    ang_vel = self.calc_obs(ang_vel_name, current_time)

    # Incremental pose integration
    self.step_robot_base_pose(current_time, lin_vel_yaw_frame, ang_vel)

    # Return combined pose (position + quaternion)
    return torch.cat([self.base_pos, self.base_quat], dim=1)

def step_robot_base_pose(self, current_time: torch.Tensor, lin_vel_yaw_frame: torch.Tensor, ang_vel: torch.Tensor):
    """High-precision incremental pose integration"""

    # Compute time delta with safety checks
    dt = torch.where(self.last_pose_tme < 0, torch.zeros_like(self.last_pose_tme), current_time - self.last_pose_tme)
    dt = dt.clamp(min=0)  # Ensure non-negative time steps

    # Transform linear velocity to world frame
    quat_yaw = yaw_quat(self.base_quat)  # Extract yaw-only quaternion
    lin_vel = quat_apply(quat_yaw, lin_vel_yaw_frame)

    # Integrate position
    self.base_pos += lin_vel * dt.unsqueeze(1)

    # Integrate orientation using quaternion mathematics
    rot_vec = quat_apply(quat_inv(self.base_quat), ang_vel) * dt.unsqueeze(1)
    self.base_quat = quat_mul(self.base_quat, quat_inv(angle_axis_to_quaternion(rot_vec)))

    # Update time tracking
    self.last_pose_tme = torch.where(
        torch.logical_and(self.last_pose_tme < 0, current_time < 1e-5),
        self.last_pose_tme,
        current_time.clone()
    )
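In the planar case the update above reduces to: rotate the yaw-frame velocity by the current heading, advance the position, then advance the heading. A minimal explicit-Euler sketch under that simplification; integrate_planar is hypothetical, uses the usual counter-clockwise-positive yaw convention, and does not reproduce the module's quaternion composition.

```python
import math


def integrate_planar(x, y, yaw, v_forward, v_lateral, yaw_rate, dt):
    """One explicit-Euler step of a planar base pose (hypothetical helper).

    The velocity is expressed in the yaw frame, rotated into the world frame
    with the current heading, then the heading is advanced and wrapped to [0, 2*pi).
    """
    c, s = math.cos(yaw), math.sin(yaw)
    x += (c * v_forward - s * v_lateral) * dt
    y += (s * v_forward + c * v_lateral) * dt
    yaw = (yaw + yaw_rate * dt) % (2 * math.pi)
    return x, y, yaw
```

Walking forward at 1 m/s for ten 0.1 s steps moves the base 1 m along its heading; starting at a heading of π/2, that displacement lands on the world y axis.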

📈 Cumulative Base Pose Computation

Method Signature:

def calc_base_pose_cumulative(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:

🔧 Advanced Implementation:

def calc_base_pose_cumulative(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
    """Ultra-efficient cumulative pose computation using integral methods"""

    # Frame rate-based time step calculation
    dt = 1.0 / self.frame_rate[self.env_ref_id]

    # High-performance cumulative integration
    lin_pos = self.calc_cumulative_obs_v2(lin_vel_name, current_time) * dt.unsqueeze(1)
    ang_pos = self.calc_cumulative_obs_v2(ang_vel_name, current_time) * dt.unsqueeze(1)

    # Normalize angular position to the [0, 2π) range
    ang_pos = ang_pos % (2 * np.pi)

    # Convert to quaternion representation
    base_quat = quat_from_euler_xyz(*ang_pos.T)

    return torch.cat([lin_pos, base_quat], dim=1)

📊 Performance Benefits:

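  • Integral-Based Computation: Positions are rectangle-rule integrals (sum of per-frame samples × 1/frame_rate) computed without iterating over time steps
  • Vectorized Processing: Simultaneous computation across all environments
  • Angle Wrap-Around: Integrated angles are wrapped into [0, 2π); treating them as XYZ Euler angles is exact only when rotation stays about a single fixed axis (e.g. pure yaw)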
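The two numeric steps of calc_base_pose_cumulative, scaling a sum of per-frame samples by dt = 1/frame_rate and wrapping angles into [0, 2π), can be sketched on plain lists. rectangle_rule_pose and its list-of-tuples inputs are hypothetical; the sketch omits the Euler-to-quaternion conversion.

```python
import math


def rectangle_rule_pose(lin_vel, ang_vel, frame_rate):
    """Integrate per-frame velocity samples with the rectangle rule (hypothetical helper).

    lin_vel, ang_vel: lists of (x, y, z) samples, one per elapsed frame.
    Returns (position, angles); each angle is wrapped into [0, 2*pi).
    """
    dt = 1.0 / frame_rate
    pos = tuple(dt * sum(sample[k] for sample in lin_vel) for k in range(3))
    ang = tuple((dt * sum(sample[k] for sample in ang_vel)) % (2 * math.pi) for k in range(3))
    return pos, ang
```

Two seconds of 0.5 m/s forward velocity at 30 fps integrate to 1 m, and one second of 7 rad/s yaw rate wraps to 7 − 2π ≈ 0.717 rad.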

🎛️ Environment Management

🔄 Intelligent Reset System

Method Signature:

def reset(self, env, env_ids: Sequence[int] | None = None):

🔧 Implementation:

def reset(self, env, env_ids: Sequence[int] | None = None):
    """Efficient environment reset with domain randomization support"""

    if env_ids is None:
        env_ids = slice(None)  # Reset all environments

    # Uniformly random reference assignment (domain randomization)
    self.env_ref_id[env_ids] = torch.randint(
        0, self.num_ref,
        size=self.env_ref_id[env_ids].shape,
        dtype=self.env_ref_id.dtype,
        device=self.env_ref_id.device,
    )

    # Reset pose tracking (-1 marks "no previous time stamp", so the next step uses dt = 0)
    self.last_pose_tme[env_ids] = -1

    # Initialize base pose from environment state
    root_pose_w = env.scene["robot"].data.root_state_w[env_ids, :7]
    root_pose_w = root_pose_w.clone()

    # Store the position relative to each environment's origin
    self.base_pos[env_ids] = root_pose_w[:, :3] - env.scene.env_origins[env_ids]
    self.base_quat[env_ids] = root_pose_w[:, 3:]
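reset draws a new reference uniformly at random for each reset environment and stores the base position relative to the environment origin. A plain-Python sketch of those two operations; reset_ref_ids and local_base_pos are hypothetical helpers, and passing a seeded random.Random makes the draw reproducible in tests.

```python
import random


def reset_ref_ids(env_ref_id, num_ref, env_ids=None, rng=random):
    """Give each reset environment a uniformly random reference id (hypothetical helper).

    env_ids=None resets every environment, like slice(None) in reset().
    Environments not listed keep their current reference.
    """
    ids = range(len(env_ref_id)) if env_ids is None else env_ids
    for i in ids:
        env_ref_id[i] = rng.randrange(num_ref)
    return env_ref_id


def local_base_pos(root_pos_w, env_origin):
    """World-frame root position expressed relative to the environment origin."""
    return tuple(p - o for p, o in zip(root_pos_w, env_origin))
```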

📊 Sequence Validation

Method Signature:

def calc_mask(self, current_time: torch.Tensor) -> torch.Tensor:

🔧 Implementation:

def calc_mask(self, current_time: torch.Tensor) -> torch.Tensor:
    """Efficient sequence validity checking"""

    current_idx = self.calc_idx(current_time)
    return current_idx >= 0  # Valid sequence positions
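Since compute_idx writes -1 once a singular reference has run past its end, the complement of this mask identifies environments that could be reset. A minimal sketch on plain lists; finished_envs is hypothetical, and with tensors the same selection would be torch.where(~mask)[0].

```python
def finished_envs(current_idx):
    """Indices of environments whose reference has ended (index -1), hypothetical helper."""
    return [env for env, idx in enumerate(current_idx) if idx < 0]
```

Recurrent references wrap instead of ending, so only environments assigned a singular reference ever appear in this list.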

💡 Usage Examples

🚀 Basic Buffer Setup

from GBC.utils.buffer.ref_buffer import BufferManager
import torch

# Initialize high-performance buffer manager
num_ref = 100
buffer_manager = BufferManager(
    num_envs=4096,  # Large-scale parallel training
    num_ref=num_ref,  # Multiple reference sequences
    working_mode="recurrent",  # Cyclic sequence support
    device="cuda:0",
)

# Every ref_id in [0, num_ref) must be filled for each key before prepare_buffers()
for ref_id in range(num_ref):
    # Random placeholder data standing in for real motion clips
    joint_positions = torch.randn(200, 29, device="cuda:0")  # [T, num_joints]
    lin_velocities = torch.randn(200, 3, device="cuda:0")  # [T, 3]
    ang_velocities = torch.randn(200, 3, device="cuda:0")  # [T, 3]

    buffer_manager.add_reference(
        key="joint_pos",
        ref_id=ref_id,
        buffer_raw=joint_positions,
        is_constant=False,
        frame_rate=30,
        cyclic_subseq=(20, 180),  # Detected walking cycle
    )
    # Velocity data for dynamic pose computation
    buffer_manager.add_reference("lin_vel", ref_id, lin_velocities, False, 30, (20, 180))
    buffer_manager.add_reference("ang_vel", ref_id, ang_velocities, False, 30, (20, 180))

# Consolidate each key once all references have been added
for key in ("joint_pos", "lin_vel", "ang_vel"):
    buffer_manager.prepare_buffers(key)

🎯 High-Performance Training Loop

# Training loop with ultra-efficient data access
# calc_base_pose integrates from the base pose set by reset(env); call buffer_manager.reset(env) first
current_time = torch.rand(4096, device="cuda:0") * 10.0  # Random time points in seconds

# Standard observation retrieval
joint_targets = buffer_manager.calc_obs("joint_pos", current_time)  # [4096, 29]

# Sum of lin_vel samples up to current_time (multiply by 1 / frame_rate to get displacement)
cumulative_lin_vel = buffer_manager.calc_cumulative_obs_v2("lin_vel", current_time)  # [4096, 3]

# Real-time base pose computation
base_poses = buffer_manager.calc_base_pose(current_time, "lin_vel", "ang_vel")  # [4096, 7]: position + quaternion

# Sequence validity checking
valid_mask = buffer_manager.calc_mask(current_time)  # [4096] boolean mask

print(f"Joint targets shape: {joint_targets.shape}")
print(f"Valid environments: {valid_mask.sum().item()}/{len(valid_mask)}")

🔄 Domain Randomization Integration

# Efficient environment reset with automatic reference randomization
# `env` is the simulation environment; `episode_done` is a [num_envs] bool tensor of finished episodes
env_ids_to_reset = torch.where(episode_done)[0]  # Completed episodes

buffer_manager.reset(env, env_ids_to_reset)  # Assigns new random references to these envs

# Reference assignment for curriculum learning
# Easy sequences for all environments (set_all_env_ref_id presumably expects one ID per environment)
easy_ref_ids = torch.randint(0, 20, (4096,), dtype=torch.int32, device="cuda:0")  # First 20 refs are easy
buffer_manager.set_all_env_ref_id(easy_ref_ids)

# Advanced sequences for experienced environments
hard_ref_ids = torch.randint(80, 100, (1024,), dtype=torch.int32, device="cuda:0")  # Last 20 refs are hard
buffer_manager.env_ref_id[2048:3072] = hard_ref_ids  # Assign to specific envs

📈 Performance Benchmarking

import time
import torch

# Benchmark cumulative computation efficiency
num_iterations = 1000
current_time = torch.rand(4096, device="cuda:0") * 10.0

def time_method(method):
    # CUDA kernels launch asynchronously: synchronize before and after timing
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        result = method("lin_vel", current_time)
    torch.cuda.synchronize()
    return result, time.perf_counter() - start_time

# Standard vs. optimized method timing
result_v1, v1_time = time_method(buffer_manager.calc_cumulative_obs)
result_v2, v2_time = time_method(buffer_manager.calc_cumulative_obs_v2)

print(f"Standard method: {v1_time:.4f}s")
print(f"Optimized method: {v2_time:.4f}s")
print(f"Speedup: {v1_time/v2_time:.2f}x")
print(f"Results match: {torch.allclose(result_v1, result_v2)}")

🚨 Performance Optimization Guidelines

Memory Efficiency Best Practices

  1. Buffer Consolidation: Always call prepare_buffers() after adding all references for a key
  2. Device Consistency: Ensure all tensors are on the same GPU device
  3. Data Type Optimization: Use appropriate precision (float32 vs float64) based on requirements
  4. Batch Size Tuning: Optimize num_envs for your GPU memory capacity

Computational Efficiency

  1. Cyclic Sequence Design: Use cyclic_subseq parameters for walking/running motions
  2. Reference Diversity: Balance number of references vs. memory usage
  3. Access Pattern Optimization: Minimize random access, prefer sequential patterns
  4. Caching Utilization: Leverage built-in caching for repeated time queries
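To balance reference diversity and num_envs against GPU memory, a rough size estimate helps. A sketch with two hypothetical helpers: buffer_bytes counts one consolidated time-varying buffer (padding row plus every stored frame), and cumulative_scratch_bytes counts one gathered segment tensor in calc_cumulative_obs_v2, which is shaped [num_envs, longest segment, feature_dim]. Both assume float32 (4 bytes) and ignore allocator overhead.

```python
def buffer_bytes(max_len, feature_dim, bytes_per_elem=4):
    """Approximate size of one consolidated time-varying buffer (hypothetical helper)."""
    return (1 + sum(max_len)) * feature_dim * bytes_per_elem


def cumulative_scratch_bytes(num_envs, segment_len, feature_dim, bytes_per_elem=4):
    """Approximate size of one gathered segment tensor in calc_cumulative_obs_v2 (hypothetical helper)."""
    return num_envs * segment_len * feature_dim * bytes_per_elem
```

For 100 references of 200 frames with 29 joints, the joint_pos buffer is about 2.3 MB; gathering a 160-frame cycle of 3-D velocities for 4096 environments needs about 7.9 MB per segment tensor, and v2 builds three such tensors (first, cyclic, last) per call.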

🔧 Integration Optimization

# Optimal setup for maximum performance
def setup_high_performance_buffer():
    """Configure buffer for maximum training efficiency"""

    # Your data loading function (hypothetical): assumed to return one entry per reference motion,
    # e.g. {"frame_rate": 30, "cyclic_subseq": (st, ed) or None, "data": {key: tensor [T, ...]}}
    reference_data = load_reference_dataset()

    buffer_manager = BufferManager(
        num_envs=8192,  # Size to fit your GPU memory
        num_ref=len(reference_data),  # One slot per reference motion
        working_mode="recurrent",
        device="cuda:0",
    )

    # Register every key of a reference under the same ref_id
    for ref_id, ref in enumerate(reference_data):
        for key, data in ref["data"].items():
            buffer_manager.add_reference(
                key=key,
                ref_id=ref_id,
                buffer_raw=data.to("cuda:0"),
                is_constant=False,
                frame_rate=ref["frame_rate"],
                cyclic_subseq=ref["cyclic_subseq"],
            )

    # Consolidate each key once all references have been added
    for key in ["joint_pos", "lin_vel", "ang_vel", "foot_contact"]:
        buffer_manager.prepare_buffers(key)

    return buffer_manager

This buffer system provides the foundation for efficient imitation learning training: reference data stays on the GPU in consolidated buffers, per-environment lookups are vectorized across thousands of parallel environments, and the integral-based computation with cyclic sequence handling recovers positions from velocity data without iterating over time steps. These properties make it a core component of reinforcement learning pipelines built on the reference_observation_manager.