Преглед изворног кода

1- 提交新训练的模型。

bijl пре 6 година
родитељ
комит
b5de333969

+ 14 - 4
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionEnsembleModel.java

@@ -11,7 +11,9 @@ import org.tensorflow.Tensor;
 import java.io.File;
 import java.nio.FloatBuffer;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.*;
 
 /**
@@ -26,7 +28,7 @@ public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
     private SavedModelBundle bundle; // 模型捆绑
     private Session session;  // 会话
     private RelationExtractionDataSet dataSet;
-    private RelationExtractionSubModel[] subModels = new RelationExtractionSubModel[3];
+    private RelationExtractionSubModel[] subModels = new RelationExtractionSubModel[2];
     private ExecutorService executorService = Executors.newCachedThreadPool();
 
     public RelationExtractionEnsembleModel() {
@@ -39,9 +41,17 @@ public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
 
         this.dataSet = new RelationExtractionDataSet(dataSetPath);
         this.init(exportDir);
-        subModels[0] = new RelationExtractionSubModel("cnn_1d_low");
-        subModels[1] = new RelationExtractionSubModel("cnn_1d_lstm_low");
-        subModels[2] = new RelationExtractionSubModel("lstm_low_api");
+
+        Map<String, Tensor<Float>> cnn_1d_low_map = new HashMap<>();
+        cnn_1d_low_map.put("keep_prob",Tensor.create(1.0f, Float.class));
+        subModels[0] = new RelationExtractionSubModel("cnn_1d_low", cnn_1d_low_map);
+//        subModels[1] = new RelationExtractionSubModel("cnn_1d_lstm_low");
+
+        Map<String, Tensor<Float>> lstm_low_api_map = new HashMap<>();
+        lstm_low_api_map.put("input_keep_prob",Tensor.create(1.0f, Float.class));
+        lstm_low_api_map.put("output_keep_prob",Tensor.create(1.0f, Float.class));
+        lstm_low_api_map.put("state_keep_prob",Tensor.create(1.0f, Float.class));
+        subModels[1] = new RelationExtractionSubModel("lstm_low_api", lstm_low_api_map);
     }
 
     /**

+ 21 - 6
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionSubModel.java

@@ -4,10 +4,12 @@ import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
 import org.diagbot.pub.utils.PropertiesUtil;
 import org.tensorflow.SavedModelBundle;
 import org.tensorflow.Session;
+import org.tensorflow.Session.Runner;
 import org.tensorflow.Tensor;
 
 import java.io.File;
 import java.nio.IntBuffer;
+import java.util.Map;
 
 /**
  * @Author: bijl
@@ -18,18 +20,20 @@ public class RelationExtractionSubModel {
     private final String X_PLACEHOLDER = "X";
     private final String pos1_PLACEHOLDER = "pos1";
     private final String pos2_PLACEHOLDER = "pos2";
+    private Map<String, Tensor<Float>> keep_probs = null;
     private String PREDICTION = null;
     private final int NUM_LABEL = 1;
     private SavedModelBundle bundle; // 模型捆绑
     private Session session;  // 会话
     protected RelationExtractionDataSet dataSet;
 
-    public RelationExtractionSubModel(String modelName) {
+    public RelationExtractionSubModel(String modelName, Map<String, Tensor<Float>> keep_probs) {
+
+        this.keep_probs = keep_probs;
 
         PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
 
         String modelsPath = prop.getProperty("basicPath");  // 模型基本路径
-        String re_path = prop.getProperty("relationExtraction");  // 模型基本路径
         this.PREDICTION = modelName + "/prediction/Sigmoid";
         String exportDir = modelsPath.replace("model_version_replacement", modelName);
         String dataSetPath = modelsPath.substring(0, modelsPath.indexOf("model_version_replacement"));
@@ -76,12 +80,23 @@ public class RelationExtractionSubModel {
                 IntBuffer.wrap(inputValues[2])
         );
 
-        return this.session.runner()
+//        return this.session.runner()
+//                .feed(this.X_PLACEHOLDER, charId)  // 输入,字符id
+//                .feed(this.pos1_PLACEHOLDER, pos1)  // 输入,相对位置1
+//                .feed(this.pos2_PLACEHOLDER, pos2)  //  输入,相对位置2
+//                .feed("keep_prob", Tensor.create(1.0f, Float.class))  // 输入,dropout保留率
+//                .fetch(this.PREDICTION).run().get(0)  // 输出,tensor
+//                .copyTo(new float[batchSize][this.NUM_LABEL]);  // tensor转float[]对象
+
+        Runner runner = this.session.runner()  // 不同子模型共用的输入
                 .feed(this.X_PLACEHOLDER, charId)  // 输入,字符id
                 .feed(this.pos1_PLACEHOLDER, pos1)  // 输入,相对位置1
-                .feed(this.pos2_PLACEHOLDER, pos2)  //  输入,相对位置2
-                .feed("keep_prob", Tensor.create(1.0f, Float.class))  // 输入,dropout保留率
-                .fetch(this.PREDICTION).run().get(0)  // 输出,tensor
+                .feed(this.pos2_PLACEHOLDER, pos2);  //  输入,相对位置2
+
+        for (Map.Entry<String, Tensor<Float>> entry : this.keep_probs.entrySet()) // 非共用的输入
+            runner = runner.feed(entry.getKey(), entry.getValue());
+
+        return  runner.fetch(this.PREDICTION).run().get(0)  // 输出,tensor
                 .copyTo(new float[batchSize][this.NUM_LABEL]);  // tensor转float[]对象
     }
 

+ 9 - 1
algorithm/src/main/java/org/algorithm/test/ReSubModelTest.java

@@ -1,6 +1,10 @@
 package org.algorithm.test;
 
 import org.algorithm.core.cnn.model.RelationExtractionSubModel;
+import org.tensorflow.Tensor;
+
+import java.util.HashMap;
+import java.util.Map;
 
 /**测试子模型加载
  * @Author: bijl
@@ -12,7 +16,11 @@ public class ReSubModelTest {
     public static void main(String[] args) {
 //        RelationExtractionSubModel subModel = new RelationExtractionSubModel("cnn_1d_low");
 //        RelationExtractionSubModel subModel = new RelationExtractionSubModel("cnn_1d_lstm_low");
-        RelationExtractionSubModel subModel = new RelationExtractionSubModel("lstm_low_api");
+        Map<String, Tensor<Float>> lstm_low_api_map = new HashMap<>();
+        lstm_low_api_map.put("input_keep_prob",Tensor.create(1.0f, Float.class));
+        lstm_low_api_map.put("output_keep_prob",Tensor.create(1.0f, Float.class));
+        lstm_low_api_map.put("state_keep_prob",Tensor.create(1.0f, Float.class));
+        RelationExtractionSubModel subModel = new RelationExtractionSubModel("lstm_low_api", lstm_low_api_map);
 
 
         int[][] inputValues = new int[3][512];