专栏名称: chopper_bbf4
目录
相关文章推荐
笛扬新闻  ·  太突然!暴涨489%! ·  昨天  
青岛新闻网  ·  已确认去世!千万粉丝网红发文道歉 ·  3 天前  
青岛新闻网  ·  已确认去世!千万粉丝网红发文道歉 ·  3 天前  
51好读  ›  专栏  ›  chopper_bbf4

pytorch 模型训练详解

chopper_bbf4  · 简书  ·  · 2021-01-23 23:52

正文

建立模型并进行训练主要分5步:

  1. 数据处理
    1.1 数据收集 : Img,Label
    1.2 数据划分:train,valid,test【训练集,验证集。测试集】
    1.3 数据读取:DateLoader【Sampler:Index;Dataset:Img,Label】
    1)读哪些数据
    了解数据内容及数据文件包位置,通过index知道读取哪些数据,Sampler输出的Index
    2)从哪读数据
    设置数据存储的路径【分为绝对和相对路径,建议使用相对路径】,Dataset中的data_dir
    3)怎么读数据
    通过Dateset读取本地数据,getitem
    1.4 数据预处理:transforms
    torchvision计算机视觉工具包
    包含transforms,datasets和model模块
    transforems常用预处理方法:
    数据中心化,数据标准化,缩放,裁剪,旋转,翻转,填充,噪声增加,灰度变换,线性变换,仿射变换,亮度、饱和度及对比度变换
    用于数据预处理及增强,目的为提高泛化能力
    具体22种方法:
    https://blog.csdn.net/qq_38410428/article/details/94719553
    例子:
transforms.Compose([transforms.resize((32,32)),transforms.RandomCrop(32,padding=4),tansform.ToTensor(),tansforms.Normalize(norm_mean,norm_std)])
# 缩放,随机裁剪,变成张量形式,数据标准化

核心概念:数据增强:对训练集进行变换,使训练集更丰富,具有泛化能力

  1. 模型建立

  2. 选择合适的损失函数
    损失函数的目的是计算梯度,即计算得到的值和原标签值相差
    根据不同数据及训练目标,其所使用的损失函数也不尽相同

  3. 选择合适的优化器
    优化器即使用合适的方式对损失函数进行向后传播,常用SGD,Adam
    详见: https://blog.csdn.net/weixin_40170902/article/details/80092628

  4. 迭代训练
    主要分epoch,iter,batch-size
    epoch: 全部训练样本训练一次
    iter: 每个batch训练完一次
    batch-size: 一同放入内存进行计算的数据量大小