专栏名称: Python开发者
人生苦短,我用 Python。伯乐在线旗下账号「Python开发者」分享 Python 相关的技术文章、工具资源、精选课程、热点资讯等。
目录
相关文章推荐
Python爱好者社区  ·  DeepSeek全攻略 ... ·  4 天前  
Python爱好者社区  ·  中国最难入的IT公司。 ·  4 天前  
Python爱好者社区  ·  黄仁勋预言成真!!! ·  2 天前  
Python爱好者社区  ·  DeepSeek彻底爆了。。。 ·  5 天前  
Python开发者  ·  湖南大学的 DeepSeek ... ·  3 天前  
51好读  ›  专栏  ›  Python开发者

TensorFlow与中文手写汉字识别

Python开发者  · 公众号  · Python  · 2017-03-24 21:29

正文

(点击 上方蓝字 ,快速关注我们)


来源: 小石头

www.duanshishi.com/?p=1753

如有好文章投稿,请点击 → 这里了解详情


Goal


本文目标是利用TensorFlow做一个简单的图像分类器,在比较大的数据集上,尽可能高效地做图像相关处理,从Train,Validation到Inference,是一个比较基本的Example, 从一个基本的任务学习如果在TensorFlow下做高效地图像读取,基本的图像处理,整个项目很简单,但其中有一些trick,在实际项目当中有很大的好处, 比如绝对不要一次读入所有的 的数据到内存(尽管在Mnist这类级别的例子上经常出现)…


