瀏覽代碼

解决推送模型内存过度占用导致溢出问题

louhr 5 年之前
父節點
當前提交
5dc928004f
共有 1 個文件被更改,包括 17 次插入7 次删除
  1. 17 7
      algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

+ 17 - 7
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -107,23 +107,33 @@ public class TensorflowModel {
                 FloatBuffer.wrap(inputValues)
         );
 
+        float[][] result = null;
+        Tensor<?> t = null;
         // 序列数据
         if (this.withSequenceInputs){
             Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
-            this.session.runner();
 
-            return this.session.runner().feed(this.X, inputTensor)
+            t = this.session.runner().feed(this.X, inputTensor)
                     .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
                     .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
                     .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                    .fetch(this.SOFT_MAX).run().get(0)
-                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+                    .fetch(this.SOFT_MAX).run().get(0);
+
+            for (Map.Entry<String, Tensor<Integer>> entry : sequenceTensorMap.entrySet()) {
+                entry.getValue().close();
+            }
+
         }else{
-            return this.session.runner().feed(this.X, inputTensor)
+            t =  this.session.runner().feed(this.X, inputTensor)
                     .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                    .fetch(this.SOFT_MAX).run().get(0)
-                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+                    .fetch(this.SOFT_MAX).run().get(0);
         }
+        result = t.copyTo(new float[numExamples][this.NUM_LABEL]);
+
+        t.close();
+        inputTensor.close();
+
+        return result;
     }