java tensorflow tensor访问

ffscu2ro  于 2021-07-06  发布在  Java
关注(0)|答案(0)|浏览(316)

如何使用TensorFlowJava0.2.0在Tensor中找到最大浮点值的索引?
我对TensorFlowJava相当陌生,还不清楚ndarray和tensor之间的关系。如果有更好的方法把我的列表转换成Tensor输入,请告诉我!
SavedModelBundle 我使用 SessionRunner 输入Tensor。
通缉行为如下
text=“爱它”
val=Tensor[0.12,0.2,0.68]
.预测(文本)返回-->2
java

public Integer predict(String text) {

        // Returns tokenized text in a List of MAXLEN
        List<Integer> token_ids = tokenize(text);

        // Convert List to tensor for Input
        IntDataBuffer bufferTokens = DataBuffers.ofInts(MAXLEN);
        int[] primArr = new int[MAXLEN];
        for (int i=0; i<MAXLEN; i++) {
            primArr[i] = token_ids.get(i);
        }
        bufferTokens.write(primArr);

        IntNdArray tokensMatrix = NdArrays.ofInts(Shape.of(1, MAXLEN));
        IntNdArray vector = tokensMatrix.get(0);
        vector.write(bufferTokens);

        Tensor<TInt32> input = TInt32.tensorOf(tokensMatrix);

        // Model.predict
        Tensor output = model.session()
                .runner()
                .feed("serving_default_input_ids:0", input)
                .fetch("StatefulPartitionedCall:0")
                .run() // List<Tensor<?>>
                .get(0);

        // TODO - HELP NEEDED: Extract arg max from tensor
        Tensor val = output.expect(TFloat32.DTYPE); // val = FLOAT (1) tensor with shape [1, 3]
        Integer maxIndex = ????????

        return maxIndex;
    }

运行时的模型输入、输出信息 myModel.metaGraphDef().getSignatureDefMap().get("serving_default"); 如下所示。

ModelInfo
    inputs {
        key: "input_ids"
        value {
            name: "serving_default_input_ids:0"
            dtype: DT_INT32
            tensor_shape {
                dim {
                    size: -1
                }
                dim {
                    size: MAXLEN
                }
            }
        }
    }
    outputs {
        key: "dense_3"
        value {
            name: "StatefulPartitionedCall:0"
            dtype: DT_FLOAT
            tensor_shape {
                dim {
                    size: -1
                }
                dim {
                    size: 3
                }
            }
        }
    }

提前谢谢!

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题