
Adversarial Motion Prior (AMP)

Module: rsl_rl.modules.amp

This implementation is based on the paper AMP: Adversarial Motion Priors for Stylized Physics-Based Character Control by Xue Bin Peng, Ze Ma, Pieter Abbeel, Sergey Levine, and Angjoo Kanazawa.

🎯 Overview

The Adversarial Motion Prior (AMP) module implements a GAN-based discriminator that learns to distinguish between expert demonstrations and policy-generated motions. This enables reinforcement learning agents to learn natural, life-like behaviors by mimicking motion capture data.

🏗️ Core Concepts

🎭 Adversarial Training

  • Discriminator Network: AMPNet acts as a binary classifier distinguishing expert vs. policy motions
  • Motion Representation: Uses state transitions [current_state, next_state] for motion classification
  • Reward Signal: Provides intrinsic motivation for natural motion through discriminator scores
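
As a concrete picture of the motion representation above, the discriminator scores concatenated state transitions rather than single states. A minimal sketch (tensor shapes here are illustrative):

import torch

num_envs, state_dim = 4096, 50
cur_state = torch.randn(num_envs, state_dim)
next_state = torch.randn(num_envs, state_dim)

# The discriminator input is the concatenated transition [cur_state, next_state],
# so its input dimension is 2 * state_dim.
transition = torch.cat([cur_state, next_state], dim=-1)  # (num_envs, 2 * state_dim)

# Scoring such transitions with AMPNet is what turns "how expert-like does this
# motion look" into an intrinsic reward for the policy (see amp_reward below).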

🎯 Key Features

  • Style Transfer: Learn stylized motions from motion capture data
  • Robust Training: Gradient penalty and label smoothing for stable training
  • Flexible Architecture: Configurable backbone networks (MLP, CNN, etc.)
  • Reward Shaping: Saturated cross-entropy rewards to avoid reward explosion

🧠 Core Components

🔧 HingeLoss

Module Name: rsl_rl.modules.amp.HingeLoss

Definition:

class HingeLoss(nn.Module):
    def __init__(self, reduction: str = "mean")

📥 Parameters:

  • reduction (str): Reduction method ("mean" or "sum"). Default is "mean"

🔧 Functionality: Implements hinge loss for Tanh-activated discriminators: loss = max(0, 1 - x * label) where labels are ±1.

💡 Usage: Used in AMP discriminator training when out_activation="tanh" to provide margin-based classification loss.
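
For intuition, a self-contained sketch of such a hinge loss, assuming ±1 labels and the reduction argument described above (an illustration, not the module's exact source):

import torch
import torch.nn as nn

class HingeLossSketch(nn.Module):
    """Margin loss for a Tanh discriminator: loss = max(0, 1 - x * label)."""

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, x: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        loss = torch.clamp(1.0 - x * label, min=0.0)
        return loss.mean() if self.reduction == "mean" else loss.sum()

# Example: expert samples use label +1, policy samples label -1
scores = torch.tensor([0.8, -0.3, 0.1])
labels = torch.tensor([1.0, 1.0, -1.0])
print(HingeLossSketch()(scores, labels))  # tensor(0.8667): mean of [0.2, 1.3, 1.1]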


🕸️ AMPNet

Module Name: rsl_rl.modules.amp.AMPNet

Definition:

class AMPNet(nn.Module):
    def __init__(self, backbone_input_dim: int, backbone_output_dim: int,
                 backbone: str = "mlp", activation: str = "elu",
                 out_activation: str = "tanh", device: str = "cpu",
                 label_smoothing: float = 0.0, **kwargs)

📥 Parameters:

  • backbone_input_dim (int): Input dimension (typically 2 × state_dim for [cur_state, next_state])
  • backbone_output_dim (int): Hidden dimension of the backbone network
  • backbone (str): Backbone architecture type ("mlp", "cnn", etc.). Default is "mlp"
  • activation (str): Activation function for backbone layers. Default is "elu"
  • out_activation (str): Output activation ("tanh" or "sigmoid"). Default is "tanh"
  • device (str): Device for computation ("cpu" or "cuda"). Default is "cpu"
  • label_smoothing (float): Label smoothing factor for stable training. Default is 0.0
  • **kwargs: Additional arguments passed to backbone construction

🔧 Functionality: Core discriminator network that classifies motion segments as expert-like or policy-generated. Acts as the adversarial component in AMP training.

💡 Key Features:

  • Dual Activation Support: Tanh (±1 classification) or Sigmoid ([0,1] classification)
  • Adaptive Loss: Automatically selects HingeLoss for Tanh or BCELoss for Sigmoid
  • Regularization: Built-in gradient penalty and label smoothing support
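
The adaptive loss selection can be pictured as a simple switch on the output activation (a sketch of the idea, not the module's internal code):

import torch.nn as nn
from rsl_rl.modules.amp import HingeLoss

def select_discriminator_loss(out_activation: str) -> nn.Module:
    # Tanh outputs lie in [-1, 1], which matches a margin (hinge) loss with ±1 labels;
    # Sigmoid outputs lie in [0, 1], which matches binary cross-entropy with 0/1 labels.
    if out_activation == "tanh":
        return HingeLoss()
    if out_activation == "sigmoid":
        return nn.BCELoss()
    raise ValueError(f"Unsupported output activation: {out_activation}")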

🎯 Core Methods

📊 Forward Pass
def forward(self, x: torch.Tensor) -> torch.Tensor

📥 Input: x (torch.Tensor): Concatenated state transitions, shape (batch_size, 2*state_dim)
📤 Output: Raw discriminator scores, shape (batch_size, 1)

🤖 Policy Training
def policy_loss(self, y: torch.Tensor) -> torch.Tensor

📥 Input: y (torch.Tensor): Discriminator scores for policy-generated motions
📤 Output: Policy loss encouraging expert-like motions
🔧 Function: Computes loss that encourages policy to generate motions classified as expert-like

def policy_acc(self, y: torch.Tensor) -> torch.Tensor

📥 Input: y (torch.Tensor): Discriminator scores for policy motions
📤 Output: Classification accuracy for policy samples
🔧 Function: Measures how often policy motions are classified as expert-like

👨‍🏫 Expert Training
def expert_loss(self, y: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor

📥 Input:

  • y (torch.Tensor): Discriminator scores, shape (num_envs, 1)
  • tgt_mask (torch.Tensor): Mask for available expert data, shape (num_envs,)

📤 Output: Expert classification loss
🔧 Function: Computes loss for discriminator to correctly classify expert motions

def expert_acc(self, y: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor

📥 Input: Same as expert_loss
📤 Output: Classification accuracy for expert samples
🔧 Function: Measures discriminator accuracy on expert motions
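
One plausible reading of the tgt_mask semantics, as a hedged sketch (not the module's actual implementation): only environments with valid expert data contribute to the statistic.

import torch

def masked_expert_acc(scores: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
    # scores: (num_envs, 1) Tanh-discriminator outputs; the expert target is +1,
    # so a sample counts as correct when its score is positive.
    # tgt_mask: (num_envs,) boolean, True where expert data is available.
    correct = (scores.squeeze(-1) > 0.0) & tgt_mask
    return correct.sum().float() / tgt_mask.sum().clamp(min=1).float()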

🔗 Gradient Penalty
def expert_grad_penalty(self, expert_cur_state: torch.Tensor,
                        expert_next_state: torch.Tensor,
                        expert_available_mask: torch.Tensor) -> torch.Tensor

📥 Input:

  • expert_cur_state (torch.Tensor): Current states, shape (num_envs, state_dim)
  • expert_next_state (torch.Tensor): Next states, shape (num_envs, state_dim)
  • expert_available_mask (torch.Tensor): Availability mask, shape (num_envs, 1)

📤 Output: Gradient penalty term for training stability
🔧 Function: Implements WGAN-GP style gradient penalty to regularize discriminator training
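
Since the method only receives expert transitions, a plausible form is a gradient-norm penalty evaluated at the expert samples (in the spirit of WGAN-GP / R1 regularization). A hedged sketch, not the module's exact code:

import torch

def expert_grad_penalty_sketch(amp_net, cur_state, next_state, available_mask):
    # Penalize the squared gradient norm of the discriminator at expert transitions.
    expert_input = torch.cat([cur_state, next_state], dim=-1).requires_grad_(True)
    scores = amp_net(expert_input)
    grad = torch.autograd.grad(
        outputs=scores.sum(), inputs=expert_input, create_graph=True
    )[0]
    # Zero out environments without valid expert data before averaging.
    penalty = grad.norm(2, dim=-1).pow(2) * available_mask.squeeze(-1).float()
    return penalty.sum() / available_mask.float().sum().clamp(min=1)
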
🎁 Reward Computation
def amp_reward(self, cur_state: torch.Tensor, next_state: torch.Tensor,
               epsilon: float = 1e-4, reward_shift: float = 0.45) -> torch.Tensor

📥 Input:

  • cur_state (torch.Tensor): Current states, shape (num_envs, state_dim)
  • next_state (torch.Tensor): Next states, shape (num_envs, state_dim)
  • epsilon (float): Minimum threshold for saturated cross-entropy. Default is 1e-4
  • reward_shift (float): Reward baseline shift. Default is 0.45

📤 Output: AMP rewards, shape (num_envs,)
🔧 Function: Computes intrinsic rewards based on discriminator confidence

🧮 Reward Formula:

reward = -log(max(expert_score - predicted_score - label_smoothing, epsilon)) - reward_shift
reward = clamp(reward, min=-0.1)
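
Read literally, the formula above can be implemented as follows, taking the expert target score to be 1 (a sketch under that assumption, not the module's exact code):

import torch

def amp_reward_sketch(score, label_smoothing=0.0, epsilon=1e-4, reward_shift=0.45):
    # score: post-activation discriminator output for a policy transition.
    expert_score = 1.0  # assumed expert target for a Tanh/Sigmoid discriminator
    # Saturated cross-entropy: clamping the log argument at epsilon bounds the reward
    # even when the discriminator is maximally confident.
    gap = torch.clamp(expert_score - score - label_smoothing, min=epsilon)
    reward = -torch.log(gap) - reward_shift
    return torch.clamp(reward, min=-0.1)

Under this reading, clearly non-expert transitions are floored at -0.1 while expert-like transitions saturate near -log(epsilon) - reward_shift.
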
📊 Score Evaluation
def amp_score(self, cur_state: torch.Tensor, next_state: torch.Tensor) -> torch.Tensor

📥 Input:

  • cur_state (torch.Tensor): Current states
  • next_state (torch.Tensor): Next states

📤 Output: Normalized discriminator scores, shape (num_envs,)
🔧 Function: Returns post-activation discriminator scores for analysis

💡 Training Pipeline

🔄 AMP Training Loop

# 1. Discriminator training phase (expert_scores / policy_scores come from forward
#    passes on expert and policy transitions respectively)
expert_loss = amp_net.expert_loss(expert_scores, expert_mask)
policy_loss = amp_net.policy_loss(policy_scores)
grad_penalty = amp_net.expert_grad_penalty(expert_cur, expert_next, expert_mask)

discriminator_loss = expert_loss + policy_loss + lambda_gp * grad_penalty

# 2. Policy training phase: mix AMP rewards into the task rewards
amp_rewards = amp_net.amp_reward(cur_states, next_states)
total_rewards = task_rewards + lambda_amp * amp_rewards

# 3. Update networks
discriminator_optimizer.zero_grad()
discriminator_loss.backward()
discriminator_optimizer.step()

# policy_loss_with_amp_rewards is the RL objective (e.g. a PPO surrogate loss)
# computed from total_rewards
policy_optimizer.zero_grad()
policy_loss_with_amp_rewards.backward()
policy_optimizer.step()

🎯 Key Training Considerations

🔧 Hyperparameter Guidelines:

  • lambda_gp (gradient penalty weight): 0.1 - 1.0
  • lambda_amp (AMP reward weight): 0.1 - 0.5
  • label_smoothing: 0.0 - 0.2 for training stability
  • reward_shift: 0.4 - 0.5 to avoid positive reward explosion
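
A starting configuration inside these ranges might look like the following (key names are illustrative, not a fixed config schema):

amp_cfg = {
    "lambda_gp": 0.5,        # gradient penalty weight (0.1 - 1.0)
    "lambda_amp": 0.3,       # AMP reward weight (0.1 - 0.5)
    "label_smoothing": 0.1,  # 0.0 - 0.2 for training stability
    "reward_shift": 0.45,    # 0.4 - 0.5 to avoid positive reward explosion
}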

⚖️ Training Balance:

  • Update discriminator and policy at similar frequencies
  • Monitor discriminator accuracy (should stay around 50-70%)
  • Watch for discriminator overfitting (policy can't improve)
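
In practice this can be monitored by logging both accuracies every update, continuing the training loop above (a sketch using the documented accuracy methods):

expert_acc = amp_net.expert_acc(expert_scores, expert_mask).item()
policy_acc = amp_net.policy_acc(policy_scores).item()

# A discriminator that is nearly always right gives the policy an almost constant
# reward signal, so learning stalls; consider raising lambda_gp or lowering the
# discriminator learning rate when both accuracies stay above ~0.9.
print(f"expert_acc={expert_acc:.2f}, policy_acc={policy_acc:.2f}")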

🚀 Usage Example

import torch
from rsl_rl.modules.amp import AMPNet

# Initialize AMP discriminator
amp_net = AMPNet(
    backbone_input_dim=100,  # 2 * state_dim (current + next state)
    backbone_output_dim=256,
    backbone="mlp",
    activation="elu",
    out_activation="tanh",
    label_smoothing=0.1,
    net_kwargs={
        "hidden_sizes": [512, 512, 256],
        "dropout": 0.1,
    },
)

# Training data
batch_size = 1024
state_dim = 50
cur_states = torch.randn(batch_size, state_dim)
next_states = torch.randn(batch_size, state_dim)
expert_mask = torch.randint(0, 2, (batch_size,)).bool()

# Forward pass
state_transitions = torch.cat([cur_states, next_states], dim=-1)
discriminator_scores = amp_net(state_transitions)

# Compute losses
expert_loss = amp_net.expert_loss(discriminator_scores, expert_mask)
policy_loss = amp_net.policy_loss(discriminator_scores)
grad_penalty = amp_net.expert_grad_penalty(cur_states, next_states, expert_mask.unsqueeze(-1))

# Compute AMP rewards for RL training
amp_rewards = amp_net.amp_reward(cur_states, next_states)

print(f"Expert Loss: {expert_loss.item():.4f}")
print(f"Policy Loss: {policy_loss.item():.4f}")
print(f"Gradient Penalty: {grad_penalty.item():.4f}")
print(f"Average AMP Reward: {amp_rewards.mean().item():.4f}")

🎨 Motion Style Applications

🏃 Supported Motion Types

  • Locomotion: Walking, running, jumping gaits
  • Athletic Motions: Dancing, martial arts, sports movements
  • Character Animation: Expressive gestures, personality traits
  • Multi-Character: Group behaviors and interactions

📊 Performance Metrics

  • Motion Quality: Discriminator accuracy and confidence scores
  • Diversity: Motion variation within learned style manifold
  • Stability: Consistent performance across different environments
  • Transfer: Adaptation to new tasks while preserving style

🔧 Implementation Notes

⚠️ Important Considerations

🎯 State Representation:

  • Use consistent state normalization for current and next states
  • Include relevant motion features (velocities, joint angles, contact info)
  • Avoid including task-specific information in AMP states
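
A hedged sketch of such a feature builder; every input name here is hypothetical and not part of the rsl_rl API:

import torch

def build_amp_state(joint_pos, joint_vel, base_lin_vel, base_ang_vel):
    # Hypothetical helper: concatenate motion-relevant features only, and apply the
    # same normalization to current and next states. Task-specific observations
    # (goal positions, commands) are deliberately left out of the AMP state.
    return torch.cat([joint_pos, joint_vel, base_lin_vel, base_ang_vel], dim=-1)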

🔄 Training Stability:

  • Start with small lambda_amp and gradually increase
  • Monitor discriminator loss - should decrease initially then stabilize
  • Use gradient penalty to prevent discriminator from becoming too strong
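
One simple way to ramp the AMP weight, shown as a hypothetical helper (not part of the module):

def lambda_amp_schedule(iteration: int, warmup_iters: int = 1000, final_weight: float = 0.3) -> float:
    # Linearly increase the AMP reward weight so task rewards dominate early
    # training and style shaping takes over gradually.
    return final_weight * min(1.0, iteration / warmup_iters)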

🎨 Style Control:

  • Mix multiple motion datasets for style diversity
  • Use curriculum learning: start with single style, add complexity
  • Consider hierarchical approaches for complex multi-style policies

🚨 Common Pitfalls

  • Discriminator Collapse: Too strong discriminator prevents policy learning
  • Mode Collapse: Policy learns limited motion repertoire
  • Reward Explosion: Insufficient reward clipping or shifting
  • State Mismatch: Inconsistent state representations between expert and policy

📖 References

Original Paper:

Peng, X. B., Ma, Z., Abbeel, P., Levine, S., & Kanazawa, A. (2021). AMP: Adversarial Motion Priors for Stylized Physics-Based Character Control. ACM Transactions on Graphics (TOG), 40(4), 1-20.

Key Concepts:

  • Adversarial Training: GAN-based approach for motion imitation
  • Motion Priors: Learning implicit motion manifolds from demonstrations
  • Style Transfer: Transferring motion characteristics across tasks
  • Reward Shaping: Using discriminator confidence as intrinsic motivation

Related Work:

  • DeepMimic: Direct motion imitation without adversarial training
  • GAIL: Generative Adversarial Imitation Learning, the general adversarial imitation framework AMP builds on
  • Motion VAE: Variational approaches to motion modeling

💡 Best Practice: Start with simple locomotion tasks and gradually move to more complex, stylized behaviors. Use motion capture data preprocessing to ensure clean, consistent expert demonstrations.