소스 검색

测试只保留大数据推送,服务器内存使用情况

louhr 5 년 전
부모
커밋
51068dcb8b
1개의 변경된 파일29개의 추가작업 그리고 24개의 파일을 삭제
  1. 29 24
      algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

+ 29 - 24
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -107,33 +107,38 @@ public class TensorflowModel {
                 FloatBuffer.wrap(inputValues)
         );
 
-        inputTensor.close();
+//        float[][] fl = new float[numExamples][NUM_LABEL];
+//        for (int i = 0; i < numExamples; i++) {
+//            for (int j = 0; j < NUM_LABEL; j++) {
+//                fl[i][j] = 0.1f;
+//            }
+//        }
+//        return fl;
+        float[][] result = null;
+
+        // 序列数据
+        if (this.withSequenceInputs){
+            Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
 
-        float[][] fl = new float[numExamples][NUM_LABEL];
-        for (int i = 0; i < numExamples; i++) {
-            for (int j = 0; j < NUM_LABEL; j++) {
-                fl[i][j] = 0.1f;
+            result = this.session.runner().feed(this.X, inputTensor)
+                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
+                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
+                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+                    .fetch(this.SOFT_MAX).run().get(0)
+                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+
+            for (Map.Entry<String, Tensor<Integer>> entry : sequenceTensorMap.entrySet()) {
+                entry.getValue().close();
             }
+        }else{
+            result = this.session.runner().feed(this.X, inputTensor)
+                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+                    .fetch(this.SOFT_MAX).run().get(0)
+                    .copyTo(new float[numExamples][this.NUM_LABEL]);
         }
-        return fl;
-//
-//        // 序列数据
-//        if (this.withSequenceInputs){
-//            Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
-//            this.session.runner();
-//
-//            return this.session.runner().feed(this.X, inputTensor)
-//                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
-//                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
-//                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-//                    .fetch(this.SOFT_MAX).run().get(0)
-//                    .copyTo(new float[numExamples][this.NUM_LABEL]);
-//        }else{
-//            return this.session.runner().feed(this.X, inputTensor)
-//                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-//                    .fetch(this.SOFT_MAX).run().get(0)
-//                    .copyTo(new float[numExamples][this.NUM_LABEL]);
-//        }
+        inputTensor.close();
+
+        return result;
     }