|
@@ -21,12 +21,21 @@ public abstract class NNDataSet {
|
|
|
|
|
|
protected final Map<String, Integer> LABEL_DICT = new HashMap<>();
|
|
protected final Map<String, Integer> LABEL_DICT = new HashMap<>();
|
|
protected final Map<String, Integer> NEGATIVE_DICT = new HashMap<>();
|
|
protected final Map<String, Integer> NEGATIVE_DICT = new HashMap<>();
|
|
- protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>();
|
|
|
|
- protected final Map<String, Map<String, Integer>> RELATED_DIAGNOSIS_DICT = new HashMap<>();
|
|
|
|
- protected final List<String> FEATURE_NAME_STORE = new ArrayList<>();
|
|
|
|
|
|
+
|
|
private final String[] FEATURE_DICT_ARRAY;
|
|
private final String[] FEATURE_DICT_ARRAY;
|
|
private final String[] LABEL_DICT_ARRAY;
|
|
private final String[] LABEL_DICT_ARRAY;
|
|
- private boolean doFilterDiagnosis = false;
|
|
|
|
|
|
+
|
|
|
|
+ // 再分词和疾病过滤相关容器
|
|
|
|
+ protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>(); // 在分词表
|
|
|
|
+ protected final List<String> FEATURE_NAME_STORE = new ArrayList<>(); // 特征保存
|
|
|
|
+ protected final Map<String, Map<String, Integer>> RELATED_DIAGNOSIS_DICT = new HashMap<>(); // 特征与疾病相关表
|
|
|
|
+ private boolean doFilterDiagnosis = false; // 是否做疾病过滤
|
|
|
|
+
|
|
|
|
+ private final float firstRateThreshold = 0.15f; // 第一个疾病的概率阈值
|
|
|
|
+ private final float rateSumThreshold = 0.6f; // 概率和阈值
|
|
|
|
+ private final int numToPush = 3; // 推荐推送的个数
|
|
|
|
+ private final float rapidFallTimes = 5; // 骤降倍数
|
|
|
|
+
|
|
|
|
|
|
public NNDataSet(String modelAndVersion) {
|
|
public NNDataSet(String modelAndVersion) {
|
|
this.readDict(modelAndVersion);
|
|
this.readDict(modelAndVersion);
|
|
@@ -74,36 +83,152 @@ public abstract class NNDataSet {
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
- * 打包特征名和概率 + 过滤疾病
|
|
|
|
- * 基本操作,过滤前20个疾病,如果
|
|
|
|
|
|
+ * 推送个数过滤[无效病历]
|
|
|
|
+ * 规则:最大概率疾病的概率要超过给定阈值,如果不超过,则认为疾病不收敛,不予推送
|
|
|
|
+ *
|
|
|
|
+ * @param nameAndValueListSorted
|
|
|
|
+ */
|
|
|
|
+ private void pushCountFilterBefore(List<NameAndValue> nameAndValueListSorted) {
|
|
|
|
+ if (nameAndValueListSorted.get(0).getValue() < this.firstRateThreshold)
|
|
|
|
+ nameAndValueListSorted.clear();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /**
|
|
|
|
+ * 推送个数过滤[概率和和概率骤降过滤]
|
|
|
|
+ * 规则:
|
|
|
|
+ * 1- 为了防止一棍子打死,我们还是尽量要推送3个病历的,除非概率骤降。
|
|
|
|
+ * 2- 概率骤降过滤,当病历收敛到一个或几个疾病之后,再出现的疾病,概率会骤然下降很多倍
|
|
|
|
+ * ,这时,这个疾病差不多是随机推送的,因此要过滤掉。【都要做】
|
|
|
|
+ * 2- 概率和,就是概率和不超过某个阈值【只有在剩余疾病个数超过阈值时做】
|
|
|
|
+ *
|
|
|
|
+ * @param nameAndValueListSorted
|
|
|
|
+ */
|
|
|
|
+ private void pushCountFilterAfter(List<NameAndValue> nameAndValueListSorted) {
|
|
|
|
+
|
|
|
|
+ // 如果不超过尽量推送的个数,只做概率骤降判断
|
|
|
|
+ Iterator<NameAndValue> it = nameAndValueListSorted.iterator();
|
|
|
|
+ boolean deleteTheRest = false; // 是否删除剩余的疾病
|
|
|
|
+ float preRate = 0.0f; // 前一个疾病的概率
|
|
|
|
+ int restCnt = 0; // 剩余疾病数
|
|
|
|
+ float rateSum = 0.0f; // 概率和
|
|
|
|
+
|
|
|
|
+ while (it.hasNext()) {
|
|
|
|
+ NameAndValue nameAndValue = it.next();
|
|
|
|
+ if (!deleteTheRest) {
|
|
|
|
+ // 相对于前一个疾病概率骤降rapidFallTimes倍
|
|
|
|
+ if (preRate / nameAndValue.getValue() >= this.rapidFallTimes)
|
|
|
|
+ deleteTheRest = true;
|
|
|
|
+ else {
|
|
|
|
+ rateSum += nameAndValue.getValue();
|
|
|
|
+ preRate = nameAndValue.getValue();
|
|
|
|
+ restCnt += 1;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if (deleteTheRest) // 删除剩下的疾病
|
|
|
|
+ it.remove();
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ if (!deleteTheRest && restCnt >= this.numToPush) {
|
|
|
|
+
|
|
|
|
+ // 如果超过尽量推送的个数,那么做概率和阈值过滤【从下一个开始删除】
|
|
|
|
+ if (rateSum >= this.rateSumThreshold)
|
|
|
|
+ deleteTheRest = true;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /**
|
|
|
|
+ * 打包特征名和概率 + 过滤疾病 + 推送个数选择
|
|
|
|
+ * 基本操作,过滤前20个疾病,如果有疾病留下,否则前50个疾病
|
|
*
|
|
*
|
|
* @param predict 模型输出
|
|
* @param predict 模型输出
|
|
* @return
|
|
* @return
|
|
*/
|
|
*/
|
|
- public Map<String, Float> wrapAndFilter(float[][] predict) {
|
|
|
|
|
|
+ public Map<String, Float> wrapAndFilterWithPushCountFilter(float[][] predict) {
|
|
List<NameAndValue> nameAndValueList = new ArrayList<>();
|
|
List<NameAndValue> nameAndValueList = new ArrayList<>();
|
|
for (int i = 0; i < predict[0].length; i++)
|
|
for (int i = 0; i < predict[0].length; i++)
|
|
nameAndValueList.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
|
|
nameAndValueList.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
|
|
nameAndValueList.sort(Comparator.reverseOrder()); // 按概率从大到小排列
|
|
nameAndValueList.sort(Comparator.reverseOrder()); // 按概率从大到小排列
|
|
|
|
|
|
|
|
+ // TODO:delete
|
|
|
|
+ System.out.println("原来__推送:...............................................................");
|
|
|
|
+ System.out.println(nameAndValueList.subList(0, 10));
|
|
|
|
+
|
|
|
|
+ pushCountFilterBefore(nameAndValueList); // 推送个数过滤【无效病历过滤】
|
|
|
|
+
|
|
|
|
+ nameAndValueList = filterDiagnosis(nameAndValueList); // 疾病过滤
|
|
|
|
+
|
|
|
|
+ this.pushCountFilterAfter(nameAndValueList); // 推送个数过滤【概率骤降和概率和阈值过滤】
|
|
|
|
+
|
|
|
|
+ // TODO:delete
|
|
|
|
+ System.out.println("新版本__最终__推送:.......................................................");
|
|
|
|
+ System.out.println("长度:" + nameAndValueList.size());
|
|
|
|
+ System.out.println(nameAndValueList);
|
|
|
|
+
|
|
Map<String, Float> result = new HashMap<>();
|
|
Map<String, Float> result = new HashMap<>();
|
|
|
|
+ for (NameAndValue nameAndValue : nameAndValueList)
|
|
|
|
+ result.put(nameAndValue.getName(), nameAndValue.getValue());
|
|
|
|
+
|
|
|
|
+ return result;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /**
|
|
|
|
+ * 疾病过滤
|
|
|
|
+ * 基本规则:
|
|
|
|
+ * 如果没有一个特征与该疾病共现过,那么删除该疾病
|
|
|
|
+ *
|
|
|
|
+ * @param nameAndValueListSorted
|
|
|
|
+ * @return
|
|
|
|
+ */
|
|
|
|
+ public List<NameAndValue> filterDiagnosis(List<NameAndValue> nameAndValueListSorted) {
|
|
Integer cnt = 0;
|
|
Integer cnt = 0;
|
|
String diagnosis;
|
|
String diagnosis;
|
|
NameAndValue nameAndValue;
|
|
NameAndValue nameAndValue;
|
|
Map<String, Integer> relatedDiagnoses = null;
|
|
Map<String, Integer> relatedDiagnoses = null;
|
|
- for (int i = 0; i < nameAndValueList.size(); i++) {
|
|
|
|
- nameAndValue = nameAndValueList.get(i);
|
|
|
|
|
|
+ List<NameAndValue> candidateNameAndValues = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < nameAndValueListSorted.size(); i++) {
|
|
|
|
+ nameAndValue = nameAndValueListSorted.get(i);
|
|
diagnosis = nameAndValue.getName();
|
|
diagnosis = nameAndValue.getName();
|
|
|
|
+
|
|
for (String featureName : this.FEATURE_NAME_STORE) {
|
|
for (String featureName : this.FEATURE_NAME_STORE) {
|
|
relatedDiagnoses = this.RELATED_DIAGNOSIS_DICT.get(featureName);
|
|
relatedDiagnoses = this.RELATED_DIAGNOSIS_DICT.get(featureName);
|
|
if (relatedDiagnoses != null && relatedDiagnoses.get(diagnosis) != null) {
|
|
if (relatedDiagnoses != null && relatedDiagnoses.get(diagnosis) != null) {
|
|
- result.put(nameAndValue.getName(), nameAndValue.getValue());
|
|
|
|
|
|
+ candidateNameAndValues.add(nameAndValue);
|
|
cnt += 1;
|
|
cnt += 1;
|
|
|
|
+ break; // 有一个共现即可
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if ((i >= 20 || i >= 50) && cnt > 0) // 如果前20或50个推送中有相关的疾病,只过滤他们
|
|
if ((i >= 20 || i >= 50) && cnt > 0) // 如果前20或50个推送中有相关的疾病,只过滤他们
|
|
break;
|
|
break;
|
|
}
|
|
}
|
|
|
|
+ return candidateNameAndValues;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /**
|
|
|
|
+ * 打包特征名和概率 + 过滤疾病
|
|
|
|
+ * 基本操作,过滤前20个疾病,如果
|
|
|
|
+ *
|
|
|
|
+ * @param predict 模型输出
|
|
|
|
+ * @return
|
|
|
|
+ */
|
|
|
|
+ public Map<String, Float> wrapAndFilter(float[][] predict) {
|
|
|
|
+ List<NameAndValue> nameAndValueList = new ArrayList<>();
|
|
|
|
+ for (int i = 0; i < predict[0].length; i++)
|
|
|
|
+ nameAndValueList.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
|
|
|
|
+ nameAndValueList.sort(Comparator.reverseOrder()); // 按概率从大到小排列
|
|
|
|
+
|
|
|
|
+ nameAndValueList = filterDiagnosis(nameAndValueList); // 疾病过滤
|
|
|
|
+
|
|
|
|
+ // TODO:delete
|
|
|
|
+ System.out.println("原版本__最终__推送 ......................................................");
|
|
|
|
+ System.out.println("长度:" + nameAndValueList.size());
|
|
|
|
+ System.out.println(nameAndValueList);
|
|
|
|
+
|
|
|
|
+ Map<String, Float> result = new HashMap<>();
|
|
|
|
+ for (NameAndValue nameAndValue : nameAndValueList)
|
|
|
|
+ result.put(nameAndValue.getName(), nameAndValue.getValue());
|
|
return result;
|
|
return result;
|
|
}
|
|
}
|
|
|
|
|
|
@@ -137,6 +262,14 @@ public abstract class NNDataSet {
|
|
public String getName() {
|
|
public String getName() {
|
|
return name;
|
|
return name;
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public String toString() {
|
|
|
|
+ return "NameAndValue{" +
|
|
|
|
+ "name='" + name + '\'' +
|
|
|
|
+ ", value=" + value +
|
|
|
|
+ '}';
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
@@ -147,8 +280,11 @@ public abstract class NNDataSet {
|
|
*/
|
|
*/
|
|
public Map<String, Float> wrap(float[][] predict) {
|
|
public Map<String, Float> wrap(float[][] predict) {
|
|
if (this.doFilterDiagnosis) // 过滤疾病
|
|
if (this.doFilterDiagnosis) // 过滤疾病
|
|
- return this.wrapAndFilter(predict);
|
|
|
|
- else
|
|
|
|
|
|
+// r
|
|
|
|
+ {
|
|
|
|
+ this.wrapAndFilter(predict);
|
|
|
|
+ return this.wrapAndFilterWithPushCountFilter(predict);
|
|
|
|
+ } else
|
|
return this.basicWrap(predict);
|
|
return this.basicWrap(predict);
|
|
}
|
|
}
|
|
|
|
|
|
@@ -175,10 +311,11 @@ public abstract class NNDataSet {
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
- * 存储特征名称
|
|
|
|
|
|
+ * 存储特征名称
|
|
|
|
+ *
|
|
* @param features
|
|
* @param features
|
|
*/
|
|
*/
|
|
- public void storeFeatureNames(Map<String, Map<String, String>> features){
|
|
|
|
|
|
+ public void storeFeatureNames(Map<String, Map<String, String>> features) {
|
|
this.FEATURE_NAME_STORE.clear();
|
|
this.FEATURE_NAME_STORE.clear();
|
|
this.FEATURE_NAME_STORE.addAll(features.keySet());
|
|
this.FEATURE_NAME_STORE.addAll(features.keySet());
|
|
}
|
|
}
|