瀏覽代碼

测试只保留大数据推送,服务器内存使用情况

louhr 5 年之前
父節點
當前提交
28c81eb563
共有 1 個文件被更改,包括 9 次插入10 次删除
  1. 9 10
      algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

+ 9 - 10
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -115,21 +115,21 @@ public class TensorflowModel {
         }
 //        return fl;
         float[][] result = null;
-
+        Session.Runner runner = null;
         // 序列数据
         if (this.withSequenceInputs){
             Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
-//
-//            result = this.session.runner().feed(this.X, inputTensor)
-//                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
-//                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
-//                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-//                    .fetch(this.SOFT_MAX).run().get(0)
-//                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+            Tensor<?> t = this.session.runner().feed(this.X, inputTensor)
+                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
+                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
+                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+                    .fetch(this.SOFT_MAX).run().get(0);
+            result = t.copyTo(new float[numExamples][this.NUM_LABEL]);
 
             for (Map.Entry<String, Tensor<Integer>> entry : sequenceTensorMap.entrySet()) {
                 entry.getValue().close();
             }
+            t.close();
         }else{
             result = this.session.runner().feed(this.X, inputTensor)
                     .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
@@ -137,8 +137,7 @@ public class TensorflowModel {
                     .copyTo(new float[numExamples][this.NUM_LABEL]);
         }
         inputTensor.close();
-
-        return fl;
+        return result;
     }