Deep Learning 2.0-23. Keras High-Level API: Saving and Loading Models

1147-柳同学


Saving and loading models

1.load/save_weights

Hands-on
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """
    x is a single image, not a batch
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)

ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()

network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )

network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

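# Save only the model's parameters (weights); the architecture is not stored.
# Note: Keras 3 requires a filename ending in '.weights.h5' here.
network.save_weights('weights.ckpt')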
print('saved weights.')
del network

# Rebuild the multi-layer network with exactly the same architecture
network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )
# Load the saved parameters into the rebuilt network
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)
Output:
Epoch 2/3
469/469 [==============================] - 3s 7ms/step - loss: 0.1344 - accuracy: 0.9629 - val_loss: 0.1209 - val_accuracy: 0.9648
Epoch 3/3
469/469 [==============================] - 3s 6ms/step - loss: 0.1082 - accuracy: 0.9701
79/79 [==============================] - 0s 5ms/step - loss: 0.1372 - accuracy: 0.9664
saved weights.
loaded weights!
79/79 [==============================] - 1s 6ms/step - loss: 0.1372 - accuracy: 0.9664

Evaluation before saving and after loading reports the same loss (0.1372) and accuracy (0.9664), so the weights were restored correctly.
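
The weight round trip can also be checked on a toy network without training. The following is a minimal sketch, not the author's code: the helper `build_network`, the layer sizes, the random input, and the file name `demo.weights.h5` are all assumptions made for illustration. The `.weights.h5` suffix is chosen because it should be accepted by both older tf.keras and Keras 3.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, layers


def build_network():
    """Hypothetical helper: load_weights needs the same layer stack each time."""
    net = Sequential([layers.Dense(8, activation='relu'),
                      layers.Dense(3)])
    net.build(input_shape=(None, 4))
    return net


# Toy input standing in for a batch of flattened images (assumption).
x = np.random.rand(5, 4).astype('float32')

original = build_network()
before = original(x).numpy()

# Only the parameters are written; the architecture is not stored in the file.
path = os.path.join(tempfile.mkdtemp(), 'demo.weights.h5')
original.save_weights(path)

restored = build_network()      # starts with fresh random weights
restored.load_weights(path)     # overwritten by the saved parameters
after = restored(x).numpy()

weights_match = np.allclose(before, after, atol=1e-6)
print('outputs identical after reload:', weights_match)
```

Because only parameters are stored, the second call to `build_network()` is required: loading into a network with different layer shapes would fail.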

2.save/load entire model

Hands-on
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

# Data preprocessing
def preprocess(x, y):
    """
    x is a single image, not a batch
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


batchsz = 128
# Load the dataset
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())

db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)

ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)

network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dense(128, activation='relu'),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary()

network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )

network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2)

network.evaluate(ds_val)

# Save the entire model (architecture and weights) to one file
network.save('model.h5')
print('saved total model.')
del network

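# Load the entire model; there is no need to rebuild the layers by hand.
# compile=False skips restoring the optimizer and loss, so compile again below.
network = tf.keras.models.load_model('model.h5', compile=False)
print('loaded model from file.')
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy']
                )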

x_val = tf.cast(x_val, dtype=tf.float32) / 255.
x_val = tf.reshape(x_val, [-1, 28 * 28])
y_val = tf.cast(y_val, dtype=tf.int32)
y_val = tf.one_hot(y_val, depth=10)

ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(128)
network.evaluate(ds_val)
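
The difference from the weights-only approach is that the layer stack itself comes back from the file. A minimal sketch under assumptions (toy layer sizes, random input, and the file name `toy_model.h5` are illustrative and not from the original post):

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential, layers

# Toy input standing in for a batch of flattened images (assumption).
x = np.random.rand(5, 4).astype('float32')

model = Sequential([layers.Dense(8, activation='relu'),
                    layers.Dense(3)])
model.build(input_shape=(None, 4))
expected = model(x).numpy()

# One HDF5 file holds both the architecture and the weights.
path = os.path.join(tempfile.mkdtemp(), 'toy_model.h5')
model.save(path)
del model

# No Sequential([...]) call here: the layers are recreated from the file.
reloaded = tf.keras.models.load_model(path, compile=False)
actual = reloaded(x).numpy()

same_outputs = np.allclose(expected, actual, atol=1e-6)
layer_types = [type(layer).__name__ for layer in reloaded.layers]
print(layer_types, same_outputs)
```

With `compile=False` the reloaded model can run inference directly; `compile` is only needed again before calling `evaluate` or `fit`.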

3.saved_model - for deployment in production environments
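
The SavedModel format stores the traced computation graph together with the variables, so the model can be loaded without the Python code that defined it. The sketch below is an illustration under assumptions, not the author's example: `TinyLinear` is a hypothetical stand-in for a trained network, written as a plain `tf.Module` so that it does not depend on the Keras version. In TF 2.x a Keras model could typically be passed to `tf.saved_model.save` in the same way, and Keras 3 offers `model.export` for this purpose.

```python
import os
import tempfile

import tensorflow as tf


class TinyLinear(tf.Module):
    """Hypothetical stand-in for a trained network: y = x @ w + b."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.ones([4, 2]), name='w')
        self.b = tf.Variable(tf.zeros([2]), name='b')

    # A fixed input signature lets the graph be traced and written to disk.
    @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


export_dir = os.path.join(tempfile.mkdtemp(), 'tiny_linear')
tf.saved_model.save(TinyLinear(), export_dir)

# Loading needs only the directory, not the TinyLinear class definition.
loaded = tf.saved_model.load(export_dir)
y = loaded(tf.ones([3, 4])).numpy()
print(y.shape)
```

The exported directory contains a `saved_model.pb` file plus a `variables` folder, which is the layout that serving tools such as TensorFlow Serving expect.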

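Reproduction without permission is prohibited. Author: 1147-柳同学. When reposting or copying, please include a hyperlink and credit the source: 拜师资源博客
Original article: "Deep Learning 2.0-23. Keras High-Level API: Saving and Loading Models", published 2020-10-10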
