Browse Source

测试只保留大数据推送,服务器内存使用情况

louhr 5 years atrás
parent
commit
f55dc66774
1 changed files with 14 additions and 14 deletions
  1. 14 14
      algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

+ 14 - 14
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -107,25 +107,25 @@ public class TensorflowModel {
                 FloatBuffer.wrap(inputValues)
         );
 
-//        float[][] fl = new float[numExamples][NUM_LABEL];
-//        for (int i = 0; i < numExamples; i++) {
-//            for (int j = 0; j < NUM_LABEL; j++) {
-//                fl[i][j] = 0.1f;
-//            }
-//        }
+        float[][] fl = new float[numExamples][NUM_LABEL];
+        for (int i = 0; i < numExamples; i++) {
+            for (int j = 0; j < NUM_LABEL; j++) {
+                fl[i][j] = 0.1f;
+            }
+        }
 //        return fl;
         float[][] result = null;
 
         // 序列数据
         if (this.withSequenceInputs){
             Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
-
-            result = this.session.runner().feed(this.X, inputTensor)
-                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
-                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
-                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                    .fetch(this.SOFT_MAX).run().get(0)
-                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+//
+//            result = this.session.runner().feed(this.X, inputTensor)
+//                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
+//                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
+//                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+//                    .fetch(this.SOFT_MAX).run().get(0)
+//                    .copyTo(new float[numExamples][this.NUM_LABEL]);
 
             for (Map.Entry<String, Tensor<Integer>> entry : sequenceTensorMap.entrySet()) {
                 entry.getValue().close();
@@ -138,7 +138,7 @@ public class TensorflowModel {
         }
         inputTensor.close();
 
-        return result;
+        return fl;
     }