|
@@ -1,10 +1,10 @@
|
|
|
package org.algorithm.core.neural.dataset;
|
|
|
|
|
|
-import java.util.HashMap;
|
|
|
-import java.util.Map;
|
|
|
+import java.util.*;
|
|
|
|
|
|
/**
|
|
|
* 神经网络用数据处理模块
|
|
|
+ *
|
|
|
* @Author: bijl
|
|
|
* @Date: 2018年7月20日-下午4:01:34
|
|
|
* @Description:
|
|
@@ -13,18 +13,20 @@ public abstract class NNDataSet {
|
|
|
protected final int NUM_FEATURE;
|
|
|
private final int NUM_LABEL;
|
|
|
protected final Map<String, Integer> FEATURE_DICT = new HashMap<>();
|
|
|
-
|
|
|
+
|
|
|
// 新版本新加的三种关键词
|
|
|
protected final Map<String, Integer> PARTBODY_DICT = new HashMap<>();
|
|
|
protected final Map<String, Integer> PROPERTY_DICT = new HashMap<>();
|
|
|
protected final Map<String, Integer> DURATION_DICT = new HashMap<>();
|
|
|
-
|
|
|
+
|
|
|
protected final Map<String, Integer> LABEL_DICT = new HashMap<>();
|
|
|
protected final Map<String, Integer> NEGATIVE_DICT = new HashMap<>();
|
|
|
- protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>();
|
|
|
+ protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>();
|
|
|
+ protected final Map<String, Map<String, Integer>> RELATED_DIAGNOSIS_DICT = new HashMap<>();
|
|
|
+ protected final List<String> FEATURE_NAME_STORE = new ArrayList<>();
|
|
|
private final String[] FEATURE_DICT_ARRAY;
|
|
|
private final String[] LABEL_DICT_ARRAY;
|
|
|
-
|
|
|
+ private boolean doFilterDiagnosis = false;
|
|
|
|
|
|
public NNDataSet(String modelAndVersion) {
|
|
|
this.readDict(modelAndVersion);
|
|
@@ -35,9 +37,10 @@ public abstract class NNDataSet {
|
|
|
this.makeDictArr();
|
|
|
this.readReSplitWordDict();
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
/**
|
|
|
* 将外部输入转为特征向量
|
|
|
+ *
|
|
|
* @param inputs
|
|
|
* @return
|
|
|
*/
|
|
@@ -52,28 +55,113 @@ public abstract class NNDataSet {
|
|
|
* 读取再分词字典
|
|
|
*/
|
|
|
public abstract void readReSplitWordDict();
|
|
|
-
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 读取过滤字典
|
|
|
+ */
|
|
|
+ public abstract void readFilterDiagnosisDict();
|
|
|
+
|
|
|
/**
|
|
|
* 生成字典列表
|
|
|
*/
|
|
|
private void makeDictArr() {
|
|
|
- for (Map.Entry<String, Integer> entry : this.FEATURE_DICT.entrySet())
|
|
|
+ for (Map.Entry<String, Integer> entry : this.FEATURE_DICT.entrySet())
|
|
|
this.FEATURE_DICT_ARRAY[entry.getValue()] = entry.getKey();
|
|
|
-
|
|
|
- for (Map.Entry<String, Integer> entry : this.LABEL_DICT.entrySet())
|
|
|
+
|
|
|
+ for (Map.Entry<String, Integer> entry : this.LABEL_DICT.entrySet())
|
|
|
this.LABEL_DICT_ARRAY[entry.getValue()] = entry.getKey();
|
|
|
-
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 打包特征名和概率 + 过滤疾病
|
|
|
+ * 基本操作,过滤前20个疾病,如果
|
|
|
+ *
|
|
|
+ * @param predict 模型输出
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ public Map<String, Float> wrapAndFilter(float[][] predict) {
|
|
|
+ List<NameAndValue> nameAndValueList = new ArrayList<>();
|
|
|
+ for (int i = 0; i < predict[0].length; i++)
|
|
|
+ nameAndValueList.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
|
|
|
+ nameAndValueList.sort(Comparator.reverseOrder()); // 按概率从大到小排列
|
|
|
+
|
|
|
+ Map<String, Float> result = new HashMap<>();
|
|
|
+ Integer cnt = 0;
|
|
|
+ String diagnosis;
|
|
|
+ NameAndValue nameAndValue;
|
|
|
+ Map<String, Integer> relatedDiagnoses = null;
|
|
|
+ for (int i = 0; i < nameAndValueList.size(); i++) {
|
|
|
+ nameAndValue = nameAndValueList.get(i);
|
|
|
+ diagnosis = nameAndValue.getName();
|
|
|
+ for (String featureName : this.FEATURE_NAME_STORE) {
|
|
|
+ relatedDiagnoses = this.RELATED_DIAGNOSIS_DICT.get(featureName);
|
|
|
+ if (relatedDiagnoses != null && relatedDiagnoses.get(diagnosis) == 1) {
|
|
|
+ result.put(nameAndValue.getName(), nameAndValue.getValue());
|
|
|
+ cnt += 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if ((i >= 20 || i >= 50) && cnt > 0) // 如果前20或50个推送中有相关的疾病,只过滤他们
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
/**
 * Pairs a label name with its predicted value so that predictions can be sorted.
 *
 * <p>The natural ordering is ascending by value and follows
 * {@link Float#compare(float, float)}, which gives a total order (NaN sorts above every
 * other value and {@code -0.0f} sorts below {@code 0.0f}). The value must not be null
 * when instances are compared.
 */
class NameAndValue implements Comparable<NameAndValue> {

    private final String name;
    private final Float value;

    /**
     * @param name  label name
     * @param value predicted value for that label
     */
    NameAndValue(String name, Float value) {
        this.name = name;
        this.value = value;
    }

    /**
     * Compares by value in ascending order.
     *
     * @param o the other pair
     * @return a negative number, zero, or a positive number as this value is less than,
     *         equal to, or greater than the other value
     */
    @Override
    public int compareTo(NameAndValue o) {
        // Float.compare keeps the ordering antisymmetric for NaN and signed zeros,
        // which a mix of '>' and equals() does not.
        return Float.compare(this.value, o.getValue());
    }

    /** @return the predicted value */
    public Float getValue() {
        return value;
    }

    /** @return the label name */
    public String getName() {
        return name;
    }
}
|
|
|
|
|
|
/**
|
|
|
* 打包模型输出结果给调用者
|
|
|
- *
|
|
|
+ *
|
|
|
* @param predict 模型输出
|
|
|
* @return
|
|
|
*/
|
|
|
public Map<String, Float> wrap(float[][] predict) {
|
|
|
+ if (this.doFilterDiagnosis) // 过滤疾病
|
|
|
+ return this.wrapAndFilter(predict);
|
|
|
+ else
|
|
|
+ return this.basicWrap(predict);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 打包模型输出结果给调用者
|
|
|
+ *
|
|
|
+ * @param predict 模型输出
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ public Map<String, Float> basicWrap(float[][] predict) {
|
|
|
Map<String, Float> result = new HashMap<>();
|
|
|
- for (int i=0; i<predict[0].length; i++) { // 只返回一维向量
|
|
|
+ for (int i = 0; i < predict[0].length; i++) { // 只返回一维向量
|
|
|
result.put(this.LABEL_DICT_ARRAY[i], predict[0][i]);
|
|
|
}
|
|
|
return result;
|
|
@@ -86,6 +174,15 @@ public abstract class NNDataSet {
|
|
|
return this.NUM_FEATURE;
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * 存储特征名称
|
|
|
+ * @param features
|
|
|
+ */
|
|
|
+ public void storeFeatureNames(Map<String, Map<String, String>> features){
|
|
|
+ this.FEATURE_NAME_STORE.clear();
|
|
|
+ this.FEATURE_NAME_STORE.addAll(features.keySet());
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* @return
|
|
|
*/
|
|
@@ -93,4 +190,9 @@ public abstract class NNDataSet {
|
|
|
return this.NUM_LABEL;
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+ public void setDoFilterDiagnosis(boolean doFilterDiagnosis) {
|
|
|
+ this.doFilterDiagnosis = doFilterDiagnosis;
|
|
|
+ }
|
|
|
+
|
|
|
}
|