Browse Source

1- 添加辅检数据到关系树分枝的代码。

bijl 5 years ago
parent
commit
c751d41e2c

+ 211 - 0
algorithm/src/main/java/org/algorithm/core/RelationTreeUtils.java

@@ -0,0 +1,211 @@
+package org.algorithm.core;
+
+import org.algorithm.core.cnn.entity.Lemma;
+import org.algorithm.core.cnn.entity.Triad;
+
+import java.util.*;
+
+/**
+ * 关系树工具类
+ *
+ * @Author: bijl
+ * @Date: 2019/9/5 15:16
+ * @Description:
+ */
+public class RelationTreeUtils {
+
+    /**
+     * 同名实体(这里也叫词项)归并
+     * 规则:
+     *  1- 直接替代为位置最前面的一个
+     *
+     * @param triads 实体对列表
+     */
+    public static void sameTextLemmaMerge(List<Triad> triads) {
+
+        Map<String, Lemma> lemmaMap = new HashMap<>();
+        for (Triad triad : triads) {
+            Lemma l1 = triad.getL_1();
+            Lemma l2 = triad.getL_2();
+
+            if (lemmaMap.get(l1.getText()) == null)
+                lemmaMap.put(l1.getText(), l1);
+            else {
+                Lemma l1Pre = lemmaMap.get(l1.getText());
+                if (l1Pre.getStartPosition() > l1.getStartPosition())
+                    triad.setL_1(l1);  // 取靠前的
+            }
+
+            if (lemmaMap.get(l2.getText()) == null)
+                lemmaMap.put(l2.getText(), l2);
+            else {
+                Lemma l2Pre = lemmaMap.get(l2.getText());
+                if (l2Pre.getStartPosition() > l2.getStartPosition())
+                    triad.setL_2(l2);  // 取靠前的
+            }
+        }
+        for (Triad triad : triads) {
+            Lemma l1 = triad.getL_1();
+            Lemma l2 = triad.getL_2();
+            triad.setL_1(lemmaMap.get(l1.getText()));  // 用前面的同名实体(这里也叫词项)替代后面的
+            triad.setL_2(lemmaMap.get(l2.getText()));  // 用前面的同名实体(这里也叫词项)替代后面的
+        }
+    }
+
+    /**
+     * 构建关系树
+     * 基本规则:
+     *  1- 两个有关系的实体,前面的为父节点,后面的为子节点
+     *
+     * @param triads 有关系的三元组列表
+     */
+    public static void buildRelationTree(List<Triad> triads) {
+        for (Triad triad : triads) {
+            Lemma l1 = triad.getL_1();
+            Lemma l2 = triad.getL_2();
+            if (l1.getStartPosition() < l2.getStartPosition()) {  // 在前者为父节点
+                l1.setHasChildren(true);
+                l2.setParent(l1);
+            } else {
+                l2.setHasChildren(true);
+                l1.setParent(l2);
+            }
+        }
+    }
+
+    /**
+     * 获取关系树的分枝
+     *
+     * @param projectName 项目名称,如:核磁共振
+     * @param triads      有关系,并且设置了父子节点关系的三元组
+     */
+    public static Object[] getRelationTreeBranches(String projectName, List<Triad> triads) {
+        List<Lemma> hasNoChildrenLemmas = new ArrayList<>();
+        for (Triad triad : triads) {
+            if (!triad.getL_1().isHasChildren())
+                hasNoChildrenLemmas.add(triad.getL_1());
+
+            if (!triad.getL_2().isHasChildren())
+                hasNoChildrenLemmas.add(triad.getL_2());
+        }
+
+        List<List<String>> branches = new ArrayList<>();
+        for (Lemma lemma : hasNoChildrenLemmas) {
+            List<String> aBranch = new ArrayList<>();
+            while (lemma != null) {
+                aBranch.add(lemma.getText());  // 只加入,文本
+                lemma = lemma.getParent();
+            }
+            branches.addAll(permute(aBranch));  // 排列
+        }
+
+        Object[] obj = {projectName, branches};
+
+        return obj;
+    }
+
+    /**
+     * 从三元组列表到关系树分枝
+     *
+     * @param projectName
+     * @param triads
+     * @return
+     */
+    public static Object[] triadsToRelationTreeBranches(String projectName, List<Triad> triads) {
+        sameTextLemmaMerge(triads);
+        buildRelationTree(triads);
+        Object[] obj = getRelationTreeBranches("胃造影", triads);
+        return obj;
+    }
+
+    /**
+     * 全排列算法
+     *
+     * @param stringList 字符串列表
+     * @return
+     */
+    public static ArrayList<ArrayList<String>> permute(List<String> stringList) {
+        ArrayList<ArrayList<String>> result = new ArrayList<ArrayList<String>>();
+        result.add(new ArrayList<String>());
+
+        for (int i = 0; i < stringList.size(); i++) {
+            //list of list in current iteration of the stringList num
+            ArrayList<ArrayList<String>> current = new ArrayList<ArrayList<String>>();
+
+            for (ArrayList<String> l : result) {
+                // # of locations to insert is largest index + 1
+                for (int j = 0; j < l.size() + 1; j++) {
+                    // + add num[i] to different locations
+                    l.add(j, stringList.get(i));
+
+                    ArrayList<String> temp = new ArrayList<String>(l);
+                    current.add(temp);
+
+                    // - remove num[i] add
+                    l.remove(j);
+                }
+            }
+
+            result = new ArrayList<ArrayList<String>>(current);
+        }
+
+        return result;
+    }
+
+
+    /**
+     * 测试文件
+     */
+    public static void test() {
+
+        List<Triad> triads = new ArrayList<>();
+        Lemma l1_1 = new Lemma();
+        Lemma l1_2 = new Lemma();
+        l1_1.setText("子宫");
+        l1_1.setPosition("0,2");
+
+        l1_2.setText("内膜");
+        l1_2.setPosition("5,8");
+
+        Triad triad_1 = new Triad();
+        triad_1.setL_1(l1_1);
+        triad_1.setL_2(l1_2);
+        triads.add(triad_1);
+
+        Lemma l2_1 = new Lemma();
+        Lemma l2_2 = new Lemma();
+        l2_1.setText("宫颈线");
+        l2_1.setPosition("11,13");
+
+        l2_2.setText("很长");
+        l2_2.setPosition("15,18");
+
+        Triad triad_2 = new Triad();
+        triad_2.setL_1(l2_1);
+        triad_2.setL_2(l2_2);
+        triads.add(triad_2);
+
+
+        Lemma l3_1 = new Lemma();
+        Lemma l3_2 = new Lemma();
+
+        l3_1.setText("内膜");
+        l3_1.setPosition("5,8");
+
+        l3_2.setText("出血");
+        l3_2.setPosition("9,10");
+
+        Triad triad_3 = new Triad();
+        triad_3.setL_1(l3_1);
+        triad_3.setL_2(l3_2);
+        triads.add(triad_3);
+
+        sameTextLemmaMerge(triads);
+        buildRelationTree(triads);
+        Object[] obj = getRelationTreeBranches("胃造影", triads);
+
+        System.out.println(obj[0]);
+        System.out.println(obj[1]);
+    }
+
+}

