pose_transformer_trainer
Module: GBC.utils.data_preparation.pose_transformer_trainer
This module provides a comprehensive training framework for PoseTransformer models, implementing advanced training techniques and sophisticated loss functions for robust human-to-robot motion transfer. The trainer incorporates differentiable forward kinematics, multi-component loss functions, and various regularization strategies to ensure stable and accurate pose retargeting.
📚 Dependencies
This module requires the following Python packages:
- torch: PyTorch framework for neural networks and optimization
- human_body_prior: SMPL+H body model implementation
- smplx: Extended SMPL body models
- wandb: Experiment tracking and logging (optional)
- tqdm: Progress bars for training loops
- dataclasses: Configuration management
- GBC.utils.base.base_fk: Robot forward kinematics
- GBC.utils.data_preparation.amass_loader: AMASS dataset loading
- GBC.utils.data_preparation.pose_transformer: PoseTransformer model
🏗️ Core Components
📊 LossCoefficients
Module Name: GBC.utils.data_preparation.pose_transformer_trainer.LossCoefficients
Definition:
@dataclass
class LossCoefficients:
    main_loss: float = 5000.0
    aux_loss: float = 1.0
    out_of_range_loss: float = 1000.0
    high_value_action_loss: float = 1000.0
    direct_mapping_loss: float = 1000.0
    symmetry_loss: float = 1000.0
    reference_action_loss: float = 5000.0
    bone_direction_loss: float = 1000.0
    disturbance_base: float = 1.0
🔧 Functionality: Configuration class for managing all loss coefficients in the training process. Provides fine-grained control over the relative importance of different loss components.
⚙️ Configuration Parameters:
- main_loss: Primary FK distance loss weight (λ_dist)
- symmetry_loss: Bilateral symmetry preservation weight (λ_sym)
- out_of_range_loss: Joint limit violation penalty (λ_limit)
- disturbance_base: Action disturbance regularization base (λ_disturb)
- direct_mapping_loss: Single-DOF joint mapping loss
- bone_direction_loss: Bone direction consistency loss
- high_value_action_loss: Penalty on large action magnitudes
- reference_action_loss: Reference motion consistency loss
- aux_loss: Auxiliary model loss (for MoE models)
🎯 PoseFormerTrainer
Module Name: GBC.utils.data_preparation.pose_transformer_trainer.PoseFormerTrainer
Definition:
class PoseFormerTrainer:
    def __init__(self,
                 urdf_path: str,
                 dataset_path: str,
                 mapping_table: Dict[str, str],
                 smplh_model_path: str,
                 dmpls_model_path: str,
                 smpl_fits_dir: str,
                 load_hands: bool = False,
                 batch_size: int = 1,
                 train_size: float = 0.8,
                 device: str = 'cuda',
                 loss_coefficients: Optional[LossCoefficients] = None,
                 use_wandb: bool = True,
                 **kwargs
                ):
📥 Input Parameters:
- urdf_path (str): Path to robot URDF description file
- dataset_path (str): Path to AMASS dataset directory
- mapping_table (Dict[str, str]): SMPL+H to robot joint mapping (see create_smplh documentation)
- smplh_model_path (str): Path to SMPL+H model file
- dmpls_model_path (str): Path to DMPL muscle model file
- smpl_fits_dir (str): Path to fitted SMPL+H parameters (.pt file)
- load_hands (bool): Include hand joints in training
- batch_size (int): Training batch size
- train_size (float): Training/validation split ratio
- device (str): Computing device ("cuda" or "cpu")
- loss_coefficients (LossCoefficients): Loss function weights configuration
- use_wandb (bool): Enable Weights & Biases logging
🔧 Functionality: Comprehensive training framework for PoseTransformer models with advanced loss functions, data augmentation, and validation capabilities.
🧠 Advanced Loss Functions
The training framework implements a sophisticated multi-component loss function designed to address the inherent challenges in human-to-robot motion retargeting. The composite objective combines several specialized loss terms:
🎯 Primary Loss: FK Distance Loss (L_dist)
Mathematical Formulation:
L_dist(p) = ||FK_hn[f(p)] - FK_hm(p)||_2
🔧 Implementation:
# Compute robot joint positions from predicted actions
robot_joints = self.fk(predicted_actions)
# Compute human joint positions from SMPL pose
human_joints = self.body_model(**body_params).Jtr
# Primary distance loss
main_loss = criterion(robot_joints, human_joints[target_indices])
💡 Purpose: Minimizes Euclidean distance between robot end-effector locations and corresponding human joint positions from motion capture data. This forms the core objective for pose retargeting.
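When joint_weights are supplied (see the advanced configuration below), end effectors can be emphasized in this loss. A minimal sketch, assuming per-joint weights that broadcast over a per-joint squared error; the trainer's exact weighting scheme may differ:
# Hypothetical weighted variant: scale each joint's squared position error
# by its importance weight before averaging over the batch
per_joint_sq_err = torch.sum((robot_joints - human_joints[target_indices]) ** 2, dim=-1)
main_loss = torch.mean(per_joint_sq_err * joint_weights)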
⚖️ Joint-Limit Loss (L_limit)
Mathematical Formulation:
L_limit(p) = ||max(a_min - f(p), 0) + max(f(p) - a_max, 0)||_2
🔧 Implementation:
@staticmethod
def out_of_range_loss(actions: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor):
    """Penalize joint limit violations"""
    lower_violation = torch.clamp(min_val - actions, min=0.0)
    upper_violation = torch.clamp(actions - max_val, min=0.0)
    return torch.mean(lower_violation + upper_violation)
💡 Purpose: Prevents joint limit violations by penalizing actions that exceed the robot's physical constraints. Joint limits are parsed from URDF specifications with additional empirical human range bounds.
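A quick usage sketch with hypothetical limits for a 2-DOF example; only the amount by which each action exceeds its bound contributes to the penalty:
# Hypothetical 2-DOF limits; real values are parsed from the URDF
min_val = torch.tensor([-1.57, 0.0])
max_val = torch.tensor([1.57, 2.5])
actions = torch.tensor([[1.8, -0.1]])  # violates both bounds
loss = PoseFormerTrainer.out_of_range_loss(actions, min_val, max_val)
# lower_violation = [0.0, 0.1], upper_violation = [0.23, 0.0] -> mean = 0.165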
🌊 Action-Disturbance Loss (L_disturb)
Mathematical Formulation:
L_disturb(p) = ||FK_hn[f(p)] - FK_hn[f(p + δp)]||_2 / ||δp||_2
where δp ~ N(0, σ(2π)²) with σ ≪ 1
🔧 Implementation:
@staticmethod
def disturbance_loss(actions: torch.Tensor, actions_disturbed: torch.Tensor, disturbance: torch.Tensor):
    """Enforce Lipschitz continuity through input perturbation"""
    action_diff = torch.norm(actions - actions_disturbed, dim=-1)
    disturbance_norm = torch.norm(disturbance, dim=-1)
    return torch.mean(action_diff / (disturbance_norm + 1e-8))
💡 Purpose: Enforces Lipschitz continuity by minimizing sensitivity to input perturbations. This regularization ensures smooth, dynamically feasible output sequences and reduces discontinuities in the learned mapping.
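A usage sketch (helper names hypothetical): perturb the input pose, run the model twice, and penalize output sensitivity. Note that this implementation measures the change in predicted actions directly, a practical surrogate for the FK-space formulation above:
# Perturb the input pose with small Gaussian noise (σ ≪ 1)
delta_p = torch.randn_like(pose_body) * 0.01
actions = model(prepare_model_input(pose_body))                      # hypothetical helper
actions_disturbed = model(prepare_model_input(pose_body + delta_p))
loss = PoseFormerTrainer.disturbance_loss(actions, actions_disturbed, delta_p)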
🔄 Symmetry Loss (L_sym)
Mathematical Formulation:
L_sym(p) = ||S(FK_hn[f(p)]) - FK_hn[f(S(p))]||_2
where S denotes reflection about the humanoid's sagittal plane
🔧 Implementation:
def symmetry_loss(self, actions, pose_body, pose_hand=None):
    """Preserve bilateral symmetry in retargeting"""
    # Apply symmetry transformation to input pose
    pose_body_flipped = symmetry_smplh_pose(pose_body)
    
    # Forward pass on flipped pose
    flipped_input = self.prepare_model_input(pose_body_flipped, pose_hand)
    actions_flipped = self.model(flipped_input)
    
    # Apply symmetry to original actions and compare
    actions_sym = self.flip_left_right.apply_action_flip(actions)
    return F.mse_loss(actions_sym, actions_flipped)
💡 Purpose: Preserves bilateral symmetry by ensuring that left-right reflected input poses produce correspondingly reflected robot actions. Critical for maintaining natural, symmetric movement patterns.
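A useful sanity check before training, assuming the flip utilities are involutions (applying them twice recovers the original):
# Flipping twice should recover the original pose and actions
assert torch.allclose(symmetry_smplh_pose(symmetry_smplh_pose(pose_body)), pose_body, atol=1e-6)
flipped_twice = flip_left_right.apply_action_flip(flip_left_right.apply_action_flip(actions))
assert torch.allclose(flipped_twice, actions, atol=1e-6)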
🎯 Direct Mapping Loss
Mathematical Formulation: For single-DOF joints (e.g., knees), a specific Euler-angle component of the human joint rotation is used directly as the target for the corresponding robot action:
🔧 Implementation:
def direct_mapping_loss(self, actions: torch.Tensor, pose_body: torch.Tensor, pose_hand: torch.Tensor = None):
    """Direct angle mapping for single-DOF joints"""
    pose_body = pose_body.reshape(-1, 21, 3)
    ypr = batch_angle_axis_to_ypr(pose_body)  # Convert to yaw-pitch-roll
    
    # Extract specified angle components and corresponding actions
    ypr_selected = ypr[:, self.pose_body_indices, self.angle_indices]
    actions_selected = actions[:, self.action_indices]
    
    return F.mse_loss(actions_selected, ypr_selected.detach())
💡 Purpose: Provides direct supervision for joints with single valid degrees of freedom (e.g., knee pitch). Ensures anatomically correct joint mappings for constrained robot joints.
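For reference, a minimal sketch of what batch_angle_axis_to_ypr might compute, assuming a ZYX (yaw-pitch-roll) Euler convention; the library's actual implementation and angle convention may differ:
import torch

def batch_angle_axis_to_ypr_sketch(aa: torch.Tensor) -> torch.Tensor:
    """Axis-angle (..., 3) -> yaw-pitch-roll (..., 3), ZYX convention (sketch)."""
    theta = torch.norm(aa, dim=-1, keepdim=True).clamp(min=1e-8)
    x, y, z = (aa / theta).unbind(-1)
    theta = theta.squeeze(-1)
    c, s = torch.cos(theta), torch.sin(theta)
    C = 1.0 - c
    # Rotation-matrix entries needed for ZYX Euler extraction (Rodrigues' formula)
    r00 = c + x * x * C
    r10 = y * x * C + z * s
    r20 = z * x * C - y * s
    r21 = z * y * C + x * s
    r22 = c + z * z * C
    yaw = torch.atan2(r10, r00)
    pitch = torch.atan2(-r20, torch.sqrt(r21 ** 2 + r22 ** 2))
    roll = torch.atan2(r21, r22)
    return torch.stack([yaw, pitch, roll], dim=-1)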
🦴 Bone Direction Loss
🔧 Implementation:
def bone_direction_loss(self, robot_joints: torch.Tensor, gt_joints: torch.Tensor):
    """Enforce bone direction consistency"""
    # Compute bone vectors for predefined bone pairs
    robot_bone_vecs = robot_joints[:, self.bone_end_indices] - robot_joints[:, self.bone_start_indices]
    gt_bone_vecs = gt_joints[:, self.bone_end_indices] - gt_joints[:, self.bone_start_indices]
    
    # Normalize and compute cosine similarity
    robot_bone_units = F.normalize(robot_bone_vecs, dim=-1)
    gt_bone_units = F.normalize(gt_bone_vecs, dim=-1)
    cos_sim = torch.sum(robot_bone_units * gt_bone_units, dim=-1)
    
    return torch.mean(1.0 - cos_sim)
💡 Purpose: Maintains consistent bone directions between robot and human skeletons, preserving kinematic chain relationships and anatomical structure.
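For orientation, a sketch of how such bone pairs could be expressed, using hypothetical indices that follow the example mapping-table order shown later in this guide:
# Hypothetical bone pairs over the mapped joint order:
# index 1 = L_Hip, 3 = L_Knee, 5 = L_Ankle
bone_start_indices = torch.tensor([1, 3])  # bones start at L_Hip and L_Knee
bone_end_indices = torch.tensor([3, 5])    # and end at L_Knee and L_Ankle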
🚀 Training Configuration
📋 Required Configuration
⚠️ Critical Setup Requirements:
1. Joint Mapping Configuration:
   - Configure mapping_table following the same principles as SMPL+H fitting
   - Reference: Create SMPL+H Documentation for detailed mapping guidelines
   - Ensure consistent joint correspondence between SMPL+H and robot URDF
2. Model Paths:
trainer = PoseFormerTrainer(
    urdf_path="/path/to/robot.urdf",
    smplh_model_path="/path/to/smplh/model.npz",
    dmpls_model_path="/path/to/dmpls/model.npz",
    smpl_fits_dir="/path/to/fitted_params.pt",  # From SMPL+H fitting
    dataset_path="/path/to/AMASS_dataset",
)
3. Loss Coefficient Tuning:
# Production-tested weight hierarchy for Turin robot training
loss_coeffs = LossCoefficients(
    main_loss=5000.0,               # Primary FK distance (λ_dist)
    aux_loss=1.0,                   # Auxiliary model outputs
    out_of_range_loss=1000.0,       # Joint limit violations (λ_limit)
    high_value_action_loss=1000.0,  # Large action penalties
    direct_mapping_loss=1000.0,     # Single DOF joint mapping
    symmetry_loss=1000.0,           # Bilateral symmetry (λ_sym)
    reference_action_loss=5000.0,   # Reference motion consistency
    bone_direction_loss=50.0,       # Bone direction alignment
    disturbance_base=1.0,           # Lipschitz regularization base (λ_disturb)
)
4. Batch Configuration:
# Optimized for GPU memory and training stability
batch_size = 256      # Training batch size
ik_batch_size = 1024  # Inverse kinematics batch size
device = "cuda:0"     # Primary GPU device
use_renderer = True   # Enable visualization capabilities
sample_steps = 10     # Validation sampling frequency
🎯 Advanced Training Techniques
📈 Progressive Disturbance Training
🔧 Implementation:
def _get_progressive_disturbance_coeff(self, epoch, total_epochs, min_coeff=1.0, max_coeff=50.0):
    """Progressively increase disturbance loss weight during training"""
    progress = min(epoch / total_epochs, 1.0)
    return min_coeff + (max_coeff - min_coeff) * progress
💡 Strategy:
- Early Training: Low disturbance coefficient for stable convergence
- Late Training: High disturbance coefficient for robust generalization
- Benefits: Balances learning speed with robustness to input variations
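With the production settings used later in this guide (disturbance_min_coeff=100.0, disturbance_max_coeff=200.0 over 1000 epochs), the weight ramps linearly, as the values below illustrate:
# coeff(epoch) = 100.0 + (200.0 - 100.0) * min(epoch / 1000, 1.0)
# epoch    0 -> 100.0
# epoch  250 -> 125.0
# epoch  500 -> 150.0
# epoch 1000 -> 200.0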
🔄 Data Augmentation Strategies
Noise Injection:
# Gaussian noise perturbation for disturbance loss
noise_std = 0.01
pose_noise = torch.randn_like(pose_body) * noise_std
pose_body_disturbed = pose_body + pose_noise
Symmetry Augmentation:
# Random left-right flipping for bilateral symmetry
if apply_symmetry and random.random() < 0.5:
    pose_body = symmetry_smplh_pose(pose_body)
    actions = flip_left_right.apply_action_flip(actions)
📊 Learning Rate Scheduling
Warmup + Cosine Annealing:
import math
import torch

class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Combines warmup with cosine annealing for stable training"""
    # Constructor sketched in for completeness; the module's actual signature may differ
    def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup toward each base learning rate
            return [base_lr * (self.last_epoch / self.warmup_epochs) for base_lr in self.base_lrs]
        else:
            # Cosine annealing from base_lr down to min_lr
            progress = (self.last_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
            return [self.min_lr + (base_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
                    for base_lr in self.base_lrs]
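A usage sketch matching the production schedule below (20-epoch warmup to lr=1e-4, cosine decay to min_lr=1e-6); train_one_epoch is a hypothetical stand-in for the epoch loop:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = WarmupCosineAnnealingLR(optimizer, warmup_epochs=20, total_epochs=1000, min_lr=1e-6)
for epoch in range(1000):
    train_one_epoch(model, optimizer)  # hypothetical
    scheduler.step()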
💡 Usage Examples
🚀 Basic Training Setup
from GBC.utils.data_preparation.pose_transformer_trainer import PoseFormerTrainer, LossCoefficients
from GBC.utils.base.assets import DATA_PATHS
# Configure mapping table defined in create_smplh (this is just an example)
mapping_table = {
    "Pelvis": "base_link",
    "L_Hip": "l_hip_yaw_link",
    "R_Hip": "r_hip_yaw_link",
    "L_Knee": "l_knee_link",
    "R_Knee": "r_knee_link",
    "L_Ankle": "l_ankle_pitch_link",
    "R_Ankle": "r_ankle_pitch_link",
    "L_Shoulder": "l_arm_roll_link",
    "R_Shoulder": "r_arm_roll_link",
    "L_Elbow": "l_elbow_roll_link",
    "R_Elbow": "r_elbow_roll_link",
    "L_Wrist": "l_wrist_roll_link",
    "R_Wrist": "r_wrist_roll_link",
}
# Configure loss coefficients
loss_coeffs = LossCoefficients(
    main_loss=5000.0,
    aux_loss=1.0,
    out_of_range_loss=1000.0,
    high_value_action_loss=1000.0,
    direct_mapping_loss=1000.0,
    symmetry_loss=1000.0,
    reference_action_loss=5000.0,
    bone_direction_loss=50.0,
    disturbance_base=1.0,
)
# Initialize trainer
trainer = PoseFormerTrainer(
    urdf_path="/path/to/your_robot.urdf",
    dataset_path="your_dataset_path",
    mapping_table=mapping_table,
    smplh_model_path=DATA_PATHS.smplh_model_path,
    dmpls_model_path=DATA_PATHS.dmpls_model_path,
    smpl_fits_dir="your_smplh_fit_pt",  # From SMPL+H fitting
    batch_size=256,
    device='your_device',  # e.g., 'cuda:0'
    use_renderer=True,
    sample_steps=10,
    save_dir="your_save_path",
    loss_coefficients=loss_coeffs,
    use_wandb=True,
    wandb_project="pose_transformer_{}".format("your_project_name")
)
🔧 Advanced Training Configuration
# Configure joint weights for end-effector emphasis
joint_weights = torch.tensor([
    1.0,  # Pelvis - base reference
    1.0,  # L_Hip - structural
    1.0,  # R_Hip - structural
    1.5,  # L_Knee - important for leg kinematics
    1.5,  # R_Knee - important for leg kinematics
    4.0,  # L_Ankle - end effector, high precision needed
    4.0,  # R_Ankle - end effector, high precision needed
    1.0,  # L_Shoulder - structural
    1.0,  # R_Shoulder - structural
    1.5,  # L_Elbow - important for arm kinematics
    1.5,  # R_Elbow - important for arm kinematics
    2.0,  # L_Wrist - end effector, precision needed
    2.0,  # R_Wrist - end effector, precision needed
])
# Configure bilateral symmetry for Turin robot
from GBC.utils.data_preparation.robot_flip_left_right import YourFlipperModule # This one is critical if you want to enable symmetry
from GBC.utils.base.base_fk import RobotKinematics
flipper = YourFlipperModule()
fk = RobotKinematics(urdf_path, device=device)
flipper.prepare_flip_joint_ids(fk.get_dof_names())
trainer.set_flip_left_right(flipper)
🎯 Production Training Execution
# Execute training with production-grade configuration
trainer.train(
    epochs=1000,
    lr=1e-4,                       # Conservative learning rate
    min_lr=1e-6,                   # Lower bound for scheduling
    warmup_epochs=20,              # Gradual warmup
    validation_interval=1,         # Validate every epoch
    criterion=torch.nn.MSELoss(),  # Standard regression loss
    save_interval=5,               # Save checkpoints every 5 epochs
    save_figs=True,                # Save validation visualizations
    apply_symmetry=True,           # Enable symmetry augmentation
    apply_noise=True,              # Enable disturbance training
    visualize=False,               # Disable real-time visualization for speed
    joint_weights=joint_weights,   # Emphasize end effectors
    disturbance_min_coeff=100.0,   # Higher disturbance for robustness
    disturbance_max_coeff=200.0,   # Strong final regularization
    load=False,                    # Train from scratch
    link_trans_offset=None         # No additional link offsets
)
🔍 Validation and Monitoring
📊 Validation Metrics
Core Metrics:
- FK Distance: Primary pose retargeting accuracy
- Joint Limit Violations: Safety and feasibility assessment
- Symmetry Consistency: Bilateral movement preservation
- Action Smoothness: Temporal continuity evaluation
Implementation:
def validate(self, criterion, joint_weights=None, visualize=False, **kwargs):
    """Comprehensive validation with multiple metrics"""
    self.model.eval()
    val_metrics = {
        'val_main_loss': 0.0,
        'val_symmetry_loss': 0.0,
        'val_limit_violations': 0.0,
        'val_disturbance_loss': 0.0
    }
    
    # Validation loop with metric computation
    for batch in self.test_loader:
        # ... validation logic
        
    return val_metrics
💾 Model Persistence
Checkpoint Management:
# Save training state
trainer.save("checkpoint_epoch_100.pt", optimizer)
# Load from checkpoint
trainer.load("checkpoint_epoch_100.pt", optimizer)
# Model export for deployment
torch.save({
    'model_state_dict': trainer.model.state_dict(),
    'config': model_config,
    'mapping_table': mapping_table
}, 'final_model.pt')
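The exported file can then be restored on the deployment side; a sketch assuming the export format above (the PoseTransformer constructor arguments depend on the contents of model_config):
# Deployment-side loading sketch
ckpt = torch.load('final_model.pt', map_location='cpu')
model = PoseTransformer(**ckpt['config'])  # hypothetical: depends on model_config
model.load_state_dict(ckpt['model_state_dict'])
model.eval()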
📈 Experiment Tracking
Weights & Biases Integration:
# Automatic hyperparameter logging
from dataclasses import asdict

config = {
    'batch_size': batch_size,
    'learning_rate': lr,
    'loss_coefficients': asdict(loss_coeffs),  # LossCoefficients is a plain dataclass
    'model_architecture': 'PoseTransformer'
}
# Real-time metric monitoring
trainer.logger.log_metrics({
    'train_loss': train_loss,
    'val_loss': val_loss,
    'learning_rate': current_lr
}, step=epoch)
🚨 Best Practices
✅ Training Guidelines
- Initialization: Start with converged SMPL+H fitting results
- Loss Balancing: Use production-tested weights (main=5000, symmetry=1000, limits=1000)
- Progressive Training: Start with low disturbance (100.0) and scale to high (200.0)
- Validation: Monitor multiple metrics and save validation figures
- Regularization: Enable both noise injection and symmetry augmentation for robustness
- Dataset: Use full AMASS dataset for comprehensive coverage
- Architecture: Configure appropriate batch sizes (256 train, 1024 IK) for memory efficiency
🔧 Troubleshooting
Common Issues:
- Convergence Problems:
  - Verify SMPL+H fitting quality (should be from converged fitting)
  - Check joint mapping accuracy against URDF specifications
  - Balance loss coefficients (main_loss=5000.0 typically optimal)
- Joint Violations:
  - Increase the out_of_range_loss coefficient (e.g., 1000.0)
  - Verify URDF joint limits are realistic
- Asymmetric Results:
  - Verify the YourFlipperModule configuration
  - Increase the symmetry_loss coefficient (e.g., 1000.0)
  - Enable apply_symmetry=True in training
- Jerky Motions:
  - Increase the disturbance coefficient range (100.0-200.0)
  - Enable apply_noise=True for smoothness regularization
  - Use a conservative learning rate (lr=1e-4)
🎯 Performance Optimization
Training Efficiency:
- Batch Size: Use 256 for training, 1024 for IK (optimized for modern GPUs)
- Device Management: Specify exact device (cuda:0) for consistency
- Data Loading: Enable renderer (use_renderer=True) only when needed
- Checkpoint Saving: Use save_interval=5 to balance storage and recovery
- Validation Frequency: Use validation_interval=1 for close monitoring
- Learning Rate: Conservative schedule (1e-4 → 1e-6) with 20-epoch warmup
Memory Management:
# Production memory optimization
force_retrain = True  # Clear previous state
use_renderer = True   # Only when visualization needed
sample_steps = 10     # Balanced validation sampling
Production Configuration Template:
# Complete production setup based on Turin robot training
if __name__ == "__main__":
    force_retrain = True
    if not os.path.exists(save_dir) or force_retrain:
        trainer.train(
            epochs=1000, 
            lr=1e-4, 
            min_lr=1e-6,
            warmup_epochs=20,
            validation_interval=1,
            criterion=torch.nn.MSELoss(),
            save_interval=5, 
            save_figs=True, 
            apply_symmetry=True, 
            apply_noise=True, 
            visualize=False,  # Disable for production speed
            link_trans_offset=None, 
            load=False,  # Train from scratch
            joint_weights=joint_weights,
            disturbance_min_coeff=100.0,   # Production-tested values
            disturbance_max_coeff=200.0
        )
This comprehensive training framework provides the foundation for robust, accurate human-to-robot motion retargeting with state-of-the-art regularization techniques and validation capabilities.