专栏名称: 深度学习与神经网络
关注深度学习教育,关注人工智能前沿科技
目录
相关文章推荐
央视财经  ·  即将全面上映!“反响前所未有”→ ·  15 小时前  
EETOP  ·  HDMI 2.2 新规范抢先解读 ·  2 天前  
21世纪经济报道  ·  【南财早新闻】① ... ·  2 天前  
有限次重复博弈  ·  买得起可能养不起 ... ·  2 天前  
神嘛事儿  ·  我回答了 @紫气东来硕珩 ... ·  3 天前  
51好读  ›  专栏  ›  深度学习与神经网络

浅入浅出TensorFlow 4 — 训练CIFAR数据

深度学习与神经网络  · 公众号  ·  · 2018-04-20 09:30

正文

上篇,Amusi带着大家学习了如何 浅入浅出TensorFlow 3 — MNIST手写体识别 ,今天继续给大家介绍linolzhang大佬的 TensorFlow系列课程 ,带大家学习 训练CIFAR数据


正文


一.  CIFAR数据集

CIFAR数据集是一个经典的数据集,提供两个版本的分类样本,CIFAR-10和CIFAR-100。

CIFAR-10 提供10类标注数据,每类6000张(32*32),其中5000张用于训练,1000张用于测试。

获取数据集的方法:

1git clone https://github.com/tensorflow/models.git  

2cd models/tutorials/image/cifar10



可以看一下我们从github上down下来的数据,外面不看了,直接进tutorials/image,教程专用,看来是基础的不能再基础了。

里面提供了几个典型的数据集的 下载、训练等接口,方便直接在python里调用。

进入cifar10,能够看到:

其中文件 cifar10.py 和 cifar10_input.py 就是接下来我们要 import 的。


二.  代码实现

撸一段 Python 代码,可以View里面的注释讲解:

  1#coding=utf-8  
 2import cifar10,cifar10_input  
 3import tensorflow as tf  
 4import numpy as np  
 5import time  
 6
 7# define max_iter_step  batch_size  
 8max_iter_step = 1000  
 9batch_size = 128  
10
11# define variable_with_weight_loss  
12# 和之前定义的weight有所不同,  
13# 这里定义附带loss的weight,通过权重惩罚避免部分权重系数过大,导致overfitting  
14def variable_with_weight_loss(shape,stddev,w1):  
15    var = tf.Variable(tf.truncated_normal(shape,stddev=stddev))  
16    if w1 is not None:  
17        weight_loss = tf.multiply(tf.nn.l2_loss(var),w1,name='weight_loss')  
18        tf.add_to_collection('losses',weight_loss)  
19    return var  
20
21# 下载数据集 - 调用cifar10函数下载并解压  
22cifar10.maybe_download_and_extract()  
23# 注意路径  
24cifar_dir = './cifar-10-batches-bin'  
25
26# 采用 data augmentation进行数据处理  
27# 生成训练数据,训练数据通过cifar10_input的distort变化  
28images_train, labels_train = cifar10_input.distorted_inputs(data_dir=cifar_dir,batch_size=batch_size)  
29# 测试数据(eval_data 测试数据)  
30images_test,labels_test = cifar10_input.inputs(eval_data=True,data_dir=cifar_dir,batch_size=batch_size)  
31
32 # 创建输入数据,采用 placeholder  
33x_input = tf.placeholder(tf.float32,[batch_size,24,24,3])  
34y_input = tf.placeholder(tf.int32,[batch_size])  
35
36# 创建第一个卷积层 input:3(channel) kernel:64 size:5*5  
37weight1 = variable_with_weight_loss(shape=[5,5,3,64],stddev=5e-2,w1=0.0)  
38bias1 = tf.Variable(tf.constant(0.0,shape=[64]))  
39conv1 = tf.nn.conv2d(x_input,weight1,[1,1,1,1],padding='SAME')  
40relu1 = tf.nn.relu(tf.nn.bias_add(conv1,bias1))  
41pool1 = tf.nn.max_pool(conv1,ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME')  
42norm1 = tf.nn.lrn(pool1,4,bias=1.0,alpha=0.001/9.0,beta=0.75)  
43
44# 创建第二个卷积层 input:64 kernel:64 size:5*5  
45weight2 = variable_with_weight_loss(shape=[5,5,64,64],stddev=5e-2,w1=0.0)  
46bias2 = tf.Variable(tf.constant(0,1,shape=[64]))  
47conv2 = tf.nn.conv2d(norm1,weight2,[1,1,1,1],padding='SAME')  
48relu2 = tf.nn.relu(tf.nn.bias_add(conv2,bias2))  
49norm2 = tf.nn.lrn(relu2,4,bias=1.0,alpha=0.001/9.0,beta=0.75)  
50pool2 = tf.nn.max_pool(norm2,ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME')  
51
52# 创建第三个层-全连接层  output:384  
53reshape = tf.reshape(pool2,[batch_size,-1])  
54dim = reshape.get_shape()[1].value  
55weight3 = variable_with_weight_loss(shape=[dim,384],stddev=0.04,w1=0.004)  
56bias3 = tf.Variable(tf.constant(0.1,shape=[384]))  
57local3 = tf.nn.relu(tf.matmul(reshape,weight3)+bias3)  
58
59# 创建第四个层-全连接层  output:192  
60weight4 = variable_with_weight_loss(shape=[384,192],stddev=0.04,w1=0.004)  
61bias4 = tf.Variable(tf.constant(0.1,shape=[192]))  
62local4 = tf.nn.relu(tf.matmul(local3,weight4)+bias4)  
63
64# 最后一层  output:10  
65weight5 = variable_with_weight_loss(shape=[192,10],stddev=1/192.0,w1=0.0)  
66bias5 = tf.Variable(tf.constant(0.0,shape=[10]))  
67results = tf.add(tf.matmul(local4,weight5),bias5)  
68
69# 定义loss  
70def loss(results,labels):  
71    labels = tf.cast(labels,tf.int64)  
72    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=results,labels=labels,name='cross_entropy_per_example')  
73    cross_entropy_mean = tf.reduce_mean(cross_entropy,name='cross_entropy')  
74    tf.add_to_collection('losses',cross_entropy_mean)  
75    return tf.add_n(tf.get_collection('losses'),name='total_loss')  
76
77# 计算loss  
78loss = loss(results,y_input)  
79
80train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)  # Adam  
81top_k_op = tf.nn.in_top_k(results,y_input,1)  # top1 准确率  
82
83sess = tf.InteractiveSession()         # 创建session  
84# tf.global_variable_initializer().run() # 初始化全部模型  
85tf.initialize_all_variables().run()
86
87tf.train.start_queue_runners()  # 启动多线程加速  
88






请到「今天看啥」查看全文