pose_transformer_trainer
Module: GBC.utils.data_preparation.pose_transformer_trainer
This module provides a comprehensive training framework for PoseTransformer models, implementing advanced training techniques and sophisticated loss functions for robust human-to-robot motion transfer. The trainer incorporates differentiable forward kinematics, multi-component loss functions, and various regularization strategies to ensure stable and accurate pose retargeting.
📚 Dependencies
This module requires the following Python packages:
- `torch` - PyTorch framework for neural networks and optimization
- `human_body_prior` - SMPL+H body model implementation
- `smplx` - Extended SMPL body models
- `wandb` - Experiment tracking and logging (optional)
- `tqdm` - Progress bars for training loops
- `dataclasses` - Configuration management
- `GBC.utils.base.base_fk` - Robot forward kinematics
- `GBC.utils.data_preparation.amass_loader` - AMASS dataset loading
- `GBC.utils.data_preparation.pose_transformer` - PoseTransformer model
🏗️ Core Components
📊 LossCoefficients
Module Name: GBC.utils.data_preparation.pose_transformer_trainer.LossCoefficients
Definition:
```python
@dataclass
class LossCoefficients:
    main_loss: float = 5000.0
    aux_loss: float = 1.0
    out_of_range_loss: float = 1000.0
    high_value_action_loss: float = 1000.0
    direct_mapping_loss: float = 1000.0
    symmetry_loss: float = 1000.0
    reference_action_loss: float = 5000.0
    bone_direction_loss: float = 1000.0
    disturbance_base: float = 1.0
```
🔧 Functionality: Configuration class for managing all loss coefficients in the training process. Provides fine-grained control over the relative importance of different loss components.
⚙️ Configuration Parameters:
- main_loss: Primary FK distance loss weight (λ_dist)
- symmetry_loss: Bilateral symmetry preservation weight (λ_sym)
- out_of_range_loss: Joint limit violation penalty (λ_limit)
- disturbance_base: Action disturbance regularization (λ_disturb)
- direct_mapping_loss: Single DOF joint mapping loss
- bone_direction_loss: Bone direction consistency loss
- aux_loss: Auxiliary model loss (for MoE models)
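Because `LossCoefficients` is a dataclass, only the fields you want to change need to be passed; the rest keep their defaults. A small illustrative override (the values here are examples, not recommendations):

```python
from GBC.utils.data_preparation.pose_transformer_trainer import LossCoefficients

# Illustrative override: emphasize the primary FK objective, relax symmetry
coeffs = LossCoefficients(main_loss=8000.0, symmetry_loss=500.0)
```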
🎯 PoseFormerTrainer
Module Name: GBC.utils.data_preparation.pose_transformer_trainer.PoseFormerTrainer
Definition:
```python
class PoseFormerTrainer:
    def __init__(self,
                 urdf_path: str,
                 dataset_path: str,
                 mapping_table: Dict[str, str],
                 smplh_model_path: str,
                 dmpls_model_path: str,
                 smpl_fits_dir: str,
                 load_hands: bool = False,
                 batch_size: int = 1,
                 train_size: float = 0.8,
                 device: str = 'cuda',
                 loss_coefficients: Optional[LossCoefficients] = None,
                 use_wandb: bool = True,
                 **kwargs
                 ):
```
📥 Input Parameters:
- urdf_path (str): Path to robot URDF description file
- dataset_path (str): Path to AMASS dataset directory
- mapping_table (Dict[str, str]): SMPL+H to robot joint mapping (see create_smplh documentation)
- smplh_model_path (str): Path to SMPL+H model file
- dmpls_model_path (str): Path to DMPL muscle model file
- smpl_fits_dir (str): Path to fitted SMPL+H parameters (.pt file)
- load_hands (bool): Include hand joints in training
- batch_size (int): Training batch size
- train_size (float): Training/validation split ratio
- device (str): Computing device ("cuda" or "cpu")
- loss_coefficients (LossCoefficients): Loss function weights configuration
- use_wandb (bool): Enable Weights & Biases logging
🔧 Functionality: Comprehensive training framework for PoseTransformer models with advanced loss functions, data augmentation, and validation capabilities.
🧠 Advanced Loss Functions
The training framework implements a sophisticated multi-component loss function designed to address the inherent challenges in human-to-robot motion retargeting. The composite objective combines several specialized loss terms:
🎯 Primary Loss: FK Distance Loss (L_dist)
Mathematical Formulation:
L_dist(p) = ||FK_hn[f(p)] - FK_hm(p)||_2
where FK_hn denotes the humanoid (robot) forward kinematics, FK_hm the human (SMPL+H) forward kinematics, and f the learned pose-to-action mapping
🔧 Implementation:
```python
# Compute robot joint positions from predicted actions
robot_joints = self.fk(predicted_actions)
# Compute human joint positions from the SMPL+H pose
human_joints = self.body_model(**body_params).Jtr
# Primary distance loss between corresponding joints
main_loss = criterion(robot_joints, human_joints[target_indices])
```
💡 Purpose: Minimizes Euclidean distance between robot end-effector locations and corresponding human joint positions from motion capture data. This forms the core objective for pose retargeting.
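When some joints matter more than others (e.g., end effectors), the distance term can be weighted per joint. The following is a minimal sketch of one plausible weighted form, assuming (B, J, 3) joint tensors and the `joint_weights` vector passed to `trainer.train` later in this document; it is not necessarily the trainer's exact implementation:

```python
import torch

def weighted_fk_distance(robot_joints, human_joints, joint_weights):
    # robot_joints, human_joints: (B, J, 3); joint_weights: (J,)
    per_joint_dist = torch.norm(robot_joints - human_joints, dim=-1)  # (B, J)
    return (per_joint_dist * joint_weights).mean()
```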
⚖️ Joint-Limit Loss (L_limit)
Mathematical Formulation:
L_limit(p) = ||max(a_min - f(p), 0) + max(f(p) - a_max, 0)||_2
🔧 Implementation:
```python
@staticmethod
def out_of_range_loss(actions: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor):
    """Penalize joint limit violations"""
    lower_violation = torch.clamp(min_val - actions, min=0.0)
    upper_violation = torch.clamp(actions - max_val, min=0.0)
    return torch.mean(lower_violation + upper_violation)
```
💡 Purpose: Prevents joint limit violations by penalizing actions that exceed the robot's physical constraints. Joint limits are parsed from URDF specifications with additional empirical human range bounds.
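A quick sanity check of the penalty on toy tensors (values are arbitrary):

```python
import torch
from GBC.utils.data_preparation.pose_transformer_trainer import PoseFormerTrainer

# One action (1.2) exceeds its upper bound (1.0) by 0.2; the other two are in range
actions = torch.tensor([[0.0, 1.2, -0.5]])
min_val = torch.tensor([-1.0, -1.0, -1.0])
max_val = torch.tensor([1.0, 1.0, 1.0])
loss = PoseFormerTrainer.out_of_range_loss(actions, min_val, max_val)
print(loss)  # tensor(0.0667): the 0.2 violation averaged over 3 joints
```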
🌊 Action-Disturbance Loss (L_disturb)
Mathematical Formulation:
L_disturb(p) = ||FK_hn[f(p)] - FK_hn[f(p + δp)]||_2 / ||δp||_2
where δp ~ N(0, (2πσ)²) with σ ≪ 1
🔧 Implementation:
```python
@staticmethod
def disturbance_loss(actions: torch.Tensor, actions_disturbed: torch.Tensor, disturbance: torch.Tensor):
    """Enforce Lipschitz continuity through input perturbation"""
    action_diff = torch.norm(actions - actions_disturbed, dim=-1)
    disturbance_norm = torch.norm(disturbance, dim=-1)
    return torch.mean(action_diff / (disturbance_norm + 1e-8))
```
💡 Purpose: Enforces Lipschitz continuity by minimizing sensitivity to input perturbations. This regularization ensures smooth, dynamically feasible output sequences and reduces discontinuities in the learned mapping.
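Putting the pieces together, a hedged sketch of how the perturbed forward pass might feed this loss; the perturbation scale and the single-argument `prepare_model_input` call are assumptions based on the snippets in this document:

```python
import torch

sigma = 0.01  # assumed small perturbation scale (sigma << 1)
disturbance = torch.randn_like(pose_body) * sigma

# Model outputs for clean and perturbed poses
actions = trainer.model(trainer.prepare_model_input(pose_body))
actions_disturbed = trainer.model(trainer.prepare_model_input(pose_body + disturbance))

loss = PoseFormerTrainer.disturbance_loss(actions, actions_disturbed, disturbance)
```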
🔄 Symmetry Loss (L_sym)
Mathematical Formulation:
L_sym(p) = ||S(FK_hn[f(p)]) - FK_hn[f(S(p))]||_2
where S denotes reflection about the humanoid's sagittal plane
🔧 Implementation:
```python
def symmetry_loss(self, actions, pose_body, pose_hand=None):
    """Preserve bilateral symmetry in retargeting"""
    # Apply symmetry transformation to input pose
    pose_body_flipped = symmetry_smplh_pose(pose_body)
    # Forward pass on flipped pose
    flipped_input = self.prepare_model_input(pose_body_flipped, pose_hand)
    actions_flipped = self.model(flipped_input)
    # Apply symmetry to original actions and compare
    actions_sym = self.flip_left_right.apply_action_flip(actions)
    return F.mse_loss(actions_sym, actions_flipped)
```
💡 Purpose: Preserves bilateral symmetry by ensuring that left-right reflected input poses produce correspondingly reflected robot actions. Critical for maintaining natural, symmetric movement patterns.
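Since S is a reflection, applying it twice must return the original pose, which gives a cheap sanity check for the flip utilities. This sketch assumes `symmetry_smplh_pose` (the helper used in `symmetry_loss` above; import path omitted) operates on flattened pose_body tensors:

```python
import torch

pose_body = torch.randn(4, 63)  # batch of SMPL+H body poses (21 joints x 3)
flipped_twice = symmetry_smplh_pose(symmetry_smplh_pose(pose_body))
assert torch.allclose(flipped_twice, pose_body, atol=1e-6)  # S is an involution
```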
🎯 Direct Mapping Loss
Mathematical Formulation: For single-DOF joints (e.g., knees), the relevant human joint angle component is mapped directly onto the corresponding robot action:
🔧 Implementation:
```python
def direct_mapping_loss(self, actions: torch.Tensor, pose_body: torch.Tensor, pose_hand: torch.Tensor = None):
    """Direct angle mapping for single-DOF joints"""
    pose_body = pose_body.reshape(-1, 21, 3)
    ypr = batch_angle_axis_to_ypr(pose_body)  # Convert axis-angle to yaw-pitch-roll
    # Extract specified angle components and corresponding actions
    ypr_selected = ypr[:, self.pose_body_indices, self.angle_indices]
    actions_selected = actions[:, self.action_indices]
    return F.mse_loss(actions_selected, ypr_selected.detach())
```
💡 Purpose: Provides direct supervision for joints with single valid degrees of freedom (e.g., knee pitch). Ensures anatomically correct joint mappings for constrained robot joints.
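The index attributes used above must be configured per robot. A hypothetical configuration for knee pitch, with all indices illustrative (the real trainer may expose a dedicated configuration API); SMPL+H pose_body rows 3 and 4 correspond to L_Knee and R_Knee, and component 1 of yaw-pitch-roll is pitch:

```python
import torch

# Hypothetical setup, for illustration only
trainer.pose_body_indices = torch.tensor([3, 4])  # L_Knee, R_Knee rows in pose_body
trainer.angle_indices = torch.tensor([1, 1])      # pitch component of yaw-pitch-roll
trainer.action_indices = torch.tensor([10, 16])   # example robot knee DOF indices
```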
🦴 Bone Direction Loss
🔧 Implementation:
```python
def bone_direction_loss(self, robot_joints: torch.Tensor, gt_joints: torch.Tensor):
    """Enforce bone direction consistency"""
    # Compute bone vectors for predefined bone pairs
    robot_bone_vecs = robot_joints[:, self.bone_end_indices] - robot_joints[:, self.bone_start_indices]
    gt_bone_vecs = gt_joints[:, self.bone_end_indices] - gt_joints[:, self.bone_start_indices]
    # Normalize and compute cosine similarity
    robot_bone_units = F.normalize(robot_bone_vecs, dim=-1)
    gt_bone_units = F.normalize(gt_bone_vecs, dim=-1)
    cos_sim = torch.sum(robot_bone_units * gt_bone_units, dim=-1)
    return torch.mean(1.0 - cos_sim)
```
💡 Purpose: Maintains consistent bone directions between robot and human skeletons, preserving kinematic chain relationships and anatomical structure.
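The bone index pairs are defined over the mapped joint list. Using the ordering of the example `mapping_table` from the usage section below, the leg chains could be expressed as follows (indices are illustrative, not from the source):

```python
import torch

# hip -> knee and knee -> ankle on both legs, indexing into the mapped joint list
bone_start_indices = torch.tensor([1, 2, 3, 4])  # L_Hip, R_Hip, L_Knee, R_Knee
bone_end_indices   = torch.tensor([3, 4, 5, 6])  # L_Knee, R_Knee, L_Ankle, R_Ankle
```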
🚀 Training Configuration
📋 Required Configuration
⚠️ Critical Setup Requirements:
1. Joint Mapping Configuration:
   - Configure `mapping_table` following the same principles as SMPL+H fitting
   - Reference: Create SMPL+H Documentation for detailed mapping guidelines
   - Ensure consistent joint correspondence between SMPL+H and the robot URDF

2. Model Paths:

```python
trainer = PoseFormerTrainer(
    urdf_path="/path/to/robot.urdf",
    smplh_model_path="/path/to/smplh/model.npz",
    dmpls_model_path="/path/to/dmpls/model.npz",
    smpl_fits_dir="/path/to/fitted_params.pt",  # From SMPL+H fitting
    dataset_path="/path/to/AMASS_dataset"
)
```

3. Loss Coefficient Tuning:

```python
# Production-tested weight hierarchy for Turin robot training
loss_coeffs = LossCoefficients(
    main_loss=5000.0,               # Primary FK distance (λ_dist)
    aux_loss=1.0,                   # Auxiliary model outputs
    out_of_range_loss=1000.0,       # Joint limit violations (λ_limit)
    high_value_action_loss=1000.0,  # Large action penalties
    direct_mapping_loss=1000.0,     # Single DOF joint mapping
    symmetry_loss=1000.0,           # Bilateral symmetry (λ_sym)
    reference_action_loss=5000.0,   # Reference motion consistency
    bone_direction_loss=50.0,       # Bone direction alignment
    disturbance_base=1.0,           # Lipschitz regularization base (λ_disturb)
)
```

4. Batch Configuration:

```python
# Optimized for GPU memory and training stability
batch_size = 256      # Training batch size
ik_batch_size = 1024  # Inverse kinematics batch size
device = "cuda:0"     # Primary GPU device
use_renderer = True   # Enable visualization capabilities
sample_steps = 10     # Validation sampling frequency
```
🎯 Advanced Training Techniques
📈 Progressive Disturbance Training
🔧 Implementation:
```python
def _get_progressive_disturbance_coeff(self, epoch, total_epochs, min_coeff=1.0, max_coeff=50.0):
    """Progressively increase the disturbance loss weight during training"""
    progress = min(epoch / total_epochs, 1.0)
    return min_coeff + (max_coeff - min_coeff) * progress
```
💡 Strategy:
- Early Training: Low disturbance coefficient for stable convergence
- Late Training: High disturbance coefficient for robust generalization
- Benefits: Balances learning speed with robustness to input variations
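With the production range used later in this document (100.0 to 200.0), the coefficient ramps linearly with epoch progress. A small illustration using the private helper shown above:

```python
total_epochs = 1000
for epoch in (0, 250, 500, 1000):
    coeff = trainer._get_progressive_disturbance_coeff(
        epoch, total_epochs, min_coeff=100.0, max_coeff=200.0)
    print(epoch, coeff)  # 0 -> 100.0, 250 -> 125.0, 500 -> 150.0, 1000 -> 200.0
```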
🔄 Data Augmentation Strategies
Noise Injection:
```python
# Gaussian noise perturbation for disturbance loss
noise_std = 0.01
pose_noise = torch.randn_like(pose_body) * noise_std
pose_body_disturbed = pose_body + pose_noise
```
Symmetry Augmentation:
```python
# Random left-right flipping for bilateral symmetry
if apply_symmetry and random.random() < 0.5:
    pose_body = symmetry_smplh_pose(pose_body)
    actions = flip_left_right.apply_action_flip(actions)
```
📊 Learning Rate Scheduling
Warmup + Cosine Annealing:
```python
import math
import torch

class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler):
    """Combines linear warmup with cosine annealing for stable training"""

    # Note: __init__ is reconstructed from the attributes get_lr relies on
    def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup from 0 toward each base learning rate
            return [base_lr * (self.last_epoch / self.warmup_epochs)
                    for base_lr in self.base_lrs]
        # Cosine annealing from base_lr down to min_lr
        progress = (self.last_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
        return [self.min_lr + (base_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
                for base_lr in self.base_lrs]
```
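A minimal usage sketch, assuming the constructor signature reconstructed above (step once per epoch, matching the warmup and total-epoch values used in the training examples below):

```python
import torch

optimizer = torch.optim.Adam(trainer.model.parameters(), lr=1e-4)
scheduler = WarmupCosineAnnealingLR(optimizer, warmup_epochs=20,
                                    total_epochs=1000, min_lr=1e-6)
for epoch in range(1000):
    # ... one training epoch ...
    scheduler.step()
```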
💡 Usage Examples
🚀 Basic Training Setup
```python
from GBC.utils.data_preparation.pose_transformer_trainer import PoseFormerTrainer, LossCoefficients
from GBC.utils.base.assets import DATA_PATHS

# Configure the mapping table defined in create_smplh (this is just an example)
mapping_table = {
    "Pelvis": "base_link",
    "L_Hip": "l_hip_yaw_link",
    "R_Hip": "r_hip_yaw_link",
    "L_Knee": "l_knee_link",
    "R_Knee": "r_knee_link",
    "L_Ankle": "l_ankle_pitch_link",
    "R_Ankle": "r_ankle_pitch_link",
    "L_Shoulder": "l_arm_roll_link",
    "R_Shoulder": "r_arm_roll_link",
    "L_Elbow": "l_elbow_roll_link",
    "R_Elbow": "r_elbow_roll_link",
    "L_Wrist": "l_wrist_roll_link",
    "R_Wrist": "r_wrist_roll_link",
}

# Configure loss coefficients
loss_coeffs = LossCoefficients(
    main_loss=5000.0,
    aux_loss=1.0,
    out_of_range_loss=1000.0,
    high_value_action_loss=1000.0,
    direct_mapping_loss=1000.0,
    symmetry_loss=1000.0,
    reference_action_loss=5000.0,
    bone_direction_loss=50.0,
    disturbance_base=1.0,
)

# Initialize trainer
trainer = PoseFormerTrainer(
    urdf_path="/path/to/your_robot.urdf",
    dataset_path="your_dataset_path",
    mapping_table=mapping_table,
    smplh_model_path=DATA_PATHS.smplh_model_path,
    dmpls_model_path=DATA_PATHS.dmpls_model_path,
    smpl_fits_dir="your_smplh_fit_pt",  # From SMPL+H fitting
    batch_size=256,
    device='your_device',  # e.g., 'cuda:0'
    use_renderer=True,
    sample_steps=10,
    save_dir="your_save_path",
    loss_coefficients=loss_coeffs,
    use_wandb=True,
    wandb_project="pose_transformer_{}".format("your_project_name")
)
```
🔧 Advanced Training Configuration
```python
import torch

# Configure joint weights for end-effector emphasis
joint_weights = torch.tensor([
    1.0,  # Pelvis - base reference
    1.0,  # L_Hip - structural
    1.0,  # R_Hip - structural
    1.5,  # L_Knee - important for leg kinematics
    1.5,  # R_Knee - important for leg kinematics
    4.0,  # L_Ankle - end effector, high precision needed
    4.0,  # R_Ankle - end effector, high precision needed
    1.0,  # L_Shoulder - structural
    1.0,  # R_Shoulder - structural
    1.5,  # L_Elbow - important for arm kinematics
    1.5,  # R_Elbow - important for arm kinematics
    2.0,  # L_Wrist - end effector, precision needed
    2.0,  # R_Wrist - end effector, precision needed
])

# Configure bilateral symmetry for Turin robot
# (critical if you want to enable the symmetry loss)
from GBC.utils.data_preparation.robot_flip_left_right import YourFlipperModule
from GBC.utils.base.base_fk import RobotKinematics

flipper = YourFlipperModule()
fk = RobotKinematics(urdf_path, device=device)
flipper.prepare_flip_joint_ids(fk.get_dof_names())
trainer.set_flip_left_right(flipper)
```
🎯 Production Training Execution
```python
# Execute training with production-grade configuration
trainer.train(
    epochs=1000,
    lr=1e-4,                       # Conservative learning rate
    min_lr=1e-6,                   # Lower bound for scheduling
    warmup_epochs=20,              # Gradual warmup
    validation_interval=1,         # Validate every epoch
    criterion=torch.nn.MSELoss(),  # Standard regression loss
    save_interval=5,               # Save checkpoints every 5 epochs
    save_figs=True,                # Save validation visualizations
    apply_symmetry=True,           # Enable symmetry augmentation
    apply_noise=True,              # Enable disturbance training
    visualize=False,               # Disable real-time visualization for speed
    joint_weights=joint_weights,   # Emphasize end effectors
    disturbance_min_coeff=100.0,   # Higher disturbance for robustness
    disturbance_max_coeff=200.0,   # Strong final regularization
    load=False,                    # Train from scratch
    link_trans_offset=None         # No additional link offsets
)
```
🔍 Validation and Monitoring
📊 Validation Metrics
Core Metrics:
- FK Distance: Primary pose retargeting accuracy
- Joint Limit Violations: Safety and feasibility assessment
- Symmetry Consistency: Bilateral movement preservation
- Action Smoothness: Temporal continuity evaluation
Implementation:
```python
def validate(self, criterion, joint_weights=None, visualize=False, **kwargs):
    """Comprehensive validation with multiple metrics"""
    self.model.eval()
    val_metrics = {
        'val_main_loss': 0.0,
        'val_symmetry_loss': 0.0,
        'val_limit_violations': 0.0,
        'val_disturbance_loss': 0.0
    }
    # Validation loop with metric computation
    for batch in self.test_loader:
        ...  # validation logic
    return val_metrics
```
💾 Model Persistence
Checkpoint Management:
```python
# Save training state
trainer.save("checkpoint_epoch_100.pt", optimizer)

# Load from checkpoint
trainer.load("checkpoint_epoch_100.pt", optimizer)

# Model export for deployment
torch.save({
    'model_state_dict': trainer.model.state_dict(),
    'config': model_config,
    'mapping_table': mapping_table
}, 'final_model.pt')
```
📈 Experiment Tracking
Weights & Biases Integration:
```python
# Automatic hyperparameter logging
config = {
    'batch_size': batch_size,
    'learning_rate': lr,
    'loss_coefficients': loss_coeffs.to_dict(),
    'model_architecture': 'PoseTransformer'
}

# Real-time metric monitoring
trainer.logger.log_metrics({
    'train_loss': train_loss,
    'val_loss': val_loss,
    'learning_rate': current_lr
}, step=epoch)
```
🚨 Best Practices
✅ Training Guidelines
- Initialization: Start with converged SMPL+H fitting results
- Loss Balancing: Use production-tested weights (main=5000, symmetry=1000, limits=1000)
- Progressive Training: Start with low disturbance (100.0) and scale to high (200.0)
- Validation: Monitor multiple metrics and save validation figures
- Regularization: Enable both noise injection and symmetry augmentation for robustness
- Dataset: Use full AMASS dataset for comprehensive coverage
- Batching: Configure appropriate batch sizes (256 train, 1024 IK) for memory efficiency
🔧 Troubleshooting
Common Issues:
- Convergence Problems:
  - Verify SMPL+H fitting quality (results should come from a converged fitting)
  - Check joint mapping accuracy against URDF specifications
  - Balance loss coefficients (main_loss=5000.0 is typically optimal)
- Joint Violations:
  - Increase the `out_of_range_loss` coefficient (e.g., 1000.0)
  - Verify URDF joint limits are realistic
- Asymmetric Results:
  - Verify the `YourFlipperModule` configuration
  - Increase the `symmetry_loss` coefficient (e.g., 1000.0)
  - Enable `apply_symmetry=True` in training
- Jerky Motions:
  - Increase the disturbance coefficient range (100.0-200.0)
  - Enable `apply_noise=True` for smoothness regularization
  - Use a conservative learning rate (`lr=1e-4`)
🎯 Performance Optimization
Training Efficiency:
- Batch Size: Use 256 for training, 1024 for IK (optimized for modern GPUs)
- Device Management: Specify the exact device (`cuda:0`) for consistency
- Data Loading: Enable the renderer (`use_renderer=True`) only when needed
- Checkpoint Saving: Use `save_interval=5` to balance storage and recovery
- Validation Frequency: Use `validation_interval=1` for close monitoring
- Learning Rate: Conservative schedule (1e-4 → 1e-6) with 20-epoch warmup
Memory Management:
```python
# Production memory optimization
force_retrain = True  # Clear previous state
use_renderer = True   # Only when visualization is needed
sample_steps = 10     # Balanced validation sampling
```
Production Configuration Template:
```python
import os

# Complete production setup based on Turin robot training
# (assumes trainer, joint_weights, and save_dir are defined as in the examples above)
if __name__ == "__main__":
    force_retrain = True
    if not os.path.exists(save_dir) or force_retrain:
        trainer.train(
            epochs=1000,
            lr=1e-4,
            min_lr=1e-6,
            warmup_epochs=20,
            validation_interval=1,
            criterion=torch.nn.MSELoss(),
            save_interval=5,
            save_figs=True,
            apply_symmetry=True,
            apply_noise=True,
            visualize=False,              # Disable for production speed
            link_trans_offset=None,
            load=False,                   # Train from scratch
            joint_weights=joint_weights,
            disturbance_min_coeff=100.0,  # Production-tested values
            disturbance_max_coeff=200.0
        )
```
This comprehensive training framework provides the foundation for robust, accurate human-to-robot motion retargeting with state-of-the-art regularization techniques and validation capabilities.