org.tensorflow.Graph.importGraphDef(): usage and code examples

Reposted by x33g5p2x on 2022-01-20 under Other

This article collects code examples of the Java method org.tensorflow.Graph.importGraphDef() to show how it is used in practice. The examples come mainly from GitHub, Stack Overflow and Maven, extracted from selected projects, and should serve as a practical reference. Method details:
Package: org.tensorflow
Class: Graph
Method: importGraphDef

Graph.importGraphDef overview

Imports a serialized representation of a TensorFlow graph.

The serialized representation of a graph, often called a GraphDef, can be generated by #toGraphDef() and by the equivalent functions in other language APIs.
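The following sketch shows the round trip described above. It builds a tiny graph, serializes it with toGraphDef(), and imports the bytes into a fresh Graph. It assumes the TensorFlow Java 1.x API, meaning the org.tensorflow package with the generic Tensor<T> type. The class name GraphDefRoundTrip and the op name "answer" are made up for illustration.

```java
import org.tensorflow.Graph;
import org.tensorflow.Tensor;

/** Hypothetical demo class: round-trips a tiny graph through toGraphDef() and importGraphDef(). */
public class GraphDefRoundTrip {

  /** Builds a graph with a single Const op named "answer" and returns its GraphDef bytes. */
  public static byte[] buildGraphDef() {
    try (Graph g = new Graph();
        Tensor<Integer> value = Tensor.create(42, Integer.class)) {
      g.opBuilder("Const", "answer")
          .setAttr("dtype", value.dataType())
          .setAttr("value", value)
          .build();
      return g.toGraphDef();
    }
  }

  /** Imports the bytes into a fresh Graph and reports whether an op with this name exists. */
  public static boolean importedHasOp(byte[] graphDef, String opName) {
    try (Graph g = new Graph()) {
      g.importGraphDef(graphDef); // throws IllegalArgumentException on malformed input
      return g.operation(opName) != null;
    }
  }

  public static void main(String[] args) {
    byte[] graphDef = buildGraphDef();
    System.out.println("GraphDef size in bytes: " + graphDef.length);
    System.out.println("Imported graph has 'answer': " + importedHasOp(graphDef, "answer"));
  }
}
```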

Code examples

Example source: apache/ignite

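/** {@inheritDoc} */
@Override public Session parseModel(byte[] mdl) {
  // Rebuild the graph from the serialized GraphDef bytes of the model.
  Graph graph = new Graph();
  graph.importGraphDef(mdl);

  // The returned Session runs the imported graph.
  return new Session(graph);
}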

Example source: org.tensorflow/libtensorflow (identical code also appears in org.bytedeco.javacpp-presets/tensorflow)

/**
 * Import a serialized representation of a TensorFlow graph.
 *
 * <p>The serialized representation of the graph, often referred to as a <i>GraphDef</i>, can be
 * generated by {@link #toGraphDef()} and equivalents in other language APIs.
 *
 * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
 * @see #importGraphDef(byte[], String)
 */
public void importGraphDef(byte[] graphDef) throws IllegalArgumentException {
 importGraphDef(graphDef, "");
}

Example source: org.tensorflow/libtensorflow (identical code also appears in org.bytedeco.javacpp-presets/tensorflow)

/**
 * Import a serialized representation of a TensorFlow graph.
 *
 * @param graphDef the serialized representation of a TensorFlow graph.
 * @param prefix a prefix that will be prepended to names in graphDef
 * @throws IllegalArgumentException if graphDef is not a recognized serialization of a graph.
 * @see #importGraphDef(byte[])
 */
public void importGraphDef(byte[] graphDef, String prefix) throws IllegalArgumentException {
 if (graphDef == null || prefix == null) {
  throw new IllegalArgumentException("graphDef and prefix cannot be null");
 }
 synchronized (nativeHandleLock) {
  importGraphDef(nativeHandle, graphDef, prefix);
 }
}
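The prefix parameter lets one GraphDef be imported more than once into the same Graph without name clashes. The sketch below shows this and assumes the TensorFlow Java 1.x API. PrefixedImport, the op name "weight", and the prefixes "model_a" and "model_b" are illustrative. It relies on TensorFlow's importer joining prefix and original name with "/", so an op "weight" imported under "model_a" is expected to become "model_a/weight".

```java
import org.tensorflow.Graph;
import org.tensorflow.Tensor;

/** Hypothetical demo class: imports one GraphDef several times, each copy under its own prefix. */
public class PrefixedImport {

  /** Serializes a graph that holds one Const op named "weight". */
  public static byte[] constGraphDef() {
    try (Graph g = new Graph();
        Tensor<Float> value = Tensor.create(1.5f, Float.class)) {
      g.opBuilder("Const", "weight")
          .setAttr("dtype", value.dataType())
          .setAttr("value", value)
          .build();
      return g.toGraphDef();
    }
  }

  /** Imports graphDef once per prefix; the caller owns (and must close) the returned Graph. */
  public static Graph importUnderPrefixes(byte[] graphDef, String... prefixes) {
    Graph g = new Graph();
    try {
      for (String prefix : prefixes) {
        g.importGraphDef(graphDef, prefix); // a null prefix is rejected with IllegalArgumentException
      }
      return g;
    } catch (RuntimeException e) {
      g.close(); // do not leak the native graph if an import fails
      throw e;
    }
  }

  public static void main(String[] args) {
    try (Graph g = importUnderPrefixes(constGraphDef(), "model_a", "model_b")) {
      System.out.println("model_a/weight present: " + (g.operation("model_a/weight") != null));
      System.out.println("model_b/weight present: " + (g.operation("model_b/weight") != null));
      System.out.println("unprefixed weight present: " + (g.operation("weight") != null));
    }
  }
}
```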

Example source: org.bytedeco.javacpp-presets/tensorflow

private void loadGraph(byte[] graphDef, Graph g) throws IOException {
 final long startMs = System.currentTimeMillis();
 try {
  // importGraphDef reports a malformed GraphDef with an unchecked IllegalArgumentException;
  // convert it to a checked IOException and keep the original exception as the cause.
  g.importGraphDef(graphDef);
 } catch (IllegalArgumentException e) {
  throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage(), e);
 }
 final long endMs = System.currentTimeMillis();
 Log.i(
   TAG,
   "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
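The example above converts the unchecked IllegalArgumentException into a checked IOException. The same idea can be packaged as a small validation helper. This sketch assumes the TensorFlow Java 1.x API, and GraphDefValidation, importOrThrow and isImportable are hypothetical names. The bytes {1, 2, 3} do not parse as a GraphDef protocol buffer, so importing them should raise IllegalArgumentException, as the Javadoc of importGraphDef states. A null array is rejected by the Java-side null check.

```java
import java.io.IOException;
import org.tensorflow.Graph;

/** Hypothetical helpers built on importGraphDef's IllegalArgumentException contract. */
public class GraphDefValidation {

  /** Imports graphDef into g, turning a malformed-input error into a checked IOException. */
  public static void importOrThrow(Graph g, byte[] graphDef) throws IOException {
    try {
      g.importGraphDef(graphDef);
    } catch (IllegalArgumentException e) {
      // Keep the original exception as the cause so the native error message is preserved.
      throw new IOException("Not a valid TensorFlow Graph serialization", e);
    }
  }

  /** Returns true if graphDef can be imported into a throwaway Graph. */
  public static boolean isImportable(byte[] graphDef) {
    try (Graph g = new Graph()) {
      importOrThrow(g, graphDef);
      return true;
    } catch (IOException e) {
      return false;
    }
  }

  public static void main(String[] args) {
    byte[] valid;
    try (Graph g = new Graph()) {
      valid = g.toGraphDef(); // serialization of an empty graph
    }
    System.out.println("empty graph:  " + isImportable(valid));
    System.out.println("random bytes: " + isImportable(new byte[] {1, 2, 3}));
    System.out.println("null:         " + isImportable(null));
  }
}
```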

Example source: org.springframework.cloud.stream.app/spring-cloud-starter-stream-common-tensorflow

public TensorFlowService(Resource modelLocation) throws IOException {
  try (InputStream is = modelLocation.getInputStream()) {
    if (logger.isInfoEnabled()) {
      logger.info("Loading TensorFlow graph model: " + modelLocation);
    }
    graph = new Graph();
    graph.importGraphDef(StreamUtils.copyToByteArray(is));
  }
}
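StreamUtils.copyToByteArray comes from Spring. Outside Spring, a plain-JDK equivalent could look like the sketch below. It is Java 8 compatible, and ModelBytes and readAll are hypothetical names. On Java 9 and later, InputStream.readAllBytes() does the same job. For a file on disk, java.nio.file.Files.readAllBytes(Path) is simpler still.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

/** Hypothetical plain-JDK stand-in for Spring's StreamUtils.copyToByteArray (Java 8 compatible). */
public final class ModelBytes {

  private ModelBytes() {}

  /** Reads the stream to the end and returns its content as one array; does not close the stream. */
  public static byte[] readAll(InputStream in) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    byte[] buffer = new byte[8192];
    int n;
    while ((n = in.read(buffer)) != -1) {
      out.write(buffer, 0, n);
    }
    return out.toByteArray();
  }
}
```

Inside the constructor above, graph.importGraphDef(ModelBytes.readAll(is)) would stand in for the StreamUtils call. The try-with-resources block there already closes the stream.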

Example source: spring-cloud-stream-app-starters/tensorflow

public TensorFlowService(Resource modelLocation) {
  if (logger.isInfoEnabled()) {
    logger.info("Loading TensorFlow graph model: " + modelLocation);
  }
  graph = new Graph();
  byte[] model = new ModelExtractor().getModel(modelLocation);
  graph.importGraphDef(model);
}

Example source: spotify/zoltar

/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.
 *
 * @param id model id {@link Model.Id}.
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}.
 * @param prefix a prefix that will be prepended to names in graphDef.
 */
public static TensorFlowGraphModel create(
  final Model.Id id,
  final byte[] graphDef,
  @Nullable final ConfigProto config,
  @Nullable final String prefix)
  throws IOException {
 final Graph graph = new Graph();
 final Session session = new Session(graph, config != null ? config.toByteArray() : null);
 final long loadStart = System.currentTimeMillis();
 if (prefix == null) {
  LOG.debug("Loading graph definition without prefix");
  graph.importGraphDef(graphDef);
 } else {
  LOG.debug("Loading graph definition with prefix: {}", prefix);
  graph.importGraphDef(graphDef, prefix);
 }
 LOG.info("TensorFlow graph loaded in {} ms", System.currentTimeMillis() - loadStart);
 return new AutoValue_TensorFlowGraphModel(id, graph, session);
}
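When create() is called with a non-null prefix, every node name later passed to feed() or fetch() has to carry that prefix as well. The helper sketched below builds those names; PrefixedNames and qualify are hypothetical. It assumes that TensorFlow's importer separates prefix and name with "/" and does not add a second "/" when the prefix already ends with one. An output-index suffix such as ":0" passes through unchanged.

```java
/** Hypothetical helper: builds feed/fetch names for a graph that may have been imported with a prefix. */
public final class PrefixedNames {

  private PrefixedNames() {}

  /**
   * Returns name unchanged when prefix is null or empty, otherwise prefix + "/" + name.
   * Assumption: the importer joins prefix and name with a single "/", and does not add
   * another one when the prefix already ends with "/".
   */
  public static String qualify(String prefix, String name) {
    if (prefix == null || prefix.isEmpty()) {
      return name;
    }
    return prefix.endsWith("/") ? prefix + name : prefix + "/" + name;
  }
}
```

For example, a model created with prefix "m" would be fed through s.runner().feed(PrefixedNames.qualify("m", "input"), t), where "input" is a hypothetical node name.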

Example source: com.spotify/zoltar-tensorflow

/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.</p>
 *
 * @param id       model id {@link Model.Id}.
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config   ConfigProto config for TensorFlow {@link Session}.
 * @param prefix   a prefix that will be prepended to names in graphDef.
 */
public static TensorFlowGraphModel create(final Model.Id id,
                     final byte[] graphDef,
                     @Nullable final ConfigProto config,
                     @Nullable final String prefix)
  throws IOException {
 final Graph graph = new Graph();
 final Session session = new Session(graph, config != null ? config.toByteArray() : null);
 final long loadStart = System.currentTimeMillis();
 if (prefix == null) {
  LOG.debug("Loading graph definition without prefix");
  graph.importGraphDef(graphDef);
 } else {
  LOG.debug("Loading graph definition with prefix: {}", prefix);
  graph.importGraphDef(graphDef, prefix);
 }
 LOG.info("TensorFlow graph loaded in {} ms", System.currentTimeMillis() - loadStart);
 return new AutoValue_TensorFlowGraphModel(id, graph, session);
}

Example source: tahaemara/object-recognition-tensorflow

private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
  try (Graph g = new Graph()) {
    g.importGraphDef(graphDef);
    try (Session s = new Session(g);
        Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("softmax").run().get(0)) {
      final long[] rshape = result.shape();
      if (result.numDimensions() != 2 || rshape[0] != 1) {
        throw new RuntimeException(
            String.format(
                "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                Arrays.toString(rshape)));
      }
      int nlabels = (int) rshape[1];
      return result.copyTo(new float[1][nlabels])[0];
    }
  }
}
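This example checks that the model output has shape [1 N] before copying it into a Java array, and the same check recurs in other image-classification snippets. Factoring it into a helper keeps the error message in one place. The plain-Java sketch below also guards the long-to-int cast that the inline version performs unchecked. OutputShapes and labelCount are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical helper for validating a classifier's output shape. */
public final class OutputShapes {

  private OutputShapes() {}

  /** Checks that shape is [1, N] and returns N, the number of labels. */
  public static int labelCount(long[] shape) {
    if (shape.length != 2 || shape[0] != 1) {
      throw new IllegalStateException(
          "Expected a [1 N] shaped tensor, got " + Arrays.toString(shape));
    }
    if (shape[1] > Integer.MAX_VALUE) {
      // A Java array cannot hold more than Integer.MAX_VALUE elements.
      throw new IllegalStateException("Too many labels for a Java array: " + shape[1]);
    }
    return (int) shape[1];
  }
}
```

Inside executeInceptionGraph, int nlabels = OutputShapes.labelCount(result.shape()); would replace the inline if block and the cast.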

Example source: jdye64/nifi-addons

private float[] executeInceptionGraph(byte[] graphDef, Tensor image, String feedNodeName, String outputNodeName) {
  try (Graph g = new Graph()) {
    g.importGraphDef(graphDef);
    // Declare the result tensor as a resource too, so its native memory is released.
    try (Session s = new Session(g);
        Tensor result = s.runner().feed(feedNodeName, image).fetch(outputNodeName).run().get(0)) {
      final long[] rshape = result.shape();
      if (result.numDimensions() != 2 || rshape[0] != 1) {
        throw new RuntimeException(
            String.format(
                "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                Arrays.toString(rshape)));
      }
      int nlabels = (int) rshape[1];
      return result.copyTo(new float[1][nlabels])[0];
    }
  }
}
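executeInceptionGraph returns one probability per label, so picking the most likely label is an argmax over that array. The following is a plain-Java sketch in which TopLabel and argMax are hypothetical names. Mapping the index to a label string would need a labels list that these snippets do not show.

```java
/** Hypothetical helper: finds the most likely label index in a probability vector. */
public final class TopLabel {

  private TopLabel() {}

  /** Returns the index of the largest probability; on ties, the first such index wins. */
  public static int argMax(float[] probabilities) {
    if (probabilities.length == 0) {
      throw new IllegalArgumentException("no probabilities");
    }
    int best = 0;
    for (int i = 1; i < probabilities.length; i++) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    return best;
  }
}
```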

Example source: org.bytedeco.javacpp-presets/tensorflow

private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
 try (Graph g = new Graph()) {
  g.importGraphDef(graphDef);
  try (Session s = new Session(g);
    // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
    Tensor<Float> result =
      s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
   final long[] rshape = result.shape();
   if (result.numDimensions() != 2 || rshape[0] != 1) {
    throw new RuntimeException(
      String.format(
        "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
        Arrays.toString(rshape)));
   }
   int nlabels = (int) rshape[1];
   return result.copyTo(new float[1][nlabels])[0];
  }
 }
}
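The comment in this example notes that a run may return several output tensors and that all of them must be closed. The sketch below shows that pattern, assuming the TensorFlow Java 1.x API in which Session.Runner.run() returns List<Tensor<?>>. CloseAllOutputs, the op names "a" and "b", and their values are made up. Closing in a finally block releases every output even if reading one of them throws.

```java
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/** Hypothetical demo class: fetches several outputs in one run and closes every returned tensor. */
public class CloseAllOutputs {

  /** Closes every tensor in the list, e.g. the result of Session.Runner.run(). */
  public static void closeAll(List<? extends Tensor<?>> tensors) {
    for (Tensor<?> t : tensors) {
      t.close();
    }
  }

  /** Adds a scalar float Const op to the graph. */
  private static void addConst(Graph g, String name, float v) {
    try (Tensor<Float> t = Tensor.create(v, Float.class)) {
      g.opBuilder("Const", name).setAttr("dtype", t.dataType()).setAttr("value", t).build();
    }
  }

  /** Fetches two scalar outputs in a single run and returns their sum. */
  public static float fetchTwoAndSum() {
    try (Graph g = new Graph()) {
      addConst(g, "a", 1.5f);
      addConst(g, "b", 2.5f);
      try (Session s = new Session(g)) {
        List<Tensor<?>> outputs = s.runner().fetch("a").fetch("b").run();
        try {
          return outputs.get(0).floatValue() + outputs.get(1).floatValue();
        } finally {
          closeAll(outputs); // release native memory for every output, not just the first
        }
      }
    }
  }

  public static void main(String[] args) {
    System.out.println(fetchTwoAndSum());
  }
}
```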
