(点击
上方蓝字
,快速关注我们)
来源: 小石头
www.duanshishi.com/?p=1753
如有好文章投稿,请点击 → 这里了解详情
Goal
本文目标是利用TensorFlow做一个简单的图像分类器,在比较大的数据集上,尽可能高效地做图像相关处理,从Train,Validation到Inference,是一个比较基本的Example, 从一个基本的任务学习如果在TensorFlow下做高效地图像读取,基本的图像处理,整个项目很简单,但其中有一些trick,在实际项目当中有很大的好处, 比如绝对不要一次读入所有的 的数据到内存(尽管在Mnist这类级别的例子上经常出现)…
最开始看到是这篇blog里面的TensorFlow练习22: 手写汉字识别(http://blog.topspeedsnail.com/archives/10897), 但是这篇文章只用了140训练与测试,试了下代码 很快,但是当扩展到所有的时,发现32g的内存都不够用,这才注意到原文中都是用numpy,会先把所有的数据放入到内存,但这个不必须的,无论在MXNet还是TensorFlow中都是不必 须的,MXNet使用的是DataIter,会在程序运行的过程中异步读取数据,TensorFlow也是这样的,TensorFlow封装了高级的api,用来做数据的读取,比如TFRecord,还有就是从filenames中读取, 来异步读取文件,然后做shuffle batch,再feed到模型的Graph中来做模型参数的更新。具体在tf如何做数据的读取可以看看reading data in tensorflow
这里我会拿到所有的数据集来做训练与测试,算作是对斗大的熊猫上面那篇文章的一个扩展。
Batch Generate
数据集来自于中科院自动化研究所,感谢分享精神!!!具体下载:
wget
http
:
//www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip
wget
http
:
//www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip
解压后发现是一些gnt文件,然后用了斗大的熊猫里面的代码,将所有文件都转化为对应label目录下的所有png的图片。(注意在HWDB1.1trn_gnt.zip解压后是alz文件,需要再次解压 我在mac没有找到合适的工具,windows上有alz的解压工具)。
处理好的数据,放到了云盘,大家可以直接在我的云盘来下载处理好的数据集HWDB1. 这里说明下,char_dict是汉字和对应的数字label的记录。
得到数据集后,就要考虑如何读取了,一次用numpy读入内存在很多小数据集上是可以行的,但是在稍微大点的数据集上内存就成了瓶颈,但是不要害怕,TensorFlow有自己的方法:
def batch_data
(
file_labels
,
sess
,
batch_size
=
128
)
:
image_list
=
[
file_label
[
0
]
for
file_label
in
file_labels
]
label_list
=
[
int
(
file_label
[
1
])
for
file_label
in
file_labels
]
print
'tag2 {0}'
.
format
(
len
(
image_list
))
images_tensor
=
tf
.
convert_to_tensor
(
image_list
,
dtype
=
tf
.
string
)
labels_tensor
=
tf
.
convert_to_tensor
(
label_list
,
dtype
=
tf
.
int64
)
input_queue
=
tf
.
train
.
slice_input_producer
([
images_tensor
,
labels_tensor
])
labels
=
input_queue
[
1
]
images_content
=
tf
.
read_file
(
input_queue
[
0
])
# images = tf.image.decode_png(images_content, channels=1)
images
=
tf
.
image
.
convert_image_dtype
(
tf
.
image
.
decode_png
(
images_content
,
channels
=
1
),
tf
.
float32
)
# images = images / 256
images
=
pre_process
(
images
)
# print images.get_shape()
# one hot
labels
=
tf
.
one_hot
(
labels
,
3755
)
image_batch
,
label_batch
=
tf
.
train
.
shuffle_batch
([
images
,
labels
],
batch_size
=
batch_size
,
capacity
=
50000
,
min_after_dequeue
=
10000
)
# print 'image_batch', image_batch.get_shape()
coord
=
tf
.
train
.
Coordinator
()
threads
=
tf
.
train
.
start_queue_runners
(
sess
=
sess
,
coord
=
coord
)
return
image_batch
,
label_batch
,
coord
,
threads
简单介绍下,首先你需要得到所有的图像的path和对应的label的列表,利用tf.convert_to_tensor转换为对应的tensor, 利用tf.train.slice_input_producer将image_list ,label_list做一个slice处理,然后做图像的读取、预处理,以及label的one_hot表示,然后就是传到tf.train.shuffle_batch产生一个个shuffle batch,这些就可以feed到你的 模型。 slice_input_producer和shuffle_batch这类操作内部都是基于queue,是一种异步的处理方式,会在设备中开辟一段空间用作cache,不同的进程会分别一直往cache中塞数据 和取数据,保证内存或显存的占用以及每一个mini-batch不需要等待,直接可以从cache中获取。
Data Augmentation
由于图像场景不复杂,只是做了一些基本的处理,包括图像翻转,改变下亮度等等,这些在TensorFlow里面有现成的api,所以尽量使用TensorFlow来做相关的处理:
def pre_process
(
images
)
:
if
FLAGS
.
random_flip_up_down
:
images
=
tf
.
image
.
random_flip_up_down
(
images
)
if
FLAGS
.
random_flip_left_right
:
images
=
tf
.
image
.
random_flip_left_right
(
images
)
if
FLAGS
.
random_brightness
:
images
=
tf
.
image
.
random_brightness
(
images
,
max_delta
=
0.3
)
if
FLAGS
.
random_contrast
:
images
=
tf
.
image
.
random_contrast
(
images
,
0.8
,
1.2
)
new_size
=
tf
.
constant
([
FLAGS
.
image_size
,
FLAGS
.
image_size
],
dtype
=
tf
.
int32
)
images
=
tf
.
image
.
resize_images
(
images
,
new_size
)
return
images
Build Graph
这里很简单的构造了一个两个卷积+一个全连接层的网络,没有做什么更深的设计,感觉意义不大,设计了一个dict,用来返回后面要用的所有op,还有就是为了方便再训练中查看loss和accuracy, 没有什么特别的,很容易理解, labels 为None时 方便做inference。
def network
(
images
,
labels
=
None
)
:
endpoints
=
{}
conv_1
=
slim
.
conv2d
(
images
,
32
,
[
3
,
3
],
1
,
padding
=
'SAME'
)
max_pool_1
=
slim
.
max_pool2d
(
conv_1
,
[
2
,
2
],[
2
,
2
],
padding
=
'SAME'
)
conv_2
=
slim
.
conv2d
(
max_pool_1
,
64
,
[
3
,
3
],
padding
=
'SAME'
)
max_pool_2
=
slim
.
max_pool2d
(
conv_2
,
[
2
,
2
],[
2
,
2
],
padding
=
'SAME'
)
flatten
=
slim
.
flatten
(
max_pool_2
)
out
=
slim
.
fully_connected
(
flatten
,
3755
,
activation_fn
=
None
)
global_step
=
tf
.
Variable
(
initial_value
=
0
)
if
labels
is
not
None
:
loss
=
tf
.
reduce_mean
(
tf
.
nn
.
softmax_cross_entropy_with_logits
(
out
,
labels
))
train_op
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
0.0001
).
minimize
(
loss
,
global_step
=
global_step
)
accuracy
=
tf
.
reduce_mean
(
tf
.
cast
(
tf
.
equal
(
tf
.
argmax
(
out
,
1
),
tf
.
argmax
(
labels
,
1
)),
tf
.
float32
))
tf
.
summary
.
scalar
(
'loss'
,
loss
)
tf
.
summary
.
scalar
(
'accuracy'
,
accuracy
)
merged_summary_op
=
tf
.
summary
.
merge_all
()
output_score
=
tf
.
nn
.
softmax
(
out
)
predict_val_top3
,
predict_index_top3
=
tf
.
nn
.
top_k
(
output_score
,
k
=
3
)
endpoints
[
'global_step'
]
=
global_step
if
labels
is
not
None
:
endpoints
[
'labels'
]
=
labels
endpoints
[
'train_op'
]
=
train_op
endpoints
[
'loss'
]
=
loss
endpoints
[
'accuracy'
]
=
accuracy
endpoints
[
'merged_summary_op'
]
=
merged_summary_op
endpoints
[
'output_score'
]
=
output_score
endpoints
[
'predict_val_top3'
]
=
predict_val_top3
endpoints
[
'predict_index_top3'
]
=
predict_index_top3
return
endpoints
Train
train函数包括从已有checkpoint中restore,得到step,快速恢复训练过程,训练主要是每一次得到mini-batch,更新参数,每隔eval_steps后做一次train batch的eval,每隔save_steps 后保存一次checkpoint。
def train
()
:
sess
=
tf
.
Session
()
file_labels
=
get_imagesfile
(
FLAGS
.
train_data_dir
)
images
,
labels
,
coord
,
threads
=
batch_data
(
file_labels
,
sess
)
endpoints
=
network
(
images
,
labels
)
saver
=
tf
.
train
.
Saver
()
sess
.
run
(
tf
.
global_variables_initializer
())
train_writer
=
tf
.
train
.
SummaryWriter
(
'./log'
+
'/train'
,
sess
.
graph
)
test_writer
=
tf
.
train
.
SummaryWriter
(
'./log'
+
'/val'
)
start_step
=
0
if
FLAGS
.
restore
:
ckpt
=
tf
.
train
.
latest_checkpoint
(
FLAGS
.
checkpoint_dir
)
if
ckpt
:
saver
.
restore
(
sess
,
ckpt
)
print
"restore from the checkpoint {0}"
.
format
(
ckpt
)
start_step
+=
int
(
ckpt
.
split
(
'-'
)[
-
1
])
logger
.
info
(