专栏名称: 视学算法
公众号专注于人工智能 | 机器学习 | 深度学习 | 计算机视觉 | 自然语言处理等前沿论文和基础程序设计等算法。地球不爆炸,算法不放假。
目录
相关文章推荐
手游那点事  ·  破纪录!这个被腾讯看中的IP,哪怕过去了20 ... ·  3 天前  
手游那点事  ·  这一次,腾讯游戏将「球」交到了年轻人手里! ·  4 天前  
手游那点事  ·  起飞了!Supercell今年第一款新作曝光 ... ·  4 天前  
51好读  ›  专栏  ›  视学算法

实操教程 | 深度学习pytorch训练代码模板(个人习惯)

视学算法  · 公众号  ·  · 2022-09-15 09:55

正文

点击上方 视学算法 ”,选择加" 星标 "或“ 置顶

重磅干货,第一时间送达

作者丨 wfnian@知乎(已授权)
来源丨 https://zhuanlan.zhihu.com/p/396666255
编辑丨极市平台

导读

本文从参数定义,到网络模型定义,再到训练步骤,验证步骤,测试步骤,总结了一套较为直观的模板。

目录如下:

  1. 导入包以及设置随机种子
  2. 以类的方式定义超参数
  3. 定义自己的模型
  4. 定义早停类(此步骤可以省略)
  5. 定义自己的数据集Dataset,DataLoader
  6. 实例化模型,设置loss,优化器等
  7. 开始训练以及调整lr
  8. 绘图
  9. 预测

一、导入包以及设置随机种子

import numpy as np
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import random
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

二、以类的方式定义超参数

class argparse():
pass

args = argparse()
args.epochs, args.learning_rate, args.patience = [30, 0.001, 4]
args.hidden_size, args.input_size= [40, 30]
args.device, = [torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),]

三、定义自己的模型

class Your_model(nn.Module):
def __init__(self):
super(Your_model, self).__init__()
pass

def forward(self,x):
pass
return x

四、定义早停类(此步骤可以省略)

class EarlyStopping():
def __init__(self,patience=7,verbose=False,delta=0):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
def __call__(self,val_loss,model,path):
print("val_loss={}".format(val_loss))
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss,model,path)
elif score < self.best_score+self.delta:
self.counter+=1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter>=self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss,model,path)
self.counter = 0
def save_checkpoint(self,val_loss,model,path):
if self.verbose:
print(
f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), path+'/'+'model_checkpoint.pth')
self.val_loss_min = val_loss

五、定义自己的数据集Dataset,DataLoader

class Dataset_name(Dataset):
def __init__(self, flag='train'):
assert flag in ['train', 'test', 'valid']
self.flag = flag
self.__load_data__()

def __getitem__(self, index):
pass
def __len__(self):
pass

def __load_data__(self, csv_paths: list):
pass
print(
"train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n"
.format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))

train_dataset = Dataset_name(flag='train')
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
valid_dataset = Dataset_name(flag='valid')
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)

六、实例化模型,设置loss,优化器等

model = Your_model().to(args.device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(Your_model.parameters(),lr=args.learning_rate)

train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []

early_stopping = EarlyStopping(patience=args.patience,verbose=True)

七、开始训练以及调整lr

for epoch in range(args.epochs):
Your_model.train()
train_epoch_loss = []
for idx,(data_x,data_y) in enumerate(train_dataloader,0):
data_x = data_x.to(torch.float32).to(args.device)
data_y = data_y.to(torch.float32).to(args.device)
outputs = Your_model(data_x)






请到「今天看啥」查看全文