专栏名称: 3DCV
关注工业3D视觉、SLAM、自动驾驶技术,更专注3D视觉产业的信息传播和产品价值的创造,深度聚焦于3D视觉传感器、SLAM产品,使行业产品快速连接消费者。
51好读  ›  专栏  ›  3DCV

代码逐行解析 | 教你在C++中使用深度学习提取特征点

3DCV  · 公众号  ·  · 2024-03-12 20:34

正文

点击下方 卡片 ,关注 「3DCV」 公众号
选择 星标 ,干货第一时间送达

点击加入「3DCV」技术交流群

作者:泡椒味的口香糖 | 来源:3DCV

在公众号「3DCV」后台,回复「原论文」可获取论文pdf

添加微信:dddvision,备注:SLAM,拉你入群。文末附行业细分群

0. 写在前面

使用深度学习提取特征点的SLAM系统已经很多了,典型工作就是GCN-SLAM和SuperPoint-SLAM。感觉深度学习特征点相较传统ORB、SIFT这类特征点,主要优势在于重复性和鲁棒性,特征点的精度明显提升。缺点就是需要GPU,模型前向推理和NMS的计算本身也非常耗时。而且深度学习特征点的泛化性很差,也很难学习旋转不变性,在大旋转变化的匹配一般都不太好,当然REKD等方案也在专门研究旋转情况。

Anyway,任何方法都有自己的优劣,都有适合的场景。今天笔者将记录在C++上部署深度学习SiLK特征(ICCV 2023)的过程,并附上代码的逐行注释。

1. SiLK论文信息

标题 :SiLK -- Simple Learned Keypoints
作者机构 :Pierre Gleize, et al.Meta AI
来源 :ICCV 2023
论文 :https://arxiv.org/abs/2304.06194
代码 :https://github.com/facebookresearch/silk

2. 模型权重导出

深度学习都是在Pytorch/Tensorflow等Python框架上训练,想在C++上运行需要将模型导出为TorchScript。python环境导出模型权重的示例代码如下:

image = load_image(IMAGE_0_PATH)
model = get_model()
script_model = torch.jit.trace(model, image)
torch.jit.save(script_model, OUTPUT_MODEL)

笔者已经替大家导出好了.pt文件,如有需要可以加入 3D视觉从入门到精通知识星球 三天后获取!

在C++上运行.pt文件需要Libtorch环境,需要根据自己电脑的Cuda来下载对应版本,也可以下载CPU版本的Libtorch不使用GPU。直接下载到本地然后在CMakeLists.txt中定义一下路径即可,不用编译和安装。注意需要引用头文件:

#include 

3. 运行代码及编译

以下是使用SiLK模型提取特征点的完整代码,包含模型加载、数据预处理、前向推导、数据转换等步骤,已经做了详细的注释。注意这里使用的SiLK特征已经将NMS过程封装到了权重文件里,读者想自己实现NMS的话可以在导出权重的时候选择一下。

#include 
#include  // 包含 OpenCV 库
#include  // 包含 PyTorch 脚本解析器

using namespace std;
using namespace cv;

