onnx_exporter
Module: GBC.gyms.isaaclab_45.lab_tasks.utils.wrappers.rsl_rl.onnx_exporter
🎯 Overview
The ONNX exporter converts trained BERT-form transformer PPO policies to ONNX format for deployment. The module uses PyTorch's native ONNX export path to serialize trained policies, enabling cross-platform inference and optimized production serving.
🚀 Key Benefits
- 🌐 Cross-Platform Deployment: Run trained policies on various hardware and software platforms
- ⚡ Optimized Inference: ONNX runtime provides highly optimized execution
- 🔧 Framework Agnostic: Export from PyTorch to platform-independent format
- 📱 Edge Deployment: Enable deployment on edge devices and embedded systems
📦 Public Functions
export_policy_as_onnx()
```python
def export_policy_as_onnx(
    actor_critic: object,
    path: str,
    normalizer: object | None = None,
    filename: str = "policy.onnx",
    verbose: bool = False,
) -> None
```
Purpose: Export a trained actor-critic policy to ONNX format for deployment.
Parameters
- `actor_critic` (object): The trained actor-critic torch module to export
- `path` (str): Directory path where the ONNX file will be saved
- `normalizer` (object | None): Empirical normalizer module. If `None`, identity normalization is used
- `filename` (str): Name of the exported ONNX file (default: `"policy.onnx"`)
- `verbose` (bool): Whether to print detailed export information (default: `False`)
Functionality
- 📁 Directory Creation: Automatically creates output directory if it doesn't exist
- 🧠 Model Processing: Handles both recurrent (LSTM) and non-recurrent transformer architectures
- 🔄 Reference Support: Automatically detects and exports reference observation capabilities
- ⚙️ Optimization: Selects the ONNX opset per architecture (14 for non-recurrent, 11 for recurrent) for compatibility
Usage Example
```python
from GBC.gyms.isaaclab_45.lab_tasks.utils.wrappers.rsl_rl import export_policy_as_onnx

# Export trained policy
export_policy_as_onnx(
    actor_critic=trained_policy,
    path="./exported_models",
    normalizer=obs_normalizer,
    filename="robot_policy.onnx",
    verbose=True,
)
```
🏗️ Internal Implementation
_OnnxPolicyExporter Class
Purpose: Private helper class that handles the actual ONNX conversion process.
Key Features
🧠 Architecture Detection
- Automatically detects recurrent vs. non-recurrent models
- Handles BERT-form transformer architectures
- Supports reference observation embeddings
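A minimal sketch of what this detection could look like; the attribute names (`is_recurrent` follows the rsl_rl convention, `use_ref_obs` is a hypothetical flag) are assumptions about the policy module, not a confirmed API:

```python
import torch.nn as nn

def detect_architecture(actor_critic: nn.Module) -> dict:
    """Probe a policy module for the traits the exporter branches on."""
    return {
        # rsl_rl actor-critic modules conventionally expose `is_recurrent`
        "recurrent": bool(getattr(actor_critic, "is_recurrent", False)),
        # reference-observation support assumed to be advertised via a flag
        "reference": bool(getattr(actor_critic, "use_ref_obs", False)),
    }
```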
📊 Input/Output Handling
- Standard Mode: `obs` → `actions`
- Reference Mode: `obs`, `ref_obs`, `ref_mask` → `actions`
- Recurrent Mode: `obs`, `h_in`, `c_in` → `actions`, `h_out`, `c_out`
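For illustration, here is what each signature implies at inference time in ONNX Runtime; the tensor shapes, hidden-state layout, and `ref_mask` dtype below are assumptions, not values fixed by the exporter:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("policy.onnx")
obs = np.zeros((1, 48), dtype=np.float32)  # illustrative observation size

# Standard mode: obs -> actions
actions, = session.run(["actions"], {"obs": obs})

# Reference mode: obs, ref_obs, ref_mask -> actions
ref_obs = np.zeros((1, 48), dtype=np.float32)   # illustrative shape
ref_mask = np.ones((1,), dtype=bool)            # illustrative shape/dtype
actions, = session.run(
    ["actions"], {"obs": obs, "ref_obs": ref_obs, "ref_mask": ref_mask}
)

# Recurrent mode: obs, h_in, c_in -> actions, h_out, c_out
h = np.zeros((1, 1, 256), dtype=np.float32)  # (layers, batch, hidden), illustrative
c = np.zeros_like(h)
actions, h, c = session.run(
    ["actions", "h_out", "c_out"], {"obs": obs, "h_in": h, "c_in": c}
)
```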
⚙️ ONNX Configuration
- Opset Version: Uses opset 14 for non-recurrent, opset 11 for recurrent models
- Dynamic Axes: Configurable for flexible batch sizes
- Export Parameters: Includes all learned weights and biases
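A minimal sketch of how these settings map onto `torch.onnx.export`, assuming a non-recurrent model with a single `obs` input; the helper name and dynamic-axes layout are illustrative:

```python
import torch

def export_with_config(model, example_obs, out_path, recurrent=False):
    torch.onnx.export(
        model,
        example_obs,                       # zero tensor used purely for tracing
        out_path,
        export_params=True,                # bake learned weights into the graph
        opset_version=11 if recurrent else 14,
        input_names=["obs"],
        output_names=["actions"],
        dynamic_axes={"obs": {0: "batch"}, "actions": {0: "batch"}},
    )
```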
Supported Model Types
🔄 Non-Recurrent Models (Transformer-based)
```python
# Input specification
input_names = ["obs"]  # Standard observations
if reference_mode:
    input_names += ["ref_obs", "ref_mask"]  # Reference observations + mask

# Output specification
output_names = ["actions"]
```
🔁 Recurrent Models (LSTM-based)
```python
# Input specification
input_names = ["obs", "h_in", "c_in"]  # Observations + hidden states

# Output specification
output_names = ["actions", "h_out", "c_out"]  # Actions + updated states
```
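As a sketch of how a recurrent policy can be flattened into this single-step signature for tracing; the `rnn`/`actor` submodule names are assumptions about the policy's internals, not the module's actual code:

```python
import torch.nn as nn

class RecurrentWrapper(nn.Module):
    def __init__(self, rnn: nn.LSTM, actor: nn.Module):
        super().__init__()
        self.rnn = rnn
        self.actor = actor

    def forward(self, obs, h_in, c_in):
        # nn.LSTM expects (seq, batch, features); a single control step is seq=1
        x, (h_out, c_out) = self.rnn(obs.unsqueeze(0), (h_in, c_in))
        return self.actor(x.squeeze(0)), h_out, c_out
```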
🎯 Export Process Flow
```mermaid
graph TD
    A[Trained Policy] --> B[_OnnxPolicyExporter]
    B --> C{Architecture Type?}
    C -->|Transformer| D[Configure Non-Recurrent Export]
    C -->|LSTM| E[Configure Recurrent Export]
    D --> F{Reference Observations?}
    F -->|Yes| G[Multi-Input ONNX Graph]
    F -->|No| H[Single-Input ONNX Graph]
    E --> I[LSTM ONNX Graph]
    G --> J[Export ONNX File]
    H --> J
    I --> J
    J --> K[Deployment-Ready Model]
```
🔧 Technical Specifications
ONNX Compatibility
- Framework: PyTorch → ONNX conversion
- Opset Versions: 11 (recurrent) / 14 (non-recurrent)
- Data Types: FP32 tensors
- Batch Dimension: Configurable dynamic axes
Model Architecture Support
- ✅ BERT-form Transformers: Full support for attention-based policies
- ✅ Reference Observations: Multi-modal input handling
- ✅ LSTM Networks: Recurrent policy support
- ✅ Observation Normalization: Embedded preprocessing
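A sketch of how normalization can travel with the exported graph; the wrapper class is illustrative, but it mirrors the documented behavior of substituting an identity when `normalizer` is `None`:

```python
import torch.nn as nn

class NormalizedPolicy(nn.Module):
    def __init__(self, actor: nn.Module, normalizer: nn.Module | None = None):
        super().__init__()
        # fall back to a pass-through when no normalizer is supplied
        self.normalizer = normalizer if normalizer is not None else nn.Identity()
        self.actor = actor

    def forward(self, obs):
        return self.actor(self.normalizer(obs))
```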
Deployment Targets
- 🖥️ ONNX Runtime: CPU/GPU inference
- 📱 Mobile Platforms: iOS/Android deployment
- 🔧 Edge Devices: Embedded systems and microcontrollers
- ☁️ Cloud Services: Scalable inference endpoints
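For the ONNX Runtime target, provider selection is the usual switch between CPU and GPU execution; a minimal sketch:

```python
import onnxruntime as ort

# Prefer the CUDA provider when available, fall back to CPU.
providers = ["CPUExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers():
    providers.insert(0, "CUDAExecutionProvider")

session = ort.InferenceSession("./exported_models/robot_policy.onnx", providers=providers)
```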
💡 Best Practices
🎯 Export Optimization
- Model Preparation: Ensure model is in evaluation mode before export
- Device Placement: Export process automatically moves model to CPU
- Input Shapes: Zero tensors are used to trace the computation graph
- Normalizer Integration: Include observation normalizers in the export
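A hedged sketch of these preparation steps; the helper and its `obs_dim` parameter are illustrative, not part of the module's API:

```python
import copy
import torch

def prepare_for_export(actor_critic, obs_dim: int):
    model = copy.deepcopy(actor_critic)  # leave the training copy untouched
    model.eval()                         # freeze dropout / normalization statistics
    model.cpu()                          # tracing is performed on CPU
    dummy_obs = torch.zeros(1, obs_dim)  # zero tensor traces the computation graph
    return model, dummy_obs
```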
📊 Quality Assurance
- Validation: Test exported ONNX model against original PyTorch version
- Performance: Benchmark inference speed on target deployment platform
- Compatibility: Verify ONNX runtime version compatibility
- Precision: Check numerical precision matches original model
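A minimal validation sketch comparing the original policy with its export; it assumes a deterministic `act_inference` method on the policy (the rsl_rl convention) and an illustrative `obs_dim`:

```python
import numpy as np
import onnxruntime as ort
import torch

def validate_export(policy, onnx_path: str, obs_dim: int, atol: float = 1e-5):
    obs = torch.randn(1, obs_dim)
    with torch.no_grad():
        torch_actions = policy.act_inference(obs).numpy()

    session = ort.InferenceSession(onnx_path)
    onnx_actions = session.run(None, {"obs": obs.numpy()})[0]

    # raises if the two implementations diverge beyond the tolerance
    np.testing.assert_allclose(torch_actions, onnx_actions, atol=atol)
```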
🚀 Deployment Workflow
```python
# 1. Export trained policy
export_policy_as_onnx(policy, "./models", normalizer, "robot.onnx")

# 2. Load in ONNX Runtime
import onnxruntime as ort
session = ort.InferenceSession("./models/robot.onnx")

# 3. Run inference (observation_data should be a float32 NumPy array)
outputs = session.run(None, {"obs": observation_data})
actions = outputs[0]
```
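For recurrent exports, the same session call carries the hidden state across control steps; a sketch in which `observation_stream` and the hidden-state shape are illustrative:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("./models/robot.onnx")
h = np.zeros((1, 1, 256), dtype=np.float32)  # (layers, batch, hidden), illustrative
c = np.zeros_like(h)

for obs in observation_stream:  # hypothetical iterable of float32 observations
    actions, h, c = session.run(
        ["actions", "h_out", "c_out"],
        {"obs": obs, "h_in": h, "c_in": c},
    )
```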
🔗 Related Components
- Actor-Critic Models: Source models for export
- Observation Normalizers: Preprocessing components
- ONNX Runtime: Deployment inference engine
- Transformer Architectures: BERT-form policy networks