如何使用TensorFlowJava0.2.0在Tensor中找到最大浮点值的索引?
我对TensorFlowJava相当陌生,还不清楚ndarray和tensor之间的关系。如果有更好的方法把我的列表转换成Tensor输入,请告诉我!
从 SavedModelBundle
我使用 SessionRunner
输入Tensor。
通缉行为如下
text=“爱它”
val=Tensor[0.12,0.2,0.68]
.预测(文本)返回-->2
java
public Integer predict(String text) {
// Returns tokenized text in a List of MAXLEN
List<Integer> token_ids = tokenize(text);
// Convert List to tensor for Input
IntDataBuffer bufferTokens = DataBuffers.ofInts(MAXLEN);
int[] primArr = new int[MAXLEN];
for (int i=0; i<MAXLEN; i++) {
primArr[i] = token_ids.get(i);
}
bufferTokens.write(primArr);
IntNdArray tokensMatrix = NdArrays.ofInts(Shape.of(1, MAXLEN));
IntNdArray vector = tokensMatrix.get(0);
vector.write(bufferTokens);
Tensor<TInt32> input = TInt32.tensorOf(tokensMatrix);
// Model.predict
Tensor output = model.session()
.runner()
.feed("serving_default_input_ids:0", input)
.fetch("StatefulPartitionedCall:0")
.run() // List<Tensor<?>>
.get(0);
// TODO - HELP NEEDED: Extract arg max from tensor
Tensor val = output.expect(TFloat32.DTYPE); // val = FLOAT (1) tensor with shape [1, 3]
Integer maxIndex = ????????
return maxIndex;
}
运行时的模型输入、输出信息 myModel.metaGraphDef().getSignatureDefMap().get("serving_default");
如下所示。
ModelInfo
inputs {
key: "input_ids"
value {
name: "serving_default_input_ids:0"
dtype: DT_INT32
tensor_shape {
dim {
size: -1
}
dim {
size: MAXLEN
}
}
}
}
outputs {
key: "dense_3"
value {
name: "StatefulPartitionedCall:0"
dtype: DT_FLOAT
tensor_shape {
dim {
size: -1
}
dim {
size: 3
}
}
}
}
提前谢谢!
暂无答案!
目前还没有任何答案,快来回答吧!