中文字幕理论片,69视频免费在线观看,亚洲成人app,国产1级毛片,刘涛最大尺度戏视频,欧美亚洲美女视频,2021韩国美女仙女屋vip视频

打開APP
userphoto
未登錄

開通VIP,暢享免費電子書等14項超值服

開通VIP
Tensorflow如何導(dǎo)出與使用預(yù)測圖

干貨教程第一時間送達(dá)!

tf.train.Saver API說明

保存于恢復(fù)變量,對定義好完成訓(xùn)練或者完成部分訓(xùn)練的計算圖所有OP操作的中間變量進(jìn)行保存,保存為檢查點文件(checkpoint file),檢查點文件通過restore方法完成恢復(fù),實現(xiàn)從變量到張量值(tensor value)得映射加載,可以進(jìn)行調(diào)用或者繼續(xù)訓(xùn)練。同時Saver支持全局步長參數(shù),通過對不同的step自動保存為檢查點

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

上述代碼表示分別在step=0與step=1000的時候保存檢查點。

Saver在保存檢查點的時候默認(rèn)保存計算圖的全部變量,但是可以通過var_list來決定保存多少個變量到檢查點文件中去。對保存的檢查點進(jìn)行恢復(fù)可以調(diào)用如下的方法:

restore(
    sess,
    save_path
)

從檢查點恢復(fù)變量并映射到相關(guān)的tensor中去,要求必須有一個當(dāng)前會話才可以重新加載計算圖。當(dāng)使用這種方式時候就無需再重復(fù)調(diào)用初始化方法來初始化變量了,restore方法本身就完成了變量初始化,然后就可以繼續(xù)訓(xùn)練或者使用計算圖進(jìn)行預(yù)測。

預(yù)測圖導(dǎo)出

使用tf.train.Saver會保存檢測點文件,但是這些文件不是一個,是四個文件一組:

-checkpoint
-prefix-model-steps.data-00000-of-00001
-prefix-model-steps.index
-prefix-model-steps.meta

其中

prefix是前綴名稱
steps是運行number of steps

當(dāng)prefix=my_cnn_mnist,steps=10000時

通過讀取checkpint文件與meta文件加載計算圖,然后把所有的變量轉(zhuǎn)換為常量形式通過GFile進(jìn)行串行化寫入生成預(yù)測圖(PB文件),從檢查點導(dǎo)出成為預(yù)測圖(PB文件)的代碼如下:

# We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # We precise the file fullname of our freezed graph
    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/frozen_model.pb'

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            output_node_names.split(',')  # The output node names are used to select the usefull nodes
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, 'wb'as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

    return output_graph_def

這段代碼我也是借鑒tensorflow中一個工具類copy過來的,發(fā)現(xiàn)很好用!

一個例子

首先定義個網(wǎng)絡(luò)模型,對于輸入與預(yù)測部分tensor的name屬性我們都給予賦值。
定義輸入-X

x = tf.placeholder(shape=[None, 784], dtype=tf.float32, name='input_x')
y = tf.placeholder(shape=[None, 10], dtype=tf.float32)
keep_prob = tf.placeholder(dtype=tf.float32)

定義預(yù)測輸出

acc_mat = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
acc = tf.reduce_sum(tf.cast(acc_mat, tf.float32))
prediction = tf.argmax(logits, axis=1, name='prediction_out')

構(gòu)建卷積神經(jīng)網(wǎng)絡(luò)的代碼如下

def conv_net(x_dict, n_classes, dropout):
    conv1 = tf.layers.conv2d(x_dict, 32, 5, activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(conv1, pool_size=2, strides=2)

    conv2 = tf.layers.conv2d(pool1, 64, 3, activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(conv2, pool_size=2, strides=2)

    fc1 = tf.layers.flatten(pool2, name='fc1')
    fc2 = tf.layers.dense(fc1, 1024)
    fc3 = tf.layers.dropout(fc2, rate=dropout)
    out = tf.layers.dense(fc3, n_classes)
    return out


logits = conv_net(x_image, num_classes, keep_prob)
cross_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
loss = tf.reduce_mean(cross_loss)
step = tf.train.AdamOptimizer(learning_rate).minimize(loss)

保存檢查點的代碼如下:

saver = tf.train.Saver()
......
saver.save(sess, './my_cnn_mnist.model', global_step=10000)

導(dǎo)出預(yù)測圖之后使用預(yù)測實現(xiàn)手寫數(shù)字預(yù)測的代碼如下

import argparse
import tensorflow as tf
import numpy as np
import cv2 as cv
from tensorflow.examples.tutorials.mnist import input_data
print(tf.__version__)
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)


def load_graph(frozen_graph_filename):
    # 開始解析
    with tf.gfile.GFile(frozen_graph_filename, 'rb'as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # 加載圖
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='prefix')
    return graph


if __name__ == '__main__':
    # 傳遞參數(shù),加載預(yù)測圖
    parser = argparse.ArgumentParser()
    parser.add_argument('--frozen_model_filename', default='./frozen_model.pb', type=str,
                        help='Frozen model file to import')
    args = parser.parse_args()

    # 加載
    graph = load_graph(args.frozen_model_filename)

    # 遍歷所有
    for op in graph.get_operations():
        print(op.name)

    # 獲取張量
    input_x = graph.get_tensor_by_name('prefix/input_x:0')
    prediction = graph.get_tensor_by_name('prefix/prediction_out:0')
    print(input_x, prediction)

    # 運行預(yù)測圖
    with tf.Session(graph=graph) as sess:
        for i in range(100):
            test_img = np.expand_dims(mnist.test.images[i], 0)
            predicted_ = sess.run(prediction, feed_dict={input_x: test_img})[0]
            label = np.argmax(mnist.test.labels[i])
            print('predicted number %s, actual label : %s' % (str(predicted_), str(label)))
            ti = np.reshape(mnist.test.images[i], [2828])
            ti = cv.resize( ti, (128128))
            cv.imshow('actual image', ti)
            cv.waitKey(0)

運行結(jié)果:


天下難事,必作于易
天下大事,必作于細(xì)
歡迎掃碼加入【OpenCV研習(xí)社】

本站僅提供存儲服務(wù),所有內(nèi)容均由用戶發(fā)布,如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請點擊舉報。
打開APP,閱讀全文并永久保存 查看更多類似文章
猜你喜歡
類似文章
學(xué)習(xí)TensorFlow,保存學(xué)習(xí)到的網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)并調(diào)用
tensorflow如何使用訓(xùn)練好的模型做測試 no partice
Tensorflow實例:利用LSTM預(yù)測股票每日最高價(二)
keras訓(xùn)練的h5模型轉(zhuǎn)換為pb模型
Py之tf2onnx:tf2onnx庫的簡介、安裝、使用方法之詳細(xì)攻略
anaconda3,python3.7,pycharm安裝配置
更多類似文章 >>
生活服務(wù)
熱點新聞
分享 收藏 導(dǎo)長圖 關(guān)注 下載文章
綁定賬號成功
后續(xù)可登錄賬號暢享VIP特權(quán)!
如果VIP功能使用有故障,
可點擊這里聯(lián)系客服!

聯(lián)系客服