Explorar o código

重新添加朴素贝叶斯算法

hujing hai 5 anos
pai
achega
4dac6a4e6a

+ 18 - 4
bigdata-web/src/main/java/org/diagbot/bigdata/work/AlgorithmCore.java

@@ -6,6 +6,7 @@ import org.algorithm.util.AlgorithmClassify;
 import org.diagbot.common.push.bean.FeatureRate;
 import org.diagbot.common.push.bean.ResponseData;
 import org.diagbot.common.push.bean.SearchData;
+import org.diagbot.common.push.naivebayes.factory.AlgorithmNaiveBayesFactory;
 import org.diagbot.nlp.feature.FeatureType;
 import org.diagbot.nlp.util.Constants;
 import org.diagbot.nlp.util.NlpCache;
@@ -47,12 +48,25 @@ public class AlgorithmCore {
             if (classifies[i] == null) {
                 continue;
             }
-            //算法推理
-            AlgorithmExecutor executor = AlgorithmFactory.getInstance(classifies[i]);
             Map<String, Float> featuresMap = null;
+            AlgorithmExecutor executor = null;
+            switch (searchData.getAlgorithmType() == null ? 1 : searchData.getAlgorithmType()) {
+                case 1: //机器学习算法推理
+                    executor = AlgorithmFactory.getInstance(classifies[i]);
+                    break;
+                case 2: //朴素贝叶斯算法推理
+                    if (FeatureType.parse(featureTypes[i]) == FeatureType.DIAG) {
+                        executor = AlgorithmNaiveBayesFactory.getInstance();
+                    }
+                    break;
+                default:
+                    executor = AlgorithmFactory.getInstance(classifies[i]);
+            }
+
             if (executor != null) {
-                featuresMap = executor.execute(bigDataSearchData.getInputs());
-                ;
+                featuresMap = executor.execute(bigDataSearchData.getInputs());;
+            } else {
+                continue;
             }
             List<Map.Entry<String, Float>> featuresOrderList = null;
             if (featuresMap == null) {

+ 10 - 0
common-push/src/main/java/org/diagbot/common/push/bean/SearchData.java

@@ -60,6 +60,8 @@ public class SearchData {
     private Map<String, Map<String, String>> filters = new HashMap<>(10, 0.8f);
     //满足规则的ID集合
     private Map<String, List<Rule>> rules = new HashMap<>();
+    //特征推送走的模型 1:机器学习 2:朴素贝叶斯
+    private Integer algorithmType;
 
     public Integer getDisType() {
         return disType;
@@ -315,6 +317,14 @@ public class SearchData {
         this.rules = rules;
     }
 
+    /**
+     * Returns the model selected for feature pushing: 1 = machine learning, 2 = naive Bayes.
+     *
+     * @return the algorithm type, or {@code null} when it has never been set
+     */
+    public Integer getAlgorithmType() {
+        return this.algorithmType;
+    }
+
+    /**
+     * Selects the model used for feature pushing: 1 = machine learning, 2 = naive Bayes.
+     *
+     * @param algorithmType the algorithm type to store; stored as given, {@code null} included
+     */
+    public void setAlgorithmType(Integer algorithmType) {
+        this.algorithmType = algorithmType;
+    }
+
     public String getDiseaseName() {
         return diseaseName;
     }

+ 121 - 0
common-push/src/main/java/org/diagbot/common/push/cache/ApplicationCacheUtil.java

@@ -30,6 +30,10 @@ public class ApplicationCacheUtil {
     public static Map<String, RuleApp> kl_rule_app_filter_map = null;
     //pacs关系抽取过滤
     public static Map<String, Map<String, String>> kl_diagnose_detail_filter_map = null;
+    //朴素贝叶斯
+    public static Map<String, Map<String, Float>> doc_feature_naivebayes_prob_map = null;
+    //朴素贝叶斯规则过滤
+    public static Map<String, Map<String, Float>> relevant_feature_bayes_map = null;
 
     public static Map<String, Map<String, String>> getStandard_info_synonym_map() {
         if (standard_info_synonym_map == null) {
@@ -237,4 +241,121 @@ public class ApplicationCacheUtil {
             }
         }
     }
+
+    /**
+     * Returns the cached naive Bayes probability table, building it on first use.
+     *
+     * @return the table held in {@code doc_feature_naivebayes_prob_map}
+     */
+    public static Map<String, Map<String, Float>> getDoc_feature_naivebayes_prob_map() {
+        if (doc_feature_naivebayes_prob_map != null) {
+            return doc_feature_naivebayes_prob_map;
+        }
+        // Not built yet: load it from the dictionary files, then hand it out.
+        create_doc_feature_naivebayes_prob_map();
+        return doc_feature_naivebayes_prob_map;
+    }
+
+    /**
+     * Rebuilds the naive Bayes probability table from {@code bigdata_naivebayes_features.dict}
+     * and {@code bigdata_naivebayes_diagnose.dict} and stores it in
+     * {@code doc_feature_naivebayes_prob_map}.
+     *
+     * <p>Both files hold {@code key|value} lines keyed by record number (rdn). The resulting
+     * table maps each diagnosis to: {@code "priorProb"} (records of that diagnosis / all
+     * diagnosis records), one conditional probability per feature (feature count / records of
+     * that diagnosis), and the bookkeeping key {@code "diagnoseCount"}.
+     *
+     * <p>The table is assembled in a local map and published only once complete, so a failure
+     * while parsing never leaves a half-filled cache behind. Lines without a {@code |}
+     * separator are skipped instead of aborting the whole load.
+     */
+    public static void create_doc_feature_naivebayes_prob_map() {
+        Configuration configuration = new DefaultConfig();
+
+        // <rdn, [feature...]>: feature list of each record; the first line seen for an rdn wins.
+        Map<String, List<String>> featureMap = new HashMap<>();
+        List<String> fileFeatureContents = configuration.readFileContents("bigdata_naivebayes_features.dict");
+        for (String line : fileFeatureContents) {
+            String[] content = line.split("\\|", -1);
+            if (content.length < 2 || featureMap.containsKey(content[0])) {
+                continue;
+            }
+            List<String> featureList = new ArrayList<>();
+            for (String feature : content[1].split(" ")) {
+                featureList.add(feature);
+            }
+            featureMap.put(content[0], featureList);
+        }
+
+        // <rdn, diagnose>: diagnosis of each record; the first line seen for an rdn wins.
+        Map<String, String> diagnoseMap = new HashMap<>();
+        // <diagnose, count>: number of diagnosis lines per diagnosis.
+        Map<String, Integer> diagnoseCount = new HashMap<>();
+        // Number of well-formed diagnosis lines; denominator of every prior probability.
+        int totalDiagnoseLines = 0;
+        List<String> fileDiagnoseContents = configuration.readFileContents("bigdata_naivebayes_diagnose.dict");
+        for (String line : fileDiagnoseContents) {
+            String[] content = line.split("\\|", -1);
+            if (content.length < 2) {
+                continue;
+            }
+            totalDiagnoseLines++;
+            if (!diagnoseMap.containsKey(content[0])) {
+                diagnoseMap.put(content[0], content[1]);
+            }
+            diagnoseCount.merge(content[1], 1, Integer::sum);
+        }
+
+        // <diagnose, <feature, count>>, plus the "diagnoseCount" entry holding the number of
+        // records of that diagnosis.
+        Map<String, Map<String, Integer>> diagnose2featureCount = new HashMap<>();
+        for (Map.Entry<String, String> diagnoseMapEntry : diagnoseMap.entrySet()) {
+            // e.g. featureMap -> <1000000_144, [feature, feature, ...]>
+            List<String> features = featureMap.get(diagnoseMapEntry.getKey());
+            if (features == null) {
+                continue;
+            }
+            String diagnose = diagnoseMapEntry.getValue();
+            Map<String, Integer> featureCount = diagnose2featureCount.get(diagnose);
+            if (featureCount == null) {
+                featureCount = new HashMap<>();
+                featureCount.put("diagnoseCount", diagnoseCount.get(diagnose));
+                diagnose2featureCount.put(diagnose, featureCount);
+            }
+            for (String feature : features) {
+                featureCount.merge(feature, 1, Integer::sum);
+            }
+        }
+
+        Map<String, Map<String, Float>> probMap = new HashMap<>();
+        for (Map.Entry<String, Map<String, Integer>> diagnose2featureCountEntry : diagnose2featureCount.entrySet()) {
+            Map<String, Integer> counts = diagnose2featureCountEntry.getValue();
+            int recordsOfDiagnose = counts.get("diagnoseCount");
+            Map<String, Float> prob = new HashMap<>();
+            // Prior probability of the diagnosis.
+            prob.put("priorProb", (float) recordsOfDiagnose / totalDiagnoseLines);
+            // Conditional probability of each feature given the diagnosis.
+            for (Map.Entry<String, Integer> featuresCount : counts.entrySet()) {
+                prob.put(featuresCount.getKey(), (float) featuresCount.getValue() / recordsOfDiagnose);
+            }
+            probMap.put(diagnose2featureCountEntry.getKey(), prob);
+        }
+
+        // Publish only the fully built table.
+        doc_feature_naivebayes_prob_map = probMap;
+    }
+
+    /**
+     * Returns the cached diagnosis-to-relevant-feature table, building it on first use.
+     *
+     * @return the table held in {@code relevant_feature_bayes_map}
+     */
+    public static Map<String, Map<String,Float>> getRelevant_feature_map() {
+        if (relevant_feature_bayes_map != null) {
+            return relevant_feature_bayes_map;
+        }
+        // Not built yet: load it from the dictionary file, then hand it out.
+        createRelevant_feature_map();
+        return relevant_feature_bayes_map;
+    }
+
+    /**
+     * Rebuilds the diagnosis-to-relevant-feature table from
+     * {@code bigdata_relevant_feature.dict} ({@code diagnose|feature} lines) and stores it in
+     * {@code relevant_feature_bayes_map}. Every listed feature is stored with the value
+     * {@code 0.00f}.
+     *
+     * <p>The table is assembled in a local map and published only once complete, so a failure
+     * while parsing never leaves a half-filled cache behind. Lines without a {@code |}
+     * separator are skipped instead of aborting the whole load.
+     *
+     * @return the newly built table
+     */
+    public static Map<String, Map<String,Float>> createRelevant_feature_map() {
+        Map<String, Map<String, Float>> relevantMap = new HashMap<>();
+        Configuration configuration = new DefaultConfig();
+        List<String> relevantFeatureList = configuration.readFileContents("bigdata_relevant_feature.dict");
+        for (String relevantFeature : relevantFeatureList) {
+            String[] content = relevantFeature.split("\\|", -1);
+            if (content.length < 2) {
+                continue;
+            }
+            Map<String, Float> relevantFeatureProb = relevantMap.get(content[0]);
+            if (relevantFeatureProb == null) {
+                relevantFeatureProb = new HashMap<>();
+                relevantMap.put(content[0], relevantFeatureProb);
+            }
+            relevantFeatureProb.put(content[1], 0.00f);
+        }
+        // Publish only the fully built table.
+        relevant_feature_bayes_map = relevantMap;
+        return relevantMap;
+    }
 }

+ 40 - 4
common-push/src/main/java/org/diagbot/common/push/cache/CacheFileManager.java

@@ -22,9 +22,9 @@ import java.util.*;
 public class CacheFileManager {
     Logger logger = LoggerFactory.getLogger(CacheFileManager.class);
 
-    private String user = "teamdata";
-    private String password = "jiO2rfnYhg";
-    private String url = "jdbc:mysql://192.168.2.121:3306/med?useUnicode=true&characterEncoding=UTF-8";
+    private String user = "root";
+    private String password = "lantone";
+    private String url = "jdbc:mysql://192.168.2.236:3306/med?useUnicode=true&characterEncoding=UTF-8";
 
     private String path = "";
 
@@ -381,7 +381,7 @@ public class CacheFileManager {
             }
             fw.close();
 
-            sql = "SELECT id, rule_id, type_id, remind FROM kl_rule_app";
+            sql = "SELECT id, rule_id, rule_type, remind FROM kl_rule_app";
             st = conn.createStatement();
             rs = st.executeQuery(sql);
             fw = new FileWriter(path + "bigdata_rule_app_filter.dict");
@@ -411,6 +411,42 @@ public class CacheFileManager {
             }
             fw.close();
 
+            sql = "SELECT rdn, GROUP_CONCAT(feature_name ORDER BY sn SEPARATOR ' ') AS features FROM doc_feature WHERE feature_type = 9 GROUP BY rdn;";
+            st = conn.createStatement();
+            rs = st.executeQuery(sql);
+            fw = new FileWriter(path + "bigdata_naivebayes_features.dict");
+            while (rs.next()) {
+                r1 = rs.getString(1);
+                r2 = rs.getString(2);
+                fw.write(encrypDES.encrytor(r1+ "|" + r2));
+                fw.write("\n");
+            }
+            fw.close();
+
+            sql = "select rdn, feature_name as diagnose from doc_feature where feature_type=2";
+            st = conn.createStatement();
+            rs = st.executeQuery(sql);
+            fw = new FileWriter(path + "bigdata_naivebayes_diagnose.dict");
+            while (rs.next()) {
+                r1 = rs.getString(1);
+                r2 = rs.getString(2);
+                fw.write(encrypDES.encrytor(r1+ "|" + r2));
+                fw.write("\n");
+            }
+            fw.close();
+
+            sql = "SELECT diagnose,feature FROM doc_relevant_feature;";
+            st = conn.createStatement();
+            rs = st.executeQuery(sql);
+            fw = new FileWriter(path + "bigdata_relevant_feature.dict");
+            while (rs.next()) {
+                r1 = rs.getString(1);
+                r2 = rs.getString(2);
+                fw.write(encrypDES.encrytor(r1+ "|" + r2));
+                fw.write("\n");
+            }
+            fw.close();
+
             //化验辅检体征性别年龄
             sql = "SELECT k1.lib_name, k1.lib_type, kcc.sex_type, kcc.min_age, kcc.max_age " +
                     "FROM kl_concept_common kcc, kl_concept k1 " +

+ 30 - 0
common-push/src/main/java/org/diagbot/common/push/naivebayes/NaiveBayesTest.java

@@ -0,0 +1,30 @@
+package org.diagbot.common.push.naivebayes;
+
+import org.diagbot.common.push.naivebayes.core.AlgorithmNaiveBayesExecutor;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * @Description: Manual smoke run of the naive Bayes executor. Feeds one symptom, prints every
+ * diagnosis whose resulting value is exactly 0, then prints the sum of all values.
+ * @Author: HUJING
+ * @Date: 2019/10/11 14:30
+ */
+public class NaiveBayesTest {
+    public static void main(String[] args) {
+        AlgorithmNaiveBayesExecutor executor = new AlgorithmNaiveBayesExecutor();
+
+        // Single-symptom query; put further symptoms into the map to try a combined query.
+        Map<String, Map<String, String>> inputs = new HashMap<>();
+        inputs.put("咽部异物感", new HashMap<>());
+
+        Map<String, Float> softmax = executor.execute(inputs);
+
+        double total = 0.00;
+        for (Map.Entry<String, Float> diagnosis : softmax.entrySet()) {
+            float probability = diagnosis.getValue();
+            total += probability;
+            if (probability == 0) {
+                System.out.println(diagnosis.getKey());
+            }
+        }
+        System.out.println(total);
+    }
+}

+ 92 - 0
common-push/src/main/java/org/diagbot/common/push/naivebayes/core/AlgorithmNaiveBayesExecutor.java

@@ -0,0 +1,92 @@
+package org.diagbot.common.push.naivebayes.core;
+
+import org.algorithm.core.AlgorithmExecutor;
+import org.diagbot.common.push.cache.ApplicationCacheUtil;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+/**
+ * @Description: Naive Bayes scorer. Scores every diagnosis of the cached probability table
+ * against the input features in log10 space, then normalises the scores with a softmax and
+ * returns them in descending order.
+ * @Author: HUJING
+ * @Date: 2019/10/11 14:25
+ */
+public class AlgorithmNaiveBayesExecutor extends AlgorithmExecutor {
+    /** log10 score of a feature known to be relevant to the diagnosis but never counted in the records (10^-2). */
+    private static final double UNKNOWN_PROB_WITH_RELEVANT = -2;
+    /** log10 score of a feature with no known relation to the diagnosis (10^-6). */
+    private static final double UNKNOWN_PROB_WITHOUT_RELEVANT = -6;
+    /** Key under which the probability table stores the prior probability of a diagnosis. */
+    private static final String PRIOR_PROB_KEY = "priorProb";
+
+    /**
+     * Scores all diagnoses for the given features.
+     *
+     * @param inputs input features; only the keys (feature names) are read
+     * @return diagnosis to normalised score, sorted from highest to lowest; empty when
+     *         {@code inputs} is null or empty
+     */
+    public Map<String, Float> execute(Map<String, Map<String, String>> inputs) {
+        return softmax(probCalc(inputs));
+    }
+
+    /**
+     * Computes, for each diagnosis, log10(prior) plus one log10 term per input feature.
+     *
+     * @param inputs input features; only the keys (feature names) are read
+     * @return diagnosis to unnormalised log10 score; empty when {@code inputs} is null or empty
+     */
+    public Map<String, Float> probCalc(Map<String, Map<String, String>> inputs) {
+        Map<String, Float> naivebayesResult = new HashMap<>();
+        if (inputs == null || inputs.isEmpty()) {
+            return naivebayesResult;
+        }
+        Map<String, Map<String, Float>> probTable = ApplicationCacheUtil.getDoc_feature_naivebayes_prob_map();
+        Map<String, Map<String, Float>> relevantTable = ApplicationCacheUtil.getRelevant_feature_map();
+
+        for (Map.Entry<String, Map<String, Float>> naivebayesProb : probTable.entrySet()) {
+            Map<String, Float> diagnosisProbs = naivebayesProb.getValue();
+            Map<String, Float> relevantFeatures = relevantTable.get(naivebayesProb.getKey());
+
+            double score = Math.log10(diagnosisProbs.get(PRIOR_PROB_KEY));
+            for (String input : inputs.keySet()) {
+                Float conditionProb = diagnosisProbs.get(input);
+                if (conditionProb != null) {
+                    // The feature was counted for this diagnosis: use its conditional probability.
+                    score += Math.log10(conditionProb);
+                } else if (relevantFeatures != null && relevantFeatures.containsKey(input)) {
+                    // Not counted, but listed as relevant by the association rules: smoothed value.
+                    score += UNKNOWN_PROB_WITH_RELEVANT;
+                } else {
+                    score += UNKNOWN_PROB_WITHOUT_RELEVANT;
+                }
+            }
+            naivebayesResult.put(naivebayesProb.getKey(), (float) score);
+        }
+        return naivebayesResult;
+    }
+
+    /**
+     * Normalises the scores with a base-e softmax and sorts them in descending order.
+     *
+     * <p>The denominator is computed per call. Caching it in a static field, as this class
+     * previously did, normalised every later request with the sum of the first one. The maximum
+     * score is subtracted before exponentiating; this leaves the result mathematically unchanged
+     * but keeps very negative scores from underflowing.
+     */
+    private Map<String, Float> softmax(Map<String, Float> naivebayesResultMap) {
+        Map<String, Float> softmaxResult = new HashMap<>();
+        if (naivebayesResultMap.isEmpty()) {
+            return sortMap(softmaxResult);
+        }
+
+        double maxScore = Double.NEGATIVE_INFINITY;
+        for (Float score : naivebayesResultMap.values()) {
+            maxScore = Math.max(maxScore, score);
+        }
+
+        double denominator = 0.00;
+        for (Float score : naivebayesResultMap.values()) {
+            denominator += Math.exp(score - maxScore);
+        }
+
+        for (Map.Entry<String, Float> naivebayesResult : naivebayesResultMap.entrySet()) {
+            double numerator = Math.exp(naivebayesResult.getValue() - maxScore);
+            softmaxResult.put(naivebayesResult.getKey(), (float) (numerator / denominator));
+        }
+        return sortMap(softmaxResult);
+    }
+
+    /**
+     * Returns a copy of the map whose iteration order is by value, highest first.
+     *
+     * @param ResultMap the map to sort; it is not modified
+     * @return a new {@link LinkedHashMap} holding the same entries in descending value order
+     */
+    public Map<String, Float> sortMap(Map<String, Float> ResultMap) {
+        ArrayList<Map.Entry<String, Float>> entries = new ArrayList<>(ResultMap.entrySet());
+        Comparator<Map.Entry<String, Float>> descendingByValue =
+                (left, right) -> right.getValue().compareTo(left.getValue());
+        entries.sort(descendingByValue);
+
+        Map<String, Float> sorted = new LinkedHashMap<>();
+        for (Map.Entry<String, Float> entry : entries) {
+            sorted.put(entry.getKey(), entry.getValue());
+        }
+        return sorted;
+    }
+}

+ 34 - 0
common-push/src/main/java/org/diagbot/common/push/naivebayes/factory/AlgorithmNaiveBayesFactory.java

@@ -0,0 +1,34 @@
+package org.diagbot.common.push.naivebayes.factory;
+
+import org.algorithm.core.AlgorithmExecutor;
+import org.algorithm.core.cnn.model.RelationExtractionEnsembleModel;
+import org.diagbot.common.push.naivebayes.core.AlgorithmNaiveBayesExecutor;
+
+/**
+ * @Description: Hands out the single shared {@link AlgorithmNaiveBayesExecutor}.
+ * @Author: HUJING
+ * @Date: 2019/9/10 15:25
+ */
+public class AlgorithmNaiveBayesFactory {
+
+    /**
+     * Lazy holder: the JVM initialises this class, and with it the executor, exactly once on
+     * first access. This replaces the earlier reflective creation, whose synchronized block
+     * did not re-check for an existing instance and so could build more than one executor.
+     */
+    private static final class Holder {
+        private static final AlgorithmNaiveBayesExecutor INSTANCE = new AlgorithmNaiveBayesExecutor();
+    }
+
+    private AlgorithmNaiveBayesFactory() {
+        // Static factory only; not meant to be instantiated.
+    }
+
+    /**
+     * Returns the shared naive Bayes executor, creating it on first use.
+     *
+     * @return the shared executor; never {@code null}
+     */
+    public static AlgorithmExecutor getInstance() {
+        return Holder.INSTANCE;
+    }
+}