TensorFlow图形是否可以通过C API按文件而非目录加载?

dl5txlt9  于 2022-12-04  发布在  其他
关注(0)|答案(1)|浏览(96)

对于我们的项目强制性要求,网络是按文件而不是按目录加载的。
现在我按目录加载saved_model.pb,它正在工作:

auto tfSessionOpts = TF_NewSessionOptions();
runOpts = nullptr:
TF_Session * pSession = TF_LoadSessionFromSavedModel(tfSessionOpts,
                                                     runOpts,
                                                     m_strModelDir.c_str(),
                                                     m_vecTags.data(),
                                                     static_cast<int>(m_vecTags.size()),
                                                     m_graph,
                                                     nullptr,
                                                     status);

现在我必须切换到加载每个文件的方法,同时使用c_api。我有一个frozen_graph.pb,我想在C中加载它,并从它获得一些信息(Tensor等),并运行会话。我所做的是:
将frozen_graph.pb加载到缓冲区中:

std::ifstream f(m_strModelFile, std::ios::binary);
if (f.seekg(0, std::ios::end).fail()) { throw; }
auto fsize = f.tellg();
if (f.seekg(0, std::ios::beg).fail()) { throw; }
if (fsize <= 0) { throw; }
auto data = static_cast<char*>(std::malloc(fsize));
if (f.read(data, fsize).fail()) { throw; }
TF_Buffer* pBuffer = TF_NewBuffer();
pBuffer->data = data;
pBuffer->length = fsize;
pBuffer->data_deallocator = DeallocateBuffer;

使用缓冲区加载图形(状态正常):

TF_ImportGraphDefOptions* pGraphDefOptions = TF_NewImportGraphDefOptions();
TF_GraphImportGraphDef(m_graph, pBuffer, pGraphDefOptions, status);

加载会话(状态正常):

auto tfSessionOpts = TF_NewSessionOptions();
TF_Session* session = TF_NewSession(m_graph, sessionOptions, status);

缓冲区大小正确,图形和会话不是nullptr,状态为“ok”。但当我想使用图形时,没有信息可以使用。例如,我无法获得具体的操作:

auto input_op = TF_Output{ TF_GraphOperationByName(m_graph, "StatefulPartitionedCall"), 1 };
if (input_op.oper == nullptr)
{
 throw; // this happens
}

但是,当我使用TF_LoadSessionFromSavedModel函数构造图时,我可以访问这些信息。
是我做错了什么,还是一个有有用信息的图形不能按文件名加载?
如果使用c_api不可能,那么使用c++ API是否可能?

dgsult0t

dgsult0t1#

难道你还没有加载图形操作?

std::unique_ptr<TF_ImportGraphDefOptions,
                decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {
TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
TF_GraphImportGraphDef(m_graph.get(), def, graph_opts.get(),
                       status.get());

相关问题