+ 1 - 1
algorithm/src/main/java/org/algorithm/core/cnn/dataset/RelationExtractionDataSet.java

@@ -17,7 +17,7 @@ import com.alibaba.fastjson.JSONObject;
 public class RelationExtractionDataSet {
 
     private Map<String, Integer> char2id = new HashMap<>();
-    public final int MAX_LEN = 512;
+    public final int MAX_LEN = 256;
 
 
     public RelationExtractionDataSet(String dir) {

+ 26 - 0
algorithm/src/main/java/org/algorithm/core/cnn/entity/Lemma.java

@@ -16,6 +16,32 @@ public class Lemma {
     private int len;
     private String property;
 
+    private Lemma parent;
+    private boolean hasChildren;
+
+    public Lemma getParent() {
+        return parent;
+    }
+
+    public void setParent(Lemma parent) {
+        this.parent = parent;
+    }
+
+
+    public boolean isHasChildren() {
+        return hasChildren;
+    }
+
+    public void setHasChildren(boolean hasChildren) {
+        this.hasChildren = hasChildren;
+    }
+
+    public int getStartPosition(){
+        String[] pos = this.position.split(",");
+        return Integer.parseInt(pos[0]);
+    }
+
+
     private List<Lemma> relationLemmas = new ArrayList<>();
 
     public String getText() {

+ 20 - 5
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionEnsembleModel.java

@@ -1,5 +1,6 @@
 package org.algorithm.core.cnn.model;
 
+import org.algorithm.core.RelationTreeUtils;
 import org.algorithm.core.cnn.AlgorithmCNNExecutor;
 import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
 import org.algorithm.core.cnn.entity.Triad;
@@ -32,6 +33,7 @@ public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
     private ExecutorService executorService = Executors.newCachedThreadPool();
 
     public RelationExtractionEnsembleModel() {
+        // 解析路径
         PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
 
         String modelsPath = prop.getProperty("basicPath");  // 模型基本路径
@@ -39,18 +41,20 @@ public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
         dataSetPath = dataSetPath + File.separator + "char2id.json";
         String exportDir = modelsPath.replace("model_version_replacement", "ensemble_model_2");
 
+        // 加载数据集和初始化集成模型
         this.dataSet = new RelationExtractionDataSet(dataSetPath);
         this.init(exportDir);
 
+        // 添加子模型系数,并加载子模型cnn_1d_low
         Map<String, Tensor<Float>> cnn_1d_low_map = new HashMap<>();
-        cnn_1d_low_map.put("keep_prob",Tensor.create(1.0f, Float.class));
+        cnn_1d_low_map.put("keep_prob", Tensor.create(1.0f, Float.class));
         subModels[0] = new RelationExtractionSubModel("cnn_1d_low", cnn_1d_low_map);
-//        subModels[1] = new RelationExtractionSubModel("cnn_1d_lstm_low");
 
+        // 添加子模型系数,并加载子模型lstm_low_api
         Map<String, Tensor<Float>> lstm_low_api_map = new HashMap<>();
-        lstm_low_api_map.put("input_keep_prob",Tensor.create(1.0f, Float.class));
-        lstm_low_api_map.put("output_keep_prob",Tensor.create(1.0f, Float.class));
-        lstm_low_api_map.put("state_keep_prob",Tensor.create(1.0f, Float.class));
+        lstm_low_api_map.put("input_keep_prob", Tensor.create(1.0f, Float.class));
+        lstm_low_api_map.put("output_keep_prob", Tensor.create(1.0f, Float.class));
+        lstm_low_api_map.put("state_keep_prob", Tensor.create(1.0f, Float.class));
         subModels[1] = new RelationExtractionSubModel("lstm_low_api", lstm_low_api_map);
     }
 
@@ -162,6 +166,17 @@ public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
         return triads;
     }
 
+    /**
+     * 从三元组列表到关系树分枝
+     *  TODO:真实与外部对接还没做,包括无实体对的情况
+     * @param projectName
+     * @param triads
+     * @return
+     */
+    public Object[] triadsToRelationTreeBranches(String projectName, List<Triad> triads) {
+        return RelationTreeUtils.triadsToRelationTreeBranches(projectName, triads);
+    }
+
 
     /**
      * @param inputValues 字符id,相对于实体1位置,相对于实体2位置

+ 5 - 39
algorithm/src/main/java/org/algorithm/test/Test.java

@@ -1,49 +1,15 @@
 package org.algorithm.test;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
 
 public class Test {
-    
-    public static void main(String[] args) {
-        
-//        Integer aa = new Integer(53);
-//        Integer bb = new Integer(954);
-//        float xx = 1.0f;
-//        for(int i=1; i< 955; i++) {
-//            xx = (float)(Math.round(1.0f * i / bb*100000))/100000;
-//            System.out.println(i+":"+xx);
-////        }
-//        String filePath = "/opt/models/model_version_replacement/model";
-//        int index = filePath.indexOf("model_version_replacement");
-//
-//        System.out.println(filePath.substring(0, index));
-//            public static void testJSONStrToJavaBeanObj(){
-//
-//        Student student = JSON.parseObject(JSON_OBJ_STR, new TypeReference<Student>() {});
-//        //Student student1 = JSONObject.parseObject(JSON_OBJ_STR, new TypeReference<Student>() {});//因为JSONObject继承了JSON,所以这样也是可以的
-//
-//        System.out.println(student.getStudentName()+":"+student.getStudentAge());
-//
-        String JSON_ARRAY_STR = "[{\"length\":4,\"offset\":0,\"property\":\"1\",\"text\":\"剑突下痛\",\"threshold\":0.0},{\"length\":2,\"offset\":4,\"property\":\"1\",\"text\":\"胀痛\",\"threshold\":0.0},{\"length\":2,\"offset\":6,\"property\":\"2\",\"text\":\"1天\",\"threshold\":0.0},{\"length\":1,\"offset\":8,\"text\":\",\",\"threshold\":0.0}]\n";
-//        JSONArray jsonArray = JSONArray.parseArray(JSON_ARRAY_STR);
-////        String jsonString = "{\"length\":4,\"offset\":0,\"property\":\"1\",\"text\":\"剑突下痛\",\"threshold\":0.0}";
-//
-//       for (int i = 0; i < jsonArray.size(); i++){
-//           JSONObject job = jsonArray.getJSONObject(i);
-//           LemmaInfo info = JSON.parseObject(job.toJSONString(), new TypeReference<LemmaInfo>() {});
-//           //Student student1 = JSONObject.parseObject(JSON_OBJ_STR, new TypeReference<Student>() {});//因为JSONObject继承了JSON,所以这样也是可以的
-//
-//           System.out.println(info.getLength()+":"+info.getText());
-//       }
 
-        int index = 0;
-        for (int i=0; i<5; i++)
-            for (int j = i+1; j< 6; j++){
-                System.out.println(i + "," + j);
-                index ++;
-            }
 
-        System.out.println(index);
+    public static void main(String[] args) {
 
     }
 
 }
+

+ 15 - 0
algorithm/src/main/java/org/algorithm/test/TestRelationTreeUtils.java

@@ -0,0 +1,15 @@
+package org.algorithm.test;
+
+import org.algorithm.core.RelationTreeUtils;
+
+/**
+ * @Author: bijl
+ * @Date: 2019/9/5 17:07
+ * @Description:
+ */
+public class TestRelationTreeUtils {
+
+    public static void main(String[] args) {
+        RelationTreeUtils.test();
+    }
+}

+ 1 - 1
algorithm/src/main/java/org/algorithm/util/MysqlConnector.java

@@ -45,7 +45,7 @@ public class MysqlConnector {
     
     /**
      * 执行sql语句
-     * @param sql
+     * @param sqls
      */
     public void executeBatch(List<String> sqls) {
         Statement stmt = null;

+ 2 - 2
algorithm/src/main/resources/algorithm.properties

@@ -1,8 +1,8 @@
 ################################ model basic url ###################################
 
 #basicPath=E:/project/push/algorithm/src/main/models/model_version_replacement/model
-basicPath=/opt/models/dev/models/model_version_replacement/model
-#basicPath=E:/xxx/model_version_replacement/model
+#basicPath=/opt/models/dev/models/model_version_replacement/model
+basicPath=E:/re_models/model_version_replacement/model
 
 ############################### current model version ################################
 diagnosisPredict.version=outpatient_556_IOE_1