AI老铁 · 2020年12月17日

双目测距系列(七)monodepth2训练前数据集准备过程的简析

转载自:双目测距系列(七)monodepth2训练前数据集准备过程的简析
作者:ltshan139

前言

上一篇文章说过,monodepth2模型有三种训练方式。针对我们的双目场景,准备使用stereo training方法。

monodepth2的训练入口函数在train.py中,如下图所示。

总共就2行代码,第一行代码(类Trainer的构造函数)主要是来初始化和数据集准备;第二行代码(Trainer类的成员函数)是真正执行训练过程。

下文将结合代码讲解数据集准备过程。

数据加载

在Train()构造函数中,首先会对Trainer类成员变量进行初始化。这里会摘取重点部分进行讲解。

1)

    self.num_scales = len(self.opt.scales)
    self.num_input_frames = len(self.opt.frame_ids)

代码中的opt是对options.py中的参数parse得到的dict。其参数对应值可以通过运行train.py脚本时输入参数来进行设置,如下所示。如果在运行train.py时没有显示指定参数值,那么该参数就对应使用缺省值。

python train.py --frame_ids 0 --use_stereo
回到代码,因为在运行train.py时没有输入scales参数,所以其为缺省值[0,1,2,3],其含义是在encoder和decoder时进行4级缩小和放大的多尺度,其倍数分别对应为1, 2, 4, 8。

frame_ids的缺省值为[0,-1,1],这里如果采用stereo training的话 要显示输入参数:--frame_ids 0,即当前图片,而不考虑它的时间域上的上一帧和下一帧。

2)

if self.opt.use_stereo:

        self.opt.frame_ids.append("s")

如果是stereo training,那么需要显示添加参数--use_stereo,这样上面代码if条件为true, frame_ids就变成了["0", "s"]

3)接下来就到了数据加载部分

    datasets_dict = {"kitti": datasets.KITTIRAWDataset,
                     "kitti_odom": datasets.KITTIOdomDataset}
    self.dataset = datasets_dict[self.opt.dataset]

KITTI数据集有两个子类型:KITTIRAW和KITTIOdom,monodepth使用的是前者,本系列四(https://blog.csdn.net/ltshan1...)有专门对它进行说明。

    fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")

    train_filenames = readlines(fpath.format("train"))
    val_filenames = readlines(fpath.format("val"))
    img_ext = '.png' if self.opt.png else '.jpg'

上面第一行代码来获取train和valid的文件路径:fpath。在monodepth2开源项目根目录下有一个splits的子目录,然后在它的下面又分了eigen, eigen_full和eigen_zhou等子目录,最后每个子目录下才带有train_files.txt和val_files.txt。其目录结构如下所示:

根据github上的readme,单目训练时推荐用的是eigen_zhou,双目用的是eigen_full。

最后一行img_ext用来显示告诉当前训练和验证样本图片的格式是png还是jpg。

    train_dataset = self.dataset(
        self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
        self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
    self.train_loader = DataLoader(
        train_dataset, self.opt.batch_size, True,
        num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
    val_dataset = self.dataset(
        self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
        self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
    self.val_loader = DataLoader(
        val_dataset, self.opt.batch_size, True,
        num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)

上面的代码就是真正数据加载部分。因为train和valid数据加载原理一样,而且DatalLoader是pytorch的API,没啥好讲的,所以这里主要分析下train_dataset = self.dataset(...)的运行过程。

前面已经讲过了 self.dataset=datasets.KITTIRAWDataset。调用self.dataset(...)实际上调用的是datasets.KITTIRAWDataset的构造函数,如下所示。

class KITTIRAWDataset(KITTIDataset):

"""KITTI dataset which loads the original velodyne depth maps for ground truth
"""
def __init__(self, *args, **kwargs):
    super(KITTIRAWDataset, self).__init__(*args, **kwargs)

其构造函数只有一行代码: super(KITTIRAWDataset, self).__init__(args, *kwargs),实际上它会调用其父类KITTIDataset的构造函数,如下所示。

class KITTIDataset(MonoDataset):

"""Superclass for different types of KITTI dataset loaders
"""
def __init__(self, *args, **kwargs):
    super(KITTIDataset, self).__init__(*args, **kwargs)
    。。。 。。。

里面的super函数又会调用KITTIDataset的父类MonoDataset的构造函数。

class MonoDataset(data.Dataset):

"""Superclass for monocular dataloaders
Args:
    data_path
    filenames
    height
    width
    frame_idxs
    num_scales
    is_train
    img_ext
"""
def __init__(self,
             data_path,
             filenames,
             height,
             width,
             frame_idxs,
             num_scales,
             is_train=False,
             img_ext='.jpg'):
    super(MonoDataset, self).__init__()

    self.data_path = data_path
    self.filenames = filenames
    self.height = height
    self.width = width
    self.num_scales = num_scales
    self.interp = Image.ANTIALIAS

    self.frame_idxs = frame_idxs

    self.is_train = is_train
    self.img_ext = img_ext

    self.loader = pil_loader
    self.to_tensor = transforms.ToTensor()
    。。。 。。。

注意,self.dataset(。。。)所带的实参全部赋值给了MonoDataset(。。。),比如说data_path, filenames等。相当于把全部训练和验证样本文件名拿到了,以便后面训练时一个一个batch来从数据集里面随机抽取。

MonoDataset的构造函数运行完成后再回到KITTIDataset的构造函数剩余部分执行。


推荐阅读



更多海思AI芯片方案学习笔记欢迎关注海思AI芯片方案学习

推荐阅读
关注数
16760
内容数
1233
嵌入式端AI,包括AI算法在推理框架Tengine,MNN,NCNN,PaddlePaddle及相关芯片上的实现。欢迎加入微信交流群,微信号:aijishu20(备注:嵌入式)
目录
极术微信服务号
关注极术微信号
实时接收点赞提醒和评论通知
安谋科技学堂公众号
关注安谋科技学堂
实时获取安谋科技及 Arm 教学资源
安谋科技招聘公众号
关注安谋科技招聘
实时获取安谋科技中国职位信息