Understanding and Applying TensorFlow's Shared-Variable Mechanism

Creating Variables

TensorFlow provides two ways to create a variable:

  1. tf.get_variable()
  2. tf.Variable()

The differences between them:

Inside a tf.name_scope, the name of a variable created with tf.get_variable() is not prefixed by the name_scope. Moreover, when variable sharing is not enabled, tf.get_variable() raises an error on a duplicate name, whereas tf.Variable() detects the name collision itself and resolves it by uniquifying the name (e.g. appending _1).

import tensorflow as tf

with tf.name_scope('name_scope_x'):
    var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    var3 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
    var4 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var3.name, sess.run(var3))
    print(var4.name, sess.run(var4))

# Output:
# var1:0 [-0.30036557]         <- no 'name_scope_x' prefix on the name
# name_scope_x/var2:0 [ 2.]
# name_scope_x/var2_1:0 [ 2.]  <- auto-renamed to 'var2_1' to avoid clashing with 'var2'
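In contrast, tf.variable_scope() does prefix the names of variables created with tf.get_variable(). A minimal sketch (the scope name 'scope_y' and variable 'var_a' are illustrative, not from the original):

import tensorflow as tf

with tf.variable_scope('scope_y'):
    var_a = tf.get_variable(name='var_a', shape=[1], dtype=tf.float32)

print(var_a.name)  # scope_y/var_a:0 -- the variable_scope prefix is applied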

If you create variables with tf.get_variable() and have not enabled variable sharing, a duplicate name raises an error:

import tensorflow as tf

with tf.name_scope('name_scope_1'):
    var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    var2 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var2.name, sess.run(var2))

# ValueError: Variable var1 already exists, disallowed. Did you mean
# to set reuse=True in VarScope? Originally defined at:
#   var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)

Sharing Variables

Basic Usage

To share variables, use tf.variable_scope():

import tensorflow as tf

with tf.variable_scope('variable_scope_y') as scope:
    var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    scope.reuse_variables()  # enable variable sharing from this point on
    var1_reuse = tf.get_variable(name='var1')
    var2 = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
    var2_reuse = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var1.name, sess.run(var1))
    print(var1_reuse.name, sess.run(var1_reuse))
    print(var2.name, sess.run(var2))
    print(var2_reuse.name, sess.run(var2_reuse))

# Output:
# variable_scope_y/var1:0 [-1.59682846]
# variable_scope_y/var1:0 [-1.59682846]  <- var1_reuse reuses var1
# variable_scope_y/var2:0 [ 2.]
# variable_scope_y/var2_1:0 [ 2.]        <- tf.Variable() ignores reuse and simply uniquifies the name

Or, equivalently:

with tf.variable_scope('foo') as foo_scope:
    v = tf.get_variable('v', [1])
with tf.variable_scope('foo', reuse=True):
    v1 = tf.get_variable('v')

You can also pass the scope object itself instead of its name:

with tf.variable_scope('foo') as foo_scope:
    v = tf.get_variable('v', [1])
with tf.variable_scope(foo_scope, reuse=True):
    v1 = tf.get_variable('v')

A More Elegant Approach

The approaches above require you to set reuse=True (or call scope.reuse_variables()) on every use after the first, which is rather clumsy. Using tf.AUTO_REUSE is arguably more elegant: get_variable() creates the variable on the first call and returns the existing one on later calls.

with tf.variable_scope('foo', reuse=tf.AUTO_REUSE):
    v = tf.get_variable('v', [1])
    v1 = tf.get_variable('v')
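A quick sanity check, assuming the block above has run: both handles refer to the same underlying variable.

print(v is v1)          # True
print(v.name, v1.name)  # foo/v:0 foo/v:0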

A complete example:

import numpy as np
import tensorflow as tf

def convolution(in_put, in_channel, out_channel):
    # With reuse=tf.AUTO_REUSE, 'weights' is created on the first call
    # and reused on every later call.
    with tf.variable_scope(name_or_scope='', reuse=tf.AUTO_REUSE):
        weights = tf.get_variable(name="weights", shape=[2, 2, in_channel, out_channel],
                                  initializer=tf.contrib.layers.xavier_initializer_conv2d())
        output = tf.nn.conv2d(input=in_put, filter=weights, strides=[1, 1, 1, 1], padding="SAME")
    return output

def main():
    with tf.Graph().as_default():
        input_x = tf.placeholder(dtype=tf.float32, shape=[1, 4, 4, 1])
        # Five conv ops are built, all sharing the single 'weights' variable.
        for _ in range(5):
            output = convolution(input_x, 1, 1)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            _output = sess.run([output], feed_dict={input_x: np.random.uniform(low=0, high=255, size=[1, 4, 4, 1])})
            print(_output)

if __name__ == "__main__":
    main()
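To convince yourself that the five calls really share one set of weights, you could list the graph's variables after the loop (a hypothetical check, to be placed inside main() before the Session block; without tf.AUTO_REUSE, the second call to convolution() would instead raise a ValueError):

for v in tf.global_variables():
    print(v.name)
# Expected: a single entry, 'weights:0'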

Using the reuse Parameter

  • When reuse=False, get_variable() creates a variable; requesting a name that already exists raises an error:
import tensorflow as tf

with tf.variable_scope("foo", reuse=False):
    v = tf.get_variable("v", [1], initializer=tf.constant_initializer(1.0))
    v1 = tf.get_variable("v", [1])

# Output:
# ValueError: Variable foo/v already exists, disallowed.
# Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
  • When reuse=True, get_variable() fetches an existing variable:
import tensorflow as tf

with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1], initializer=tf.constant_initializer(1.0))
with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])
print(v1 == v)
# Output: True
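Since v and v1 are the same variable, an update made through one handle is visible through the other. A minimal sketch, continuing from the block above:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(v.assign([3.0]))  # write through v ...
    print(sess.run(v1))        # ... read through v1: [3.]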

When reuse=True is set on tf.variable_scope(), calling get_variable() inside the scope "foo" fetches the variable "v". If that variable does not yet exist in the scope, the lookup fails:

import tensorflow as tf

with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])

# Output:
# ValueError: Variable foo/v does not exist, or was not created with tf.get_variable().
# Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
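With reuse=tf.AUTO_REUSE, the same lookup succeeds, because the variable is created when it does not yet exist. A minimal sketch (the scope name "bar" is illustrative):

import tensorflow as tf

with tf.variable_scope("bar", reuse=tf.AUTO_REUSE):
    v1 = tf.get_variable("v", [1])  # created on first use, no error
print(v1.name)  # bar/v:0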

