训练好的模型参数保存起来,以便以后直接使用或进行进一步的训练,这是我们经常要做的事情,本文对TensorFlow模型保存和重新加载的方法进行了介绍。

  • 保存模型本质上保存的是模型的各项参数,对神经网络来说就是网络的结构信息、各个边的权值等,tf中提供了两种可以用来保存模型的格式:
    • checkpoints格式,a format dependent on the code that created the model。对应的tf.train.Saver类。
    • SavedModel格式,a format independent of the code that created the model。对应tf.saved_model模块。

tf.train.Saver

基本方法

Saver类保存模型的基本方法为:
1、定义变量
2、使用Saver.save()方法保存
3、重新定义变量
4、使用Saver.restore()方法加载

  • 这种方法要求按照原有模型重新定义一遍变量和网络结构,restore()方法会将保存的值加载到对应的变量中(即加载到名字相同的变量)。
  • 因为重新定义一遍网络太过麻烦,且要保证重新定义的变量名字、类型都与原变量相同,因此这种方法实际用处不大。
  • example code:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    # save the model
    import tensorflow as tf

    W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='W')
    v = tf.Variable([[4,4],[5,5]], dtype = tf.int32, name='v')
    s = tf.train.Saver()

    with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    s.save(sess, "run/model")

    # restore the model
    import tensorflow as tf

    # 类型必须相同
    W = tf.Variable(tf.zeros([2,3], dtype = tf.float32), name='W')
    v = tf.Variable(tf.zeros([2,2], dtype = tf.int32), name='v')
    s = tf.train.Saver()

    with tf.Session() as sess:
    s.restore(sess, 'run/model') # 只用传入文件前缀即可
    print(W.eval())
    print(v.eval())

不需要重新定义网络结构的加载方法

很多时候我们都希望能够读取一个文件然后直接使用模型,而不是还要把模型重新定义一遍,因此就有了这种方法。基本流程为:
1、定义变量
2、使用Saver.save()方法保存
3、使用tf.train.import_meta_graph()加载网络结构,会返回一个saver对象
4、使用saver.restore()方法恢复网络中的变量值
5、使用graph.get_operation_by_name()和graph.get_tensor_by_name()方法获取op和tensor;使用collection机制获取之前保存的值。

  • example code:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    import tensorflow as tf

    # 不需要定义网络结构的方法
    # store the model
    W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='W')
    v = tf.Variable([[4,4],[5,5]], name='v')
    b = tf.Variable([[7,7], [8,8]], name='b')
    tf.add_to_collection('b_collection', b)
    s = tf.train.Saver()

    with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    s.save(sess, "run/model")

    # restore the model
    with tf.Session() as sess:
    s = tf.train.import_meta_graph("run/model.meta")
    s.restore(sess, "run/model") # 只用传入文件前缀即可
    graph = tf.get_default_graph()
    W = graph.get_tensor_by_name("W:0")
    v = graph.get_tensor_by_name("v:0")
    b = tf.get_collection("b_collection")[0]
    print(W.eval())
    print(v.eval())
    print(b.eval())
  • 如何获取op:tf.Graph.get_operation_by_name('name')

  • 如何获取tensor/variable/placeholder:tf.Graph.get_tensor_by_name('name'),注意tensor的命名规则:<op_name>:<output_index>
  • 对于Variable,get_tensor_by_name是获取了其内部封装的tensor,但可以像使用Variable那样使用这个的tensor(因为variable的外部op如v/initial_value, v/Assign, v/read也是被保留在图中的)。
  • saver只会存tensorflow相关的数据:如tensor、operation和collection,对于python的变量类型(如string,int),saver默认不会保存,可以通过collection来存(但需要注意:string在collection中是以b’str’形式保存的,取出后需要转化为utf-8格式)。

如何重新开始训练模型

  • 在初次定义模型时,不要把输入数据、标记数据定义为tensor类型,这样saver就不会保存训练数据,节省保存模型的大小。
  • 恢复模型后可重新生成训练数据来训练,数据一般也会做shffule处理,所以不用担心会一直训练某一小部分的数据。
  • 初次定义模型时养成好习惯,对tensor/opertaion等都要起好名字,这样才能在以后恢复。
  • 需要恢复的重要变量有:train_op,loss,global_step,模型输入的placeholder等。训练过程其实就是sess.run(loss, train_op),所以这些变量尤为重要。
  • 不需要恢复optimizer类,只需要恢复train_op即可,train_op中就已经包含了当初定义的optimizer信息。
  • 恢复模型后不用再sess.run(tf.global_variables_initializer()),否则保存的变量值都会被重新覆盖掉。

模型的保存格式

  • 模型和数据会被保存为三个文件:filename.data-00000-of-00001filename.indexfilename.meta,三个文件名字都是一样的,只是后缀不同。
  • 三个文件作用如下:.meta文件存储计算图结构,.data文件存储计算图中所有的变量值,.index文件则负责指示如何在.data文件中查找计算图变量对应的值(相当于一个.meta文件和.data文件的映射)。

init

init(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)

  • 参数解释:
    • var_list:指定要保存的变量,默认是保存所传入Session中的全部变量;
    • max_to_keep:最多保存的模型数量,默认为5,即只保存最近的5次模型训练结果;若为None或0,则保存所有的模型。
    • keep_checkpoint_every_n_hours:每隔多长时间保存一次模型;
    • 通常情况下直接调用构造函数即可,不用指定参数。

tf.train.Saver.save()

save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix=’meta’,
write_meta_graph=True,
write_state=True,
strip_default_attrs=False
)

  • 函数作用:保存模型。
  • 参数解释:
    • sess:运行模型的session,默认是将session中所有的参数都保存下来;
    • save_path:保存模型的路径,可以是相对路径,也可以是绝对路径。
    • global_step:文件名后缀,用训练步数对文件名添加的数字标记。
    • 注1:save_path其实是模型文件名的一部分(前缀),比如如果值为”ckpt/mnist”,会将模型保存到ckpt目录下,文件名为”mnist-xxx”,xxx为后缀,由global_step指定。
    • 注2:无需事先创建目录,若路径中的目录不存在tf会自动创建。
  • return:返回文件的绝对路径,type:string。

tf.train.Saver.restore()

restore(
sess,
save_path
)

  • 函数作用:恢复之前保存的模型。
  • 参数解释:
    • sess:要恢复到哪个session;
    • save_path:文件保存的路径。可直接传入save()函数的返回值来掉用。
  • return:void

tf.train.import_meta_graph

tf.train.import_meta_graph(
meta_graph_or_file,
clear_devices=False,
import_scope=None,
**kwargs
)

  • 函数作用:加载计算图结构信息。
  • 参数解释:
    • meta_graph_or_file:.meta后缀文件路径
  • return:返回一个Saver对象,可后续进行restore操作

tf.train.Saver.last_checkpoints

  • Saver类的一个属性,type:list,保存了储存在磁盘上的模型的名字(包括路径),依照产生顺序从旧到新排列。
  • list中任何一个元素都可以直接作为参数传给restore()函数。

Post Date: 2019-04-03

版权声明: 本文为原创文章,转载请注明出处