|
@@ -107,23 +107,33 @@ public class TensorflowModel {
|
|
FloatBuffer.wrap(inputValues)
|
|
FloatBuffer.wrap(inputValues)
|
|
);
|
|
);
|
|
|
|
|
|
|
|
+ float[][] result = null;
|
|
|
|
+ Tensor<?> t = null;
|
|
// 序列数据
|
|
// 序列数据
|
|
if (this.withSequenceInputs){
|
|
if (this.withSequenceInputs){
|
|
Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
|
|
Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
|
|
- this.session.runner();
|
|
|
|
|
|
|
|
- return this.session.runner().feed(this.X, inputTensor)
|
|
|
|
|
|
+ t = this.session.runner().feed(this.X, inputTensor)
|
|
.feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
|
|
.feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
|
|
.feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
|
|
.feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
|
|
.feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
.feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
- .fetch(this.SOFT_MAX).run().get(0)
|
|
|
|
- .copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
|
|
|
|
+ .fetch(this.SOFT_MAX).run().get(0);
|
|
|
|
+
|
|
|
|
+ for (Map.Entry<String, Tensor<Integer>> entry : sequenceTensorMap.entrySet()) {
|
|
|
|
+ entry.getValue().close();
|
|
|
|
+ }
|
|
|
|
+
|
|
}else{
|
|
}else{
|
|
- return this.session.runner().feed(this.X, inputTensor)
|
|
|
|
|
|
+ t = this.session.runner().feed(this.X, inputTensor)
|
|
.feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
.feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
- .fetch(this.SOFT_MAX).run().get(0)
|
|
|
|
- .copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
|
|
|
|
+ .fetch(this.SOFT_MAX).run().get(0);
|
|
}
|
|
}
|
|
|
|
+ result = t.copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
|
|
+
|
|
|
|
+ t.close();
|
|
|
|
+ inputTensor.close();
|
|
|
|
+
|
|
|
|
+ return result;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|