ref_buffer
Module: GBC.utils.buffer.ref_buffer
This module implements a GPU-resident reference data buffer for imitation learning. The BufferManager stores all reference motion sequences in consolidated tensors and serves them to every environment in a single batched lookup. It also supports reassigning references to environments on reset (for domain randomization), integral-based computation of accumulated values (root positions, joint positions), and cyclic sequence playback. This system serves as the foundation for the reference_observation_manager.
🚀 Key Performance Advantages
⚡ Ultra-High Efficiency
- GPU-Native Operations: Leverages Warp kernels for parallel computation across thousands of environments
- Memory-Optimized Storage: Intelligent buffer concatenation minimizes GPU memory fragmentation
- Vectorized Access: Batch processing of reference data across all environments simultaneously
- Cached Index Computation: Avoids redundant calculations through intelligent caching mechanisms
🧠 Memory-Friendly Design
- Shared Buffer Architecture: Single consolidated buffer for all reference sequences
- Dynamic Buffer Types: Support for constant, singular, and recurrent data with different memory patterns
- Efficient Indexing: Compact index structures minimize memory overhead
- Zero-Copy Operations: Direct GPU tensor operations without CPU-GPU transfers
🔄 Domain Randomization Adaptation
- Dynamic Reference Assignment: Real-time environment-to-reference mapping for curriculum learning
- Efficient Reset Mechanisms: Fast environment resets with minimal computational overhead
- Multi-Reference Support: Simultaneous handling of diverse motion patterns across environments
📈 Advanced Integral Computation
- Cumulative Observation Calculation: High-performance integration for position/velocity data
- Cyclic Sequence Optimization: Specialized handling of periodic motions with mathematical precision
- Real-Time Base Pose Tracking: Efficient root position and orientation computation from velocity data
🏗️ Core Architecture
📊 BufferType Enumeration
Definition:
class BufferType:
    singular = 0      # Single-shot sequences (non-repeating)
    recurrent = 1     # Standard cyclic sequences
    recurrent_strict = 2  # Strictly enforced cyclic sequences
🎯 Buffer Type Characteristics:
- Singular: Non-repeating sequences, ideal for one-shot motions (e.g., getting up, specific gestures)
- Recurrent: Cyclic sequences with flexible repetition (e.g., walking, running gaits)
- Recurrent Strict: Enforced cyclic patterns with precise temporal constraints
🏭 BufferManager Class
Module Name: GBC.utils.buffer.ref_buffer.BufferManager
Definition:
class BufferManager:
    def __init__(self, num_envs: int, num_ref: int, working_mode: str, device: str):
        self.num_envs = num_envs          # Number of parallel environments
        self.num_ref = num_ref            # Number of reference sequences
        self.device = device              # GPU device for operations
        
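        # Core buffer management
        # NOTE: buffer_type_id is not defined in this excerpt; deriving it from
        # working_mode as below is an assumption
        buffer_type_id = getattr(BufferType, working_mode)
        self.buffer_type = torch.ones(num_ref, dtype=torch.int8, device=device) * buffer_type_id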
        self.frame_rate = torch.zeros(num_ref, dtype=torch.float32, device=device)
        self.max_len = torch.zeros(num_ref, dtype=torch.int32, device=device)
        self.recurrent_subseq = torch.ones((num_ref, 2), dtype=torch.int32, device=device) * -1
        
        # Environment-reference mapping
        self.env_ref_id = torch.zeros(num_envs, dtype=torch.int32, device=device)
        self.start_index = None           # Filled lazily by prepare_buffers()
        
        # High-performance storage
        self.ref_buffer_list = dict()     # Temporary storage during setup
        self.ref_buffer = dict()          # Optimized consolidated buffers
        self.is_constant = dict()         # Constant vs. time-varying data flags
📥 Initialization Parameters:
- num_envs (int): Number of parallel training environments
- num_ref (int): Number of reference motion sequences
- working_mode (str): Buffer operation mode ("singular", "recurrent", "recurrent_strict")
- device (str): GPU device identifier (e.g., "cuda:0")
🔧 Core Functionality
📥 Reference Data Management
Method Signature:
def add_reference(self, 
                 key: str, 
                 ref_id: int, 
                 buffer_raw: torch.Tensor, 
                 is_constant: bool, 
                 frame_rate: int, 
                 cyclic_subseq: tuple = None):
📥 Input Parameters:
- key (str): Data identifier (e.g., "joint_pos", "lin_vel", "ang_vel")
- ref_id (int): Unique reference sequence identifier
- buffer_raw (torch.Tensor): Raw reference data [T, ...]
- is_constant (bool): Whether data is time-invariant
- frame_rate (int): Sequence frame rate (Hz)
- cyclic_subseq (tuple): Optional cyclic subsequence bounds (start_idx, end_idx)
🔧 Processing Pipeline:
# 1. Configure buffer type and parameters
buffer_type_id = BufferType.recurrent
self.buffer_type[ref_id] = buffer_type_id
self.frame_rate[ref_id] = frame_rate
self.is_constant[key] = is_constant
# 2. Handle cyclic subsequences and record the stored length
if not is_constant:
    if cyclic_subseq is not None:
        st, ed = cyclic_subseq
        buffer_raw = buffer_raw[:ed, ...]  # Drop frames after the cycle end; keep the lead-in [0, st)
        self.recurrent_subseq[ref_id, 0] = st
        self.recurrent_subseq[ref_id, 1] = ed
    # Every time-varying reference needs its length: compute_idx compares against max_len
    self.max_len[ref_id] = buffer_raw.shape[0]
# 3. Store in temporary buffer list
if key not in self.ref_buffer_list:
    self.ref_buffer_list[key] = [None for _ in range(self.num_ref)]
self.ref_buffer_list[key][ref_id] = buffer_raw
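To make the effect of cyclic_subseq concrete, the following plain-Python sketch reports how many frames are stored and how they split into lead-in and cycle. The helper name `describe_cyclic_layout` and the bounds check are illustrative assumptions; the pipeline above does not show any validation.

```python
def describe_cyclic_layout(num_frames, cyclic_subseq=None):
    """Return (stored_len, lead_in_len, cycle_len) for one reference.

    Hypothetical helper mirroring the truncation buffer_raw[:ed] shown above.
    """
    if cyclic_subseq is None:
        # No cycle: the whole sequence is stored and played once
        return num_frames, num_frames, 0
    st, ed = cyclic_subseq
    # Assumed sanity check (not shown in the original pipeline)
    if not (0 <= st < ed <= num_frames):
        raise ValueError(f"invalid cyclic_subseq {cyclic_subseq} for {num_frames} frames")
    # Frames [0, st) are the lead-in, frames [st, ed) repeat, frames >= ed are dropped
    return ed, st, ed - st


layout = describe_cyclic_layout(200, (20, 180))
```

For a 200-frame clip with cyclic_subseq=(20, 180), 180 frames are kept: a 20-frame lead-in followed by a 160-frame cycle.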
🏗️ Buffer Optimization and Consolidation
Method Signature:
def prepare_buffers(self, key: str):
🔧 Optimization Process:
# Memory-efficient buffer consolidation
if self.is_constant[key]:
    # Constant data: simple stacking
    self.ref_buffer[key] = torch.stack(self.ref_buffer_list[key])
else:
    # Time-varying data: add padding and concatenate
    padding = torch.zeros_like(self.ref_buffer_list[key][0])[:1, ...]
    self.ref_buffer_list[key] = [padding] + self.ref_buffer_list[key]
    self.ref_buffer[key] = torch.concatenate(self.ref_buffer_list[key], dim=0).contiguous()
    
    # Compute efficient start indices for fast access
    if self.start_index is None:
        len_sum = torch.cumsum(self.max_len, dim=0)
        self.start_index = torch.zeros(self.num_ref, dtype=torch.int32, device=self.device)
        self.start_index[1:] = len_sum[:-1]
        self.start_index += 1  # Account for padding
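The start-index arithmetic can be checked by hand with a few small lengths. The sketch below reproduces the cumulative-sum-plus-padding logic in plain Python; `compute_start_indices` is an illustrative name, not a method of BufferManager.

```python
from itertools import accumulate


def compute_start_indices(lengths, padding_rows=1):
    """First row of each reference inside the consolidated buffer.

    Row 0 holds the zero padding frame, so every start is shifted by padding_rows.
    """
    if not lengths:
        return []
    # Exclusive prefix sum: reference i starts after all earlier references
    exclusive = [0] + list(accumulate(lengths))[:-1]
    return [start + padding_rows for start in exclusive]


starts = compute_start_indices([3, 5, 2])
```

With lengths 3, 5 and 2 the consolidated buffer has 11 rows: row 0 is padding, and the references begin at rows 1, 4 and 9.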
⚡ High-Performance Access Methods
🎯 GPU-Accelerated Index Computation
Warp Kernel Implementation:
@wp.kernel
def compute_idx(
    current_idx: wp.array(dtype=int),
    current_time: wp.array(dtype=float),
    env_ref_id: wp.array(dtype=int),
    frame_rate: wp.array(dtype=float),
    start_index: wp.array(dtype=int),
    max_len: wp.array(dtype=int),
    recurrent_subseq: wp.array(dtype=wp.vec2i),
):  # Warp kernels cannot return values; the result is written to current_idx
    tid = wp.tid()  # Thread ID for parallel processing
    rid = env_ref_id[tid]  # Reference ID for this environment
    
    # Convert time to frame index
    idx = wp.int32(current_time[tid] * frame_rate[rid])
    
    if current_time[tid] < 0:
        idx = 0
    elif recurrent_subseq[rid][0] == -1 or max_len[rid] == -1:
        # Singular buffer: linear progression
        if idx >= max_len[rid]:
            idx = -1  # End of sequence marker
        else:
            idx += start_index[rid]
    else:
        # Recurrent buffer: cyclic repetition
        rec_st = recurrent_subseq[rid][0]
        rec_ed = recurrent_subseq[rid][1]
        rec_len = rec_ed - rec_st
        
        if idx >= rec_st:
            # Cyclic wrapping for continuous playback
            idx = rec_st + (idx - rec_st) % rec_len
        idx += start_index[rid]
    
    current_idx[tid] = idx
Performance Benefits:
- Parallel Processing: Simultaneous index computation for all environments
- Cyclic Optimization: Efficient modulo operations for seamless looping
- Memory Coalescing: Optimized memory access patterns for GPU architectures
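For unit tests on the CPU, the kernel's branching can be mirrored in plain Python. The function below is a sketch that follows the kernel for a single environment; the name `reference_index` and its scalar arguments are illustrative and not part of the module.

```python
def reference_index(t, frame_rate, start_index, max_len, recurrent_subseq):
    """Scalar mirror of compute_idx for one environment."""
    if t < 0:
        return 0  # Row 0 is the zero padding frame
    idx = int(t * frame_rate)  # Truncating time-to-frame conversion
    rec_st, rec_ed = recurrent_subseq
    if rec_st == -1 or max_len == -1:
        # Singular buffer: play once, then signal the end with -1
        return -1 if idx >= max_len else idx + start_index
    # Recurrent buffer: wrap into [rec_st, rec_ed) after the lead-in
    if idx >= rec_st:
        idx = rec_st + (idx - rec_st) % (rec_ed - rec_st)
    return idx + start_index


looped = reference_index(7.0, 30.0, 1, 180, (20, 180))
ended = reference_index(7.0, 30.0, 1, 180, (-1, -1))
```

Here frame 210 of a looping reference wraps to local frame 50 (buffer row 51), while the same time on a singular reference yields -1. Because -1 is also a valid PyTorch index (the last row), callers should rely on the validity mask rather than on the gathered data for expired sequences.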
📊 Standard Observation Access
Method Signature:
def calc_obs(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
🔧 Implementation:
def calc_obs(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
    """High-performance observation retrieval with intelligent caching"""
    
    if self.is_constant[key]:
        # Constant data: direct environment-based indexing
        return self.ref_buffer[key][self.env_ref_id, ...]
    
    # Time-varying data: efficient index computation with caching
    current_idx = self.calc_idx(current_time)
    return self.ref_buffer[key][current_idx, ...]
🧮 Advanced Cumulative Computation
Method Signature:
def calc_cumulative_obs_v2(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
🔧 Optimized Implementation:
def calc_cumulative_obs_v2(self, key: str, current_time: torch.Tensor) -> torch.Tensor:
    """Ultra-high performance cumulative observation calculation"""
    
    if self.is_constant[key]:
        return self.ref_buffer[key][self.env_ref_id, ...]
    
    # Compute cyclic sequence parameters
    current_num_cyclic_subseq, current_idx, current_begin_idx, current_end_idx = \
        self.calc_num_cyclic_subseq(current_time)
    
    start_idx = self.start_index[self.env_ref_id]  # shape: (num_envs,)
    
    # Initialize output tensor
    cumulative_obs = torch.zeros(
        (self.num_envs, *self.ref_buffer[key].shape[1:]),
        dtype=self.ref_buffer[key].dtype,
        device=self.device
    )
    
    # Vectorized index computation for all environments
    begin_indices = current_begin_idx + start_idx
    end_indices = current_end_idx + start_idx
    current_indices = current_idx + start_idx
    
    # Dynamic sequence length determination
    max_first_len = (begin_indices - start_idx).max()
    max_cyclic_len = (end_indices - begin_indices).max()
    max_last_len = (current_indices - begin_indices).max()
    
    # Parallel index tensor creation
    first_indices = torch.arange(max_first_len, device=self.device).unsqueeze(0) + start_idx.unsqueeze(1)
    cyclic_indices = torch.arange(max_cyclic_len, device=self.device).unsqueeze(0) + begin_indices.unsqueeze(1)
    last_indices = torch.arange(max_last_len, device=self.device).unsqueeze(0) + begin_indices.unsqueeze(1)
    
    # Efficient masking for valid indices
    first_mask = first_indices < begin_indices.unsqueeze(1)
    cyclic_mask = cyclic_indices < end_indices.unsqueeze(1)
    last_mask = last_indices < current_indices.unsqueeze(1)
    
    # Boundary protection for buffer access
    buffer_max_len = self.ref_buffer[key].shape[0]
    first_indices = torch.where(first_indices < buffer_max_len, first_indices, buffer_max_len - 1)
    cyclic_indices = torch.where(cyclic_indices < buffer_max_len, cyclic_indices, buffer_max_len - 1)
    last_indices = torch.where(last_indices < buffer_max_len, last_indices, buffer_max_len - 1)
    
    # High-performance vectorized data retrieval
    first_sequences = self.ref_buffer[key][first_indices] * first_mask.unsqueeze(-1)
    cyclic_sequences = self.ref_buffer[key][cyclic_indices] * cyclic_mask.unsqueeze(-1)
    last_sequences = self.ref_buffer[key][last_indices] * last_mask.unsqueeze(-1)
    
    # Efficient integration with cyclic repetition support
    cumulative_obs += torch.sum(first_sequences, dim=1)
    if torch.any(current_end_idx > current_begin_idx):
        cumulative_obs += torch.sum(cyclic_sequences, dim=1) * current_num_cyclic_subseq.to(torch.float32).unsqueeze(1)
    cumulative_obs += torch.sum(last_sequences, dim=1)
    
    return cumulative_obs
⚡ Performance Optimizations:
- Vectorized Operations: Batch processing across all environments
- Memory Coalescing: Optimized GPU memory access patterns
- Intelligent Masking: Efficient boundary handling without branching
- Cyclic Integration: Mathematical precision for periodic sequences
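The three-part sum (lead-in, whole cycles, partial cycle) can be cross-checked against a frame-by-frame loop. The sketch below does this in plain Python for one scalar channel; both function names are illustrative, and the exclusive upper bound on the current frame is inferred from the `<` masks above rather than confirmed by the source.

```python
def cumulative_sum_decomposed(frames, idx, rec_st, rec_ed):
    """Sum of all frames played before playback frame idx, looping over [rec_st, rec_ed)."""
    if idx <= rec_st:
        return sum(frames[:idx])  # Still inside the lead-in
    n_cycles, rem = divmod(idx - rec_st, rec_ed - rec_st)
    lead_in = sum(frames[:rec_st])             # Played exactly once
    cycle = sum(frames[rec_st:rec_ed])         # Played n_cycles times
    partial = sum(frames[rec_st:rec_st + rem]) # Current, unfinished cycle
    return lead_in + n_cycles * cycle + partial


def cumulative_sum_naive(frames, idx, rec_st, rec_ed):
    """Reference implementation: replay every frame one by one."""
    total = 0
    for k in range(idx):
        pos = k if k < rec_st else rec_st + (k - rec_st) % (rec_ed - rec_st)
        total += frames[pos]
    return total


frames = list(range(10))
fast = cumulative_sum_decomposed(frames, 13, 2, 6)
slow = cumulative_sum_naive(frames, 13, 2, 6)
```

The decomposed version touches at most one lead-in, one cycle and one partial cycle regardless of how long playback has been running, which is the property the vectorized method exploits.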
🌐 Dynamic Base Pose Computation
🎯 Real-Time Pose Integration
Method Signature:
def calc_base_pose(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
🔧 Implementation:
def calc_base_pose(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
    """Real-time base pose computation from velocity data"""
    
    # Retrieve velocity observations
    lin_vel_yaw_frame = self.calc_obs(lin_vel_name, current_time)
    ang_vel = self.calc_obs(ang_vel_name, current_time)
    
    # Incremental pose integration
    self.step_robot_base_pose(current_time, lin_vel_yaw_frame, ang_vel)
    
    # Return combined pose (position + quaternion)
    return torch.cat([self.base_pos, self.base_quat], dim=1)
def step_robot_base_pose(self, current_time: torch.Tensor, lin_vel_yaw_frame: torch.Tensor, ang_vel: torch.Tensor):
    """High-precision incremental pose integration"""
    
    # Compute time delta with safety checks
    dt = torch.where(self.last_pose_tme < 0, torch.zeros_like(self.last_pose_tme), current_time - self.last_pose_tme)
    dt = dt.clamp(min=0)  # Ensure non-negative time steps
    
    # Transform linear velocity to world frame
    quat_yaw = yaw_quat(self.base_quat)  # Extract yaw-only quaternion
    lin_vel = quat_apply(quat_yaw, lin_vel_yaw_frame)
    
    # Integrate position
    self.base_pos += lin_vel * dt.unsqueeze(1)
    
    # Integrate orientation using quaternion mathematics
    rot_vec = quat_apply(quat_inv(self.base_quat), ang_vel) * dt.unsqueeze(1)
    self.base_quat = quat_mul(self.base_quat, quat_inv(angle_axis_to_quaternion(rot_vec)))
    
    # Update time tracking
    self.last_pose_tme = torch.where(
        torch.logical_and(self.last_pose_tme < 0, current_time < 1e-5),
        self.last_pose_tme,
        current_time.clone()
    )
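As a simplified illustration of the same integration order, the sketch below performs one step in the plane: rotate the yaw-frame velocity into the world frame, integrate position, then integrate heading. It is a 2D analogue with hypothetical names, not the quaternion implementation above.

```python
import math


def step_planar_pose(x, y, yaw, vx, vy, yaw_rate, dt):
    """One explicit Euler step of a planar pose driven by yaw-frame velocity."""
    dt = max(dt, 0.0)  # Same safeguard as dt.clamp(min=0)
    c, s = math.cos(yaw), math.sin(yaw)
    # Rotate the yaw-frame velocity into the world frame using the current heading
    x += (c * vx - s * vy) * dt
    y += (s * vx + c * vy) * dt
    # Integrate heading after position, matching the order used above
    yaw += yaw_rate * dt
    return x, y, yaw


# Robot facing +y, moving forward at 1 m/s for half a second
pose = step_planar_pose(0.0, 0.0, math.pi / 2, 1.0, 0.0, 0.0, 0.5)
```

A forward velocity expressed in the yaw frame moves the robot along its current heading, so the pose above advances along the world y axis.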
📈 Cumulative Base Pose Computation
Method Signature:
def calc_base_pose_cumulative(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
🔧 Advanced Implementation:
def calc_base_pose_cumulative(self, current_time: torch.Tensor, lin_vel_name: str, ang_vel_name: str) -> torch.Tensor:
    """Ultra-efficient cumulative pose computation using integral methods"""
    
    # Frame rate-based time step calculation
    dt = 1.0 / self.frame_rate[self.env_ref_id]
    
    # High-performance cumulative integration
    lin_pos = self.calc_cumulative_obs_v2(lin_vel_name, current_time) * dt.unsqueeze(1)
    ang_pos = self.calc_cumulative_obs_v2(ang_vel_name, current_time) * dt.unsqueeze(1)
    
    # Wrap angular position into the [0, 2π) range
    ang_pos = ang_pos % (2 * np.pi)
    
    # Convert to quaternion representation
    base_quat = quat_from_euler_xyz(*ang_pos.T)
    
    return torch.cat([lin_pos, base_quat], dim=1)
📊 Performance Benefits:
- Integral-Based Computation: Direct mathematical integration without iterative steps
- Vectorized Processing: Simultaneous computation across all environments
- Numerical Stability: Robust handling of angular wrap-around and discontinuities
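Written out, the cumulative method amounts to a rectangle-rule integral of the stored velocity frames. The notation below (f for the frame rate, v_i for frame i, k for the current frame index) is introduced here for illustration and does not appear in the source.

```latex
\Delta t = \frac{1}{f}, \qquad
k = \lfloor t \, f \rfloor, \qquad
p(t) \approx \Delta t \sum_{i=0}^{k-1} v_i
```

For example, with f = 30 Hz and a constant forward velocity of 0.6 m/s, t = 2 s gives k = 60 and p ≈ (1/30) · 60 · 0.6 = 1.2 m. The same sum applied to angular velocity yields the Euler angles passed to quat_from_euler_xyz, which is exact only when the rotation stays about a single axis (for example pure yaw).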
🎛️ Environment Management
🔄 Intelligent Reset System
Method Signature:
def reset(self, env, env_ids: Sequence[int] | None = None):
🔧 Implementation:
def reset(self, env, env_ids: Sequence[int] | None = None):
    """Efficient environment reset with domain randomization support"""
    
    if env_ids is None:
        env_ids = slice(None)  # Reset all environments
    
    # Uniformly random reference assignment for the reset environments
    self.env_ref_id[env_ids] = torch.randint(
        0, self.num_ref,
        size=self.env_ref_id[env_ids].shape,
        dtype=self.env_ref_id.dtype,
        device=self.env_ref_id.device,
    )
    
    # Reset pose tracking
    self.last_pose_tme[env_ids] = -1
    
    # Initialize base pose from environment state
    root_pose_w = env.scene["robot"].data.root_state_w[env_ids, :7]
    root_pose_w = root_pose_w.clone()
    
    self.base_pos[env_ids] = root_pose_w[:, :3] - env.scene.env_origins[env_ids]
    self.base_quat[env_ids] = root_pose_w[:, 3:]
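The key property of the reset is that only the selected environments receive a new reference. A plain-Python sketch of that behaviour, using the hypothetical helper `reassign_references` in place of the tensor indexing above:

```python
import random


def reassign_references(env_ref_id, env_ids, num_ref, rng):
    """Return a copy of env_ref_id with fresh uniform draws for env_ids only."""
    updated = list(env_ref_id)
    for env in env_ids:
        updated[env] = rng.randrange(num_ref)  # Uniform over [0, num_ref)
    return updated


rng = random.Random(0)  # Seeded only to make the sketch reproducible
before = [0, 0, 0, 0, 0, 0]
after = reassign_references(before, env_ids=[1, 4], num_ref=100, rng=rng)
```

Environments 1 and 4 draw new reference IDs while the other four keep theirs, so running episodes are not disturbed by resets elsewhere.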
📊 Sequence Validation
Method Signature:
def calc_mask(self, current_time: torch.Tensor) -> torch.Tensor:
🔧 Implementation:
def calc_mask(self, current_time: torch.Tensor) -> torch.Tensor:
    """Efficient sequence validity checking"""
    
    current_idx = self.calc_idx(current_time)
    return current_idx >= 0  # Valid sequence positions
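A typical use of the mask is to find environments whose singular reference has ended. The sketch below works on a plain list of indices; `split_by_validity` is an illustrative name rather than part of BufferManager.

```python
def split_by_validity(current_idx):
    """Return (mask, expired_env_ids) from per-environment buffer indices."""
    mask = [idx >= 0 for idx in current_idx]
    expired = [env for env, ok in enumerate(mask) if not ok]
    return mask, expired


# Environment 1 has run past the end of a singular sequence (index -1);
# environment 2 sits on the padding row (index 0), which still counts as valid
mask, expired = split_by_validity([16, -1, 0, 51])
```

The expired list can then be used to select environments for reset or to zero out their reference observations.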
💡 Usage Examples
🚀 Basic Buffer Setup
from GBC.utils.buffer.ref_buffer import BufferManager
import torch
# Initialize the buffer manager
num_ref = 100
buffer_manager = BufferManager(
    num_envs=4096,          # Large-scale parallel training
    num_ref=num_ref,        # Multiple reference sequences
    working_mode="recurrent", # Cyclic sequence support
    device="cuda:0"
)
# Register data for every ref_id: prepare_buffers() concatenates all num_ref
# entries of a key, so none of them may be left unset
for ref_id in range(num_ref):
    # Joint position reference data (random placeholder values)
    joint_positions = torch.randn(200, 29, device="cuda:0")  # [T, num_joints]
    buffer_manager.add_reference(
        key="joint_pos",
        ref_id=ref_id,
        buffer_raw=joint_positions,
        is_constant=False,
        frame_rate=30,
        cyclic_subseq=(20, 180)  # Detected walking cycle
    )
    # Velocity data for dynamic pose computation
    lin_velocities = torch.randn(200, 3, device="cuda:0")  # [T, 3]
    ang_velocities = torch.randn(200, 3, device="cuda:0")  # [T, 3]
    buffer_manager.add_reference("lin_vel", ref_id, lin_velocities, False, 30, (20, 180))
    buffer_manager.add_reference("ang_vel", ref_id, ang_velocities, False, 30, (20, 180))
# Consolidate the per-reference tensors into one buffer per key
buffer_manager.prepare_buffers("joint_pos")
buffer_manager.prepare_buffers("lin_vel")
buffer_manager.prepare_buffers("ang_vel")
🎯 High-Performance Training Loop
# Query reference data for all environments in one batched call
current_time = torch.rand(4096, device="cuda:0") * 10.0  # Random time points
# Standard observation retrieval
joint_targets = buffer_manager.calc_obs("joint_pos", current_time)  # [4096, 29]
# Advanced cumulative computation for integral-based features
cumulative_displacement = buffer_manager.calc_cumulative_obs_v2("lin_vel", current_time)  # [4096, 3]
# Real-time base pose computation
base_poses = buffer_manager.calc_base_pose(current_time, "lin_vel", "ang_vel")  # [4096, 7]
# Sequence validity checking
valid_mask = buffer_manager.calc_mask(current_time)  # [4096] boolean mask
print(f"Joint targets shape: {joint_targets.shape}")
print(f"Valid environments: {valid_mask.sum().item()}/{len(valid_mask)}")
🔄 Domain Randomization Integration
# Efficient environment reset with automatic reference randomization
env_ids_to_reset = torch.where(episode_done)[0]  # Completed episodes
buffer_manager.reset(env, env_ids_to_reset)  # Automatic reference shuffling
# Reference assignment for curriculum learning
# Easy sequences for all environments (one reference ID per environment)
easy_ref_ids = torch.randint(0, 20, (buffer_manager.num_envs,), device="cuda:0")  # First 20 refs are easy
buffer_manager.set_all_env_ref_id(easy_ref_ids)
# Advanced sequences for experienced environments
hard_ref_ids = torch.randint(80, 100, (1024,), device="cuda:0")  # Last 20 refs are hard
buffer_manager.env_ref_id[2048:3072] = hard_ref_ids  # Assign to specific envs
📈 Performance Benchmarking
import time
# Benchmark cumulative computation efficiency
num_iterations = 1000
current_time = torch.rand(4096, device="cuda:0") * 10.0
# CUDA work is launched asynchronously, so synchronize before reading the clock
# Standard method timing
torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(num_iterations):
    result_v1 = buffer_manager.calc_cumulative_obs("lin_vel", current_time)
torch.cuda.synchronize()
v1_time = time.perf_counter() - start_time
# Optimized method timing
torch.cuda.synchronize()
start_time = time.perf_counter()
for _ in range(num_iterations):
    result_v2 = buffer_manager.calc_cumulative_obs_v2("lin_vel", current_time)
torch.cuda.synchronize()
v2_time = time.perf_counter() - start_time
print(f"Standard method: {v1_time:.4f}s")
print(f"Optimized method: {v2_time:.4f}s")
print(f"Speedup: {v1_time/v2_time:.2f}x")
print(f"Results match: {torch.allclose(result_v1, result_v2)}")
🚨 Performance Optimization Guidelines
✅ Memory Efficiency Best Practices
- Buffer Consolidation: Always call prepare_buffers() after adding all references for a key
- Device Consistency: Ensure all tensors are on the same GPU device
- Data Type Optimization: Use appropriate precision (float32 vs float64) based on requirements
- Batch Size Tuning: Optimize num_envs for your GPU memory capacity
⚡ Computational Efficiency
- Cyclic Sequence Design: Use cyclic_subseq parameters for walking/running motions
- Reference Diversity: Balance number of references vs. memory usage
- Access Pattern Optimization: Minimize random access, prefer sequential patterns
- Caching Utilization: Leverage built-in caching for repeated time queries
🔧 Integration Optimization
# Optimal setup for maximum performance
def setup_high_performance_buffer():
    """Configure buffer for maximum training efficiency"""
    
    buffer_manager = BufferManager(
        num_envs=8192,      # Power of 2 for optimal GPU utilization
        num_ref=128,        # Sufficient diversity for robust training
        working_mode="recurrent",
        device="cuda:0"
    )
    
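    # Batch reference addition for efficiency
    # load_reference_dataset() is your own loader. Here it is assumed to yield one
    # (motion, cyclic_info) pair per reference, where motion maps each key to its
    # [T, ...] tensor; the number of pairs must equal num_ref
    reference_data = load_reference_dataset()
    
    for ref_id, (motion, cyclic_info) in enumerate(reference_data):
        # Every key must be registered for every ref_id before prepare_buffers()
        for key, data in motion.items():
            buffer_manager.add_reference(
                key=key,
                ref_id=ref_id,
                buffer_raw=data.to("cuda:0"),
                is_constant=False,
                frame_rate=30,
                cyclic_subseq=cyclic_info
            )
    
    # Consolidate each key once all references have been added
    for key in ["joint_pos", "lin_vel", "ang_vel", "foot_contact"]:
        buffer_manager.prepare_buffers(key)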
    
    return buffer_manager
This buffer system provides the foundation for efficient imitation learning training: it keeps reference data on the GPU in consolidated tensors and serves thousands of parallel environments with batched lookups. Integral-based computation and cyclic sequence handling let it reconstruct accumulated quantities such as base pose without replaying every past cycle frame by frame.