边缘AI部署全链路:从TensorRT到ONNX的模型优化

引言:边缘部署的核心挑战

关键指标

指标 要求
延迟 <30ms(30fps)
内存 <2GB(嵌入式)
功耗 <5W(车规)
精度损失 <1%

一、模型量化

1.1 量化类型

类型 精度 模型大小 速度
FP32 最高 最大 最慢
FP16 中等 50% 快2倍
INT8 25% 快4倍

1.2 INT8量化流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn

class Quantizer:
"""
模型量化器
"""
def __init__(self):
# 量化配置
self.calibrator = calibrate()
self.quant_config = {
'weight_dtype': torch.qint8,
'activation_dtype': torch.qint8,
'scheme': 'per_channel' # 逐通道量化
}

def quantize(self, model, calibrate_data):
"""
量化模型
"""
# 1. 收集校准数据
calibration_stats = self.collect_calibration_stats(
model, calibrate_data
)

# 2. 生成量化模型
quantized_model = torch.quantization.quantize_dynamic(
model,
calibration_stats,
**self.quant_config
)

return quantized_model

def collect_calibration_stats(self, model, calibrate_data):
"""
收集校准统计
"""
stats = {}

for name, module in model.named_modules():
# 只量化卷积层和全连接层
if isinstance(module, (nn.Conv2d, nn.Linear)):
# 收集激活值范围
activation_range = self.collect_activation_range(
module, calibrate_data
)

stats[name] = activation_range

return stats

def collect_activation_range(self, module, calibrate_data):
"""
收集激活值范围
"""
activations = []

for data in calibrate_data:
# 前向传播收集激活值
with torch.no_grad():
output = module(data)
activations.append(output.detach())

# 计算min/max
activations_tensor = torch.cat(activations, dim=0)

min_val = torch.min(activations_tensor)
max_val = torch.max(activations_tensor)

return {
'min': min_val.item(),
'max': max_val.item(),
'range': max_val - min_val
}

1.3 混合精度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class HybridQuantization:
"""
混合精度量化
"""
def __init__(self):
# 关键层FP16,其他INT8
self.critical_layers = {
'backbone.conv1': torch.float16,
'backbone.conv2': torch.float16,
'neck': torch.float16
}

self.quant_config = {
'default': torch.qint8
}

def quantize(self, model):
"""
混合精度量化
"""
quantized_modules = {}

for name, module in model.named_modules():
# 检查是否为关键层
if name in self.critical_layers:
# FP16量化
quantized_modules[name] = self.quantize_fp16(module)
else:
# INT8量化
quantized_modules[name] = self.quantize_int8(module)

return quantized_modules

def quantize_fp16(self, module):
"""
FP16量化
"""
# 转换权重
module.weight.data = module.weight.data.half()
if module.bias is not None:
module.bias.data = module.bias.data.half()

return module

二、框架对比

2.1 TensorRT

特点

特点 说明
硬件加速 NVIDIA GPU/Orin专用
内核融合 自动优化算子
动态形状 支持变长输入
INT8优化 硬件加速

部署流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import tensorrt as trt

class TensorRTDeployer:
"""
TensorRT部署器
"""
def __init__(self, engine_path):
# 加载引擎
self.logger = trt.Logger(trt.Logger.WARNING)
self.runtime = trt.Runtime(self.logger)

with open(engine_path, 'rb') as f:
self.engine = self.runtime.deserialize_cuda_engine(f.read())

# 创建执行上下文
self.context = self.engine.create_execution_context(
stream=None,
device='cuda'
)

def infer(self, input_data):
"""
推理
"""
# 1. 创建输入缓冲区
input_buffer = input_data.cuda()

# 2. 分配输出缓冲区
output_buffer = torch.empty(
self.engine.max_batch_size,
self.engine.max_shape,
dtype=torch.float32
).cuda()

# 3. 执行推理
self.context.execute_v2(
bindings=[
input_buffer,
output_buffer
],
stream=self.stream
)

# 4. 复制到CPU
result = output_buffer.cpu()

return result

2.2 ONNX Runtime

特点

特点 说明
跨平台 Windows/Linux/Android/Mac
轻量级 最小<5MB
硬件支持 CPU/GPU/NPU

部署流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import onnxruntime as ort

class ONNXDeployer:
"""
ONNX运行时部署器
"""
def __init__(self, model_path):
# 创建推理会话
self.session = ort.InferenceSession(
model_path,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)

# 获取输入输出信息
self.input_name = self.session.get_inputs()[0].name
self.output_names = [o.name for o in self.session.get_outputs()]

def infer(self, input_data):
"""
推理
"""
# 1. 准备输入
inputs = {self.input_name: input_data}

# 2. 执行推理
outputs = self.session.run(
None,
inputs,
self.output_names
)

return outputs

2.3 TFLite

特点

特点 说明
移动优化 Android/iOS专用
NPU支持 硬件神经网络加速
极小尺寸 <1MB

部署流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import tflite as tfl

class TFLiteDeployer:
"""
TFLite部署器
"""
def __init__(self, model_path):
# 加载模型
self.interpreter = tfl.Interpreter(
model_path=model_path,
num_threads=4
)

