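This article collects code examples for the Java method org.tensorflow.Graph.importGraphDef() and shows how Graph.importGraphDef() is used in practice. The examples come mainly from GitHub, Stack Overflow, and Maven, and were extracted from selected projects, so they serve as practical references. Details of Graph.importGraphDef() are as follows: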
Package path: org.tensorflow.Graph
Class name: Graph
Method name: importGraphDef
Import a serialized representation of a TensorFlow graph.
The serialized representation of the graph, often referred to as a GraphDef, can be generated by Graph.toGraphDef() and equivalents in other language APIs.
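Since importGraphDef() takes the GraphDef as a byte array, a typical first step is reading a serialized graph file from disk. The following is a minimal sketch of that step, not part of the TensorFlow API: the class name GraphDefLoader and its load method are hypothetical, and the import step is passed in as a Consumer<byte[]> so the snippet compiles without TensorFlow on the classpath. With TensorFlow available, the importer would presumably be a method reference such as graph::importGraphDef.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.function.Consumer;

/** Hypothetical helper (an assumption for illustration, not a TensorFlow class). */
final class GraphDefLoader {
  private GraphDefLoader() {}

  /**
   * Reads the serialized graph at {@code path} and hands the bytes to {@code importer}.
   * An IllegalArgumentException from the importer is rethrown as an IOException,
   * keeping the original exception as the cause.
   */
  static byte[] load(Path path, Consumer<byte[]> importer) throws IOException {
    byte[] graphDef = Files.readAllBytes(path);
    try {
      importer.accept(graphDef);
    } catch (IllegalArgumentException e) {
      throw new IOException("Not a valid TensorFlow graph serialization: " + path, e);
    }
    return graphDef;
  }
}
```

Assuming a Graph instance named graph, usage would look like GraphDefLoader.load(Paths.get("model.pb"), graph::importGraphDef), where model.pb is a placeholder file name.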
Code example source: apache/ignite
/** {@inheritDoc} */
@Override public Session parseModel(byte[] mdl) {
  Graph graph = new Graph();
  graph.importGraphDef(mdl);
  return new Session(graph);
}
Code example source: org.tensorflow/libtensorflow
/**
 * Import a serialized representation of a TensorFlow graph.
 *
 * <p>The serialized representation of the graph, often referred to as a <i>GraphDef</i>, can be
 * generated by {@link #toGraphDef()} and equivalents in other language APIs.
 *
 * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
 * @see #importGraphDef(byte[], String)
 */
public void importGraphDef(byte[] graphDef) throws IllegalArgumentException {
  importGraphDef(graphDef, "");
}
Code example source: org.tensorflow/libtensorflow
/**
 * Import a serialized representation of a TensorFlow graph.
 *
 * @param graphDef the serialized representation of a TensorFlow graph.
 * @param prefix a prefix that will be prepended to names in graphDef
 * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
 * @see #importGraphDef(byte[])
 */
public void importGraphDef(byte[] graphDef, String prefix) throws IllegalArgumentException {
  if (graphDef == null || prefix == null) {
    throw new IllegalArgumentException("graphDef and prefix cannot be null");
  }
  synchronized (nativeHandleLock) {
    importGraphDef(nativeHandle, graphDef, prefix);
  }
}
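Because the prefix is prepended to the names in graphDef, feed and fetch names used after the import have to carry the prefix too. The sketch below illustrates how such names could be computed. It assumes TensorFlow joins the prefix and the original name with a "/" (so prefix "inception" and node "softmax" give "inception/softmax"), and that an empty prefix leaves names unchanged. The class PrefixedNames is hypothetical and only mirrors that convention; it does not call TensorFlow.

```java
/** Hypothetical helper mirroring the assumed "prefix/name" convention; not a TensorFlow class. */
final class PrefixedNames {
  private PrefixedNames() {}

  /** Returns the name an operation is expected to have after import with the given prefix. */
  static String prefixed(String prefix, String name) {
    if (prefix == null || name == null) {
      throw new IllegalArgumentException("prefix and name cannot be null");
    }
    if (prefix.isEmpty()) {
      return name; // empty prefix: names are left as they are
    }
    // Avoid a double slash when the caller already ended the prefix with "/".
    return prefix.endsWith("/") ? prefix + name : prefix + "/" + name;
  }
}
```

For instance, after graph.importGraphDef(graphDef, "inception"), a fetch that used "softmax" before would be requested as PrefixedNames.prefixed("inception", "softmax") under this assumption.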
Code example source: org.bytedeco.javacpp-presets/tensorflow
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
  final long startMs = System.currentTimeMillis();
  if (VERSION.SDK_INT >= 18) {
    // Body is empty in this excerpt.
  }
  try {
    g.importGraphDef(graphDef);
  } catch (IllegalArgumentException e) {
    throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage());
  }
  if (VERSION.SDK_INT >= 18) {
    // Body is empty in this excerpt.
  }
  final long endMs = System.currentTimeMillis();
  Log.i(
      TAG,
      "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
Code example source: org.springframework.cloud.stream.app/spring-cloud-starter-stream-common-tensorflow
public TensorFlowService(Resource modelLocation) throws IOException {
  try (InputStream is = modelLocation.getInputStream()) {
    if (logger.isInfoEnabled()) {
      logger.info("Loading TensorFlow graph model: " + modelLocation);
    }
    graph = new Graph();
    graph.importGraphDef(StreamUtils.copyToByteArray(is));
  }
}
Code example source: spring-cloud-stream-app-starters/tensorflow
public TensorFlowService(Resource modelLocation) {
  if (logger.isInfoEnabled()) {
    logger.info("Loading TensorFlow graph model: " + modelLocation);
  }
  graph = new Graph();
  byte[] model = new ModelExtractor().getModel(modelLocation);
  graph.importGraphDef(model);
}
Code example source: spotify/zoltar
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.
 *
 * @param id model id {@link Model.Id}.
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}.
 * @param prefix a prefix that will be prepended to names in graphDef.
 */
public static TensorFlowGraphModel create(
    final Model.Id id,
    final byte[] graphDef,
    @Nullable final ConfigProto config,
    @Nullable final String prefix)
    throws IOException {
  final Graph graph = new Graph();
  final Session session = new Session(graph, config != null ? config.toByteArray() : null);
  final long loadStart = System.currentTimeMillis();
  if (prefix == null) {
    LOG.debug("Loading graph definition without prefix");
    graph.importGraphDef(graphDef);
  } else {
    LOG.debug("Loading graph definition with prefix: {}", prefix);
    graph.importGraphDef(graphDef, prefix);
  }
  LOG.info("TensorFlow graph loaded in {} ms", System.currentTimeMillis() - loadStart);
  return new AutoValue_TensorFlowGraphModel(id, graph, session);
}
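The libtensorflow source shown earlier implements importGraphDef(byte[]) by delegating to importGraphDef(graphDef, ""), and rejects a null prefix with an IllegalArgumentException. Given that, a nullable prefix can also be normalized to the empty string before a single call, instead of branching between the two overloads. The sketch below shows the idea; the class NullablePrefix and its methods are hypothetical, and the import call is abstracted as a BiConsumer so the snippet stands alone without TensorFlow.

```java
import java.util.function.BiConsumer;

/** Hypothetical helper (an assumption for illustration, not part of zoltar or TensorFlow). */
final class NullablePrefix {
  private NullablePrefix() {}

  /** Maps a null prefix to "", the value the one-argument overload passes on. */
  static String orEmpty(String prefix) {
    return prefix == null ? "" : prefix;
  }

  /** Calls the two-argument importer with a prefix that is never null. */
  static void importGraph(BiConsumer<byte[], String> importer, byte[] graphDef, String prefix) {
    importer.accept(graphDef, orEmpty(prefix));
  }
}
```

With a Graph instance named graph, the call would presumably read NullablePrefix.importGraph(graph::importGraphDef, graphDef, prefix). The branching version above still has the advantage of logging which path was taken.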
Code example source: tahaemara/object-recognition-tensorflow
private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
  try (Graph g = new Graph()) {
    g.importGraphDef(graphDef);
    try (Session s = new Session(g);
        Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("softmax").run().get(0)) {
      final long[] rshape = result.shape();
      if (result.numDimensions() != 2 || rshape[0] != 1) {
        throw new RuntimeException(
            String.format(
                "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                Arrays.toString(rshape)));
      }
      int nlabels = (int) rshape[1];
      return result.copyTo(new float[1][nlabels])[0];
    }
  }
}
Code example source: jdye64/nifi-addons
private float[] executeInceptionGraph(byte[] graphDef, Tensor image, String feedNodeName, String outputNodeName) {
  try (Graph g = new Graph()) {
    g.importGraphDef(graphDef);
    try (Session s = new Session(g);
        // The result tensor is closed here as well, so its native memory is not leaked.
        Tensor result = s.runner().feed(feedNodeName, image).fetch(outputNodeName).run().get(0)) {
      final long[] rshape = result.shape();
      if (result.numDimensions() != 2 || rshape[0] != 1) {
        throw new RuntimeException(
            String.format(
                "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                Arrays.toString(rshape)));
      }
      int nlabels = (int) rshape[1];
      return result.copyTo(new float[1][nlabels])[0];
    }
  }
}
Code example source: org.bytedeco.javacpp-presets/tensorflow
private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
  try (Graph g = new Graph()) {
    g.importGraphDef(graphDef);
    try (Session s = new Session(g);
        // Generally, there may be multiple output tensors; all of them must be closed to prevent resource leaks.
        Tensor<Float> result =
            s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
      final long[] rshape = result.shape();
      if (result.numDimensions() != 2 || rshape[0] != 1) {
        throw new RuntimeException(
            String.format(
                "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                Arrays.toString(rshape)));
      }
      int nlabels = (int) rshape[1];
      return result.copyTo(new float[1][nlabels])[0];
    }
  }
}