Browse Source

测试只保留大数据推送,服务器内存使用情况

louhr 5 years ago
parent
commit
d1ad548f03
1 changed files with 30 additions and 22 deletions
  1. 30 22
      algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

+ 30 - 22
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -101,29 +101,37 @@ public class TensorflowModel {
      * @return 模型的输出
      */
     private float[][] run(float[] inputValues, Map<String, int[]> sequenceValues, int numExamples) {
-        long[] inputShape = {numExamples, this.NUM_FEATURE};
-        Tensor<Float> inputTensor = Tensor.create(
-                inputShape,
-                FloatBuffer.wrap(inputValues)
-        );
-
-        // 序列数据
-        if (this.withSequenceInputs){
-            Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
-            this.session.runner();
-
-            return this.session.runner().feed(this.X, inputTensor)
-                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
-                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
-                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                    .fetch(this.SOFT_MAX).run().get(0)
-                    .copyTo(new float[numExamples][this.NUM_LABEL]);
-        }else{
-            return this.session.runner().feed(this.X, inputTensor)
-                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                    .fetch(this.SOFT_MAX).run().get(0)
-                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+        float[][] f  = new float[numExamples][NUM_LABEL];
+        for (int i = 0; i < numExamples; i++) {
+            for (int j = 0; j < NUM_LABEL; j++) {
+                f[i][j] = 0.1f;
+            }
         }
+        return f;
+
+//        long[] inputShape = {numExamples, this.NUM_FEATURE};
+//        Tensor<Float> inputTensor = Tensor.create(
+//                inputShape,
+//                FloatBuffer.wrap(inputValues)
+//        );
+//
+//        // 序列数据
+//        if (this.withSequenceInputs){
+//            Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
+//            this.session.runner();
+//
+//            return this.session.runner().feed(this.X, inputTensor)
+//                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
+//                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
+//                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+//                    .fetch(this.SOFT_MAX).run().get(0)
+//                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+//        }else{
+//            return this.session.runner().feed(this.X, inputTensor)
+//                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+//                    .fetch(this.SOFT_MAX).run().get(0)
+//                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+//        }
     }