最开始看到是这篇blog里面的TensorFlow练习22: 手写汉字识别(http://blog.topspeedsnail.com/archives/10897), 但是这篇文章只用了140训练与测试,试了下代码 很快,但是当扩展到所有的时,发现32g的内存都不够用,这才注意到原文中都是用numpy,会先把所有的数据放入到内存,但这个不必须的,无论在MXNet还是TensorFlow中都是不必 须的,MXNet使用的是DataIter,会在程序运行的过程中异步读取数据,TensorFlow也是这样的,TensorFlow封装了高级的api,用来做数据的读取,比如TFRecord,还有就是从filenames中读取, 来异步读取文件,然后做shuffle batch,再feed到模型的Graph中来做模型参数的更新。具体在tf如何做数据的读取可以看看reading data in tensorflow



这里我会拿到所有的数据集来做训练与测试,算作是对斗大的熊猫上面那篇文章的一个扩展。


Batch Generate


数据集来自于中科院自动化研究所,感谢分享精神!!!具体下载:


wget http : //www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip

wget http : //www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip

解压后发现是一些gnt文件,然后用了斗大的熊猫里面的代码,将所有文件都转化为对应label目录下的所有png的图片。(注意在HWDB1.1trn_gnt.zip解压后是alz文件,需要再次解压 我在mac没有找到合适的工具,windows上有alz的解压工具)。



处理好的数据,放到了云盘,大家可以直接在我的云盘来下载处理好的数据集HWDB1. 这里说明下,char_dict是汉字和对应的数字label的记录。


得到数据集后,就要考虑如何读取了,一次用numpy读入内存在很多小数据集上是可以行的,但是在稍微大点的数据集上内存就成了瓶颈,但是不要害怕,TensorFlow有自己的方法:


def batch_data ( file_labels , sess , batch_size = 128 ) :

image_list = [ file_label [ 0 ] for file_label in file_labels ]

label_list = [ int ( file_label [ 1 ]) for file_label in file_labels ]

print 'tag2 {0}' . format ( len ( image_list ))

images_tensor = tf . convert_to_tensor ( image_list , dtype = tf . string )

labels_tensor = tf . convert_to_tensor ( label_list , dtype = tf . int64 )

input_queue = tf . train . slice_input_producer ([ images_tensor , labels_tensor ])

labels = input_queue [ 1 ]

images_content = tf . read_file ( input_queue [ 0 ])

# images = tf.image.decode_png(images_content, channels=1)

images = tf . image . convert_image_dtype ( tf . image . decode_png ( images_content , channels = 1 ), tf . float32 )

# images = images / 256

images = pre_process ( images )

# print images.get_shape()

# one hot

labels = tf . one_hot ( labels , 3755 )

image_batch , label_batch = tf . train . shuffle_batch ([ images , labels ], batch_size = batch_size , capacity = 50000 , min_after_dequeue = 10000 )

# print 'image_batch', image_batch.get_shape()

coord = tf . train . Coordinator ()

threads = tf . train . start_queue_runners ( sess = sess , coord = coord )

return image_batch , label_batch , coord , threads


简单介绍下,首先你需要得到所有的图像的path和对应的label的列表,利用tf.convert_to_tensor转换为对应的tensor, 利用tf.train.slice_input_producer将image_list ,label_list做一个slice处理,然后做图像的读取、预处理,以及label的one_hot表示,然后就是传到tf.train.shuffle_batch产生一个个shuffle batch,这些就可以feed到你的 模型。 slice_input_producer和shuffle_batch这类操作内部都是基于queue,是一种异步的处理方式,会在设备中开辟一段空间用作cache,不同的进程会分别一直往cache中塞数据 和取数据,保证内存或显存的占用以及每一个mini-batch不需要等待,直接可以从cache中获取。



Data Augmentation


由于图像场景不复杂,只是做了一些基本的处理,包括图像翻转,改变下亮度等等,这些在TensorFlow里面有现成的api,所以尽量使用TensorFlow来做相关的处理:


def pre_process ( images ) :

if FLAGS . random_flip_up_down :

images = tf . image . random_flip_up_down ( images )

if FLAGS . random_flip_left_right :

images = tf . image . random_flip_left_right ( images )

if FLAGS . random_brightness :

images = tf . image . random_brightness ( images , max_delta = 0.3 )

if FLAGS . random_contrast :

images = tf . image . random_contrast ( images , 0.8 , 1.2 )

new_size = tf . constant ([ FLAGS . image_size , FLAGS . image_size ], dtype = tf . int32 )

images = tf . image . resize_images ( images , new_size )

return images


Build Graph


这里很简单的构造了一个两个卷积+一个全连接层的网络,没有做什么更深的设计,感觉意义不大,设计了一个dict,用来返回后面要用的所有op,还有就是为了方便再训练中查看loss和accuracy, 没有什么特别的,很容易理解, labels 为None时 方便做inference。


def network ( images , labels = None ) :

endpoints = {}

conv_1 = slim . conv2d ( images , 32 , [ 3 , 3 ], 1 , padding = 'SAME' )

max_pool_1 = slim . max_pool2d ( conv_1 , [ 2 , 2 ],[ 2 , 2 ], padding = 'SAME' )

conv_2 = slim . conv2d ( max_pool_1 , 64 , [ 3 , 3 ], padding = 'SAME' )

max_pool_2 = slim . max_pool2d ( conv_2 , [ 2 , 2 ],[ 2 , 2 ], padding = 'SAME' )

flatten = slim . flatten ( max_pool_2 )

out = slim . fully_connected ( flatten , 3755 , activation_fn = None )

global_step = tf . Variable ( initial_value = 0 )

if labels is not None :

loss = tf . reduce_mean ( tf . nn . softmax_cross_entropy_with_logits ( out , labels ))

train_op = tf . train . AdamOptimizer ( learning_rate = 0.0001 ). minimize ( loss , global_step = global_step )

accuracy = tf . reduce_mean ( tf . cast ( tf . equal ( tf . argmax ( out , 1 ), tf . argmax ( labels , 1 )), tf . float32 ))

tf . summary . scalar ( 'loss' , loss )

tf . summary . scalar ( 'accuracy' , accuracy )

merged_summary_op = tf . summary . merge_all ()

output_score = tf . nn . softmax ( out )

predict_val_top3 , predict_index_top3 = tf . nn . top_k ( output_score , k = 3 )

endpoints [ 'global_step' ] = global_step

if labels is not None :

endpoints [ 'labels' ] = labels

endpoints [ 'train_op' ] = train_op

endpoints [ 'loss' ] = loss

endpoints [ 'accuracy' ] = accuracy

endpoints [ 'merged_summary_op' ] = merged_summary_op

endpoints [ 'output_score' ] = output_score

endpoints [ 'predict_val_top3' ] = predict_val_top3

endpoints [ 'predict_index_top3' ] = predict_index_top3

return endpoints


Train


train函数包括从已有checkpoint中restore,得到step,快速恢复训练过程,训练主要是每一次得到mini-batch,更新参数,每隔eval_steps后做一次train batch的eval,每隔save_steps 后保存一次checkpoint。


def train () :

sess = tf . Session ()

file_labels = get_imagesfile ( FLAGS . train_data_dir )

images , labels , coord , threads = batch_data ( file_labels , sess )

endpoints = network ( images , labels )

saver = tf . train . Saver ()

sess . run ( tf . global_variables_initializer ())

train_writer = tf . train . SummaryWriter ( './log' + '/train' , sess . graph )

test_writer = tf . train . SummaryWriter ( './log' + '/val' )

start_step = 0

if FLAGS . restore :

ckpt = tf . train . latest_checkpoint ( FLAGS . checkpoint_dir )

if ckpt :

saver . restore ( sess , ckpt )

print "restore from the checkpoint {0}" . format ( ckpt )

start_step += int ( ckpt . split ( '-' )[ - 1 ])

logger . info (







请到「今天看啥」查看全文


推荐文章
Python爱好者社区  ·  中国最难入的IT公司。
4 天前
Python爱好者社区  ·  黄仁勋预言成真!!!
2 天前
Python爱好者社区  ·  DeepSeek彻底爆了。。。
5 天前
人人都是产品经理  ·  仅仅是做一个产品经理,远远不够
7 年前
刑事法律实务  ·  公证处卷入惊天骗局?司法部介入调查
7 年前