|
@@ -101,39 +101,34 @@ public class TensorflowModel {
|
|
* @return 模型的输出
|
|
* @return 模型的输出
|
|
*/
|
|
*/
|
|
private float[][] run(float[] inputValues, Map<String, int[]> sequenceValues, int numExamples) {
|
|
private float[][] run(float[] inputValues, Map<String, int[]> sequenceValues, int numExamples) {
|
|
- float[][] f = new float[numExamples][NUM_LABEL];
|
|
|
|
- for (int i = 0; i < numExamples; i++) {
|
|
|
|
- for (int j = 0; j < NUM_LABEL; j++) {
|
|
|
|
- f[i][j] = 0.1f;
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
- return f;
|
|
|
|
-
|
|
|
|
-// long[] inputShape = {numExamples, this.NUM_FEATURE};
|
|
|
|
-// Tensor<Float> inputTensor = Tensor.create(
|
|
|
|
-// inputShape,
|
|
|
|
-// FloatBuffer.wrap(inputValues)
|
|
|
|
-// );
|
|
|
|
-
|
|
|
|
-// // 序列数据
|
|
|
|
-// if (this.withSequenceInputs){
|
|
|
|
-// Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
|
|
|
|
|
|
+ long[] inputShape = {numExamples, this.NUM_FEATURE};
|
|
|
|
+ Tensor<Float> inputTensor = Tensor.create(
|
|
|
|
+ inputShape,
|
|
|
|
+ FloatBuffer.wrap(inputValues)
|
|
|
|
+ );
|
|
|
|
|
|
|
|
+ float[][] result = null;
|
|
|
|
|
|
-// return this.session.runner().feed(this.X, inputTensor)
|
|
|
|
-// .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
|
|
|
|
-// .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
|
|
|
|
-// .feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
|
|
-// .fetch(this.SOFT_MAX).run().get(0)
|
|
|
|
-// .copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
|
|
-// }else{
|
|
|
|
-// return this.session.runner().feed(this.X, inputTensor)
|
|
|
|
-// .feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
|
|
-// .fetch(this.SOFT_MAX).run().get(0)
|
|
|
|
-// .copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
|
|
-// }
|
|
|
|
-
|
|
|
|
|
|
+ // 序列数据
|
|
|
|
+ if (this.withSequenceInputs){
|
|
|
|
+ Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
|
|
|
|
+ this.session.runner();
|
|
|
|
+
|
|
|
|
+ result = this.session.runner().feed(this.X, inputTensor)
|
|
|
|
+ .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
|
|
|
|
+ .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
|
|
|
|
+ .feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
|
|
+ .fetch(this.SOFT_MAX).run().get(0)
|
|
|
|
+ .copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
|
|
+ }else{
|
|
|
|
+ result = this.session.runner().feed(this.X, inputTensor)
|
|
|
|
+ .feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
|
|
+ .fetch(this.SOFT_MAX).run().get(0)
|
|
|
|
+ .copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
|
|
+ }
|
|
|
|
|
|
|
|
+ inputTensor.close();
|
|
|
|
+ return result;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|