smile学子吧 关注:8贴子:998
  • 3回复贴,共1

tensorflow 保存和加载模型

只看楼主收藏回复



IP属地:广东1楼2018-11-09 15:38回复
    我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。
    1.Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值。
    2.只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
    3.为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。


    IP属地:广东2楼2018-11-09 15:40
    回复
      import tensorflow as tf
      import numpy as np
      isTrain = False
      train_steps = 100
      checkpoint_steps = 50
      checkpoint_dir = r'C:\Users\lenovo\workspace\modelcunchu\checkpoint_dir\.'
      x = tf.placeholder(tf.float32, shape=[None, 1])
      y = 4 * x + 4
      w = tf.Variable(tf.random_normal([1], -1, 1))
      b = tf.Variable(tf.zeros([1]))
      y_predict = w * x + b
      loss = tf.reduce_mean(tf.square(y - y_predict))
      optimizer = tf.train.GradientDescentOptimizer(0.5)
      train = optimizer.minimize(loss)
      saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
      x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
      with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      if isTrain:
      for i in range(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
      saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
      # print(sess.run(w))
      # print(sess.run(b))
      else:
      ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
      if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
      else:
      pass
      print(sess.run(w))
      print(sess.run(b))


      IP属地:广东3楼2018-11-09 15:41
      回复
        改进一下:
        import tensorflow as tf
        import numpy as np
        def save(checkpoint_dir,step):
        checkpoint_dir = r'C:\Users\lenovo\workspace\modelcunchu\checkpoint_dir\.'
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
        def load(checkpoint_dir):
        import re
        checkpoint_dir = r'C:\Users\lenovo\workspace\modelcunchu\checkpoint_dir\.'
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        else:
        pass
        isTrain = False
        train_steps = 100
        checkpoint_steps = 50
        checkpoint_dir = ''
        x = tf.placeholder(tf.float32, shape=[None, 1])
        y = 4 * x + 4
        w = tf.Variable(tf.random_normal([1], -1, 1))
        b = tf.Variable(tf.zeros([1]))
        y_predict = w * x + b
        loss = tf.reduce_mean(tf.square(y - y_predict))
        optimizer = tf.train.GradientDescentOptimizer(0.5)
        train = optimizer.minimize(loss)
        saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
        x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
        with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if isTrain:
        for i in range(train_steps):
        sess.run(train, feed_dict={x: x_data})
        if (i + 1) % checkpoint_steps == 0:
        save(checkpoint_dir,i+1)
        print(sess.run(w))
        print(sess.run(b))
        else:
        load(checkpoint_dir)
        print(sess.run(w))
        print(sess.run(b))


        IP属地:广东4楼2018-11-09 16:45
        回复