作者:刘光聪
原文:http://www.jianshu.com/p/bebcdfb74fb1?utm_campaign=hugo&utm_medium=reader_share&utm_content=note&utm_source=weixin-friends
Variable
是一个特殊的OP,它拥有状态(Stateful)。本文通过阐述Variable初始化模型,深入理解变量初始化的过程。
以一个简单的线性模型为例(为了简化问题,此处省略了训练子图)。首先,使用tf.placeholder
定义模型的输入,然后定义了两个全局变量,同时它们都是训练参数,最后定义学习模型。
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784,10]), name='W')
b = tf.Variable(tf.zeros([10]), name='b')
y = tf.matmul(x, W) + b
在使用变量之前,必须对变量进行初始化。按照习惯用法,使用tf.global_variables_initializer()
将所有全局变量的初始化器汇总,并对其进行初始化。
init = tf.global_variables_initializer()with tf.Session() as sess:
sess.run(init)
按照既有经验,其计算图大致如下图所示。
线性模型
事实上,正如下图所示,实际的计算图要复杂得多,让我们从头说起。
线性模型
Variable
是一个特殊的OP,它拥有状态(Stateful)。如果从实现技术探究,Variable
的Kernel实现直接持有一个Tensor
实例,其生命周期与变量一致。相对于普通的Tensor实例,其生命周期仅对本次迭代(Step)有效;而Variable对多个迭代都有效,甚至可以存储到文件系统,或从文件系统中恢复。
此外,存在几个操作Variable的特殊OP,例如Assign, AssignAdd等。变量所持有的Tensor以引用的方式输入到Assign中,Assign根据初始值,就地修改Tensor内部的值,最后以引用的方式输出该Tensor。
一般地,在使用变量之前,必须对变量进行初始化。事实上,TensorFlow设计了一个精巧的变量初始化模型。Variable根据初始值(Initial Value)进行类型推演,并确定Tensor的形状(Shape)。另外,通过初始化器(Initializer)在初始化期间,将初始化值赋予Variable内部所持有Tensor,完成Variable的就地修改。
例如,变量W
的定义如下。tf.zeros([784,10])
常称为初始值,它通过初始化器Assign,将W内部持有的Tensor以引用的形式就地修改为该初始值。
W = tf.Variable(tf.zeros([784,10]), name='W')
如果要读取变量的值,则通过Identity
恒等变化,直接输出变量所持有的Tensor。但时,Identity
去除了Variable的引用标识,同时也避免了内存拷贝。
变量初始化模型
然后,通过调用tf.global_variables_initializer()
将变量的所有初始化器进行汇总,然后启动Session运行该OP。
init = tf.global_variables_initializer()
事实上,搜集所有全局变量的初始化器的OP是一个NoOp
,即不存在输入,也不存在输出。所有变量的初始化器通过控制依赖边与该NoOp相连,保证所有的全局变量被初始化。
初始化过程
同位关系是一种特殊的设备约束关系。显而易见,Assign, Identity
这两个OP与Variable
关系极其紧密,分别实现了变量的修改与读取功能。因此,它们必须与Variable
在同一个设备上执行。
这样的关系,常称为同位关系(Colocation)。可以在Assign/Identity
节点上指定_class
属性值:[s: "loc:@W"]
,它表示这两个OP与W
放在同一个设备上运行。
例如,以W/read
节点为例,该节点增加了_class
属性,指示与W
的同位关系。
node {
name: "W/read"
op: "Identity"
input: "W"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@W"
}
}
}
}
如果一个变量初始化需要依赖于另外一个变量的初始值,则需要特殊地处理。例如,变量V
的初始值依赖于W
的初始值,可以通过W.initialized_value()
指定。
W = tf.Variable(tf.zeros([784,10]), name='W')
V = tf.Variable(W.initialized_value(), name='V')
事实上,两者通过Identity
衔接,并显式地添加了依赖控制边,保证W
在V
之前初始化。此处,存在两个Identity
的OP,但职责不一样,它们分别完成初始化依赖和变量读取。
初始化依赖
同样地,可以通过调用tf.global_variables_initializer()
将变量的所有初始化器进行汇总,然后启动Session完成所有变量的初始化。
init = tf.global_variables_initializer()
按照依赖关系,因为增加了W/Assign
与Identity
之间的控制依赖边,从而巧妙地实现了W
在V
之前完成初始化,并通过W
当前的初始化值,最终完成V
的初始化。
初始化过程