Ver código fonte

1- 添加三元组过滤规则。

bijl 5 anos atrás
pai
commit
671572f61c

+ 93 - 38
algorithm/src/main/java/org/algorithm/core/RuleCheckMachine.java

@@ -19,6 +19,7 @@ public class RuleCheckMachine {
     private final List<FilterRule> filterRules = new ArrayList<>();
     private Map<String, Map<String, Set<Integer>>> key_1_map = null;
     private Map<String, Map<String, Set<Integer>>> key_2_map = null;
+    private Map<String, String> punctuations = new HashMap<>();
     private Map<String, Set<Integer>> despiteMap = null;  // 实体名:[规则uuid列表]
     private Map<String, Set<Integer>> despiteInsideMap = null; // 实体名:[规则uuid列表]
     private Map<String, Map<String, Set<Integer>>> insideMap = null;
@@ -65,36 +66,33 @@ public class RuleCheckMachine {
                 String despite = rs.getString("despite");
                 String despite_inside = rs.getString("despite_inside");
 
-                String[] insideSplit = inside.split("或");
                 String[] despiteSplit = despite.split(",");
                 String[] despiteInsideSplit = despite_inside.split(",");
-                for (int i = 0; i < insideSplit.length; i++) {
-                    for (int j = 0; j < despiteSplit.length; j++) {
-                        for (int k = 0; k < despiteInsideSplit.length; k++) {
-                            Map<String, String> variableMap = new HashMap<>();
-                            variableMap.put("key_1", key_1);
-                            variableMap.put("type_1", type_1);
+                for (int j = 0; j < despiteSplit.length; j++) {
+                    for (int k = 0; k < despiteInsideSplit.length; k++) {
+                        Map<String, String> variableMap = new HashMap<>();
+                        variableMap.put("key_1", key_1);
+                        variableMap.put("type_1", type_1);
 
-                            variableMap.put("key_2", key_2);
-                            variableMap.put("type_2", type_2);
+                        variableMap.put("key_2", key_2);
+                        variableMap.put("type_2", type_2);
 
-                            variableMap.put("inside", insideSplit[i]);
-                            variableMap.put("inside_type", inside_type);
+                        variableMap.put("inside", inside);
+                        variableMap.put("inside_type", inside_type);
 
-                            variableMap.put("despite", despiteSplit[j]);
-                            variableMap.put("despite_inside", despiteInsideSplit[k]);
+                        variableMap.put("despite", despiteSplit[j]);
+                        variableMap.put("despite_inside", despiteInsideSplit[k]);
 
-                            FilterRule filterRule = new FilterRule(variableMap);
-                            filterRule.setUuid(uuid);
-                            this.filterRules.add(filterRule);
+                        FilterRule filterRule = new FilterRule(variableMap);
+                        filterRule.setUuid(uuid);
+                        this.filterRules.add(filterRule);
 
-                            // TODO:delete
-                            System.out.println(filterRule);
+//                            System.out.println(filterRule);
 
-                            uuid += 1;
-                        }
+                        uuid += 1;
                     }
                 }
+
             }
 
         } catch (SQLException e) {
@@ -168,10 +166,13 @@ public class RuleCheckMachine {
             String inside = rule.getInside();
             String insideType = rule.getInsideType();
             Integer uuid = rule.getUuid();
+            if (insideType.equals("punc"))
+                this.punctuations.put(inside, inside);
+
             if (",".equals(inside.substring(0, 1)))
-                this.inputMaps(inside, inside, uuid, null, punctuationMap, typePunctuationMap, wordMap);
+                this.inputMaps(inside, insideType, uuid, null, typePunctuationMap, wordMap, punctuationMap);
             else
-                this.inputMaps(inside, inside, uuid, null, punctuationMap, typeMap, wordMap);
+                this.inputMaps(inside, insideType, uuid, null, typeMap, wordMap, punctuationMap);
         }
         this.insideMap = insideMap_;
     }
@@ -346,50 +347,63 @@ public class RuleCheckMachine {
 
         Set<Integer> set = null;
         set = this.despiteMap.get(entity_1_name);  // 过滤有实体1名为例外情况(即,不成立)的规则(的uuid)
-        remainUuids.removeAll(set);
+        this.removeAll(remainUuids, set);
 
         set = this.despiteMap.get(entity_2_name);  // 过滤有实体2名为例外情况(即,不成立)的规则(的uuid)
-        remainUuids.removeAll(set);
+        this.removeAll(remainUuids, set);
 
         // 过滤中间实体的名称触发例外条件情况
         for (int i = startIndex; i <= endIndex; i++) {
             NameTypeStartPosition nameTypeStartPosition = nameTypeStartPositions.get(i);
             set = this.despiteInsideMap.get(nameTypeStartPosition.getName());
-            remainUuids.remove(set);
+            this.removeAll(remainUuids, set);
         }
 
         // 三板斧过滤
         // 实体1,过滤
         set = new HashSet<>();
-        set.addAll(this.key_1_map.get("").get(""));
-        set.addAll(this.key_1_map.get("type").get(entity_1_type)); // 满足,形如("形容词", "type") 过滤条件的规则
-        set.addAll(this.key_1_map.get("word").get(entity_1_name)); // 满足,形如("胸痛", "word") 过滤条件的规则
-        remainUuids.retainAll(set);  // 求交集,同事满足实体1相关的过滤条件,且不不满足例外情况
+        this.addAll(set, this.key_1_map.get("").get(""));
+        // 满足,形如("形容词", "type") 过滤条件的规则
+        this.addAll(set, this.key_1_map.get("type").get(entity_1_type));
+        // 满足,形如("胸痛", "word") 过滤条件的规则
+        this.addAll(set, this.key_1_map.get("word").get(entity_1_name));
+        this.retainAll(remainUuids, set);  // 求交集,同时满足实体1相关的过滤条件,且不满足例外情况
         if (remainUuids.size() == 0)
             return false;
 
         // 实体2,过滤
         set = new HashSet<>();
-        set.addAll(this.key_2_map.get("").get(""));
-        set.addAll(this.key_2_map.get("type").get(entity_2_type)); // 满足,形如("形容词", "type") 过滤条件的规则
-        set.addAll(this.key_2_map.get("word").get(entity_2_name)); // 满足,形如("胸痛", "word") 过滤条件的规则
-        remainUuids.retainAll(set);  // 求交集,同事满足实体2相关的过滤条件,且不不满足例外情况
+        this.addAll(set, this.key_2_map.get("").get(""));
+        // 满足,形如("形容词", "type") 过滤条件的规则
+        this.addAll(set, this.key_2_map.get("type").get(entity_2_type));
+        // 满足,形如("胸痛", "word") 过滤条件的规则
+        this.addAll(set, this.key_2_map.get("word").get(entity_2_name));
+        this.retainAll(remainUuids, set);  // 求交集,同时满足实体2相关的过滤条件,且不满足例外情况
         if (remainUuids.size() == 0)
             return false;
 
         // 中间实体过滤
         set = new HashSet<>();
-        set.addAll(this.insideMap.get("").get(""));
         for (int i = startIndex; i <= endIndex; i++) {
             NameTypeStartPosition nameTypeStartPosition = nameTypeStartPositions.get(i);
             // 中间实体满足,形如("胸痛", "word") 过滤条件的规则
-            set.addAll(this.insideMap.get("word").get(nameTypeStartPosition.getName()));
-            set.addAll(this.insideMap.get("type").get(nameTypeStartPosition.getType()));  // 没有逗号的
+            this.addAll(set, this.insideMap.get("word").get(nameTypeStartPosition.getName()));
+            // 中间实体满足,形如("形容词", "type") 过滤条件的规则
+            this.addAll(set, this.insideMap.get("type").get(nameTypeStartPosition.getType()));  // 没有逗号的
         }
 
         int entity_1_start = nameTypeStartPositions.get(startIndex).getStartPosition();
         int entity_2_start = nameTypeStartPositions.get(endIndex).getStartPosition();
 
+        // 标点过滤
+        String aPunc = null;
+        for (int i=entity_1_start; i<entity_2_start;i++){
+            aPunc = sentence.substring(i, i+1);
+            if (this.punctuations.get(aPunc) != null)
+                this.addAll(set, this.insideMap.get("punc").get(aPunc));
+        }
+
+        // 逗号+属性 过滤
         int commaIndex = sentence.indexOf(",", entity_1_start + 1);  // 逗号位置
         String commaPadType = "";  // 逗号拼接上类型
         while (commaIndex > -1 && commaIndex < entity_2_start) {
@@ -402,18 +416,59 @@ public class RuleCheckMachine {
 
             }
         }
-        remainUuids.retainAll(set);  // 求交集,同事中间实体相关的过滤条件,且不不满足例外情况
+        this.retainAll(remainUuids, set);  // 求交集,同时满足中间实体相关的过滤条件,且不满足例外情况
+
+        // TODO: 剩下的规则
+        for (FilterRule rule: this.filterRules) {
+            if (remainUuids.contains(rule.getUuid()))
+                System.out.println(rule);
+
+        }
 
         return remainUuids.size() > 0;  // 还有规则满足,则过滤
 
     }
 
+    /**
+     * Removes every element of {@code set} from {@code basicSet} (set difference); a null or empty {@code set} is a no-op.
+     *
+     * @param basicSet set that is modified in place
+     * @param set      elements to remove; may be null or empty
+     */
+    private void removeAll(Set<Integer> basicSet, Set<Integer> set) {
+        if (set != null && set.size() > 0)
+            basicSet.removeAll(set);
+    }
+
+    /**
+     * Adds every element of {@code set} to {@code basicSet} (set union); a null or empty {@code set} is a no-op.
+     *
+     * @param basicSet set that is modified in place
+     * @param set      elements to add; may be null or empty
+     */
+    private void addAll(Set<Integer> basicSet, Set<Integer> set) {
+        if (set != null && set.size() > 0)
+            basicSet.addAll(set);
+    }
+
+    /**
+     * Keeps in {@code basicSet} only the elements that are also in {@code set} (set intersection).
+     * NOTE(review): a null or empty {@code set} leaves {@code basicSet} unchanged instead of emptying it - confirm this is intended.
+     * @param basicSet set that is modified in place
+     * @param set      elements to keep; when null or empty no filtering is applied
+     */
+    private void retainAll(Set<Integer> basicSet, Set<Integer> set) {
+        if (set != null && set.size() > 0)
+            basicSet.retainAll(set);
+    }
+
     /**
      * 检查并移除
      *
+     * @param sentence 句子
      * @param triads 三元组列表
      */
-    public void checkAndRemove(List<Triad> triads, String sentence) {
+    public void checkAndRemove(String sentence, List<Triad> triads) {
         List<NameTypeStartPosition> nameTypeStartPositions = this.getSortedNameTypeByPosition(triads);
         Map<Integer, Integer> startPositionToIndexMap = new HashMap<>();
         for (int i = 0; i < nameTypeStartPositions.size(); i++)

+ 17 - 3
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionEnsembleModel.java

@@ -1,6 +1,7 @@
 package org.algorithm.core.cnn.model;
 
 import org.algorithm.core.RelationTreeUtils;
+import org.algorithm.core.RuleCheckMachine;
 import org.algorithm.core.cnn.AlgorithmCNNExecutor;
 import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
 import org.algorithm.core.cnn.entity.Triad;
@@ -31,6 +32,7 @@ public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
     private RelationExtractionDataSet dataSet;
     private RelationExtractionSubModel[] subModels = new RelationExtractionSubModel[2];
     private ExecutorService executorService = Executors.newCachedThreadPool();
+    private final RuleCheckMachine ruleCheckMachine = new RuleCheckMachine();
 
     public RelationExtractionEnsembleModel() {
         // 解析路径
@@ -96,12 +98,24 @@ public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
         return inputValues;
     }
 
+
+    /**
+     * Pre-processing step: runs the rule-based triad check when the sentence is at most MAX_LEN long and triads is non-empty.
+     * @param content sentence text, forwarded to {@code ruleCheckMachine.checkAndRemove}
+     * @param triads  candidate triads, forwarded to the same call; presumably filtered in place - TODO confirm
+     */
+    private void preProcess(String content, List<Triad> triads){
+        if (!(content.length() > this.dataSet.MAX_LEN) && triads.size() > 0) // only when length <= MAX_LEN and at least one triad exists
+            this.ruleCheckMachine.checkAndRemove(content, triads);
+    }
+
     @Override
     public List<Triad> execute(String content, List<Triad> triads) {
-        // 句子长度不超过MAX_LEN,有三元组
-        if (content.length() > this.dataSet.MAX_LEN || triads.size() < 1) {
+        // 预处理
+        this.preProcess(content, triads);
+        if (content.length() > this.dataSet.MAX_LEN || triads.size() < 1)  // 句子长度超过MAX_LEN,或没有三元组,直接返回空列表
             return new ArrayList<>();
-        }
+
         int[][] inputValues = this.convertData(content, triads);  // shape = [3, batchSize * this.subModels.length]
         int batchSize = triads.size();
 

+ 78 - 23
algorithm/src/main/java/org/algorithm/test/TestRuleCheckMachine.java

@@ -21,43 +21,98 @@ public class TestRuleCheckMachine {
         Lemma l_1 = null;
         Lemma l_2 = null;
 
-        l_1 = new Lemma();
-        l_1.setPosition("3,4");
-        l_1.setText("剧烈");
+//        l_1 = new Lemma();
+//        l_1.setPosition("3,4");
+//        l_1.setText("剧烈");
+//
+//        l_2 = new Lemma();
+//        l_2.setPosition("5,6");
+//        l_2.setText("胸痛");
+//
+//        triad = new Triad();
+//        triad.setL_1(l_1);
+//        triad.setL_2(l_2);
+//        triads.add(triad);
+//
+//        l_1 = new Lemma();
+//        l_1.setPosition("7,8");
+//        l_1.setText("头痛");
+//        l_1.setProperty("部位");
+//
+//        l_2 = new Lemma();
+//        l_2.setPosition("9,10");
+//        l_2.setText("失眠");
+//        l_1.setProperty("反义");
+//
+//        triad = new Triad();
+//        triad.setL_1(l_1);
+//        triad.setL_2(l_2);
+//        triads.add(triad);
+//
+//        for (int i=0; i < 500; i++)
+//            triads.add(triad);
 
-        l_2 = new Lemma();
-        l_2.setPosition("5,6");
-        l_2.setText("胸痛");
+//        String sentence = "肝脏外形饱满,包膜光整,肝实质回声增强细密,分布欠均匀,血管网显示欠清晰,未见明显占位,左右肝内胆管未见明显扩张";
+//        l_1 = new Lemma();
+//        l_1.setPosition("0,1");
+//        l_1.setText("肝脏");
+//        l_1.setProperty("部位");
+//
+//        l_2 = new Lemma();
+//        l_2.setPosition("9,10");
+//        l_2.setText("光整");
+//        l_2.setProperty("属性值");
+//
+//        triad = new Triad();
+//        triad.setL_1(l_1);
+//        triad.setL_2(l_2);
+//        triads.add(triad);
 
-        triad = new Triad();
-        triad.setL_1(l_1);
-        triad.setL_2(l_2);
-        triads.add(triad);
+//        String sentence = "双卵巢大小正常,内各见十数个小卵泡回声,沿周边排列,大小约0.5-0.7cm";
+//        l_1 = new Lemma();
+//        l_1.setText("十");
+//        l_1.setPosition("11,11");
+//        l_1.setProperty("");
+//
+//        l_2 = new Lemma();
+//        l_2.setText("排列");
+//        l_2.setPosition("23,24");
+//        l_2.setProperty("辅检其他");
+//
+//        triad = new Triad();
+//        triad.setL_1(l_1);
+//        triad.setL_2(l_2);
+//        triads.add(triad);
 
+        String sentence = "双肾形态可,左肾上极见高密度小结节,中部见等低密度结节,部分突出包膜";
         l_1 = new Lemma();
-        l_1.setPosition("7,8");
-        l_1.setText("头痛");
-        l_1.setProperty("部位");
-
         l_2 = new Lemma();
-        l_2.setPosition("9,10");
-        l_2.setText("失眠");
-        l_1.setProperty("反义");
+
+        l_1.setText("脾");
+        l_2.setText("均匀");
+
+        l_1.setPosition("0,0");
+        l_2.setPosition("5,6");
+
+        l_1.setProperty("部位");
+        l_2.setProperty("属性值");
 
         triad = new Triad();
         triad.setL_1(l_1);
         triad.setL_2(l_2);
         triads.add(triad);
 
+//        for (int i=0; i < 500; i++)
+//            triads.add(triad);
 
-        String sentence = "患者剧烈胸痛,头痛失眠不安";
-        for (int i=0;i<sentence.length() - 1; i++)
-            System.out.print("" + i +  sentence.substring(i, i+1) + " ");
+        for (int i = 0; i < sentence.length() - 1; i++)
+            System.out.print("" + i + sentence.substring(i, i + 1) + " ");
 
-        System.out.println("size of triads" + triads.size());
+        System.out.println();
+        System.out.println("size of triads " + triads.size());
         RuleCheckMachine ruleCheckMachine = new RuleCheckMachine();
-        ruleCheckMachine.checkAndRemove(triads, sentence);
+        ruleCheckMachine.checkAndRemove(sentence, triads);
 
-        System.out.println("size of triads" + triads.size());
+        System.out.println("size of triads " + triads.size());
     }
 }