为了账号安全,请及时绑定邮箱和手机立即绑定

如何从给定模型中获取 Graph(或 GraphDef)?

如何从给定模型中获取 Graph(或 GraphDef)?

qq_花开花谢_0 2023-04-18 17:15:55
我有一个使用 Tensorflow 2 和 Keras 定义的大模型。该模型在 Python 中运行良好。现在,我想将它导入到 C++ 项目中。在我的 C++ 项目中,我使用TF_GraphImportGraphDef函数。*.pb如果我使用以下代码准备文件,效果很好:    with open('load_model.pb', 'wb') as f:         f.write(tf.compat.v1.get_default_graph().as_graph_def().SerializeToString())我已经在使用 Tensorflow 1(使用 tf.compat.v1.* 函数)编写的简单网络上尝试了这段代码。它运作良好。现在我想将我的大模型(开头提到的,使用Tensorflow 2编写)导出到C++项目中。为此,我需要从我的模型中获取一个Graph或GraphDef对象。问题是:如何做到这一点?我没有找到任何属性或函数来获取它。我也试过用它tf.saved_model.save(model, 'model')来保存整个模型。它生成一个包含不同文件的目录,包括saved_model.pb文件。不幸的是,当我尝试使用TF_GraphImportGraphDef函数在 C++ 中加载此文件时,程序抛出异常。
查看完整描述

2 回答

?
海绵宝宝撒

TA贡献1809条经验 获得超8个赞

生成的协议缓冲区文件tf.saved_model.save不包含GraphDef消息,而是包含一个SavedModel. 您可以在 Python 中遍历它SavedModel以获取其中的嵌入图,但这不会立即用作冻结图,因此正确处理它可能很困难。取而代之的是,C++ API 现在包含一个LoadSavedModel调用,允许您从目录加载整个保存的模型。它应该看起来像这样:

#include <iostream>

#include <...>  // Add necessary TF include directives


using namespace std;

using namespace tensorflow;


int main()

{

    // Path to saved model directory

    const string export_dir = "...";

    // Load model

    Status s;

    SavedModelBundle bundle;

    SessionOptions session_options;

    RunOptions run_options;

    s = LoadSavedModel(session_options, run_options, export_dir,

                       // default "serve" tag set by tf.saved_model.save

                       {"serve"}, &bundle));

    if (!.ok())

    {

        cerr << "Could not load model: " << s.error_message() << endl;

        return -1;

    }

    // Model is loaded

    // ...

    return 0;

}

从这里开始,您可以做不同的事情。也许您最愿意使用 将保存的模型转换为冻结图FreezeSavedModel,这应该让您可以像以前一样做事:


GraphDef frozen_graph_def;

std::unordered_set<string> inputs;

std::unordered_set<string> outputs;

s = FreezeSavedModel(bundle, &frozen_graph_def,

                     &inputs, &outputs));

if (!s.ok())

{

    cerr << "Could not freeze model: " << s.error_message() << endl;

    return -1;

}

否则,您可以直接使用保存的模型对象:


// Default "serving_default" signature name set by tf.saved_model_save

const SignatureDef& signature_def = bundle.GetSignatures().at("serving_default");

// Get input and output names (different from layer names)

// Key is input and output layer names

const string input_name = signature_def.inputs().at("my_input").name();

const string output_name = signature_def.inputs().at("my_output").name();

// Run model

Tensor input = ...;

std::vector<Tensor> outputs;

s = bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));

if (!s.ok())

{

    cerr << "Error running model: " << s.error_message() << endl;

    return -1;

}

// Get result

Tensor& output = outputs[0];


查看完整回答
反对 回复 2023-04-18
?
宝慕林4294392

TA贡献2021条经验 获得超8个赞

我找到了以下问题的解决方案:


g = tf.Graph()

with g.as_default():


    # Create model

    inputs = tf.keras.Input(...) 

    x = tf.keras.layers.Conv2D(1, (1,1), padding='same')(inputs)

    # Done creating model


    # Optionally get graph operations

    ops = g.get_operations()

    for op in ops:

        print(op.name, op.type)


    # Save graph

    tf.io.write_graph(g.as_graph_def(), 'path', 'filename.pb', as_text=False)


查看完整回答
反对 回复 2023-04-18
  • 2 回答
  • 0 关注
  • 145 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信