Просмотр исходного кода

pacs关系抽取模型入参出参代码实现

hujing 5 лет назад
Родитель
Сommit
0da9113a18

+ 15 - 8
algorithm/src/main/java/org/algorithm/core/RelationTreeUtils.java

@@ -64,10 +64,12 @@ public class RelationTreeUtils {
             Lemma l1 = triad.getL_1();
             Lemma l2 = triad.getL_2();
             if (l1.getStartPosition() < l2.getStartPosition()) {  // 在前者为父节点
-                l1.setHasChildren(true);
+                l1.setLeafNode(false);
+                l2.setLeafNode(true);
                 l2.setParent(l1);
             } else {
-                l2.setHasChildren(true);
+                l1.setLeafNode(true);
+                l2.setLeafNode(false);
                 l1.setParent(l2);
             }
         }
@@ -79,17 +81,22 @@ public class RelationTreeUtils {
      * @param triads      有关系,并且设置了父子节点关系的三元组
      */
     public static List<List<String>> getRelationTreeBranches(List<Triad> triads) {
-        List<Lemma> hasNoChildrenLemmas = new ArrayList<>();
+        Map<Lemma, Integer> leafNodeLemmas = new HashMap<>();
+
         for (Triad triad : triads) {
-            if (!triad.getL_1().isHasChildren())
-                hasNoChildrenLemmas.add(triad.getL_1());
+            if (triad.getL_1().isLeafNode()){
+                if(leafNodeLemmas.get(triad.getL_1()) == null)
+                    leafNodeLemmas.put(triad.getL_1(), 1);
+            }
 
-            if (!triad.getL_2().isHasChildren())
-                hasNoChildrenLemmas.add(triad.getL_2());
+            if (triad.getL_2().isLeafNode()){
+                if(leafNodeLemmas.get(triad.getL_2()) == null)
+                    leafNodeLemmas.put(triad.getL_2(), 1);
+            }
         }
 
         List<List<String>> branches = new ArrayList<>();
-        for (Lemma lemma : hasNoChildrenLemmas) {
+        for (Lemma lemma : leafNodeLemmas.keySet()) {
             List<String> aBranch = new ArrayList<>();
             while (lemma != null) {
                 aBranch.add(lemma.getText());  // 只加入,文本

+ 1 - 1
algorithm/src/main/java/org/algorithm/core/RuleCheckMachine.java

@@ -415,7 +415,7 @@ public class RuleCheckMachine {
                     NameTypeStartPosition nameTypeStartPosition = nameTypeStartPositions.get(i);
                     if (nameTypeStartPosition.getStartPosition() > commaIndex) {
                         commaPadType = "," + nameTypeStartPosition.getType();
-                        set.addAll(this.insideMap.get("typePunctuation").get(commaPadType));
+                        this.addAll(set, this.insideMap.get("typePunctuation").get(commaPadType));
                     }
 
                 }

+ 2 - 2
algorithm/src/main/java/org/algorithm/core/cnn/AlgorithmCNNExecutor.java

@@ -16,7 +16,7 @@ public abstract class AlgorithmCNNExecutor {
      *
      * @param content 输入句子
      * @param triads 实体列表(三元组列表)
-     * @return
+     * @return  [[有关系的一系列词]]
      */
-    public abstract List<List<String>> execute(String content, List<Triad> triads);
+    public abstract List<Triad> execute(String content, List<Triad> triads);
 }

+ 22 - 0
algorithm/src/main/java/org/algorithm/core/cnn/AlgorithmCNNExecutorPacs.java

@@ -0,0 +1,22 @@
+package org.algorithm.core.cnn;
+
+import org.algorithm.core.cnn.entity.Triad;
+
+import java.util.List;
+
+/**
+ * @ClassName org.algorithm.core.cnn.model.AlgorithmCNNExecutor
+ * @Description
+ * @Author fyeman
+ * @Date 2019/1/17/017 19:18
+ * @Version 1.0
+ **/
+public abstract class AlgorithmCNNExecutorPacs {
+    /**
+     *
+     * @param content 输入句子
+     * @param triads 实体列表(三元组列表)
+     * @return  [[有关系的一系列词]]
+     */
+    public abstract List<List<String>>  execute(String content, List<Triad> triads);
+}

+ 10 - 9
algorithm/src/main/java/org/algorithm/core/cnn/entity/Lemma.java

@@ -17,23 +17,24 @@ public class Lemma {
     private String property;
 
     private Lemma parent;
-    private boolean hasChildren;
 
-    public Lemma getParent() {
-        return parent;
+    private boolean isLeafNode;
+
+    public boolean isLeafNode() {
+        return isLeafNode;
     }
 
-    public void setParent(Lemma parent) {
-        this.parent = parent;
+    public void setLeafNode(boolean leafNode) {
+        isLeafNode = leafNode;
     }
 
 
-    public boolean isHasChildren() {
-        return hasChildren;
+    public Lemma getParent() {
+        return parent;
     }
 
-    public void setHasChildren(boolean hasChildren) {
-        this.hasChildren = hasChildren;
+    public void setParent(Lemma parent) {
+        this.parent = parent;
     }
 
     public int getStartPosition(){

+ 2 - 2
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionEnsembleModel.java

@@ -2,7 +2,7 @@ package org.algorithm.core.cnn.model;
 
 import org.algorithm.core.RelationTreeUtils;
 import org.algorithm.core.RuleCheckMachine;
-import org.algorithm.core.cnn.AlgorithmCNNExecutor;
+import org.algorithm.core.cnn.AlgorithmCNNExecutorPacs;
 import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
 import org.algorithm.core.cnn.entity.Triad;
 import org.diagbot.pub.utils.PropertiesUtil;
@@ -23,7 +23,7 @@ import java.util.concurrent.*;
  * @Date: 2019/1/22 10:21
  * @Description: 集成模型
  */
-public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
+public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutorPacs {
     private final String X_PLACEHOLDER = "X";
     private final String PREDICTION = "prediction/prediction";
     private final int NUM_LABEL = 1;

+ 4 - 4
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionModel.java

@@ -4,7 +4,7 @@ import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSONArray;
 import com.alibaba.fastjson.JSONObject;
 import com.alibaba.fastjson.TypeReference;
-import org.algorithm.core.cnn.AlgorithmCNNExecutor;
+import org.algorithm.core.cnn.AlgorithmCNNExecutorPacs;
 import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
 import org.algorithm.core.cnn.entity.LemmaInfo;
 import org.algorithm.core.cnn.entity.Triad;
@@ -21,7 +21,7 @@ import java.util.List;
  * @Date: 2019/1/22 10:21
  * @Decription:
  */
-public class RelationExtractionModel extends AlgorithmCNNExecutor {
+public class RelationExtractionModel extends AlgorithmCNNExecutorPacs {
 //    self.X = tf.placeholder(tf.int32, shape=[None, self.max_length], name='X')
 //    self.pos1 = tf.placeholder(tf.int32, shape=[None, self.max_length], name='pos1')
 //    self.pos2 = tf.placeholder(tf.int32, shape=[None, self.max_length], name='pos2')
@@ -54,7 +54,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
     }
 
     @Override
-    public List<Triad> execute(String content, List<Triad> triads) {
+    public List<List<String>> execute(String content, List<Triad> triads) {
 //        List<Lemma[]> combinations = new ArrayList<>();
 //        // 组合
 //        for(int i=0; i < lemmas.size() - 1; i++){  // 两两组合成实体对
@@ -83,7 +83,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
 //            }
 //
 //        }
-        return triads;
+        return null;
     }
 
     /**

+ 33 - 0
algorithm/src/main/java/org/algorithm/factory/RelationExtractionFactory.java

@@ -0,0 +1,33 @@
+package org.algorithm.factory;
+
+import org.algorithm.core.cnn.AlgorithmCNNExecutorPacs;
+import org.algorithm.core.cnn.model.RelationExtractionEnsembleModel;
+
+/**
+ * @Description:
+ * @Author: HUJING
+ * @Date: 2019/9/10 15:25
+ */
+public class RelationExtractionFactory {
+    private static RelationExtractionEnsembleModel relationExtractionEnsembleModelInstance = null;
+
+    public static AlgorithmCNNExecutorPacs getInstance() {
+        try {
+            relationExtractionEnsembleModelInstance = (RelationExtractionEnsembleModel) create(relationExtractionEnsembleModelInstance, RelationExtractionEnsembleModel.class);
+        } catch (InstantiationException inst) {
+            inst.printStackTrace();
+        } catch (IllegalAccessException ille) {
+            ille.printStackTrace();
+        }
+        return relationExtractionEnsembleModelInstance;
+    }
+
+    private static Object create(Object obj, Class cls) throws InstantiationException, IllegalAccessException {
+        if (obj == null) {
+            synchronized (cls) {
+                obj = cls.newInstance();
+            }
+        }
+        return obj;
+    }
+}

+ 5 - 3
algorithm/src/main/java/org/algorithm/test/ReEnsembleModelTest.java

@@ -18,7 +18,7 @@ public class ReEnsembleModelTest {
 
     public static void main(String[] args) {
         RelationExtractionEnsembleModel ensembleModel = new RelationExtractionEnsembleModel();
-
+        List<List<String>> result = new ArrayList<>();
         List<Triad> triads = new ArrayList<>();
         Triad triad_1 = new Triad();
         Lemma l_1 = new Lemma();
@@ -36,9 +36,11 @@ public class ReEnsembleModelTest {
 
         long start = System.nanoTime();
         for (int i=0; i<200; i++)  // 重复100次
-            triads = ensembleModel.execute("患者剧烈胸痛头痛失眠不安", triads);
+        {
+            result = ensembleModel.execute("患者剧烈胸痛头痛失眠不安", triads);
+        }
         long elapsedTime = System.nanoTime() - start;
-        System.out.println(triads.size());
+        System.out.println(result.size());
         System.out.println(elapsedTime);
     }
 }

+ 22 - 0
common-push/src/main/java/org/diagbot/common/push/util/PushConstants.java

@@ -1,5 +1,8 @@
 package org.diagbot.common.push.util;
 
+import java.util.HashMap;
+import java.util.Map;
+
 /**
  * @ClassName org.diagbot.bigdata.util.BigDataConstants
  * @Description TODO
@@ -38,4 +41,23 @@ public class PushConstants {
     public final static String result_mapping_vital = "resultMappingVitalMap";          //推送体征结果名称映射
     public final static String result_mapping_diag = "resultMappingDiagMap";          //推送疾病科室名称映射
     public final static String result_mapping_filter = "resultMappingFilterMap";          //推送结果年龄 性别过滤
+
+    //关系抽取property_id对应property_name
+    public final static Map<String,String> featureTypeMap = new HashMap<String,String>(){{
+        put("80","辅检其他");
+        put("9","单位");
+        put("2","时间");
+        put("3","部位");
+        put("7","反意或虚拟");
+        put("16","辅检项目");
+        put("17","辅检结果");
+        put("81","属性");
+        put("82","方位");
+        put("83","形容词");
+        put("84","局部结构");
+        put("85","属性值");
+        put("86","表现");
+        put("28","字母与数值");
+        put("87","正常表现");
+    }};
 }

+ 14 - 3
common-push/src/main/java/org/diagbot/common/push/work/ParamsDataProxy.java

@@ -1,5 +1,8 @@
 package org.diagbot.common.push.work;
 
+import org.algorithm.core.cnn.AlgorithmCNNExecutor;
+import org.algorithm.core.cnn.AlgorithmCNNExecutorPacs;
+import org.algorithm.factory.RelationExtractionFactory;
 import org.apache.commons.lang3.StringUtils;
 import org.diagbot.common.push.bean.SearchData;
 import org.diagbot.common.push.util.PushConstants;
@@ -10,9 +13,7 @@ import org.diagbot.nlp.util.NegativeEnum;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import javax.servlet.http.HttpServletRequest;
 import java.util.*;
-import java.util.regex.Pattern;
 
 /**
  * @ClassName org.diagbot.bigdata.work.ParamsDataProxy
@@ -26,7 +27,7 @@ public class ParamsDataProxy {
     //标准词只处理的词性
     public static NegativeEnum[] negativeEnums = new NegativeEnum[] { NegativeEnum.VITAL_INDEX, NegativeEnum.SYMPTOM
             , NegativeEnum.DIGITS, NegativeEnum.EVENT_TIME, NegativeEnum.UNIT, NegativeEnum.DIAG_STAND
-            , NegativeEnum.OTHER};
+            , NegativeEnum.OTHER };
     //标准词处理的三元组
     public static NegativeEnum[][] negativeEnumTriple = {
             { NegativeEnum.VITAL_INDEX, NegativeEnum.DIGITS, NegativeEnum.UNIT },
@@ -138,6 +139,16 @@ public class ParamsDataProxy {
             featuresList = fa.start(searchData.getDiag(), FeatureType.DIAG);
             paramFeatureInit(searchData, featuresList);
         }
+        if (!StringUtils.isEmpty(searchData.getPacs())) {
+            //关系抽取模型
+            AlgorithmCNNExecutorPacs algorithmCNNExecutor = RelationExtractionFactory.getInstance();
+            RelationExtractionCreateSearchData re = new RelationExtractionCreateSearchData();
+            //Pacs原始分词结果
+            List<List<String>> execute = algorithmCNNExecutor.execute(searchData.getPacs(), re.createSearchData(searchData));
+            if (execute != null && execute.size() > 0) {
+                re.start(execute, searchData);
+            }
+        }
     }
 
     /**

+ 74 - 0
common-push/src/main/java/org/diagbot/common/push/work/RelationExtractionCreateSearchData.java

@@ -0,0 +1,74 @@
+package org.diagbot.common.push.work;
+
+import org.algorithm.core.cnn.entity.Lemma;
+import org.algorithm.core.cnn.entity.Triad;
+import org.apache.commons.lang3.StringUtils;
+import org.diagbot.common.push.bean.SearchData;
+import org.diagbot.common.push.util.PushConstants;
+import org.diagbot.nlp.participle.ParticipleUtil;
+import org.diagbot.nlp.participle.word.Lexeme;
+import org.diagbot.nlp.participle.word.LexemePath;
+import org.diagbot.nlp.util.Constants;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @Description:
+ * @Author: HUJING
+ * @Date: 2019/9/9 17:30
+ */
+public class RelationExtractionCreateSearchData {
+    public List<Triad> createSearchData(SearchData searchData) throws IOException {
+        List<Triad> triads = new ArrayList<>();
+        String[] pacsSplits = searchData.getPacs().trim().split("。|\n");
+        List<Lemma> lemmaList = new ArrayList<>();
+        Lemma lemma = null;
+        for (String pacsSplit : pacsSplits) {
+            LexemePath<Lexeme> pacsLexemes = ParticipleUtil.participlePacs(pacsSplit);
+            for (int i = 0; i < pacsLexemes.size(); i++) {
+                if ("44".contains(pacsLexemes.get(i).getProperty())) {
+                    continue;
+                }
+                lemma = new Lemma();
+                lemma.setText(pacsLexemes.get(i).getText());
+                lemma.setPosition(String.valueOf(pacsLexemes.get(i).getOffset()) + "," + (Integer.valueOf(pacsLexemes.get(i).getOffset() + pacsLexemes.get(i).getLength()) - 1));
+                lemma.setProperty(PushConstants.featureTypeMap.get(pacsLexemes.get(i).getProperty()));
+                lemmaList.add(lemma);
+            }
+        }
+        for (int i = 0; i < lemmaList.size() - 1; i++) {
+            for (int j = i + 1; j < lemmaList.size(); j++) {
+                Triad triad = new Triad();
+                triad.setL_1(lemmaList.get(i));
+                triad.setL_2(lemmaList.get(j));
+                triads.add(triad);
+            }
+        }
+        return triads;
+    }
+
+    public void start(List<List<String>> relationExtractionContents, SearchData searchData) throws Exception {
+        StringBuffer sb = null;
+        for (List<String> contents : relationExtractionContents) {
+            sb = new StringBuffer();
+            for (String content : contents) {
+                sb.append(content);
+            }
+            Map<String, String> map = new HashMap<>();
+            map.put("featureType", "4");
+            map.put("featureName", sb.toString());
+            map.put("property", "17");
+            map.put("concept", sb.toString());
+            //全是有
+            map.put("negative", Constants.default_negative);
+            if (searchData.getInputs().get(map.get("featureName")) == null) {
+                searchData.getInputs().put(map.get("featureName"), map);
+            }
+        }
+    }
+
+}