mmpose-rtmpose onnx 的后处理

mmpose-rtmpose onnx 的后处理 shui 2024-01-18 15:13:19 320

RTMpose 的前后处理

RRTMPose-l onnx的

预处理

根据输入的size将图片转成相应的size[384,288,3]

def preprocess(
    img: np.ndarray, input_size: Tuple[int, int] = (192, 256)
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Do preprocessing for RTMPose model inference.

    Args:
        img (np.ndarray): Input image in shape.
        input_size (tuple): Input image size in shape (w, h).

    Returns:
        tuple:
        - resized_img (np.ndarray): Preprocessed image.
        - center (np.ndarray): Center of image.
        - scale (np.ndarray): Scale of image.
    """
    # get shape of image
    img_shape = img.shape[:2]
    bbox = np.array([0, 0, img_shape[1], img_shape[0]])

    # get center and scale
    center, scale = bbox_xyxy2cs(bbox, padding=1.25)

    # do affine transformation
    resized_img, scale = top_down_affine(input_size, scale, center, img)

    # normalize image
    mean = np.array([123.675, 116.28, 103.53])
    std = np.array([58.395, 57.12, 57.375])
    resized_img = (resized_img - mean) / std

    return resized_img, center, scale

后处理功能

def postprocess(outputs: List[np.ndarray],
                model_input_size: Tuple[int, int],
                center: Tuple[int, int],
                scale: Tuple[int, int],
                simcc_split_ratio: float = 2.0
                ) -> Tuple[np.ndarray, np.ndarray]:
    """Postprocess for RTMPose model output.

    Args:
        outputs (np.ndarray): Output of RTMPose model.
        model_input_size (tuple): RTMPose model Input image size.
        center (tuple): Center of bbox in shape (x, y).
        scale (tuple): Scale of bbox in shape (w, h).
        simcc_split_ratio (float): Split ratio of simcc.

    Returns:
        tuple:
        - keypoints (np.ndarray): Rescaled keypoints.
        - scores (np.ndarray): Model predict scores.
    """
    # use simcc to decode
    simcc_x, simcc_y = outputs
    keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)

    # rescale keypoints
    keypoints = keypoints / model_input_size * scale + center - scale / 2

    return keypoints, scores

导出带有后处理的onnx

首先我们要把SIMCC解码部分用pytorch来实现。模型的直接输出是x和y两个方向上的预测向量,只需要取两个向量的最大值的index即可获得关键点的坐标。而关键点的置信度是取 max(max(x), max(y))。代码如下:

max_val_x, x_locs = torch.max(simcc_x, dim=2)  # x方向上最大值和坐标
max_val_y, y_locs = torch.max(simcc_y, dim=2)  # y方向上最大值和坐标
scores = torch.maximum(max_val_x, max_val_y)  # 置信度取两个方向上最大值中最大的那个
keypoints = torch.stack([x_locs, y_locs], dim=-1)  # 合并x, y坐标表示
keypoints = keypoints.float() / simcc_split_ratio  # 最终的坐标需要除以采样倍率

完整的导出代码如下:

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import torch
import torch.nn as nn
import onnx
from onnxsim import simplify

from mmpose.apis import init_model


class RTMPoseWithDecode(nn.Module):
    def __init__(self, config, checkpoint):
        super().__init__()
        self.detector = init_model(config, checkpoint, 'cpu')

    def forward(self, x):
        simcc_x, simcc_y = self.detector.forward(x, None)

        max_val_x, x_locs = torch.max(simcc_x, dim=2)
        max_val_y, y_locs = torch.max(simcc_y, dim=2)
        scores = torch.maximum(max_val_x, max_val_y)
        keypoints = torch.stack([x_locs, y_locs], dim=-1)
        keypoints = keypoints.float() / self.detector.cfg.codec.simcc_split_ratio

        return keypoints, scores


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('save_path', help='onnx save path')
    parser.add_argument(
        '--input_size',
        type=int,
        nargs=2,
        default=[192, 256],
        help='network input size')
    parser.add_argument('--opset', type=int, default=11, help='opset version')
    args = parser.parse_args()
    return args


def export(args):
    model = RTMPoseWithDecode(args.config, args.checkpoint)
    dummy_image = torch.zeros((1, 3, *args.input_size[::-1]), device='cpu')

    torch.onnx.export(
        model,
        dummy_image,
        args.save_path,
        input_names=['input'],
        dynamic_axes={'input': {
            0: 'batch'
        }})

    # 使用onnx simplify简化模型,当前没用
    # onnx_model = onnx.load(args.save_path)
    # onnx_model_simp, check = simplify(onnx_model)
    # assert check, 'Simplified ONNX model could not be validated'
    # onnx.save(onnx_model_simp, args.save_path)


if __name__ == '__main__':
    args = parse_args()
    export(args)
声明:本文内容由易百纳平台入驻作者撰写,文章观点仅代表作者本人,不代表易百纳立场。如有内容侵权或者其他问题,请联系本站进行删除。
shui
红包 2 收藏 评论 打赏
评论
0个
内容存在敏感词
手气红包
    易百纳技术社区暂无数据
相关专栏
置顶时间设置
结束时间
删除原因
  • 广告/SPAM
  • 恶意灌水
  • 违规内容
  • 文不对题
  • 重复发帖
打赏作者
易百纳技术社区
shui
您的支持将鼓励我继续创作!
打赏金额:
¥1易百纳技术社区
¥5易百纳技术社区
¥10易百纳技术社区
¥50易百纳技术社区
¥100易百纳技术社区
支付方式:
微信支付
支付宝支付
易百纳技术社区微信支付
易百纳技术社区
打赏成功!

感谢您的打赏,如若您也想被打赏,可前往 发表专栏 哦~

举报反馈

举报类型

  • 内容涉黄/赌/毒
  • 内容侵权/抄袭
  • 政治相关
  • 涉嫌广告
  • 侮辱谩骂
  • 其他

详细说明

审核成功

发布时间设置
发布时间:
是否关联周任务-专栏模块

审核失败

失败原因
备注
拼手气红包 红包规则
祝福语
恭喜发财,大吉大利!
红包金额
红包最小金额不能低于5元
红包数量
红包数量范围10~50个
余额支付
当前余额:
可前往问答、专栏板块获取收益 去获取
取 消 确 定

小包子的红包

恭喜发财,大吉大利

已领取20/40,共1.6元 红包规则

    易百纳技术社区