pose_transformer_trainer

Module: GBC.utils.data_preparation.pose_transformer_trainer

This module provides a comprehensive training framework for PoseTransformer models, implementing advanced training techniques and sophisticated loss functions for robust human-to-robot motion transfer. The trainer incorporates differentiable forward kinematics, multi-component loss functions, and various regularization strategies to ensure stable and accurate pose retargeting.

📚 Dependencies

This module requires the following Python packages:

  • torch - PyTorch framework for neural networks and optimization
  • human_body_prior - SMPL+H body model implementation
  • smplx - Extended SMPL body models
  • wandb - Experiment tracking and logging (optional)
  • tqdm - Progress bars for training loops
  • dataclasses - Configuration management (Python standard library)
  • GBC.utils.base.base_fk - Robot forward kinematics
  • GBC.utils.data_preparation.amass_loader - AMASS dataset loading
  • GBC.utils.data_preparation.pose_transformer - PoseTransformer model

🏗️ Core Components

📊 LossCoefficients

Module Name: GBC.utils.data_preparation.pose_transformer_trainer.LossCoefficients

Definition:

@dataclass
class LossCoefficients:
    main_loss: float = 5000.0
    aux_loss: float = 1.0
    out_of_range_loss: float = 1000.0
    high_value_action_loss: float = 1000.0
    direct_mapping_loss: float = 1000.0
    symmetry_loss: float = 1000.0
    reference_action_loss: float = 5000.0
    bone_direction_loss: float = 1000.0
    disturbance_base: float = 1.0

🔧 Functionality: Configuration class for managing all loss coefficients in the training process. Provides fine-grained control over the relative importance of different loss components.

⚙️ Configuration Parameters:

  • main_loss: Primary FK distance loss weight (λ_dist)
  • aux_loss: Auxiliary model loss (for MoE models)
  • out_of_range_loss: Joint limit violation penalty (λ_limit)
  • high_value_action_loss: Penalty on large action magnitudes
  • direct_mapping_loss: Single-DOF joint mapping loss
  • symmetry_loss: Bilateral symmetry preservation weight (λ_sym)
  • reference_action_loss: Reference motion consistency weight
  • bone_direction_loss: Bone direction consistency loss
  • disturbance_base: Action disturbance regularization base (λ_disturb)
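
Since LossCoefficients is a plain Python dataclass, individual weights can be overridden without restating the whole configuration. A minimal sketch using the standard-library dataclasses.replace:

from dataclasses import replace

# Start from the defaults and override only the weights being tuned;
# replace() returns a new LossCoefficients without mutating the original.
base_coeffs = LossCoefficients()
tuned_coeffs = replace(base_coeffs, symmetry_loss=2000.0, bone_direction_loss=50.0)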

🎯 PoseFormerTrainer

Module Name: GBC.utils.data_preparation.pose_transformer_trainer.PoseFormerTrainer

Definition:

class PoseFormerTrainer:
    def __init__(self,
                 urdf_path: str,
                 dataset_path: str,
                 mapping_table: Dict[str, str],
                 smplh_model_path: str,
                 dmpls_model_path: str,
                 smpl_fits_dir: str,
                 load_hands: bool = False,
                 batch_size: int = 1,
                 train_size: float = 0.8,
                 device: str = 'cuda',
                 loss_coefficients: Optional[LossCoefficients] = None,
                 use_wandb: bool = True,
                 **kwargs
                 ):

📥 Input Parameters:

  • urdf_path (str): Path to robot URDF description file
  • dataset_path (str): Path to AMASS dataset directory
  • mapping_table (Dict[str, str]): SMPL+H to robot joint mapping (see create_smplh documentation)
  • smplh_model_path (str): Path to SMPL+H model file
  • dmpls_model_path (str): Path to DMPL muscle model file
  • smpl_fits_dir (str): Path to fitted SMPL+H parameters (.pt file)
  • load_hands (bool): Include hand joints in training
  • batch_size (int): Training batch size
  • train_size (float): Training/validation split ratio
  • device (str): Computing device ("cuda" or "cpu")
  • loss_coefficients (LossCoefficients): Loss function weights configuration
  • use_wandb (bool): Enable Weights & Biases logging

🔧 Functionality: Comprehensive training framework for PoseTransformer models with advanced loss functions, data augmentation, and validation capabilities.


🧠 Advanced Loss Functions

The training framework implements a sophisticated multi-component loss function designed to address the inherent challenges in human-to-robot motion retargeting. The composite objective combines several specialized loss terms:

🎯 Primary Loss: FK Distance Loss (L_dist)

Mathematical Formulation:

L_dist(p) = ||FK_hn[f(p)] - FK_hm(p)||_2

where f is the PoseTransformer mapping human pose p to robot actions, FK_hn is the robot's forward kinematics, and FK_hm is the human body model's joint regressor.

🔧 Implementation:

# Compute robot joint positions from predicted actions
robot_joints = self.fk(predicted_actions)

# Compute human joint positions from SMPL pose
human_joints = self.body_model(**body_params).Jtr

# Primary distance loss
main_loss = criterion(robot_joints, human_joints[target_indices])

💡 Purpose: Minimizes Euclidean distance between robot end-effector locations and corresponding human joint positions from motion capture data. This forms the core objective for pose retargeting.


⚖️ Joint-Limit Loss (L_limit)

Mathematical Formulation:

L_limit(p) = ||max(a_min - f(p), 0) + max(f(p) - a_max, 0)||_2

🔧 Implementation:

@staticmethod
def out_of_range_loss(actions: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor):
    """Penalize joint limit violations"""
    lower_violation = torch.clamp(min_val - actions, min=0.0)
    upper_violation = torch.clamp(actions - max_val, min=0.0)
    return torch.mean(lower_violation + upper_violation)

💡 Purpose: Prevents joint limit violations by penalizing actions that exceed the robot's physical constraints. Joint limits are parsed from URDF specifications with additional empirical human range bounds.
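
The trainer derives a_min and a_max from the URDF. The sketch below illustrates how revolute joint limits could be read with the standard library; the module's actual parsing (handled through its robot kinematics utilities) may differ:

import xml.etree.ElementTree as ET
import torch

def parse_urdf_joint_limits(urdf_path: str):
    """Collect (lower, upper) limits for every revolute joint in a URDF."""
    root = ET.parse(urdf_path).getroot()
    names, lowers, uppers = [], [], []
    for joint in root.findall("joint"):
        if joint.get("type") != "revolute":
            continue
        limit = joint.find("limit")
        if limit is None:
            continue
        names.append(joint.get("name"))
        lowers.append(float(limit.get("lower", "0")))
        uppers.append(float(limit.get("upper", "0")))
    return names, torch.tensor(lowers), torch.tensor(uppers)

# min_val and max_val can then be fed to out_of_range_loss
names, min_val, max_val = parse_urdf_joint_limits("/path/to/robot.urdf")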


🌊 Action-Disturbance Loss (L_disturb)

Mathematical Formulation:

L_disturb(p) = ||FK_hn[f(p)] - FK_hn[f(p + δp)]||_2 / ||δp||_2

where δp ~ N(0, σ(2π)²) with σ ≪ 1

🔧 Implementation:

@staticmethod
def disturbance_loss(actions: torch.Tensor, actions_disturbed: torch.Tensor, disturbance: torch.Tensor):
    """Enforce Lipschitz continuity through input perturbation"""
    action_diff = torch.norm(actions - actions_disturbed, dim=-1)
    disturbance_norm = torch.norm(disturbance, dim=-1)
    return torch.mean(action_diff / (disturbance_norm + 1e-8))

💡 Purpose: Enforces Lipschitz continuity by minimizing sensitivity to input perturbations. This regularization ensures smooth, dynamically feasible output sequences and reduces discontinuities in the learned mapping.
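
In the training step this amounts to two forward passes, one on the clean pose and one on a perturbed copy. A minimal sketch, assuming model maps a flattened pose tensor to an action tensor:

import torch

sigma = 0.01  # perturbation scale, sigma << 1

def disturbance_pair(model, pose_body):
    """Produce the clean/perturbed action pair consumed by disturbance_loss."""
    delta_p = torch.randn_like(pose_body) * sigma
    actions = model(pose_body)
    actions_disturbed = model(pose_body + delta_p)
    return actions, actions_disturbed, delta_p

# loss = PoseFormerTrainer.disturbance_loss(actions, actions_disturbed, delta_p)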


🔄 Symmetry Loss (L_sym)

Mathematical Formulation:

L_sym(p) = ||S(FK_hn[f(p)]) - FK_hn[f(S(p))]||_2

where S denotes reflection about the humanoid's sagittal plane

🔧 Implementation:

def symmetry_loss(self, actions, pose_body, pose_hand=None):
    """Preserve bilateral symmetry in retargeting"""
    # Apply symmetry transformation to input pose
    pose_body_flipped = symmetry_smplh_pose(pose_body)

    # Forward pass on flipped pose
    flipped_input = self.prepare_model_input(pose_body_flipped, pose_hand)
    actions_flipped = self.model(flipped_input)

    # Apply symmetry to original actions and compare
    actions_sym = self.flip_left_right.apply_action_flip(actions)
    return F.mse_loss(actions_sym, actions_flipped)

💡 Purpose: Preserves bilateral symmetry by ensuring that left-right reflected input poses produce correspondingly reflected robot actions. Critical for maintaining natural, symmetric movement patterns.


🎯 Direct Mapping Loss

Mathematical Formulation: For single-DOF joints (e.g., knees), the matching yaw-pitch-roll component of the human pose is supervised directly:

L_map(p) = ||f(p)[a] - YPR(p)[j, k]||_2

where a indexes the robot action and (j, k) index the SMPL+H body joint and its angle component.

🔧 Implementation:

def direct_mapping_loss(self, actions: torch.Tensor, pose_body: torch.Tensor, pose_hand: torch.Tensor = None):
    """Direct angle mapping for single-DOF joints"""
    pose_body = pose_body.reshape(-1, 21, 3)
    ypr = batch_angle_axis_to_ypr(pose_body)  # Convert axis-angle to yaw-pitch-roll

    # Extract specified angle components and corresponding actions
    ypr_selected = ypr[:, self.pose_body_indices, self.angle_indices]
    actions_selected = actions[:, self.action_indices]

    return F.mse_loss(actions_selected, ypr_selected.detach())

💡 Purpose: Provides direct supervision for joints with single valid degrees of freedom (e.g., knee pitch). Ensures anatomically correct joint mappings for constrained robot joints.
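
The index triplets used above (pose_body_indices, angle_indices, action_indices) are configured per robot. A hypothetical example for the knees, assuming the standard SMPL body joint ordering (L_Knee and R_Knee at positions 3 and 4 of the 21-joint pose_body tensor), a [yaw, pitch, roll] component ordering, and placeholder robot DOF indices:

import torch

# SMPL+H pose_body joint positions (root excluded): L_Knee = 3, R_Knee = 4
pose_body_indices = torch.tensor([3, 4])
# Supervise the pitch component (assumed index 1 in [yaw, pitch, roll])
angle_indices = torch.tensor([1, 1])
# Matching robot knee DOFs -- placeholders; look these up via
# RobotKinematics.get_dof_names() for your robot
action_indices = torch.tensor([4, 10])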


🦴 Bone Direction Loss

🔧 Implementation:

def bone_direction_loss(self, robot_joints: torch.Tensor, gt_joints: torch.Tensor):
    """Enforce bone direction consistency"""
    # Compute bone vectors for predefined bone pairs
    robot_bone_vecs = robot_joints[:, self.bone_end_indices] - robot_joints[:, self.bone_start_indices]
    gt_bone_vecs = gt_joints[:, self.bone_end_indices] - gt_joints[:, self.bone_start_indices]

    # Normalize and compute cosine similarity
    robot_bone_units = F.normalize(robot_bone_vecs, dim=-1)
    gt_bone_units = F.normalize(gt_bone_vecs, dim=-1)
    cos_sim = torch.sum(robot_bone_units * gt_bone_units, dim=-1)

    return torch.mean(1.0 - cos_sim)

💡 Purpose: Maintains consistent bone directions between robot and human skeletons, preserving kinematic chain relationships and anatomical structure.
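
The bone pairs are likewise robot-specific index tensors over the mapped joints. A hypothetical configuration, assuming the 13-joint ordering of the example mapping table in the usage section below:

import torch

# (start, end) joint indices per bone, with joints ordered as:
# 0 Pelvis, 1 L_Hip, 2 R_Hip, 3 L_Knee, 4 R_Knee, 5 L_Ankle, 6 R_Ankle,
# 7 L_Shoulder, 8 R_Shoulder, 9 L_Elbow, 10 R_Elbow, 11 L_Wrist, 12 R_Wrist
bone_start_indices = torch.tensor([1, 2, 3, 4, 7, 8, 9, 10])   # thighs, shins, upper arms, forearms
bone_end_indices = torch.tensor([3, 4, 5, 6, 9, 10, 11, 12])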


🚀 Training Configuration

📋 Required Configuration

⚠️ Critical Setup Requirements:

  1. Joint Mapping Configuration:

    • Configure mapping_table following the same principles as SMPL+H fitting
    • Reference: Create SMPL+H Documentation for detailed mapping guidelines
    • Ensure consistent joint correspondence between SMPL+H and robot URDF
  2. Model Paths:

    trainer = PoseFormerTrainer(
        urdf_path="/path/to/robot.urdf",
        smplh_model_path="/path/to/smplh/model.npz",
        dmpls_model_path="/path/to/dmpls/model.npz",
        smpl_fits_dir="/path/to/fitted_params.pt",  # From SMPL+H fitting
        dataset_path="/path/to/AMASS_dataset",
    )
  3. Loss Coefficient Tuning:

    # Production-tested weight hierarchy for Turin robot training
    loss_coeffs = LossCoefficients(
        main_loss=5000.0,               # Primary FK distance (λ_dist)
        aux_loss=1.0,                   # Auxiliary model outputs
        out_of_range_loss=1000.0,       # Joint limit violations (λ_limit)
        high_value_action_loss=1000.0,  # Large action penalties
        direct_mapping_loss=1000.0,     # Single-DOF joint mapping
        symmetry_loss=1000.0,           # Bilateral symmetry (λ_sym)
        reference_action_loss=5000.0,   # Reference motion consistency
        bone_direction_loss=50.0,       # Bone direction alignment
        disturbance_base=1.0,           # Lipschitz regularization base (λ_disturb)
    )
  4. Batch Configuration:

    # Optimized for GPU memory and training stability
    batch_size = 256 # Training batch size
    ik_batch_size = 1024 # Inverse kinematics batch size
    device = "cuda:0" # Primary GPU device
    use_renderer = True # Enable visualization capabilities
    sample_steps = 10 # Validation sampling frequency

🎯 Advanced Training Techniques

📈 Progressive Disturbance Training

🔧 Implementation:

def _get_progressive_disturbance_coeff(self, epoch, total_epochs, min_coeff=1.0, max_coeff=50.0):
    """Progressively increase disturbance loss weight during training"""
    progress = min(epoch / total_epochs, 1.0)
    return min_coeff + (max_coeff - min_coeff) * progress

💡 Strategy:

  • Early Training: Low disturbance coefficient for stable convergence
  • Late Training: High disturbance coefficient for robust generalization
  • Benefits: Balances learning speed with robustness to input variations
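
A quick standalone check of the schedule makes the linear ramp concrete:

def progressive_disturbance_coeff(epoch, total_epochs, min_coeff=1.0, max_coeff=50.0):
    """Linear ramp of the disturbance weight over training."""
    progress = min(epoch / total_epochs, 1.0)
    return min_coeff + (max_coeff - min_coeff) * progress

for epoch in (0, 250, 500, 1000):
    print(epoch, progressive_disturbance_coeff(epoch, total_epochs=1000))
# 0 -> 1.0, 250 -> 13.25, 500 -> 25.5, 1000 -> 50.0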

🔄 Data Augmentation Strategies

Noise Injection:

# Gaussian noise perturbation for disturbance loss
noise_std = 0.01
pose_noise = torch.randn_like(pose_body) * noise_std
pose_body_disturbed = pose_body + pose_noise

Symmetry Augmentation:

import random

# Random left-right flipping for bilateral symmetry
if apply_symmetry and random.random() < 0.5:
    pose_body = symmetry_smplh_pose(pose_body)
    actions = flip_left_right.apply_action_flip(actions)

📊 Learning Rate Scheduling

Warmup + Cosine Annealing:

import math
import torch

class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Combines linear warmup with cosine annealing for stable training"""
    # __init__ reconstructed for completeness; the actual signature in GBC may differ
    def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup from 0 up to base_lr
            return [base_lr * (self.last_epoch / self.warmup_epochs)
                    for base_lr in self.base_lrs]
        # Cosine annealing from base_lr down to min_lr
        progress = (self.last_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
        return [self.min_lr + (base_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
                for base_lr in self.base_lrs]
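
A usage sketch wiring the scheduler to an optimizer (assuming the reconstructed constructor above):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for trainer.model.parameters()
optimizer = torch.optim.AdamW(params, lr=1e-4)
scheduler = WarmupCosineAnnealingLR(optimizer, warmup_epochs=20, total_epochs=1000, min_lr=1e-6)

for epoch in range(1000):
    # ... run one epoch of training ...
    optimizer.step()
    scheduler.step()  # advance the warmup/cosine schedule once per epoch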

💡 Usage Examples

🚀 Basic Training Setup

from GBC.utils.data_preparation.pose_transformer_trainer import PoseFormerTrainer, LossCoefficients
from GBC.utils.base.assets import DATA_PATHS

# Configure mapping table defined in create_smplh (this is just an example)
mapping_table = {
    "Pelvis": "base_link",
    "L_Hip": "l_hip_yaw_link",
    "R_Hip": "r_hip_yaw_link",
    "L_Knee": "l_knee_link",
    "R_Knee": "r_knee_link",
    "L_Ankle": "l_ankle_pitch_link",
    "R_Ankle": "r_ankle_pitch_link",
    "L_Shoulder": "l_arm_roll_link",
    "R_Shoulder": "r_arm_roll_link",
    "L_Elbow": "l_elbow_roll_link",
    "R_Elbow": "r_elbow_roll_link",
    "L_Wrist": "l_wrist_roll_link",
    "R_Wrist": "r_wrist_roll_link",
}

# Configure loss coefficients
loss_coeffs = LossCoefficients(
    main_loss=5000.0,
    aux_loss=1.0,
    out_of_range_loss=1000.0,
    high_value_action_loss=1000.0,
    direct_mapping_loss=1000.0,
    symmetry_loss=1000.0,
    reference_action_loss=5000.0,
    bone_direction_loss=50.0,
    disturbance_base=1.0,
)

# Initialize trainer
trainer = PoseFormerTrainer(
    urdf_path="/path/to/your_robot.urdf",
    dataset_path="your_dataset_path",
    mapping_table=mapping_table,
    smplh_model_path=DATA_PATHS.smplh_model_path,
    dmpls_model_path=DATA_PATHS.dmpls_model_path,
    smpl_fits_dir="your_smplh_fit_pt",  # From SMPL+H fitting
    batch_size=256,
    device='your_device',  # e.g., 'cuda:0'
    use_renderer=True,
    sample_steps=10,
    save_dir="your_save_path",
    loss_coefficients=loss_coeffs,
    use_wandb=True,
    wandb_project="pose_transformer_{}".format("your_project_name"),
)

🔧 Advanced Training Configuration

# Configure joint weights for end-effector emphasis
joint_weights = torch.tensor([
    1.0,  # Pelvis - base reference
    1.0,  # L_Hip - structural
    1.0,  # R_Hip - structural
    1.5,  # L_Knee - important for leg kinematics
    1.5,  # R_Knee - important for leg kinematics
    4.0,  # L_Ankle - end effector, high precision needed
    4.0,  # R_Ankle - end effector, high precision needed
    1.0,  # L_Shoulder - structural
    1.0,  # R_Shoulder - structural
    1.5,  # L_Elbow - important for arm kinematics
    1.5,  # R_Elbow - important for arm kinematics
    2.0,  # L_Wrist - end effector, precision needed
    2.0,  # R_Wrist - end effector, precision needed
])

# Configure bilateral symmetry for Turin robot
from GBC.utils.data_preparation.robot_flip_left_right import YourFlipperModule  # Required to enable symmetry losses
from GBC.utils.base.base_fk import RobotKinematics

flipper = YourFlipperModule()
fk = RobotKinematics(urdf_path, device=device)
flipper.prepare_flip_joint_ids(fk.get_dof_names())
trainer.set_flip_left_right(flipper)
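
YourFlipperModule is robot-specific. For orientation only, a hypothetical minimal flipper, assuming left/right DOFs are distinguished by l_/r_ name prefixes and that roll/yaw joints change sign under mirroring (both assumptions must be adapted to your robot):

import torch

class ExampleFlipperModule:
    """Hypothetical left-right flipper: permute l_*/r_* DOFs and negate mirrored axes."""

    @staticmethod
    def _mirror(name: str) -> str:
        if name.startswith("l_"):
            return "r_" + name[2:]
        if name.startswith("r_"):
            return "l_" + name[2:]
        return name

    def prepare_flip_joint_ids(self, dof_names):
        # Permutation that swaps each DOF with its left/right counterpart
        self.perm = torch.tensor([dof_names.index(self._mirror(n)) for n in dof_names])
        # Sign flip for joints whose rotation axis mirrors across the sagittal plane
        self.signs = torch.tensor([-1.0 if ("roll" in n or "yaw" in n) else 1.0
                                   for n in dof_names])

    def apply_action_flip(self, actions: torch.Tensor) -> torch.Tensor:
        return actions[..., self.perm] * self.signs.to(actions.device)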

🎯 Production Training Execution

# Execute training with production-grade configuration
trainer.train(
    epochs=1000,
    lr=1e-4,                       # Conservative learning rate
    min_lr=1e-6,                   # Lower bound for scheduling
    warmup_epochs=20,              # Gradual warmup
    validation_interval=1,         # Validate every epoch
    criterion=torch.nn.MSELoss(),  # Standard regression loss
    save_interval=5,               # Save checkpoints every 5 epochs
    save_figs=True,                # Save validation visualizations
    apply_symmetry=True,           # Enable symmetry augmentation
    apply_noise=True,              # Enable disturbance training
    visualize=False,               # Disable real-time visualization for speed
    joint_weights=joint_weights,   # Emphasize end effectors
    disturbance_min_coeff=100.0,   # Higher disturbance for robustness
    disturbance_max_coeff=200.0,   # Strong final regularization
    load=False,                    # Train from scratch
    link_trans_offset=None,        # No additional link offsets
)

🔍 Validation and Monitoring

📊 Validation Metrics

Core Metrics:

  • FK Distance: Primary pose retargeting accuracy
  • Joint Limit Violations: Safety and feasibility assessment
  • Symmetry Consistency: Bilateral movement preservation
  • Action Smoothness: Temporal continuity evaluation

Implementation:

def validate(self, criterion, joint_weights=None, visualize=False, **kwargs):
    """Comprehensive validation with multiple metrics"""
    self.model.eval()
    val_metrics = {
        'val_main_loss': 0.0,
        'val_symmetry_loss': 0.0,
        'val_limit_violations': 0.0,
        'val_disturbance_loss': 0.0,
    }

    # Validation loop with metric computation
    for batch in self.test_loader:
        ...  # per-batch metric accumulation elided

    return val_metrics

💾 Model Persistence

Checkpoint Management:

# Save training state
trainer.save("checkpoint_epoch_100.pt", optimizer)

# Load from checkpoint
trainer.load("checkpoint_epoch_100.pt", optimizer)

# Model export for deployment
torch.save({
    'model_state_dict': trainer.model.state_dict(),
    'config': model_config,
    'mapping_table': mapping_table,
}, 'final_model.pt')
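
Reloading the exported bundle for inference; a sketch assuming the saved config dict holds the PoseTransformer constructor's keyword arguments:

import torch
from GBC.utils.data_preparation.pose_transformer import PoseTransformer

bundle = torch.load('final_model.pt', map_location='cpu')
model = PoseTransformer(**bundle['config'])  # assumes config matches the constructor
model.load_state_dict(bundle['model_state_dict'])
model.eval()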

📈 Experiment Tracking

Weights & Biases Integration:

# Automatic hyperparameter logging
config = {
    'batch_size': batch_size,
    'learning_rate': lr,
    'loss_coefficients': loss_coeffs.to_dict(),
    'model_architecture': 'PoseTransformer',
}

# Real-time metric monitoring
trainer.logger.log_metrics({
    'train_loss': train_loss,
    'val_loss': val_loss,
    'learning_rate': current_lr,
}, step=epoch)

🚨 Best Practices

✅ Training Guidelines

  1. Initialization: Start with converged SMPL+H fitting results
  2. Loss Balancing: Use production-tested weights (main=5000, symmetry=1000, limits=1000)
  3. Progressive Training: Start with low disturbance (100.0) and scale to high (200.0)
  4. Validation: Monitor multiple metrics and save validation figures
  5. Regularization: Enable both noise injection and symmetry augmentation for robustness
  6. Dataset: Use full AMASS dataset for comprehensive coverage
  7. Architecture: Configure appropriate batch sizes (256 train, 1024 IK) for memory efficiency

🔧 Troubleshooting

Common Issues:

  • Convergence Problems:
    • Verify SMPL+H fitting quality (should be from converged fitting)
    • Check joint mapping accuracy against URDF specifications
    • Balance loss coefficients (main_loss=5000.0 typically optimal)
  • Joint Violations:
    • Increase out_of_range_loss=1000.0 coefficient
    • Verify URDF joint limits are realistic
  • Asymmetric Results:
    • Verify YourFlipperModule configuration
    • Increase symmetry_loss=1000.0 coefficient
    • Enable apply_symmetry=True in training
  • Jerky Motions:
    • Increase disturbance coefficient range (100.0-200.0)
    • Enable apply_noise=True for smoothness regularization
    • Use conservative learning rate (lr=1e-4)

🎯 Performance Optimization

Training Efficiency:

  • Batch Size: Use 256 for training, 1024 for IK (optimized for modern GPUs)
  • Device Management: Specify exact device (cuda:0) for consistency
  • Data Loading: Enable renderer (use_renderer=True) only when needed
  • Checkpoint Saving: Use save_interval=5 to balance storage and recovery
  • Validation Frequency: validation_interval=1 for close monitoring
  • Learning Rate: Conservative schedule (1e-4 → 1e-6) with 20-epoch warmup

Memory Management:

# Production memory optimization
force_retrain = True # Clear previous state
use_renderer = True # Only when visualization needed
sample_steps = 10 # Balanced validation sampling

Production Configuration Template:

# Complete production setup based on Turin robot training
import os

if __name__ == "__main__":
    force_retrain = True
    if not os.path.exists(save_dir) or force_retrain:
        trainer.train(
            epochs=1000,
            lr=1e-4,
            min_lr=1e-6,
            warmup_epochs=20,
            validation_interval=1,
            criterion=torch.nn.MSELoss(),
            save_interval=5,
            save_figs=True,
            apply_symmetry=True,
            apply_noise=True,
            visualize=False,  # Disable for production speed
            link_trans_offset=None,
            load=False,       # Train from scratch
            joint_weights=joint_weights,
            disturbance_min_coeff=100.0,  # Production-tested values
            disturbance_max_coeff=200.0,
        )

This comprehensive training framework provides the foundation for robust, accurate human-to-robot motion retargeting with state-of-the-art regularization techniques and validation capabilities.