瀏覽代碼

测试只保留大数据推送,服务器内存使用情况

louhr 5 年之前
父節點
當前提交
aa6c03297e
共有 1 個文件被更改,包括 25 次插入30 次删除
  1. 25 30
      algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

+ 25 - 30
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -101,39 +101,34 @@ public class TensorflowModel {
      * @return 模型的输出
      */
     private float[][] run(float[] inputValues, Map<String, int[]> sequenceValues, int numExamples) {
-        float[][] f  = new float[numExamples][NUM_LABEL];
-        for (int i = 0; i < numExamples; i++) {
-            for (int j = 0; j < NUM_LABEL; j++) {
-                f[i][j] = 0.1f;
-            }
-        }
-        return f;
-
-//        long[] inputShape = {numExamples, this.NUM_FEATURE};
-//        Tensor<Float> inputTensor = Tensor.create(
-//                inputShape,
-//                FloatBuffer.wrap(inputValues)
-//        );
-
-//        // 序列数据
-//        if (this.withSequenceInputs){
-//            Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
+        long[] inputShape = {numExamples, this.NUM_FEATURE};
+        Tensor<Float> inputTensor = Tensor.create(
+                inputShape,
+                FloatBuffer.wrap(inputValues)
+        );
 
+        float[][] result = null;
 
-//            return this.session.runner().feed(this.X, inputTensor)
-//                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
-//                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
-//                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-//                    .fetch(this.SOFT_MAX).run().get(0)
-//                    .copyTo(new float[numExamples][this.NUM_LABEL]);
-//        }else{
-//            return this.session.runner().feed(this.X, inputTensor)
-//                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-//                    .fetch(this.SOFT_MAX).run().get(0)
-//                    .copyTo(new float[numExamples][this.NUM_LABEL]);
-//        }
-
+        // 序列数据
+        if (this.withSequenceInputs){
+            Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
+            this.session.runner();
+
+            result = this.session.runner().feed(this.X, inputTensor)
+                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
+                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
+                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+                    .fetch(this.SOFT_MAX).run().get(0)
+                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+        }else{
+            result = this.session.runner().feed(this.X, inputTensor)
+                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+                    .fetch(this.SOFT_MAX).run().get(0)
+                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+        }
 
+        inputTensor.close();
+        return result;
     }