# 分配张量
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()

def infer(self, input_data):
"""
推理
"""
# 1. 分配输入张量
input_tensor = self.interpreter.allocate_tensors()
self.input_details[0]['index'] = input_tensor

# 2. 执行推理
self.interpreter.invoke()

# 3. 获取输出
output_tensor = self.interpreter.tensor(
self.output_details[0]['index']
)

return output_tensor.numpy()

三、性能优化技巧

3.1 算子融合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class OperatorFusion:
"""
算子融合
"""
def __init__(self, model):
# 识别可融合层
self.fused_layers = []

for module in model.modules():
if isinstance(module, nn.Sequential):
# 检查是否可以融合
if self.can_fuse(module):
self.fused_layers.append(module)

def can_fuse(self, module):
"""
判断是否可以融合
"""
# 典型模式
patterns = [
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU),
(nn.Conv2d, nn.BatchNorm2d, nn.ReLU, nn.Conv2d)
]

# 检查
module_list = list(module.children())

for pattern in patterns:
matched = True
for i, layer_type in enumerate(pattern):
if not isinstance(module_list[i], layer_type):
matched = False
break

if matched:
return True

return False

3.2 内存优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class MemoryOptimizer:
"""
内存优化器
"""
def optimize(self, model):
"""
优化内存使用
"""
# 1. 共享卷积核
for module in model.modules():
if isinstance(module, nn.Conv2d):
# 使用in-place操作
if hasattr(module, 'bias'):
# 共享bias(如果有相同形状)
if hasattr(module, 'shared_bias'):
module.bias = module.shared_bias

# 2. 梯度检查点
for module in model.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
# 移除不必要的hook
if hasattr(module, 'hook'):
delattr(module, 'hook')

return model

3.3 批处理优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class BatchOptimizer:
"""
批处理优化器
"""
def __init__(self, target_batch_size=8):
self.target_batch = target_batch_size

def optimize_inference(self, model, inputs):
"""
优化推理
"""
# 1. 动态批处理
if len(inputs) < self.target_batch:
# 填充到目标大小
inputs = self.pad_to_batch_size(inputs)

# 2. 使用torch.jit
@torch.jit.script
def forward(model, x):
return model(x)

# 3. 混合精度推理
with torch.no_grad():
# FP16推理
outputs = model.half()(inputs)

# INT8后处理
outputs = outputs.float()

return outputs

四、硬件加速

4.1 NPU/TPU

硬件 性能 适用平台
NVIDIA Orin 275 TOPS 高端车型
Qualcomm Hexagon 45 TOPS 中端车型
Google Edge TPU 4 TOPS Android设备
Intel NPU 12 TOPS 服务器端

4.2 DSP卸载

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class OffloadToDSP:
"""
DSP卸载
"""
def __init__(self, dsp_runtime):
self.dsp = dsp_runtime

def offload(self, layer):
"""
卸载到DSP
"""
# 检查是否为计算密集型层
if self.is_compute_intensive(layer):
# 生成DSP代码
dsp_code = self.generate_dsp_code(layer)

# 部署到DSP
self.dsp.deploy(dsp_code)

return True

return False

五、最佳实践

5.1 部署流程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
模型训练

模型量化(FP16/INT8)

模型转换(PyTorch→ONNX→TensorRT)

性能测试

┌─────────────────────────────────┐
│ 部署选择 │
│ ├── 移动端:TFLite │
│ ├── 跨平台:ONNX │
│ ├── NVIDIA:TensorRT │
│ └── 自定义:C++ │
└─────────────────────────────────┘

┌─────────────────────────────────┐
│ 生产部署 │
│ ├── OTA更新 │
│ ├── A/B测试 │
│ ├── 监控告警 │
│ └── 回滚机制 │
└─────────────────────────────────┘

5.2 性能指标

指标 目标 测量方法
延迟 <30ms 平均P99延迟
吞吐量 >30fps 帧率测试
内存 <2GB 内存profiling
功耗 <5W 功率测试
精度 <1%损失 与FP32对比

六、总结

6.1 框架选择

| 需求 | 推荐框架 |
|——|———-|———-|
| NVIDIA GPU | TensorRT |
| 跨平台 | ONNX Runtime |
| 移动设备 | TFLite |
| 高性能定制 | 自定义C++ |

6.2 关键经验

经验 说明
量化优先 INT8比FP16节省4倍内存
FP16关键 首层FP16保持精度
内核融合 利用硬件优化算子
批处理 充分利用并行计算

参考文献

  1. NVIDIA. “TensorRT Documentation.” 2025.
  2. ONNX. “ONNX Runtime API.” 2025.
  3. TensorFlow. “TFLite Guide.” 2025.

本文是IMS边缘部署系列文章之一,上一篇:数据标注规范


边缘AI部署全链路:从TensorRT到ONNX的模型优化
https://dapalm.com/2026/03/13/2026-03-13-边缘AI部署全链路-从TensorRT到ONNX的模型优化/
作者
Mars
发布于
2026年3月13日
许可协议