Ver código fonte

1- 添加疾病过滤模块。

bijl 5 anos atrás
pai
commit
589d263efc

+ 2 - 1
algorithm/src/main/java/org/algorithm/core/neural/DiagnosisPredictExecutor.java

@@ -15,7 +15,8 @@ public class DiagnosisPredictExecutor extends AlgorithmNeuralExecutor {
     public DiagnosisPredictExecutor() {
         String modelVersion = "diagnosisPredict.version";
 
-        this.model = TensorFlowModelLoadFactory.create(modelVersion);
+//        this.model = TensorFlowModelLoadFactory.create(modelVersion);
+        this.model = TensorFlowModelLoadFactory.createAndFilterDiagnosis(modelVersion);  // 加了疾病过滤
     }
 
 }

+ 30 - 7
algorithm/src/main/java/org/algorithm/core/neural/TensorFlowModelLoadFactory.java

@@ -11,13 +11,10 @@ import org.diagbot.pub.utils.PropertiesUtil;
  * @Description:
  */
 public class TensorFlowModelLoadFactory {
-    
+
     /**
      * 加载并创建模型类
-     * @param exportDir  模型保存地址
-     * @param inputOpName  输入op的名称
-     * @param outputOpName  输出op的名称
-     * @param dataSet     模型使用的数据集
+     * @param modelVersion  模型版本号
      * @return 模型
      */
     public static TensorflowModel create(String modelVersion) {
@@ -28,10 +25,9 @@ public class TensorFlowModelLoadFactory {
         String inputOpName = "X";  // 统一输入op名称
         String outputOpName = "softmax/softmax";  // 统一输出op名称
         
-        // TODO:修改的地方
 //        NNDataSet dataSet = new NNDataSetImplNonParallel(modelVersion);  // 新模型
         NNDataSet dataSet = new NNDataSetImpl(modelVersion);  // 老模型
-        
+
         String modelPath =prop.getProperty("basicPath");  // 模型基本路径
         modelVersion = prop.getProperty(modelVersion);
         modelPath = modelPath.replace("model_version_replacement", modelVersion);  // 生成模型路径
@@ -41,4 +37,31 @@ public class TensorFlowModelLoadFactory {
         return tm;
     }
 
+    /**
+     * Loads a TensorFlow model whose data set has diagnosis filtering switched on.
+     *
+     * <p>Builds an {@code NNDataSetImpl} for the given key, sets its filter flag, loads its
+     * filter dictionary, and resolves the model directory from {@code /algorithm.properties}
+     * by substituting the version into the "basicPath" template.
+     *
+     * @param modelVersion property key naming the model version (e.g. "diagnosisPredict.version")
+     * @return the model bound to the filtering data set
+     * @throws IllegalStateException if "basicPath" or the version key resolves to null
+     */
+    public static TensorflowModel createAndFilterDiagnosis(String modelVersion) {
+        PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
+
+        String inputOpName = "X";  // input op name shared by all models
+        String outputOpName = "softmax/softmax";  // output op name shared by all models
+
+        NNDataSet dataSet = new NNDataSetImpl(modelVersion);  // legacy ("old model") data set
+        dataSet.setDoFilterDiagnosis(true);
+        dataSet.readFilterDiagnosisDict();
+
+        String basicPath = prop.getProperty("basicPath");  // model path template
+        String resolvedVersion = prop.getProperty(modelVersion);
+        // NOTE(review): assumes PropertiesUtil.getProperty returns null for a missing key - confirm.
+        // Without this guard a missing entry surfaces as a bare NullPointerException in replace().
+        if (basicPath == null || resolvedVersion == null) {
+            throw new IllegalStateException(
+                    "Missing 'basicPath' or '" + modelVersion + "' in /algorithm.properties");
+        }
+        String modelPath = basicPath.replace("model_version_replacement", resolvedVersion);
+
+        return new TensorflowModel(modelPath, inputOpName, outputOpName, dataSet);
+    }
+
 }

+ 116 - 14
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSet.java

@@ -1,10 +1,10 @@
 package org.algorithm.core.neural.dataset;
 
-import java.util.HashMap;
-import java.util.Map;
+import java.util.*;
 
 /**
  * 神经网络用数据处理模块
+ *
  * @Author: bijl
  * @Date: 2018年7月20日-下午4:01:34
  * @Description:
@@ -13,18 +13,20 @@ public abstract class NNDataSet {
     protected final int NUM_FEATURE;
     private final int NUM_LABEL;
     protected final Map<String, Integer> FEATURE_DICT = new HashMap<>();
-    
+
     // 新版本新加的三种关键词
     protected final Map<String, Integer> PARTBODY_DICT = new HashMap<>();
     protected final Map<String, Integer> PROPERTY_DICT = new HashMap<>();
     protected final Map<String, Integer> DURATION_DICT = new HashMap<>();
-    
+
     protected final Map<String, Integer> LABEL_DICT = new HashMap<>();
     protected final Map<String, Integer> NEGATIVE_DICT = new HashMap<>();
-    protected final Map<String, String>  RE_SPLIT_WORD_DICT = new HashMap<>();
+    protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>();
+    protected final Map<String, Map<String, Integer>> RELATED_DIAGNOSIS_DICT = new HashMap<>();
+    protected final List<String> FEATURE_NAME_STORE = new ArrayList<>();
     private final String[] FEATURE_DICT_ARRAY;
     private final String[] LABEL_DICT_ARRAY;
-
+    private boolean doFilterDiagnosis = false;
 
     public NNDataSet(String modelAndVersion) {
         this.readDict(modelAndVersion);
@@ -35,9 +37,10 @@ public abstract class NNDataSet {
         this.makeDictArr();
         this.readReSplitWordDict();
     }
-    
+
     /**
      * 装外部输入转为特征向量
+     *
      * @param inputs
      * @return
      */
@@ -52,28 +55,113 @@ public abstract class NNDataSet {
      * 读取再分词字典
      */
     public abstract void readReSplitWordDict();
-    
+
+    /**
+     * Reads the diagnosis-filter dictionary.
+     */
+    public abstract void readFilterDiagnosisDict();
+
     /**
      * 生成字典列表
      */
     private void makeDictArr() {
-        for (Map.Entry<String, Integer> entry : this.FEATURE_DICT.entrySet()) 
+        for (Map.Entry<String, Integer> entry : this.FEATURE_DICT.entrySet())
             this.FEATURE_DICT_ARRAY[entry.getValue()] = entry.getKey();
-        
-        for (Map.Entry<String, Integer> entry : this.LABEL_DICT.entrySet()) 
+
+        for (Map.Entry<String, Integer> entry : this.LABEL_DICT.entrySet())
             this.LABEL_DICT_ARRAY[entry.getValue()] = entry.getKey();
-        
+
+    }
+
+    /**
+     * Packs label names with their probabilities, keeping only diagnoses that are related to
+     * the stored input feature names.
+     *
+     * <p>Labels are visited from highest to lowest probability. A label is kept when any name in
+     * {@code FEATURE_NAME_STORE} has an entry in {@code RELATED_DIAGNOSIS_DICT} that lists it.
+     * The scan stops at the first index of 20 or above at which at least one label has been
+     * kept; while nothing has been kept it continues to the end of the list.
+     *
+     * <p>NOTE(review): {@code FEATURE_NAME_STORE} is filled by a separate call to
+     * {@code storeFeatureNames}; confirm a data set is never shared by concurrent requests.
+     *
+     * @param predict model output; only row 0 is read
+     * @return kept label names mapped to their probability; empty when no label is related
+     */
+    public Map<String, Float> wrapAndFilter(float[][] predict) {
+        List<NameAndValue> ranked = new ArrayList<>(predict[0].length);
+        for (int i = 0; i < predict[0].length; i++) {
+            ranked.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
+        }
+        ranked.sort(Comparator.reverseOrder());  // highest probability first
+
+        Map<String, Float> result = new HashMap<>();
+        for (int i = 0; i < ranked.size(); i++) {
+            NameAndValue candidate = ranked.get(i);
+            if (this.isRelatedToStoredFeatures(candidate.getName())) {
+                result.put(candidate.getName(), candidate.getValue());
+            }
+            // The former test "(i >= 20 || i >= 50)" is logically just "i >= 20".
+            if (i >= 20 && !result.isEmpty()) {
+                break;
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Tells whether any stored feature name lists the diagnosis in RELATED_DIAGNOSIS_DICT.
+     */
+    private boolean isRelatedToStoredFeatures(String diagnosis) {
+        for (String featureName : this.FEATURE_NAME_STORE) {
+            Map<String, Integer> related = this.RELATED_DIAGNOSIS_DICT.get(featureName);
+            // Compare through equals(): get() yields null for an unlisted diagnosis, and
+            // unboxing that null for "== 1" threw NullPointerException.
+            if (related != null && Integer.valueOf(1).equals(related.get(diagnosis))) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Pairs a label name with its probability so that labels can be sorted by probability.
+     */
+    class NameAndValue implements Comparable<NameAndValue> {
+
+        private final String name;
+        private final Float value;
+
+        NameAndValue(String name, Float value) {
+            this.name = name;
+            this.value = value;
+        }
+
+        /**
+         * Orders by value, ascending.
+         *
+         * <p>Uses {@link Float#compare(float, float)} so that NaN and signed zeros get a
+         * consistent total order; the earlier hand-written comparison answered -1 in both
+         * directions for such pairs, which breaks the Comparable contract and can make
+         * List.sort throw.
+         *
+         * @param o the pair to compare against
+         * @return negative, zero or positive as this value is less than, equal to or greater
+         */
+        @Override
+        public int compareTo(NameAndValue o) {
+            return Float.compare(this.value, o.getValue());
+        }
+
+        /** @return the probability */
+        public Float getValue() {
+            return value;
+        }
+
+        /** @return the label name */
+        public String getName() {
+            return name;
+        }
     }
 
     /**
      * 打包模型输出结果给调用者
-     * 
+     *
      * @param predict 模型输出
      * @return
      */
     public Map<String, Float> wrap(float[][] predict) {
+        // Take the filtering path only when the flag has been switched on.
+        return this.doFilterDiagnosis
+                ? this.wrapAndFilter(predict)
+                : this.basicWrap(predict);
+    }
+
+
+    /**
+     * 打包模型输出结果给调用者
+     *
+     * @param predict 模型输出
+     * @return
+     */
+    public Map<String, Float> basicWrap(float[][] predict) {
         Map<String, Float> result = new HashMap<>();
-        for (int i=0; i<predict[0].length; i++) {  // 只返回一维向量
+        for (int i = 0; i < predict[0].length; i++) {  // 只返回一维向量
             result.put(this.LABEL_DICT_ARRAY[i], predict[0][i]);
         }
         return result;
@@ -86,6 +174,15 @@ public abstract class NNDataSet {
         return this.NUM_FEATURE;
     }
 
+    /**
+     * Replaces the stored feature names with the keys of the given input map.
+     *
+     * @param features input features keyed by feature name; only the keys are read
+     */
+    public void storeFeatureNames(Map<String, Map<String, String>> features) {
+        this.FEATURE_NAME_STORE.clear();
+        for (String featureName : features.keySet()) {
+            this.FEATURE_NAME_STORE.add(featureName);
+        }
+    }
+
     /**
      * @return
      */
@@ -93,4 +190,9 @@ public abstract class NNDataSet {
         return this.NUM_LABEL;
     }
 
+
+    /**
+     * Sets the flag that makes wrap() return the filtered result instead of the basic one.
+     *
+     * @param doFilterDiagnosis true to filter diagnoses, false for the unfiltered output
+     */
+    public void setDoFilterDiagnosis(boolean doFilterDiagnosis) {
+        this.doFilterDiagnosis = doFilterDiagnosis;
+    }
+
 }

+ 36 - 2
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSetImpl.java

@@ -23,11 +23,13 @@ public class NNDataSetImpl extends NNDataSet {
         super(modelAndVersion);
     }
 
-
     @Override
     public float[] toFeatureVector(Map<String, Map<String, String>> inputs) {
 
+        // 新添加的操作
         this.reSplitWord(inputs);  // 再分词
+        this.storeFeatureNames(inputs);  // 保存特征名称
+
         float[] featureVector = new float[this.NUM_FEATURE];
 
         Iterator<Entry<String, Map<String, String>>> entries = inputs.entrySet().iterator();
@@ -144,7 +146,7 @@ public class NNDataSetImpl extends NNDataSet {
         String filePath = prop.getProperty("basicPath");  // 基本目录
         filePath = filePath.substring(0, filePath.indexOf("model_version_replacement"));
 
-        filePath = filePath + "dictionaries.bin";  // 字典文件位置
+        filePath = filePath + "re_split_word.bin";  // 字典文件位置
 
         List<String> lines = TextFileReader.readLines(filePath);
 
@@ -168,5 +170,37 @@ public class NNDataSetImpl extends NNDataSet {
 
     }
 
+    /**
+     * Loads the diagnosis-filter dictionary into RELATED_DIAGNOSIS_DICT.
+     *
+     * <p>Reads "filter_diagnoses.bin" from the directory part of the "basicPath" property
+     * (everything before "model_version_replacement"). The first line is skipped. Every other
+     * line is parsed as {@code key|diagnosis_diagnosis_...}: the key maps to the set of its
+     * diagnoses, each stored with the value 1. Lines without a diagnosis field are ignored.
+     */
+    @Override
+    public void readFilterDiagnosisDict() {
+        PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
+        String basicPath = prop.getProperty("basicPath");  // base directory template
+        String dictDir = basicPath.substring(0, basicPath.indexOf("model_version_replacement"));
+
+        List<String> lines = TextFileReader.readLines(dictDir + "filter_diagnoses.bin");
+
+        boolean headerPending = true;
+        for (String line : lines) {
+            if (headerPending) {  // drop the first line
+                headerPending = false;
+                continue;
+            }
+
+            String[] fields = line.split("\\|");
+            if (fields.length < 2) {
+                // Blank line or no diagnosis field: nothing to store, and indexing
+                // fields[1] would throw ArrayIndexOutOfBoundsException.
+                continue;
+            }
+
+            Map<String, Integer> relatedDiagnoses = new HashMap<>();
+            for (String diagnosis : fields[1].split("_")) {
+                relatedDiagnoses.put(diagnosis, 1);
+            }
+            this.RELATED_DIAGNOSIS_DICT.put(fields[0], relatedDiagnoses);
+        }
+
+        System.out.println("疾病过滤字典大小:" + this.RELATED_DIAGNOSIS_DICT.size());
+    }
+
 
 }

+ 11 - 1
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSetImplNonParallel.java

@@ -22,7 +22,17 @@ public class NNDataSetImplNonParallel extends NNDataSet {
         super(modelAndVersion);
     }
 
-    
+
+    /**
+     * Empty implementation: this data set loads no re-split-word dictionary.
+     */
+    @Override
+    public void readReSplitWordDict() {
+        // NOTE(review): body is empty - confirm that this data set needs no re-splitting.
+
+    }
+
+    /**
+     * Empty implementation: this data set loads no diagnosis-filter dictionary.
+     */
+    @Override
+    public void readFilterDiagnosisDict() {
+        // NOTE(review): body is empty - confirm diagnosis filtering is never enabled here.
+
+    }
+
     @Override
     public float[] toFeatureVector(Map<String, Map<String, String>> inputs) {
         // inputs {症状名:{partbody:部位名, property:属性名, duration:时间类别, sex:性别值, age:年龄值}