int main(int argc, char** argv) {

    if (argc != 2) { // 检查输入参数数量是否正确
        cout <"usage: ./extract_point img" <        return 1;
    }
    // 加载图像
    Mat image = imread(argv[1], CV_LOAD_IMAGE_COLOR); 
    assert(image.data && "Can not load image!"); 
    // 转换为灰度图
    Mat mImGray = image.clone(); // 深拷贝避免修改原始数据
    cvtColor(mImGray, mImGray, COLOR_RGB2GRAY); // 转换为灰度图像

    vector keypoints; // 存储特征点
    torch::jit::script::Module module = torch::jit::load("SiLK.pt", torch::kCUDA); // 加载 SiLK 模型
    mImGray.convertTo(mImGray, CV_32FC1, 1.f / 255.f, 0); // 转换图像数据类型并归一化
    int img_height = mImGray.rows, img_width = mImGray.cols; // 获取图像尺寸
    vector dims = { 1, img_height, img_width, 1 }; // 定义张量维度
    auto img_var = torch::from_blob(mImGray.data, dims, torch::kFloat32).to(torch::kCUDA); // 创建 PyTorch 张量并移动到 GPU
    img_var = img_var.permute({ 0,3,1,2 }); // 调整张量维度顺序
    vector<:jit::ivalue> inputs; // 存储输入参数
    inputs.push_back(img_var); // 添加图像张量作为输入参数
    auto output = module.forward(inputs).toTuple(); // 模型前向传播
    auto pts = output->elements()[0].toTuple()->elements()[0].toTensor().to(torch::kCPU); // 获取特征点张量
    auto desc = output->elements()[1].toTuple()->elements()[0].toTensor().to(torch::kCPU); // 获取特征描述子张量
    cv::Mat pts_mat(cv::Size(3, pts.size(0)), CV_32FC1, pts.data_ptr<float>()); // 创建特征点矩阵
    cv::Mat descriptors(desc.size(0), 128, CV_32FC1); // 创建描述子矩阵
    cv::Mat Confidence(pts.size(0), 3, CV_32FC1); // 创建置信度矩阵
    for (int i = 0; i         keypoints.push_back(cv::KeyPoint(pts_mat.at<float>(i, 1), pts_mat.at<float>(i, 0), 1.0f)); // 将特征点添加到容器中
        Confidence.at<float>(i, 2) = pts_mat.at<float>(i, 2); // 提取置信度信息
        for (int j = 0; j             descriptors.at<float>(i, j) = desc[i][j].item().toFloat(); // 提取描述子信息
    }
    drawKeypoints(image, keypoints, image, Scalar(255, 255, 0), DrawMatchesFlags::DRAW_RICH_KEYPOINTS); // 在图像上绘制特征点
    imwrite("image_with_SiLK.jpg", image); // 保存带有特征点的图像

    return 0;
}

代码运行的CMakeLists.txt文件如下,读者可以直接复制运行:

cmake_minimum_required(VERSION 3.0)
project(pose_recover)

# 设置C++编译标准
set(CMAKE_CXX_STANDARD 14)

# 寻找OpenCV库
find_package(OpenCV REQUIRED)

# Libtorch库
set(TORCH_PATH ${PROJECT_SOURCE_DIR}/libtorch)
set(CMAKE_PREFIX_PATH ${PROJECT_SOURCE_DIR}/libtorch)
find_package(Torch REQUIRED)

# 包含头文件目录
include_directories(${OpenCV_INCLUDE_DIRS} 
${Torch_INCLUDE_DIRS})

# 添加可执行文件
add_executable(extract_point extract_point.cpp)

# 链接OpenCV库
target_link_libraries(extract_point ${OpenCV_LIBS} ${TORCH_LIBRARIES})

4. 效果展示

至此已完成了SiLK特征的提取,包含特征点坐标、描述子、置信度,下篇文章我们将继续讨论如何在C++中使用深度学习模型来做帧间的特征匹配。

本文仅做学术分享,如有侵权,请联系删文。

3D视觉精品课程:
3dcver.com

3DGS、NeRF、结构光、相位偏折术、机械臂抓取、点云实战、Open3D、缺陷检测、BEV感知、Occupancy、Transformer、模型部署、3D目标检测、深度估计、多传感器标定、规划与控制、无人机仿真、三维视觉C++、三维视觉python、dToF、相机标定、ROS2、机器人控制规划、LeGo-LAOM、多模态融合SLAM、LOAM-SLAM、室内室外SLAM、VINS-Fusion、ORB-SLAM3、MVSNet三维重建、colmap、线面结构光、硬件结构光扫描仪。

▲长按扫码学习3D视觉精品课程

3D视觉学习圈子

3D视觉从入门到精通知识星球 、国内成立最早、6000+成员交流学习。包括: 星球视频课程近20门(价值超6000) 项目对接 3D视觉学习路线总结 最新顶会论文&代码 3D视觉行业最新模组 3D视觉优质源码汇总 书籍推荐 编程基础&学习工具 实战项目&作业 求职招聘&面经&面试题 等等。欢迎加入3D视觉从入门到精通知识星球,一起学习进步。

▲长按扫码加入星球

3D视觉交流群







请到「今天看啥」查看全文