Browse Source

1- 添加推送疾病个数限制。

bijl 5 năm trước cách đây
mục cha
commit
e4a4ae9425

+ 151 - 14
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSet.java

@@ -21,12 +21,21 @@ public abstract class NNDataSet {
 
     protected final Map<String, Integer> LABEL_DICT = new HashMap<>();
     protected final Map<String, Integer> NEGATIVE_DICT = new HashMap<>();
-    protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>();
-    protected final Map<String, Map<String, Integer>> RELATED_DIAGNOSIS_DICT = new HashMap<>();
-    protected final List<String> FEATURE_NAME_STORE = new ArrayList<>();
+
     private final String[] FEATURE_DICT_ARRAY;
     private final String[] LABEL_DICT_ARRAY;
-    private boolean doFilterDiagnosis = false;
+
+    // 再分词和疾病过滤相关容器
+    protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>();  // 在分词表
+    protected final List<String> FEATURE_NAME_STORE = new ArrayList<>();  // 特征保存
+    protected final Map<String, Map<String, Integer>> RELATED_DIAGNOSIS_DICT = new HashMap<>();  // 特征与疾病相关表
+    private boolean doFilterDiagnosis = false;  // 是否做疾病过滤
+
+    private final float firstRateThreshold = 0.15f;  // 第一个疾病的概率阈值
+    private final float rateSumThreshold = 0.6f;  // 概率和阈值
+    private final int numToPush = 3;  // 推荐推送的个数
+    private final float rapidFallTimes = 5;  // 骤降倍数
+
 
     public NNDataSet(String modelAndVersion) {
         this.readDict(modelAndVersion);
@@ -74,36 +83,152 @@ public abstract class NNDataSet {
     }
 
     /**
-     * 打包特征名和概率 + 过滤疾病
-     * 基本操作,过滤前20个疾病,如果
+     * 推送个数过滤[无效病历]
+     * 规则:最大概率疾病的概率要超过给定阈值,如果不超过,则认为疾病不收敛,不予推送
+     *
+     * @param nameAndValueListSorted
+     */
+    private void pushCountFilterBefore(List<NameAndValue> nameAndValueListSorted) {
+        if (nameAndValueListSorted.get(0).getValue() < this.firstRateThreshold)
+            nameAndValueListSorted.clear();
+    }
+
+    /**
+     * 推送个数过滤[概率和和概率骤降过滤]
+     * 规则:
+     * 1- 为了防止一棍子打死,我们还是尽量要推送3个病历的,除非概率骤降。
+     * 2- 概率骤降过滤,当病历收敛到一个或几个疾病之后,再出现的疾病,概率会骤然下降很多倍
+     * ,这时,这个疾病差不多是随机推送的,因此要过滤掉。【都要做】
+     * 2- 概率和,就是概率和不超过某个阈值【只有在剩余疾病个数超过阈值时做】
+     *
+     * @param nameAndValueListSorted
+     */
+    private void pushCountFilterAfter(List<NameAndValue> nameAndValueListSorted) {
+
+        // 如果不超过尽量推送的个数,只做概率骤降判断
+        Iterator<NameAndValue> it = nameAndValueListSorted.iterator();
+        boolean deleteTheRest = false;   // 是否删除剩余的疾病
+        float preRate = 0.0f; // 前一个疾病的概率
+        int restCnt = 0;  // 剩余疾病数
+        float rateSum = 0.0f;  // 概率和
+
+        while (it.hasNext()) {
+            NameAndValue nameAndValue = it.next();
+            if (!deleteTheRest) {
+                // 相对于前一个疾病概率骤降rapidFallTimes倍
+                if (preRate / nameAndValue.getValue() >= this.rapidFallTimes)
+                    deleteTheRest = true;
+                else {
+                    rateSum += nameAndValue.getValue();
+                    preRate = nameAndValue.getValue();
+                    restCnt += 1;
+                }
+            }
+
+            if (deleteTheRest)  // 删除剩下的疾病
+                it.remove();
+
+
+            if (!deleteTheRest && restCnt >= this.numToPush) {
+
+                // 如果超过尽量推送的个数,那么做概率和阈值过滤【从下一个开始删除】
+                if (rateSum >= this.rateSumThreshold)
+                    deleteTheRest = true;
+            }
+        }
+
+    }
+
+    /**
+     * 打包特征名和概率 + 过滤疾病 + 推送个数选择
+     * 基本操作,过滤前20个疾病,如果有疾病留下,否则前50个疾病
      *
      * @param predict 模型输出
      * @return
      */
-    public Map<String, Float> wrapAndFilter(float[][] predict) {
+    public Map<String, Float> wrapAndFilterWithPushCountFilter(float[][] predict) {
         List<NameAndValue> nameAndValueList = new ArrayList<>();
         for (int i = 0; i < predict[0].length; i++)
             nameAndValueList.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
         nameAndValueList.sort(Comparator.reverseOrder());  // 按概率从大到小排列
 
+        // TODO:delete
+        System.out.println("原来__推送:...............................................................");
+        System.out.println(nameAndValueList.subList(0, 10));
+
+        pushCountFilterBefore(nameAndValueList);  // 推送个数过滤【无效病历过滤】
+
+        nameAndValueList = filterDiagnosis(nameAndValueList);  // 疾病过滤
+
+        this.pushCountFilterAfter(nameAndValueList);  // 推送个数过滤【概率骤降和概率和阈值过滤】
+
+        // TODO:delete
+        System.out.println("新版本__最终__推送:.......................................................");
+        System.out.println("长度:" + nameAndValueList.size());
+        System.out.println(nameAndValueList);
+
         Map<String, Float> result = new HashMap<>();
+        for (NameAndValue nameAndValue : nameAndValueList)
+            result.put(nameAndValue.getName(), nameAndValue.getValue());
+
+        return result;
+    }
+
+    /**
+     * 疾病过滤
+     * 基本规则:
+     * 如果没有一个特征与该疾病共现过,那么删除该疾病
+     *
+     * @param nameAndValueListSorted
+     * @return
+     */
+    public List<NameAndValue> filterDiagnosis(List<NameAndValue> nameAndValueListSorted) {
         Integer cnt = 0;
         String diagnosis;
         NameAndValue nameAndValue;
         Map<String, Integer> relatedDiagnoses = null;
-        for (int i = 0; i < nameAndValueList.size(); i++) {
-            nameAndValue = nameAndValueList.get(i);
+        List<NameAndValue> candidateNameAndValues = new ArrayList<>();
+        for (int i = 0; i < nameAndValueListSorted.size(); i++) {
+            nameAndValue = nameAndValueListSorted.get(i);
             diagnosis = nameAndValue.getName();
+
             for (String featureName : this.FEATURE_NAME_STORE) {
                 relatedDiagnoses = this.RELATED_DIAGNOSIS_DICT.get(featureName);
                 if (relatedDiagnoses != null && relatedDiagnoses.get(diagnosis) != null) {
-                    result.put(nameAndValue.getName(), nameAndValue.getValue());
+                    candidateNameAndValues.add(nameAndValue);
                     cnt += 1;
+                    break;  // 有一个共现即可
                 }
             }
             if ((i >= 20 || i >= 50) && cnt > 0)  // 如果前20或50个推送中有相关的疾病,只过滤他们
                 break;
         }
+        return candidateNameAndValues;
+    }
+
+    /**
+     * 打包特征名和概率 + 过滤疾病
+     * 基本操作,过滤前20个疾病,如果
+     *
+     * @param predict 模型输出
+     * @return
+     */
+    public Map<String, Float> wrapAndFilter(float[][] predict) {
+        List<NameAndValue> nameAndValueList = new ArrayList<>();
+        for (int i = 0; i < predict[0].length; i++)
+            nameAndValueList.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
+        nameAndValueList.sort(Comparator.reverseOrder());  // 按概率从大到小排列
+
+        nameAndValueList = filterDiagnosis(nameAndValueList);  // 疾病过滤
+
+        // TODO:delete
+        System.out.println("原版本__最终__推送 ......................................................");
+        System.out.println("长度:" + nameAndValueList.size());
+        System.out.println(nameAndValueList);
+
+        Map<String, Float> result = new HashMap<>();
+        for (NameAndValue nameAndValue : nameAndValueList)
+            result.put(nameAndValue.getName(), nameAndValue.getValue());
         return result;
     }
 
@@ -137,6 +262,14 @@ public abstract class NNDataSet {
         public String getName() {
             return name;
         }
+
+        @Override
+        public String toString() {
+            return "NameAndValue{" +
+                    "name='" + name + '\'' +
+                    ", value=" + value +
+                    '}';
+        }
     }
 
     /**
@@ -147,8 +280,11 @@ public abstract class NNDataSet {
      */
     public Map<String, Float> wrap(float[][] predict) {
         if (this.doFilterDiagnosis)  // 过滤疾病
-            return this.wrapAndFilter(predict);
-        else
+//            r
+        {
+            this.wrapAndFilter(predict);
+            return this.wrapAndFilterWithPushCountFilter(predict);
+        } else
             return this.basicWrap(predict);
     }
 
@@ -175,10 +311,11 @@ public abstract class NNDataSet {
     }
 
     /**
-     *  存储特征名称
+     * 存储特征名称
+     *
      * @param features
      */
-    public void storeFeatureNames(Map<String, Map<String, String>> features){
+    public void storeFeatureNames(Map<String, Map<String, String>> features) {
         this.FEATURE_NAME_STORE.clear();
         this.FEATURE_NAME_STORE.addAll(features.keySet());
     }