如何在Google Colab上下载TensorFlow训练模型?

kuhbmx9i  于 2023-01-02  发布在  Go
关注(0)|答案(3)|浏览(334)

我是一个新手和研究机器学习,我已经创建了一个模型上GoogleColab
我的目标是在Android应用中使用该模型进行离线预测,因此我需要下载经过训练的模型。
我唯一知道的是,我需要保存我的模型为.pb文件,以便使我的Android应用程序。我四处搜索,也许有答案,但他们太短,我无法理解,所以需要非常详细的答案。
这是我的Test.ipynb文件,有人能花一点时间来训练模型,看看我们是否可以下载到本地驱动器.

brccelvz

brccelvz1#

这是我从Collab保存和下载模型文件的方式。
1.仅保存可训练变量(这些是我的存储/恢复fns):
代码如下:

def store(sess_var, model_path):        
    if model_path is not None:
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        save_path = saver.save(sess_var, model_path)
        print("Model saved in path: %s" % save_path)
    else:
        print("Model path is None - Nothing to store")


def restore(sess_var, model_path):
    if model_path is not None:
        if os.path.exists("{}.index".format(model_path)):
            saver = tf.train.Saver(var_list=tf.trainable_variables())
            saver.restore(sess_var, model_path)
            print("Model at %s restored" % model_path)
        else:
            print("Model path does not exist, skipping...")
    else:
        print("Model path is None - Nothing to restore")

1.压缩存储模型的目录-确保其中不包含其他内容:第一个月
1.下载模型:
from google.colab import files files.download('model.tar.gz')
由于您只存储可训练变量而不是整个会话,因此模型的大小很小,因此可以下载。请确保使用Chrome-我无法在Firefox上运行最后一个片段。

dbf7pr2w

dbf7pr2w2#

这应该可以做到:

import tensorflow as tf
from google.colab import files

# Specify export directory and use tensorflow to save your_model
export_dir = './saved_model'
tf.saved_model.save(your_model, export_dir=export_dir)

请注意,导出目录包含多个文件,但如果您只想下载. pb文件,则可以使用以下命令。

# Download the model
files.download(export_dir + '/saved_model.pb')
g6baxovj

g6baxovj3#

!mkdir -p saved_model
saved_model_dir = 'saved_model/my_model'
model.save(saved_model_dir) 

#Convert the model

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) 

tflite_model = converter.convert()

#Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

从tensorflow 文档。

相关问题