MNIST数字识别
文章目录
MNIST数据处理
MNIST是一个手写数字识别数据集,包含了60000张图片作为训练数据,10000张图片作为测试数据。每一张图片代表0~9中的一个数字。图片大小都是28x28。
Tensorflow提供了一个类来处理MNIST数据。1
2
3
4
5
6
7from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./mypath/mnist_data/', one_hot=True)
print("Training data size: ", mnist.train.num_examples)
print("Validating data size: ", mnist.validation.num_examples)
print("Testing data size: ", mnist.test.num_examples)
print("Example training data: ", mnist.train.images[0])
print("Example training data label: ", mnist.train.labels[0])
通过input_data.read_date_sets函数生成的类自动将 MNIST 数据集划分为train、validation和test三个数据集,train 有55000张图片,validation集合内有5000张图片,这两个集合组成了MNIST本身提供的训练数据集。test内有10000张图片。处理后的每一张图片是一个长度为784的一维数组,这个数组中的元素对应了图片像素矩阵中的每一个数字(28x28=784)。另外,这3个数据集还对应3个标签文件,用来标注图片上的数字是几,把图片和标签放在一起,称为“样本”,通过样本来实现一个有监督信号的深度学习模型。
相对应的,MNIST数据集的标签是介于0~9之间的数字,同来描述给定图片里表示的数字。标签数据是“one-hot vectors”:一个one-hot向量,除了某一位数字是1外,其余各维都是0。如标签0表示为[1,0,0,0,0,0,0,0,0,0].
独热编码是将分类变量转换为可提供给机器学习算法更好地进行预测的形式的过程。 一种稀疏向量,其中:一个元素设为 1;所有其他元素均设为 0。 one-hot 编码常用于表示拥有有限个可能值的字符串或标识符。例如,假设某个指定的植物学数据集记录了 15000 个不同的物种,其中每个物种都用独一无二的字符串标识符来表示。在特征工程过程中,您可能需要将这些字符串标识符编码为 one-hot 向量,向量的大小为 15000。
因为神经网络的输入是一个特征向量,所以在此把一张二维图像的像素矩阵放到一个一维数组中可以方便Tensorflow将图片的像素矩阵提供给神经网络的输入层。为了方便实用梯度下降,input_data.read_data_sets函数生成的类提供了mnist.train.next_batch函数,可以从所有的训练数据中读取一小部分作为一个训练的batch。1
2
3
4batch_size=100
xs,ys=mnist.train.next_batch(batch_size)
print(xs.shape) # (100,784)
print(ys.shape) # (100,10)
神经网络训练及不同模型结果对比
Tensorflow训练神经网络
1 | import tensorflow as tf |
使用验证数据集判断模型效果
虽然一个神经网络模型的最终效果是通过测试数据来评判的,但是不能直接通过模型在测试数据上的效果来选择参数。使用测试数据来选择参数可能会导致神经网络模型过度拟合测试数据,从而失去对未知数据的判断能力。
变量管理
Tensorflow提供了一种通过变量名来创建或获取一个变量的机制。通过这个机制,在不同的函数中可以直接通过变量的名字来使用变量,而不需要讲变量通过参数的形式到处传递。该机制主要通过tf.get_variable和tf.variable_scope函数实现。1
2
3# 下面两个定义是等价的
v=tf.get_variable('v',shape=[1],initializer=tf.constant_initializer(1.0))
v=tf.Variable(tf.constant(1.0,shape=[1]),name='v')
以下为Tensorflow中变量初始化函数
初始化函数 | 功能 | 主要参数 |
---|---|---|
tf.constant_initializer | 将变量初始化为给定常量 | 常量的取值 |
tf.random_normal_initializer | 将变量初始化为满足正太分布的随机值 | 正太分布的均值和标准差 |
tf.truncated_normal_initializer | 将变量初始化为满足正太分布的随机值,但如果随机出来的值偏离平均值超过2个标准差,那么这个数将会被重新随机 | 正太分布的均值和标准差 |
tf.random_uniform_initializer | 将变量初始化为满足平均分布的随机值 | 最大、小值 |
tf.uniform_unit_scaling_initializer | 将变量初始化为满足平均分布但不影响输出数量级的随机值 | factor(产生随机值时乘以的系数) |
tf.zeros_initializer | 将变量设置为全0 | 变量维度 |
tf.ones_initializer | 将变量设置为全1 | 变量的维度 |
tf.get_variable和tf.Variable最大的区别在于指定变量名称的参数。tf.Variable函数,变量名称是一个可选参数,而tf.get_variable是一个必选参数。
如果要通过tf.get_variable获取一个已经创建的变量,需要通过tf.variable_scope函数来生成一个上下文管理器,并明确指定在这个上下文管理器中,tf.get_variable将直接获取已经生成的变量。1
2
3
4
5
6
7
8
9
10
11
12
13
14import tensorflow as tf
with tf.variable_scope('foo'):
v = tf.get_variable('v', [1], initializer=tf.constant_initializer(1.0))
# 命名空间foo中已经存在名字为v的变量,所以一下代码会报错
# with tf.variable_scope('foo'):
# v = tf.get_variable('v', [1])
# 在生成上下文管理器时,将参数reuse设置为True,这样tf.get_variable函数直接获取已经声明的变量
with tf.variable_scope('foo', reuse=True):
v1 = tf.get_variable('v', [1])
print(v == v1) # 表明v和v1代表相同的Tensorflow中变量
# 将参数reuse设置为True,tf.variable_scope将只能获取已经创建过的变量.因为在命名空间bar中还没有创建变量v,所以以下代码会报错
with tf.variable_scope('bar',reuse=True):
v=tf.get_variable('v',[1])
当tf.variable_scope函数使用参数reuse=True生成上下文管理器时,这个上下文管理器内所有的tf.get_variable函数直接获取已经存在的变量。如果变量不存在,则函数将报错。相反如果tf.variable_scope函数使用参数reuse=None或reuse=False创建上下文管理器,tf.get_variable将创建新的变量。
Tensorflow中tf.variable_scope函数可以嵌套。1
2
3
4
5
6
7
8import tensorflow as tf
with tf.variable_scope('root'):
print(tf.get_variable_scope().reuse)
with tf.variable_scope('foo', reuse=True):
print(tf.get_variable_scope().reuse)
with tf.variable_scope('bar'):
print(tf.get_variable_scope().reuse)
print(tf.get_variable_scope().reuse)
Tensorflow函数生成的上下文管理器也会创建一个Tensorflow中的命名空间,在命名空间内创建的变量名称都会带上这个命名空间名作为前缀。所以tf.variable_scope函数处理可以控制tf.get_variable执行的功能,也提供了一个管理命名空间的方式。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16import tensorflow as tf
with tf.variable_scope('root'):
v1 = tf.get_variable('v', [1])
print(v1.name) # root/v:0
with tf.variable_scope('foo'):
v3 = tf.get_variable('v', [1])
print(v3.name) # foo/v:0
with tf.variable_scope('bar'):
v2 = tf.get_variable('v', [1])
print(v2.name) # foo/bar/v:0
# 创建一个名称为空的命名空间,并设置reuse=True
with tf.variable_scope('', reuse=True):
v4 = tf.get_variable("foo/bar/v", shape=[1])
# v5 = tf.get_variable('v', shape=[1])
# print(v5.name)
print(v4 == v2)
模型持久化
持久化代码实现
Tensorflow提供了tf.train.Saver类来保存和还原一个神经网络模型。
保存模型
1 | import tensorflow as tf |
Tensorflow模型一般会存在后缀为.ckpt文件中,运行上面的程序会出现四个文件。
model.ckpt.meta
保存Tensorflow计算图的结构;
model.ckpt
保存Tensorflow程序中每一个变量的取值;
checkponit
保存了一个目录下所有的模型文件列表;
model.ckpt.index
文件保存了当前参数名。
载入模型
1 | import tensorflow as tf |
如果不希望重复定义图上的运算,也可以直接加载已经持久化的图。1
2
3
4
5
6
7import tensorflow as tf
saver = tf.train.import_meta_graph('./path/model.ckpt.meta')
with tf.Session() as sess:
# 加载引进保存的模型,并通过已经保存的模型中的变量的值来计算加法
saver.restore(sess, './path/model.ckpt')
# 通过张量的名称来获取张量
print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
上面的程序默认保存和加载了Tensorflow计算图上定义的全部变量。但有时可能只需要保存或加载部分变量。比如,可能之前有一个训练好的五层神经网络模型,但现在想尝试一个6层的神经网络,那么可以将前面五层神经网络中的参数直接加载到新的模型,而仅将最后一层神经网络重新训练。
为了保存或加载部分变量,在声明 tf.train.Saver
类时可以提供一个列表来指定需要保存或加载的变量。比如加载模型时使用saver=tf.train.Saver([v1]),那么只有变量v1会被加载进来。处理可以选取需要被加载的变量,Saver类也支持在保存或加载时给变量重命名。1
2
3
4
5
6
7
8
9
10
11
12import tensorflow as tf
# 这里声明的变量名称和已经保存的模型中的变量的名称不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(-5.0, shape=[1]), name='other-v2')
# 直接使用tf.train.Saver()加载模型会报错
# 使用一个字典来重命名变量就可以记载原来模型了
# 原来名称为v1的变量现在加载到名称为other-v1中
saver = tf.train.Saver({"v1": v1, "v2": v2})
result = v1 + v2
with tf.Session() as sess:
saver.restore(sess, './path/model.ckpt')
print(sess.run(result))
Tensorflow可以通过字典将模型保存时的变量名和需要加载的变量联系起来。这样做主要是方便使用变量的滑动平均。
Tensorflow中,每一个变量的滑动平均值是通过影子变量维护的,所以要获取变量的滑动平均值实际上就是获取这个影子变量的取值。如果在加载模型时直接将影子变量映射到变量本身,那么在使用训练好的模型就不需要再调用函数来获取变量的滑动平均值了。
以下代码给出一个保存滑动平均模型的样例。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name='v')
# 在没有声明滑动平均模型时只有一个变量v,所以以下语句只会输出'v:0'
for variable in tf.global_variables():
print(variable.name)
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
# 在声明滑动平均模型后,Tensorflow会自动生成一个影子变量
for variable in tf.global_variables():
print(variable.name)
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
sess.run(tf.assign(v, 10))
sess.run(maintain_averages_op)
# Tensorflow会将v:0和v/ExponentialMovingAverage:0两个变量都保存下来
saver.save(sess, './path/model.ckpt')
print(sess.run([v, ema.average(v)]))
# [10.0,0.099999905]
以下代码给出如何通过变量重命名直接读取变量的滑动平均值。下面程序结果可以看出读取变量v的值实际上是上面代码变量中v的滑动平均值。1
2
3
4
5
6import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name='v')
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
saver.restore(sess, './path/model.ckpt')
print(sess.run(v)) # 0.099999905
ExpontentialMovingAverage类提供了variables_to_restore函数来生成Saver类所需要的变量重命名字典。1
2
3
4
5
6
7
8
9import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name='v')
ema = tf.train.ExponentialMovingAverage(0.99)
# 通过使用variables_to_restore函数可以直接生成上面代码中提供的字典
print(ema.variables_to_restore())
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
saver.restore(sess, './path/model.ckpt')
print(sess.run(v))
使用 tf.train.Saver会保存运行TensorFlow程序所需要的全部信息,然而有时并不需要
某些信息。比如在测试或者离线预测时,只需要知道如何从神经网络的输入层经过前向传
播计算得到输出层即可,而不需要类似于变量初始化、模型保存等辅助节点的信息。Tensorflow提供了convert_variables_to_constants函数,通过这个函数可以将计算图中的变量及取值通过常量的方式保存。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19import tensorflow as tf
from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
rresult = v1 + v2
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
# 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
graph_def = tf.get_default_graph().as_graph_def()
'''
将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉
最后一个参数['add']给出了需要保存的节点名称.add节点是上面定义的两个变量相加的操作,注意这里给出的是计算节点的名称,所以没有后面的:0
'''
output_graph_def = graph_util.convert_variables_to_constants(
sess, graph_def, ['add'])
# 将导出的模型存入文件
with tf.gfile.GFile('./path/combined_model.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())
通过以下程序可以直接计算定义的加法运算的结果。当只需要得到计算图中某个节点的取值时,这提供了一个更加方便的方法。1
2
3
4
5
6
7
8
9
10
11
12import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename = './path/combined_model.pb'
# 读取保存的模型文件,并将文件解析成对应的GraphDef Protocol Buffer
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def中保存的图加载到当前图中.return_elements=["add:0"]给出了返回的张量名称.在保存的时候给出的是计算节点的名称,所以为"add"
# 在加载的时候给出的是张量的名称,所以是add:0
result = tf.import_graph_def(graph_def, return_elements=['add:0'])
print(sess.run(result)) # [3.0]
持久化原理及数据格式
Tensorflow是一个通过图的形式来表达计算机的编程系统,Tensorflow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据,Tensorflow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGrapgDef中的内容就构成了Tensorflow持久化时第一个文件。1
2
3
4
5
6
7
8message MetaGraphDef{
MetaInfoDef meta_info_def=1;
GraphDef graph_def=2;
SaverDef saver_def=3;
map<string,CollectionDef> collection_def=4;
map<string,SignatureDef> signature_def=5;
repeated AssetFileDef asset_file_def=6;
}
Tensorflow提供了export_meta_graph函数,以json格式导出MetaGraphDef Protocol Buffer。1
2
3
4
5
6
7
8import tensorflow as tf
# 定义变量相加的计算
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(3.0, shape=[1]), name='v2')
result = v1 + v2
saver = tf.train.Saver()
# 通过export_meta_graph函数导出Tensorflow计算图的元图,并保存为json格式
saver.export_meta_graph("./path/model.ckpt.meta.json", as_text=True)
meta_info_def属性
meta_info_def属性是通过MetaInfoDef定义的,它记录了Tensorflow计算图中的元数据以及Tensorflow程序中所有使用到的运算方法信息。下面是MetaInfoDef Buffer的定义:1
2
3
4
5
6
7
8message MetaInfoDef {
string meta_graph_version=1;
OpList stripped_op_list=2;
google.protobuf.Any any_info=3;
repeated string tags=4;
string tensorflow_version=5;
string tensorflow_git_version=6;
}
Tensorflow计算图的元数据包括计算图的版本号(meta_graph_version属性)以及用户指定的一些标签(tags属性)。如果saver中没有特殊指定,那么这些属性都默认为空。在model.ckpt.meta.json中,meta_info_def属性里只有stripped_op_list属性是不为空的。stripped_op_list属性记录了Tensorflow计算图上使用到的所有原酸方法的信息。stripped_op_list属性记录了Tensorflow计算图上使用到的所有运算方法的信息。注意stripped_op_list属性保存的是Tensorflow运算方法的信息,所以如果每一个运算在Tensorflow计算图中出现多次,那么stripped_op_list也只会出现一次。stripped_op_list属性的类型是OpList。OpList类型是一个OpDef类型的列表,以下给出OpDef类型的定义:1
2
3
4
5
6
7
8
9
10
11
12
13message Def{
string name=1;
repeated ArgDef input_arg=2;
repeated ArgDef output_arg=3;
repeqted AttrDef attr=4;
OpDeprecation deprecation=8;
string summary =5 ;
string description=6;
bool is_commutative=18;
bool is_aggregate=16;
bool is_stateful=17;
bool allows_uninitialized_input=19;
}
OpDef类型前4个属性定义了一个运算最核心的信息,OpDef中第一个属性name定义了运算的名称,这也是运算唯一的标识符。在TensorFlow计算图元图的其他属性中,比如下面将要介绍的 GraphDef属性,将通过运算名称来引用不同的运算,OpDef的第二和第三个属性为input_arg和output_arg,它们定义了运算的输入和输出,因为输入输出都可以有多个,所以这两个属性都是列表(repeated)。第四个属性atr给出了其他的运算参数信息。在 model.ckpt.meta.json文件中总共定义了8个运算,下面将给出比较有代表性的一个运算来辅助说明 OpDef的数据结构。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26op {
name:"Add",
input_arg{
name:"x",
type_attr:"T"
}
input_arg{
name:"y"
type_attr:"T"
}
output_arg{
name:"z"
type_attr:"T"
}
attr{
name:"T"
type:"type"
allowed_values:{
list{
type:DT_HALF
type:DT_FLOAT
...
}
}
}
}
上面给出了名称为Add的运算。这个运算有2个输入和1个输出,输入输出属性都指定了属性type_attr,并且这个属性的值为T。在OpDef的attr属性中,必须要出现名称(name)为T的属性。 以上样例中,这个属性指定了运算输入输出允许的参数类型(allowed_values)。MetaInfoDef中的 tensorflow_version 和 tensorflow_git_version 属性记录了生成当前计算图的 TensorFlow 版本。
graph_def 属性
graph_def属性主要记录了 TensorFlow 计算图上的节点信息。TensorFlow 计算图的每个节点对应了 TensorFlow 程序中的一个运算。因为在 meta_info_def属性中已经包含了所有运算的具体信息,所以 graph def 属性只关注运算的连接结构。graph_def 属性是通过GraphDef Protocol Buffer定义的,GraphDef 主要包含了一个NodeDef类型的列表。以下代码给出了 GraphDef和NodeDef类型中包含的信息:1
2
3
4
5
6
7
8
9
10
11message GraphDef{
repeated NodeDef node=1;
VersionDef versions=4;
};
message NodeDef{
string name=1;
string op=2;
repeated string input=3;
string device=4;
map<string,AttrValue> attr=5;
};
GraphDef中versions除妖存储了Tensorflow的版本号。GraphDef的主要信息存储在node属性,记录了Tensorflow计算图上所有的节点信息。NodeDef类型中的名称属性name是一个节点的唯一标识符。在Tensorflow中可以通过节点的名称来获取相应的节点。op属性给出了该节点使用Tensorflow运算方法名称,通过这个名称可以在Tensorflow计算图元图的meta_info_def属性中找到该运算的具体信息。
NodeDef类型中input属性是一个字符串列表,定义了运算的输入,input属性中每个字符串的取值格式为node:src_output,其中node表示节点的名称,src_output表示这个输入是指定节点的第几个输出。当src_output为0时,可以将其省略。比如node:0表示名称为node的节点的第一个输出,也可以计为node。
NodeDef类型中device。当device属性指定了处理这个运算的设备。当device属性为空时,tensorflow会自动选取一个最合适的设备来运算。
attr属性指定了和当前运算相关的配置信息。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36grapg_def{
node{
name:"v1"
op:"VariableV2"
attr {
key:"_output_shapes"
value {
list {
shape { dim { size:1 }}
}
}
}
attr {
key:"dtype"
value {
type:DT_FLOAT
}
}
...
}
node {
name:"add"
op:"Add"
input:"v1/read"
input:"v2/read"
...
}
node {
name:"save/control_dependency"
op:"Identity"
...
}
versions {
producer:24
}
}
上面给出了model.ckpt.meta.json文件中graph_def属性里比较有代表性的几个节点。第一个节点给出的是变量定义的运算。在Tensorflow中变量定义也是一个运算,运算名称为v1(name:”v1”),运算方式的名称为Variable(op:”VariableV2”)。定义变量的运算可以有很多个,于是在NodeDef类型的node属性中可以有很多个变量定义的节点。但定义变量的运算方法只用到一个,于是在MetaInfoDef类型的stripped_op_list属性中只有一个名称为VariableV2的运算方法。除了指定计算图中节点的名称和运算方法,NodeDef类型中还定义了运算相关的属性。在节点v1中,attr属性指定了这个变量的维度以及类型。
给出的第二个节点是代表加法运算的节点,指定了2个输入,一个为v1/read,另一个为v2/read。其中v1/read代表的节点可以读取变量v1的值。因为v1的值是节点v1/read的第一输出,所以后面的:0就可以省略了。v2/read也类似的代表了变量v2的取值。以上样例文件中给出的最后一个名称为save/control_dependency,该节点是系统在完成Tensorflow模型持久化过程中自动生成的一个运算。versions表示生成该文件时Tensorflow的版本号。
save_der属性
saver_def属性中记录了持久化模型时需要用到的一些参数,比如保存到文件的文件名、保存操作和加载操作的名称以及保存频率、清理历史记录等。saver_def属性的类型为SaverDef,其定义如下。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23message SaverDef{
string filename_tensor_name = 1;
string save_tensor_name = 2;
stirng restore_op_name = 3;
int32 max_to_keep =4 ;
bool sharded = 5;
float keep_checkpoint_every_n_hours = 6;
enum CheckpointFormatVarsion{
LEGACY = 0;
V1 = 1;
V2 = 2;
}
CheckpointFormatVersion version = 7;
}
//saver_def属性的内容
saver_def{
filename_tensor_name:"save/Const:0"
save_tensor_name:"save/control_dependency:0"
resotre_op_name:"save/restore_all"
max_to_keep:5
keep_checkpoint_every_n_hours:10000.0
version:V2
}
filename_tensor_name为保存文件名的张量名称,这个张量就是节点save/Const的第一个输出。save_tensor_name表示持久化Tensorflow模型的运算所对应的节点名称。从以上文件可以看出,这个节点就是在graph_def属性中给出的save/control_denpendency节点。和持久化Tensorflow模型运算对应的是加载Tensorflow模型的运算,该运算的名称由restore_op_name属性决定。max_to_keep属性和keep_checkpoint_every_n_hours属性设定了tf.train.Saver类清理之前保存的模型的策略,如到max_to_keep为5时,在第6次调用saver.save时,第一次保存的模型就会被自动删除。通过设置keep_checkpoint_every_n_hours,每n小时可以在max_to_keep的基础上多保存一个模型。
collection_def属性
在Tensorflow计算图(tf.Graph)中底层通过collection_def这个属性可以维护不同的集合。collection_def属性是一个从集合名称到集合内容的映射,集合名称为字符串,而集合内容为CollectionDef Protocol Buffer。以下代码给出CollectionDef类型定义。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24message CollectionDef {
message NodeList {
repeated string value=1;
}
message BytesList {
repeated bytes value=1;
}
message Int64List {
repeated int64 value=1 [packed=true];
}
message FloatList {
repeated float value=1 [packed=true];
}
message AnyList {
repeated google.protobuf.Any value=1;
}
oneof kind {
NodeList node_list=1;
BytesList bytes_list=2;
Int64List int64_list=3;
FloatList float_list=4;
AnyList any_list=5;
}
}
Tensorflow计算图上的集合主要可以维护4类不同的集合。NodeList维护计算图上节点的集合。BytesList维护字符串或系列化之后的Protocol Buffer的集合。比如张量是通过Protocol Buffer表示的,而张量集合是通过BytesList维护的。Int64List用于维护整数集合,FloatList用于维护实数集合。下面给出model.ckpt.meta.json文件中collection_def属性的内容。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18collection_def {
key:"trainable_variables"
value {
bytes_list {
value:"\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
value:"\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
}
}
}
collection_def {
key:"variables"
value {
bytes_list {
value:"\n\004v1:0\022\tv1/Assign\032\tv1/read:0"
value:"\n\004v2:0\022\tv2/Assign\032\tv2/read:0"
}
}
}
上面维护了两个集合,一个是所有变量的集合,名称为variables。另一个是可训练变量的集合,名为trainable_variables。
使用tf.Saver得到的model.ckpt.index和model.ckpt.data--of-文件就保存了所有变量的取值。其中model.ckpt.data文件是通过SSTable格式存储的,可以理解为一个(key,value)列表。通过tf.train.NewCheckpointReader类来查看保存的变量信息。1
2
3
4
5
6
7
8
9
10
11import tensorflow as tf
# tr.train.NewCheckpointReader可以读取checkpoint文件中保存的所有变量
# 后面的.data和.index可以省略
reader = tf.train.NewCheckpointReader('./path/model.ckpt')
# 获取所有变量列表,这个是从变量名到变量维度的字典
global_variables = reader.get_variable_to_shape_map()
for valirable_name in global_variables:
# variable_name为变量名称,global_variables[variable_name]为变量的维度
print(valirable_name, global_variables[valirable_name])
# 获取名称为v1的变量的取值
print("v=", reader.get_tensor("v"))
checkpoint文件中维护了由一个tf.train.Saver类持久化的所有Tensorflow模型文件的文件名。当某个保存的Tensorflow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容格式为CheckpointState Protocol Buffer,下面给出CheckpointState类型的定义。1
2
3
4message CheckpointState {
string model_checkpoint_path=1;
repeated string all_model_checkpoint_paths=2;
}
model_checkpoint_path属性保存了最新的Tensorflow模型文件的文件名。all_model_checkpoint_paths属性列出了当前还没有被删除的所有Tensorflow模型文件的文件名。
常用函数
函数 | 功能 |
---|---|
tf.get_collection | 表示从collection集合中取出全部变量生成一个列表 |
tf.add | 将参数列表中对应元素相加 |
tf.cast(x,dtype) | 将参数x转换为指定数据类型 |
tf.equal | 表示对比两个矩阵或向量元素,若对应元素相等则返回True;不等返回False |
tf.reduce_mean(x,axis) | 表示求取矩阵或张量指定维度的平均值,若不指定第二个参数,则在所有元素取平均值;若指定第二个参数为0,则在每一列求平均值;若指定第二个参数为1,则每一行求平均值 |
tf.argmax(x,axis) | 返回指定维度axis下,参数x中最大值索引号 |
tf.Graph().as_default | 将当前图设置为默认图,返回一个上下文管理器。该函数一般与with关键字搭配使用,应用于将已经定义好的神经网络在计算图中复现 |