.pth
转 ONNX:从模型训练到跨平台部署在深度学习里,模型的格式决定了它的可用性。
如果你是 PyTorch 用户,你可能熟悉 .pth
文件,它用于存储训练好的模型。
但当你想在不同的环境(如 TensorRT、OpenVINO、ONNX Runtime)部署模型时,.pth
可能并不适用。这时,ONNX(Open Neural Network Exchange)就必不可少。
本文目录:
.pth
文件?.onnx
文件?.pth
到 .onnx
?.pth
文件?.pth
是 PyTorch 专属的模型权重文件,用于存储:
torch.save(model, "model.pth")
保存的情况。在 PyTorch 中,你可以用以下方式加载 .pth
:
import torch
from NestedUNet import NestedUNet # 你的模型类
# 仅保存权重的加载方式
model = NestedUNet(num_classes=2, input_channels=3)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
.pth
文件只能在 PyTorch 运行的环境中使用,不能直接在 TensorFlow、OpenVINO 或 TensorRT 里运行。
ONNX(Open Neural Network Exchange)是 一个开放的神经网络标准格式,它的目标是:
ONNX 文件是一个 .onnx
文件,它包含:
ONNX 让你可以在不同平台上运行同一个模型,而不必依赖某个特定的深度学习框架。
.pth
到 .onnx
?转换为 ONNX 主要有以下好处:
✅ 跨平台兼容
.pth
只能在 PyTorch 里用,而 .onnx
可以在 TensorRT、ONNX Runtime、OpenVINO、CoreML 等多种环境中运行。✅ 推理速度更快
✅ 支持多种硬件
.pth
主要用于 CPU/GPU,而 .onnx
可用于 FPGA、TPU、ARM 设备,如 安卓手机、树莓派、Jetson Nano 等。✅ 更轻量级
.pth
到 .onnx
?在转换前,确保你已安装 PyTorch 和 ONNX:
pip install torch torchvision onnx
假设你有一个 NestedUNet
训练好的 .pth
文件,转换方式如下:
import torch
import torch.onnx
from NestedUNet import NestedUNet # 你的模型文件
# 1. 加载 PyTorch 模型
model = NestedUNet(num_classes=2, input_channels=3, deep_supervision=False)
model.load_state_dict(torch.load("best_model.pth"))
model.eval()
# 2. 创建示例输入(确保形状正确)
dummy_input = torch.randn(1, 3, 256, 256)
# 3. 导出为 ONNX
onnx_path = "nested_unet.onnx"
torch.onnx.export(
model,
dummy_input,
onnx_path,
export_params=True,
opset_version=11, # 确保兼容性
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
print(f"✅ 模型已成功转换为 {onnx_path}")
安装 onnxruntime
并测试:
pip install onnxruntime
然后运行:
import onnxruntime as ort
import numpy as np
# 加载 ONNX
ort_session = ort.InferenceSession("nested_unet.onnx")
# 生成随机输入
input_data = np.random.randn(1, 3, 256, 256).astype(np.float32)
outputs = ort_session.run(None, {"input": input_data})
print("ONNX 推理结果:", outputs[0].shape)
✅ 提高推理速度
✅ 跨平台部署
.onnx
可用于 Windows、Linux、安卓、iOS、嵌入式设备。✅ 减少依赖
⚠ ONNX 可能不支持某些 PyTorch 操作
grid_sample
)可能在 ONNX 不支持,需要手动修改模型。⚠ ONNX 的 Upsample
可能需要 align_corners=False
如果 Upsample(scale_factor=2, mode='bilinear', align_corners=True)
,可能会导致 ONNX 兼容性问题,建议改为:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
⚠ ONNX 在 CPU 上的推理可能比 PyTorch 慢
⚠ TensorRT 需要额外优化
直接用 TensorRT 运行 ONNX 可能会报错,需要 onnx-simplifier
:
pip install onnx-simplifier
python -m onnxsim nested_unet.onnx nested_unet_simplified.onnx
比较项 | .pth (PyTorch) | .onnx (ONNX) |
---|---|---|
框架依赖 | 仅支持 PyTorch | 兼容多框架 |
推理速度 | 较慢 | 更快(ONNX Runtime / TensorRT) |
跨平台性 | 仅支持 PyTorch | 可在多种设备上运行 |
部署难度 | 需要完整 Python | 轻量级,适用于嵌入式 |
👉 建议
此文由 Mix Space 同步更新至 xLog
原始链接为 https://blog.kanes.top/posts/default/pth2ONNX