浏览代码

1- 添加集成模型。

bijl 6 年之前
父节点
当前提交
03c344fc19

+ 7 - 9
algorithm/src/main/java/org/algorithm/core/cnn/dataset/RelationExtractionDataSet.java

@@ -6,7 +6,6 @@ import java.io.IOException;
 import java.util.*;
 
 import org.algorithm.core.cnn.entity.Lemma;
-import org.algorithm.core.cnn.entity.LemmaInfo;
 import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSONObject;
 
@@ -18,8 +17,7 @@ import com.alibaba.fastjson.JSONObject;
 public class RelationExtractionDataSet {
 
     private Map<String, Integer> char2id = new HashMap<>();
-    private Map<Integer, Map<String, String>> entities_info = new HashMap<>();
-    public int maxLength = 500;
+    public final int MAX_LEN = 512;
 
 
     public RelationExtractionDataSet(String dir) {
@@ -60,7 +58,7 @@ public class RelationExtractionDataSet {
      * @return ids
      */
     public int[] sentence2ids(String sentence) {
-        int[] ids = new int[this.maxLength];
+        int[] ids = new int[this.MAX_LEN];
         char ch = '1';
         Integer id = null;
         for (int i = 0; i < sentence.length(); i++) {
@@ -71,7 +69,7 @@ public class RelationExtractionDataSet {
             }
             ids[i] = id.intValue();
         }
-        for (int i = sentence.length(); i < this.maxLength; i++)  // padding
+        for (int i = sentence.length(); i < this.MAX_LEN; i++)  // padding
             ids[i] = this.char2id.get("<PAD>");
 
         return ids;
@@ -83,7 +81,7 @@ public class RelationExtractionDataSet {
      * @return 句子中各个汉子相对于实体的位置
      */
     public int[] getRelativePositions(String sentence, String position) {
-        int[] relativePositions = new int[this.maxLength];
+        int[] relativePositions = new int[this.MAX_LEN];
         String[] positionPair = position.split(",");
         int startPos = Integer.parseInt(positionPair[0]);
         int endtPos = Integer.parseInt(positionPair[1]);
@@ -97,8 +95,8 @@ public class RelationExtractionDataSet {
                 relativePositions[i] = i - endtPos;
         }
 
-        for (int i = sentence.length(); i < this.maxLength; i++)
-            relativePositions[i] = this.maxLength - 1;
+        for (int i = sentence.length(); i < this.MAX_LEN; i++)
+            relativePositions[i] = this.MAX_LEN - 1;
 
         return relativePositions;
     }
@@ -110,7 +108,7 @@ public class RelationExtractionDataSet {
      * @return
      */
     public int[][] getExample(String sentence, Lemma entity1, Lemma entity2) {
-        int[][] example = new int[3][this.maxLength];
+        int[][] example = new int[3][this.MAX_LEN];
 
         example[0] = this.sentence2ids(sentence);
         example[1] = this.getRelativePositions(sentence, entity1.getPosition());

+ 184 - 0
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionEnsembleModel.java

@@ -0,0 +1,184 @@
+package org.algorithm.core.cnn.model;
+
+import org.algorithm.core.cnn.AlgorithmCNNExecutor;
+import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
+import org.algorithm.core.cnn.entity.Triad;
+import org.diagbot.pub.utils.PropertiesUtil;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+
+import java.io.File;
+import java.nio.FloatBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.*;
+
+/**
+ * @Author: bijl
+ * @Date: 2019/1/22 10:21
+ * @Description: 集成模型
+ */
+public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
+    private final String X_PLACEHOLDER = "X";
+    private final String PREDICTION = "prediction/prediction";
+    private final int NUM_LABEL = 1;
+    private SavedModelBundle bundle; // 模型捆绑
+    private Session session;  // 会话
+    private RelationExtractionDataSet dataSet;
+    private RelationExtractionSubModel[] subModels = new RelationExtractionSubModel[3];
+    private ExecutorService executorService = Executors.newCachedThreadPool();
+
+    public RelationExtractionEnsembleModel() {
+        PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
+
+        String modelsPath = prop.getProperty("basicPath");  // 模型基本路径
+        String dataSetPath = modelsPath.substring(0, modelsPath.indexOf("model_version_replacement"));
+        dataSetPath = dataSetPath + File.separator + "char2id.json";
+        String exportDir = modelsPath.replace("model_version_replacement", "ensemble_model_2");
+
+        this.dataSet = new RelationExtractionDataSet(dataSetPath);
+        this.init(exportDir);
+        subModels[0] = new RelationExtractionSubModel("cnn_1d_low");
+        subModels[1] = new RelationExtractionSubModel("cnn_1d_lstm_low");
+        subModels[2] = new RelationExtractionSubModel("lstm_low_api");
+    }
+
+    /**
+     * 初始化:加载模型,获取会话。
+     *
+     * @param exportDir 模型地址
+     */
+    public void init(String exportDir) {
+        /* load the model Bundle */
+        try {
+            this.bundle = SavedModelBundle.load(exportDir, "serve");
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+
+        // create the session from the Bundle
+        this.session = bundle.session();
+    }
+
+    /**
+     * 转化数据为张量
+     *
+     * @param content 句子
+     * @param triads  三元组list
+     * @return int[3][] 表示charId,pos1,pos2
+     */
+    private int[][] convertData(String content, List<Triad> triads) {
+
+        int[][] inputValues = new int[3][triads.size() * this.dataSet.MAX_LEN];
+        for (int i = 0; i < triads.size(); i++) {
+            Triad triad = triads.get(i);
+            int[][] aInput = this.dataSet.getExample(content, triad.getL_1(), triad.getL_2());
+            for (int j = 0; j < aInput.length; j++)
+                for (int k = 0; k < this.dataSet.MAX_LEN; k++)
+                    inputValues[j][i * this.dataSet.MAX_LEN] = aInput[j][k];
+        }
+
+        return inputValues;
+    }
+
+    @Override
+    public List<Triad> execute(String content, List<Triad> triads) {
+        // 句子长度不超过MAX_LEN,有三元组
+        if (content.length() > this.dataSet.MAX_LEN || triads.size() < 1) {
+            return new ArrayList<>();
+        }
+        int[][] inputValues = this.convertData(content, triads);  // shape = [3, batchSize * this.subModels.length]
+        int batchSize = triads.size();
+
+        float[] sigmoidS = new float[batchSize * this.subModels.length];  // 集成模型的输入
+
+        List<Future<float[][]>> futureList = new ArrayList<>();
+
+//         // 非并行运行子模型
+//        for (int i = 0; i < this.subModels.length; i++) {
+//            float[][] sigmoid = subModels[i].sigmoid(inputValues, batchSize);  // 子模型预测
+//            for (int j = 0; j < batchSize; j++)
+//                sigmoidS[i * batchSize + j] = sigmoid[j][0];
+//        }
+
+//         多线程运行子模型
+        for (int i = 0; i < this.subModels.length; i++) {
+            int index = i;
+            Future<float[][]> future = this.executorService.submit(new Callable<float[][]>() {
+                @Override
+                public float[][] call() throws Exception {
+                    return subModels[index].sigmoid(inputValues, batchSize);
+                }
+            });
+            futureList.add(future);
+        }
+
+        // 从future中获取数据,并填入sigmoidS中
+        for (int i = 0; i < this.subModels.length; i++) {
+            try {
+                float[][] sigmoid = futureList.get(i).get();
+                for (int j = 0; j < batchSize; j++)
+                    sigmoidS[i * batchSize + j] = sigmoid[j][0];
+
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+                System.err.println("获取数据不成功");
+            } catch (ExecutionException e) {
+                e.printStackTrace();
+                System.err.println("获取数据不成功");
+            }
+        }
+//        this.executorService.shutdown();
+
+        float[][] prediction = this.run(sigmoidS, batchSize);
+
+        //设置三元组关系
+        for (int j = 0; j < prediction.length; j++) {
+            if (prediction[j][0] == 1.0)
+                triads.get(j).setRelation("有");
+            else
+                triads.get(j).setRelation("无");
+        }
+
+        //删除无关系三元组
+        List<Triad> deleteTriads = new ArrayList<>();
+        for (Triad triad : triads)
+            if ("无".equals(triad.getRelation()))  // 有关系着留下
+                deleteTriads.add(triad);
+        for (Triad triad : deleteTriads)
+            triads.remove(triad);
+
+        return triads;
+    }
+
+
+    /**
+     * @param inputValues 字符id,相对于实体1位置,相对于实体2位置
+     * @param batchSize   批量大小
+     * @return float[][] shape = [batchSize, 1]
+     */
+    private float[][] run(float[] inputValues, int batchSize) {
+        long[] shape = {batchSize, this.subModels.length};  // 老模型
+        Tensor<Float> sigmoidS = Tensor.create(
+                shape,
+                FloatBuffer.wrap(inputValues)
+        );
+
+
+        return this.session.runner()
+                .feed(this.X_PLACEHOLDER, sigmoidS)
+                .fetch(this.PREDICTION).run().get(0)
+                .copyTo(new float[batchSize][this.NUM_LABEL]);
+    }
+
+    /**
+     * 关闭会话,释放资源
+     */
+    public void close() {
+        this.session.close();
+        this.bundle.close();
+        for (RelationExtractionSubModel subModel : this.subModels)
+            subModel.close();
+    }
+}

+ 1 - 2
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionModel.java

@@ -6,7 +6,6 @@ import com.alibaba.fastjson.JSONObject;
 import com.alibaba.fastjson.TypeReference;
 import org.algorithm.core.cnn.AlgorithmCNNExecutor;
 import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
-import org.algorithm.core.cnn.entity.Lemma;
 import org.algorithm.core.cnn.entity.LemmaInfo;
 import org.algorithm.core.cnn.entity.Triad;
 import org.tensorflow.SavedModelBundle;
@@ -112,7 +111,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
      * @return
      */
     private float[][] run(int[][] inputValues, int batchSize){
-        long[] shape = {1, dataSet.maxLength};  // 老模型
+        long[] shape = {1, dataSet.MAX_LEN};  // 老模型
         Tensor<Integer> charId = Tensor.create(
                 shape,
                 IntBuffer.wrap(inputValues[0])

+ 95 - 0
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionSubModel.java

@@ -0,0 +1,95 @@
+package org.algorithm.core.cnn.model;
+
+import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
+import org.diagbot.pub.utils.PropertiesUtil;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+
+import java.io.File;
+import java.nio.IntBuffer;
+
+/**
+ * @Author: bijl
+ * @Date 2019/6/12 14:02:31
+ * @Decription:
+ */
+public class RelationExtractionSubModel {
+    private final String X_PLACEHOLDER = "X";
+    private final String pos1_PLACEHOLDER = "pos1";
+    private final String pos2_PLACEHOLDER = "pos2";
+    private String PREDICTION = null;
+    private final int NUM_LABEL = 1;
+    private SavedModelBundle bundle; // 模型捆绑
+    private Session session;  // 会话
+    protected RelationExtractionDataSet dataSet;
+
+    public RelationExtractionSubModel(String modelName) {
+
+        PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
+
+        String modelsPath = prop.getProperty("basicPath");  // 模型基本路径
+        String re_path = prop.getProperty("relationExtraction");  // 模型基本路径
+        this.PREDICTION = modelName + "/prediction/Sigmoid";
+        String exportDir = modelsPath.replace("model_version_replacement", modelName);
+        String dataSetPath = modelsPath.substring(0, modelsPath.indexOf("model_version_replacement"));
+        dataSetPath = dataSetPath + File.separator + "char2id.json";
+
+        this.dataSet = new RelationExtractionDataSet(dataSetPath);
+        this.init(exportDir);
+    }
+
+    /**
+     * 初始化:加载模型,获取会话。
+     *
+     * @param exportDir 模型dir
+     */
+    public void init(String exportDir) {
+        /* load the model Bundle */
+        try {
+            this.bundle = SavedModelBundle.load(exportDir, "serve");
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+
+        // create the session from the Bundle
+        this.session = bundle.session();
+    }
+
+    /**
+     * @param inputValues 字符id,相对于实体1位置,相对于实体2位置
+     * @param batchSize   批量大小
+     * @return float[batchSize][NUM_LABEL]
+     */
+    public float[][] sigmoid(int[][] inputValues, int batchSize) {
+        long[] shape = {batchSize, dataSet.MAX_LEN};  // 老模型
+        Tensor<Integer> charId = Tensor.create(
+                shape,
+                IntBuffer.wrap(inputValues[0])
+        );
+        Tensor<Integer> pos1 = Tensor.create(
+                shape,
+                IntBuffer.wrap(inputValues[1])
+        );
+        Tensor<Integer> pos2 = Tensor.create(
+                shape,
+                IntBuffer.wrap(inputValues[2])
+        );
+
+        return this.session.runner()
+                .feed(this.X_PLACEHOLDER, charId)  // 输入,字符id
+                .feed(this.pos1_PLACEHOLDER, pos1)  // 输入,相对位置1
+                .feed(this.pos2_PLACEHOLDER, pos2)  //  输入,相对位置2
+                .feed("keep_prob", Tensor.create(1.0f, Float.class))  // 输入,dropout保留率
+                .fetch(this.PREDICTION).run().get(0)  // 输出,tensor
+                .copyTo(new float[batchSize][this.NUM_LABEL]);  // tensor转float[]对象
+    }
+
+    /**
+     * 关闭会话,释放资源
+     */
+    public void close() {
+        this.session.close();
+        this.bundle.close();
+    }
+}

+ 75 - 0
algorithm/src/main/java/org/algorithm/test/MultiThreadsTest.java

@@ -0,0 +1,75 @@
+package org.algorithm.test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Timer;
+import java.util.concurrent.*;
+
+/**
+ * @Author: bijl
+ * @Date: 2019/6/13 10:32
+ * @Decription: 多线程测试
+ */
+public class MultiThreadsTest {
+    public static void main(String[] args) {
+        ExecutorService exe = Executors.newCachedThreadPool();
+//        for (int i = 0; i < 10; i++)
+//            exe.execute(new MyThread());
+        List<Future<Float>> futures = new ArrayList<>();
+        Future<Float> future = null;
+        for(int j=0; j<2;j ++){
+            for (int i = 0; i < 10; i++) {
+                future = exe.submit(new MyCallable("" + i));
+                System.out.println("............_" + i);
+                futures.add(future);
+
+//                try {
+//                    System.out.println(future.get());
+//                } catch (InterruptedException e) {
+//                    e.printStackTrace();
+//                } catch (ExecutionException e) {
+//                    e.printStackTrace();
+//                }
+
+            }
+            for (Future<Float> future1:futures){
+                try {
+                    System.out.println(future1.get());
+                } catch (InterruptedException e) {
+                    e.printStackTrace();
+                } catch (ExecutionException e) {
+                    e.printStackTrace();
+                }
+            }
+//            exe.shutdown();
+        }
+        System.out.println("All done.");
+    }
+}
+
+class MyThread implements Runnable {
+
+
+    public void run() {
+        System.out.println(Thread.currentThread());
+        System.out.println("............");
+    }
+}
+
+class MyCallable implements Callable<Float> {
+    private String values = null;
+
+    public MyCallable(String values) {
+        this.values = values;
+    }
+
+    @Override
+    public Float call() {
+        try {
+            Thread.sleep(5000);
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+        return Float.parseFloat(this.values);
+    }
+}

+ 44 - 0
algorithm/src/main/java/org/algorithm/test/ReEnsembleModelTest.java

@@ -0,0 +1,44 @@
+package org.algorithm.test;
+
+import org.algorithm.core.cnn.entity.Lemma;
+import org.algorithm.core.cnn.entity.Triad;
+import org.algorithm.core.cnn.model.RelationExtractionEnsembleModel;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * 测试集成模型
+ *
+ * @Author: bijl
+ * @Date: 2019/6/12 17:53
+ * @Decription:
+ */
+public class ReEnsembleModelTest {
+
+    public static void main(String[] args) {
+        RelationExtractionEnsembleModel ensembleModel = new RelationExtractionEnsembleModel();
+
+        List<Triad> triads = new ArrayList<>();
+        Triad triad_1 = new Triad();
+        Lemma l_1 = new Lemma();
+        l_1.setPosition("3,4");
+        l_1.setText("剧烈");
+
+        Lemma l_2 = new Lemma();
+        l_2.setPosition("5,6");
+        l_2.setText("胸痛");
+
+        triad_1.setL_1(l_1);
+        triad_1.setL_2(l_2);
+        for (int i = 0; i < 500; i++)  // 500个样本
+            triads.add(triad_1);
+
+        long start = System.nanoTime();
+        for (int i=0; i<200; i++)  // 重复100次
+            triads = ensembleModel.execute("患者剧烈胸痛头痛失眠不安", triads);
+        long elapsedTime = System.nanoTime() - start;
+        System.out.println(triads.size());
+        System.out.println(elapsedTime);
+    }
+}

+ 27 - 0
algorithm/src/main/java/org/algorithm/test/ReSubModelTest.java

@@ -0,0 +1,27 @@
+package org.algorithm.test;
+
+import org.algorithm.core.cnn.model.RelationExtractionSubModel;
+
+/**测试子模型加载
+ * @Author: bijl
+ * @Date: 2019/6/12 16:57
+ * @Decription:
+ */
+public class ReSubModelTest {
+
+    public static void main(String[] args) {
+//        RelationExtractionSubModel subModel = new RelationExtractionSubModel("cnn_1d_low");
+//        RelationExtractionSubModel subModel = new RelationExtractionSubModel("cnn_1d_lstm_low");
+        RelationExtractionSubModel subModel = new RelationExtractionSubModel("lstm_low_api");
+
+
+        int[][] inputValues = new int[3][512];
+        for (int i=0; i< inputValues.length; i++)
+            for (int j=0; j<inputValues[0].length; j++)
+                inputValues[i][j] = 0;
+        float[][] result = subModel.sigmoid(inputValues, 1);
+        System.out.println(result[0][0]);
+        System.out.println(result.length);
+        System.out.println(result[0].length);
+    }
+}

+ 5 - 1
algorithm/src/main/java/org/algorithm/test/RelationExtractionDataSetTest.java

@@ -11,7 +11,7 @@ public class RelationExtractionDataSetTest {
 
     public static void main(String[] args) {
 
-        String filePath = "E:\\relation_extraction\\shao_yi_fu_data\\char2id.json";
+        String filePath = "E:\\char2id.json";
         String sentence = "有双手麻木感,活动后好转,颈部及肩部活动度无殊";
         RelationExtractionDataSet dataSet = new RelationExtractionDataSet(filePath);
 //        for (float id:dataSet.sentence2ids(sentence)) {
@@ -20,6 +20,10 @@ public class RelationExtractionDataSetTest {
 //        for (float id:dataSet.getRelativePositions(sentence, "1,2")) {
 //            System.out.println(id); //pass
 //        }
+        for (float id:dataSet.sentence2ids(sentence)) {
+            System.out.println(id);
+        }
+        System.out.println("...............");
 
         for (float id:dataSet.getRelativePositions(sentence, "1,2")) {
             System.out.println(id);

+ 29 - 0
algorithm/src/main/java/org/algorithm/test/TensorMethodTest.java

@@ -0,0 +1,29 @@
+package org.algorithm.test;
+
+import org.tensorflow.Tensor;
+
+import java.nio.FloatBuffer;
+
+/**
+ * 测试tensorflow Tensor方法
+ *
+ * @Author: bijl
+ * @Date: 2019/6/12 16:44
+ * @Decription:
+ */
+public class TensorMethodTest {
+
+    public static void main(String[] args) {
+        float[] inputValues = {1, 2, 3, 4, 5, 6};
+        long[] shape = {2, 3};  // 老模型
+        Tensor<Float> tensor = Tensor.create(
+                shape,
+                FloatBuffer.wrap(inputValues)
+        );
+        float[][] xx = tensor.copyTo(new float[2][3]);
+        for (int i = 0; i < xx.length; i++)
+            for (int j = 0; j < xx[0].length; j++)
+                System.out.println(i + "," + j + " --> " + xx[i][j]);
+
+    }
+}

+ 2 - 2
algorithm/src/main/resources/algorithm.properties

@@ -1,7 +1,7 @@
 ################################ model basic url ###################################
 
 #basicPath=E:/git/push/algorithm/src/main/models/model_version_replacement/model
-basicPath=/opt/models/dev/models/model_version_replacement/model
+#basicPath=/opt/models/dev/models/model_version_replacement/model
 #basicPath=E:/xxx/model_version_replacement/model
 
 ############################### current model version ################################
@@ -18,4 +18,4 @@ diagnosisToVital.version=diagnosis_to_vital_1
 
 ############################ relation extraction ######################################
 relationExtraction=relation_extraction
-relationExtractionUrl=http://192.168.2.234:54321/api/re/predict
+relationExtractionUrl=http://192.168.3.40:54321/api/re/predict