python3.8版本conda环境
import os
import argparse
import numpy as np
from PIL import Image
import onnxsim
import onnx
import nncase
import shutil
import math
import onnx.helper
from onnx import mapping
# 解析模型的输入和输出信息
def parse_model_input_output(model_file, input_shape):
# 加载 ONNX 模型
onnx_model = onnx.load(model_file)
# 获取模型中所有输入节点的名称
input_all = [node.name for node in onnx_model.graph.input]
# 获取模型中所有初始化器节点的名称
input_initializer = [node.name for node in onnx_model.graph.initializer]
# 计算实际输入节点的名称,即不在初始化器中的输入节点
input_names = list(set(input_all) - set(input_initializer))
# 获取实际输入节点的张量信息
input_tensors = [
node for node in onnx_model.graph.input if node.name in input_names]
# 存储输入信息的列表
inputs = []
for _, e in enumerate(input_tensors):
# 获取输入张量的类型信息
onnx_type = e.type.tensor_type
input_dict = {}
# 输入节点的名称
input_dict['name'] = e.name
# 将 ONNX 张量类型映射为 NumPy 数据类型
input_dict['dtype'] = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_type.elem_type] #弃用警告
# input_dict['dtype'] = onnx.helper.tensor_dtype_to_np_dtype(onnx_type.elem_type)
# 处理输入张量的形状,将动态维度替换为指定的输入形状
input_dict['shape'] = [(i.dim_value if i.dim_value != 0 else d) for i, d in zip(
onnx_type.shape.dim, input_shape)]
inputs.append(input_dict)
return onnx_model, inputs
# 简化 ONNX 模型
def onnx_simplify(model_file, dump_dir, input_shape):
# 解析模型的输入和输出信息
onnx_model, inputs = parse_model_input_output(model_file, input_shape)
# 进行形状推断,更新模型中的形状信息
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
# 存储输入形状的字典
input_shapes = {}
for input in inputs:
input_shapes[input['name']] = input['shape']
# 简化 ONNX 模型
# onnx_model, check = onnxsim.simplify(onnx_model, input_shapes=input_shapes) #弃用警告
onnx_model, check = onnxsim.simplify(onnx_model, overwrite_input_shapes=input_shapes)
# 确保简化后的模型可以通过验证
assert check, "Simplified ONNX model could not be validated"
# 保存简化后的模型
model_file = os.path.join(dump_dir, 'simplified.onnx')
onnx.save_model(onnx_model, model_file)
return model_file
# 读取模型文件的二进制内容
def read_model_file(model_file):
with open(model_file, 'rb') as f:
model_content = f.read()
return model_content
# 生成随机数据
def generate_data_ramdom(shape, batch):
data = []
for i in range(batch):
# 生成指定形状的随机整数数据,范围为 0 到 255,数据类型为 uint8
data.append([np.random.randint(0, 256, shape).astype(np.uint8)])
return data
# 从校准数据集生成数据
def generate_data(shape, batch, calib_dir):
# 获取校准数据集中所有图像的路径
img_paths = [os.path.join(calib_dir, p) for p in os.listdir(calib_dir)]
data = []
for i in range(batch):
# 确保校准图像数量足够
assert i < len(img_paths), "calibration images not enough."
# 打开图像并转换为 RGB 格式
img_data = Image.open(img_paths[i]).convert('RGB')
# 调整图像大小
img_data = img_data.resize((shape[3], shape[2]), Image.BILINEAR)
# 将图像数据转换为 NumPy 数组
img_data = np.asarray(img_data, dtype=np.uint8)
# 调整图像数据的维度顺序
img_data = np.transpose(img_data, (2, 0, 1))
data.append([img_data[np.newaxis, ...]])
return np.array(data)
def main():
# 创建命令行参数解析器
parser = argparse.ArgumentParser(prog="nncase")
# 添加目标设备参数,默认值为 k230
parser.add_argument("--target", default="k230", type=str, help='target to run,k230/cpu')
# 添加模型文件路径参数
parser.add_argument("--model", type=str,default="C:\\Users\\DELL\\Desktop\\K230--Openmv--yolo\\yolov5-6.0\\runs\\train\\exp22\\weights\\best.onnx", help='model file')
# 添加校准数据集路径参数
parser.add_argument("--dataset", type=str,default="C:\\Users\\DELL\\Desktop\\K230--Openmv--yolo\\yolov5-6.0\\dataset\\train\\images" ,help='calibration_dataset')
# 添加输入宽度参数,默认值为 640
parser.add_argument("--input_width", type=int, default=640, help='input_width')
# 添加输入高度参数,默认值为 640
parser.add_argument("--input_height", type=int, default=640, help='input_height')
# 添加 PTQ 选项参数,默认值为 0
parser.add_argument("--ptq_option", type=int, default=0, help='ptq_option:0,1,2,3,4')
# 解析命令行参数
args = parser.parse_args()
# 更新输入宽度为 32 的倍数
input_width = int(math.ceil(args.input_width / 32.0)) * 32
# 更新输入高度为 32 的倍数
input_height = int(math.ceil(args.input_height / 32.0)) * 32
# 模型的输入形状,维度要跟 input_layout 一致
input_shape = [1, 3, input_height, input_width]
# 临时目录,用于保存简化后的模型
dump_dir = 'tmp'
if not os.path.exists(dump_dir):
os.makedirs(dump_dir)
# 简化 ONNX 模型
model_file = onnx_simplify(args.model, dump_dir, input_shape)
# 创建编译选项对象
compile_options = nncase.CompileOptions()
# 设置目标设备
compile_options.target = args.target
# available_targets = nncase.targets()
# print("Available targets:", available_targets)
compile_options.target ='k230'
# 是否采用模型做预处理
compile_options.preprocess = True
# 是否交换 RGB 通道
compile_options.swapRB = False
# 输入图像的形状
compile_options.input_shape = input_shape
# 模型输入格式,'uint8' 或者 'float32'
compile_options.input_type = 'uint8'
# 如果输入是 'uint8' 格式,输入反量化之后的范围
compile_options.input_range = [0, 1]
# 预处理的 mean/std 值,每个 channel 一个
compile_options.mean = [0, 0, 0]
compile_options.std = [1, 1, 1]
# 设置输入的 layout,onnx 默认 'NCHW' 即可
compile_options.input_layout = "NCHW"
# compile_options.output_layout = "NCHW"
# 是否保存中间表示和汇编代码
# compile_options.dump_ir = True
# compile_options.dump_asm = True
# 保存中间文件的目录
# compile_options.dump_dir = dump_dir
# 量化类型
compile_options.quant_type = 'uint8'
# 创建编译器对象
compiler = nncase.Compiler(compile_options)
# 读取模型文件的二进制内容
model_content = read_model_file(model_file)
# 创建导入选项对象
import_options = nncase.ImportOptions()
# 导入 ONNX 模型
compiler.import_onnx(model_content, import_options)
# 创建 PTQ 选项对象
ptq_options = nncase.PTQTensorOptions()
# 设置校准样本数量
ptq_options.samples_count = 20
# 根据 PTQ 选项设置不同的校准方法和量化类型
if args.ptq_option == 0:
pass
elif args.ptq_option == 1:
ptq_options.calibrate_method = 'NoClip'
ptq_options.w_quant_type = 'int16'
elif args.ptq_option == 2:
ptq_options.calibrate_method = 'NoClip'
ptq_options.quant_type = 'int16'
else:
pass
# 设置 PTQ 所需的张量数据
# ptq_options.set_tensor_data(generate_data_ramdom(input_shape, ptq_options.samples_count))
ptq_options.set_tensor_data(generate_data(input_shape, ptq_options.samples_count, args.dataset))
# 使用 PTQ 进行量化
compiler.use_ptq(ptq_options)
# 编译模型
compiler.compile()
# 生成 kmodel 二进制数据
kmodel = compiler.gencode_tobytes()
# 获取模型文件的基本名称和扩展名
base, ext = os.path.splitext(args.model)
# 生成 kmodel 文件的名称
kmodel_name = base + ".kmodel"
# 保存 kmodel 文件
with open(kmodel_name, 'wb') as f:
f.write(kmodel)
# 删除临时目录
if os.path.exists("./tmp"):
shutil.rmtree("./tmp")
# 删除模型转储目录
if os.path.exists("./gmodel_dump_dir"):
shutil.rmtree("./gmodel_dump_dir")
if __name__ == '__main__':
main()