Browse Source

Merge remote-tracking branch 'origin/push-dev-concat-model' into push-dev-bayes

louhr 5 years ago
parent
commit
c9b302dcc9

+ 16 - 18
algorithm/src/main/java/org/algorithm/core/neural/TensorFlowModelLoadFactory.java

@@ -6,6 +6,7 @@ import org.diagbot.pub.utils.PropertiesUtil;
 
 /**
  * Tensorlflow 模型加载工厂
+ *
  * @Author: bijl
  * @Date: 2018年7月19日-下午7:28:58
  * @Description:
@@ -14,32 +15,31 @@ public class TensorFlowModelLoadFactory {
 
     /**
      * 加载并创建模型类
-     * @param modelVersion  模型版本号
+     *
+     * @param modelVersion 模型版本号
      * @return 模型
      */
     public static TensorflowModel create(String modelVersion) {
-        
-        
+
+
         PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
-        
-        String inputOpName = "X";  // 统一输入op名称
-        String outputOpName = "softmax/softmax";  // 统一输出op名称
-        
+
+
 //        NNDataSet dataSet = new NNDataSetImplNonParallel(modelVersion);  // 新模型
         NNDataSet dataSet = new NNDataSetImpl(modelVersion);  // 老模型
 
-        String modelPath =prop.getProperty("basicPath");  // 模型基本路径
+        String modelPath = prop.getProperty("basicPath");  // 模型基本路径
         modelVersion = prop.getProperty(modelVersion);
         modelPath = modelPath.replace("model_version_replacement", modelVersion);  // 生成模型路径
-        
-        TensorflowModel tm = new TensorflowModel(modelPath, inputOpName, outputOpName,
-                dataSet);
+
+        TensorflowModel tm = new TensorflowModel(modelPath, dataSet);
         return tm;
     }
 
     /**
      * 加载并创建模型类
-     * @param modelVersion  模型版本号
+     *
+     * @param modelVersion 模型版本号
      * @return 模型
      */
     public static TensorflowModel createAndFilterDiagnosis(String modelVersion) {
@@ -47,20 +47,18 @@ public class TensorFlowModelLoadFactory {
 
         PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
 
-        String inputOpName = "X";  // 统一输入op名称
-        String outputOpName = "softmax/softmax";  // 统一输出op名称
-
         NNDataSet dataSet = new NNDataSetImpl(modelVersion);  // 老模型
 
         dataSet.setDoFilterDiagnosis(true);
         dataSet.readFilterDiagnosisDict();
+        dataSet.setWithSequenceInputs(true);  // 使用序列输入
+        dataSet.readChar2IdDict(modelVersion);  // 读取字符字典
 
-        String modelPath =prop.getProperty("basicPath");  // 模型基本路径
+        String modelPath = prop.getProperty("basicPath");  // 模型基本路径
         modelVersion = prop.getProperty(modelVersion);
         modelPath = modelPath.replace("model_version_replacement", modelVersion);  // 生成模型路径
 
-        TensorflowModel tm = new TensorflowModel(modelPath, inputOpName, outputOpName,
-                dataSet);
+        TensorflowModel tm = new TensorflowModel(modelPath, dataSet);
         return tm;
     }
 

+ 90 - 34
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -6,43 +6,55 @@ import org.tensorflow.Session;
 import org.tensorflow.Tensor;
 
 import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.util.HashMap;
 import java.util.Map;
 
 /**
  * tensorflow 模型类,要求单个样本是1维向量,而不是高维向量
+ *
  * @Author: bijl
  * @Date: 2018年7月19日-下午7:21:24
  * @Description:
  */
 public class TensorflowModel {
-    
-    private final String INPUT_OPERATION_NAME;   // 输入op的名称
-    private final String OUTPUT_OPERATION_NAME;  // 输出op的名称
+
+
+    private final String X = "X";  // 输入op x的名字
+    private final String Char_ids = "Char_ids";  // 输入op Char_ids的名字
+    private final String Pos_ids = "Pos_ids";  // 输入op Pos_ids的名字
+    private final String SOFT_MAX = "softmax/softmax";  // 输出op的名称
+
     private final int NUM_FEATURE;  // 特征个数
     private final int NUM_LABEL;  //  标签(类别)个数
     private SavedModelBundle bundle; // 模型捆绑
     private Session session;  // 会话
     private NNDataSet dataSet;  // 数据集
-    
+
+
+    private boolean withSequenceInputs = false;  // 是否带有序列输入
+    private final int MAX_LEN; // 最大长度
+
+
     /**
-     * 
-     * @param exportDir  模型保存地址
-     * @param inputOpName  输入op的名称
-     * @param outputOpName  输出op的名称
-     * @param dataSet  模型使用的数据集
+     * @param exportDir 模型保存地址
+     * @param dataSet   模型使用的数据集
      */
-    public TensorflowModel(String exportDir, String inputOpName, String outputOpName, NNDataSet dataSet) {
-        this.INPUT_OPERATION_NAME = inputOpName;
-        this.OUTPUT_OPERATION_NAME = outputOpName;
+    public TensorflowModel(String exportDir, NNDataSet dataSet) {
+
+        this.init(exportDir);
         this.dataSet = dataSet;
         this.NUM_FEATURE = this.dataSet.getNumFeature();
         this.NUM_LABEL = this.dataSet.getNumLabel();
-        this.init(exportDir);
-                
+
+        // 序列数据有段的属性
+        this.MAX_LEN = this.dataSet.getMAX_LEN();
+        this.withSequenceInputs = this.dataSet.isWithSequenceInputs();
     }
-    
+
     /**
      * 初始化:加载模型,获取会话。
+     *
      * @param exportDir
      */
     public void init(String exportDir) {
@@ -54,29 +66,61 @@ public class TensorflowModel {
         }
 
         // create the session from the Bundle
-        this.session = bundle.session(); 
+        this.session = bundle.session();
+    }
+
+
+    /**
+     * 包装序列化输入
+     *
+     * @param sequenceValuesMap 序列输入的map
+     * @param numExamples       样本数
+     * @return
+     */
+    private Map<String, Tensor<Integer>> wrapSequenceInputs(Map<String, int[]> sequenceValuesMap, int numExamples) {
+        long[] inputShape = {numExamples, this.MAX_LEN};
+        Map<String, Tensor<Integer>> sequenceTensorMap = new HashMap<>();
+        for (Map.Entry<String, int[]> entry : sequenceValuesMap.entrySet()) {
+            String mapKey = entry.getKey();
+            Tensor<Integer> inputTensor = Tensor.create(
+                    inputShape,
+                    IntBuffer.wrap(entry.getValue())
+            );
+            sequenceTensorMap.put(mapKey, inputTensor);
+        }
+
+        return sequenceTensorMap;
     }
-    
+
+
     /**
      * 运行模型
-     * @param inputValues  输入值
-     * @param numExamples  样本个数
+     *
+     * @param inputValues 输入值
+     * @param numExamples 样本个数
      * @return 模型的输出
      */
-    private float[][] run(float[] inputValues, int numExamples){
-//        long[] inputShape = {numExamples, this.NUM_FEATURE, 4, 1};  // 新模型
-        long[] inputShape = {numExamples, this.NUM_FEATURE};  // 老模型
+    private float[][] run(float[] inputValues, Map<String, int[]> sequenceValues, int numExamples) {
+        long[] inputShape = {numExamples, this.NUM_FEATURE};
         Tensor<Float> inputTensor = Tensor.create(
-                inputShape,  
-                FloatBuffer.wrap(inputValues) 
+                inputShape,
+                FloatBuffer.wrap(inputValues)
         );
-        return this.session.runner().feed(this.INPUT_OPERATION_NAME, inputTensor)
+
+        // 序列数据
+        if (this.withSequenceInputs){
+            Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
+            this.session.runner().feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
+                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids));
+        }
+
+        return this.session.runner().feed(this.X, inputTensor)
                 .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                .fetch(this.OUTPUT_OPERATION_NAME).run().get(0)
+                .fetch(this.SOFT_MAX).run().get(0)
                 .copyTo(new float[numExamples][this.NUM_LABEL]);
     }
-    
-    
+
+
     /**
      * 运行模型,并将结果打包成目标格式
      */
@@ -85,14 +129,22 @@ public class TensorflowModel {
         float sum = 0;
         for (float f : inputValues)
             sum += f;
-        if(sum == 0)  // 如果输入没有有效特征,则直接返回null
+        if (sum == 0)  // 如果输入没有有效特征,则直接返回null
             return null;
-        
-        float[][] predict = this.run(inputValues, 1);  // 一次一个样本
-        return this.dataSet.wrap(predict);  
+
+        Map<String, int[]> sequenceValues = null;
+        if (this.withSequenceInputs){
+            sequenceValues = new HashMap<>();
+            sequenceValues.put(this.Char_ids, this.dataSet.toCharIds(inputs));
+            sequenceValues.put(this.Pos_ids, this.dataSet.toPosIds(inputs));
+        }
+
+
+        float[][] predict = this.run(inputValues, sequenceValues, 1);  // 一次一个样本
+        return this.dataSet.wrap(predict);
     }
-    
-    
+
+
     /**
      * 关闭会话,释放资源
      */
@@ -101,4 +153,8 @@ public class TensorflowModel {
         this.bundle.close();
     }
 
+    public void setWithSequenceInputs(boolean withSequenceInputs) {
+        this.withSequenceInputs = withSequenceInputs;
+    }
+
 }

+ 194 - 17
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSet.java

@@ -21,12 +21,26 @@ public abstract class NNDataSet {
 
     protected final Map<String, Integer> LABEL_DICT = new HashMap<>();
     protected final Map<String, Integer> NEGATIVE_DICT = new HashMap<>();
-    protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>();
-    protected final Map<String, Map<String, Integer>> RELATED_DIAGNOSIS_DICT = new HashMap<>();
-    protected final List<String> FEATURE_NAME_STORE = new ArrayList<>();
+
     private final String[] FEATURE_DICT_ARRAY;
     private final String[] LABEL_DICT_ARRAY;
-    private boolean doFilterDiagnosis = false;
+
+    // 再分词和疾病过滤相关容器
+    protected final Map<String, String> RE_SPLIT_WORD_DICT = new HashMap<>();  // 在分词表
+    protected final List<String> FEATURE_NAME_STORE = new ArrayList<>();  // 特征保存
+    protected final Map<String, Map<String, Integer>> RELATED_DIAGNOSIS_DICT = new HashMap<>();  // 特征与疾病相关表
+    private boolean doFilterDiagnosis = false;  // 是否做疾病过滤
+
+    private final float firstRateThreshold = 0.15f;  // 第一个疾病的概率阈值
+    private final float rateSumThreshold = 0.6f;  // 概率和阈值
+    private final int numToPush = 3;  // 推荐推送的个数
+    private final float rapidFallTimes = 5;  // 骤降倍数
+
+    // 序列数据
+    private final int MAX_LEN = 257;
+    private boolean withSequenceInputs = false;  // 是否带有序列输入
+    protected final Map<String, Integer> CHAR2ID_DICT = new HashMap<>();
+
 
     public NNDataSet(String modelAndVersion) {
         this.readDict(modelAndVersion);
@@ -36,6 +50,8 @@ public abstract class NNDataSet {
         this.LABEL_DICT_ARRAY = new String[this.NUM_LABEL];
         this.makeDictArr();
         this.readReSplitWordDict();
+
+
     }
 
     /**
@@ -46,11 +62,33 @@ public abstract class NNDataSet {
      */
     public abstract float[] toFeatureVector(Map<String, Map<String, String>> inputs);
 
+    /**
+     * 装外部输入转为字符ids
+     *
+     * @param inputs
+     * @return
+     */
+    public abstract int[] toCharIds(Map<String, Map<String, String>> inputs);
+
+    /**
+     * 装外部输入转为位置ids
+     *
+     * @param inputs
+     * @return
+     */
+    public abstract int[] toPosIds(Map<String, Map<String, String>> inputs);
+
     /**
      * 读取特征和类别字典
      */
     public abstract void readDict(String modelAndVersion);
 
+
+    /**
+     * 读取特征和类别字典
+     */
+    public abstract void readChar2IdDict(String modelAndVersion);
+
     /**
      * 读取再分词字典
      */
@@ -74,36 +112,152 @@ public abstract class NNDataSet {
     }
 
     /**
-     * 打包特征名和概率 + 过滤疾病
-     * 基本操作,过滤前20个疾病,如果
+     * 推送个数过滤[无效病历]
+     * 规则:最大概率疾病的概率要超过给定阈值,如果不超过,则认为疾病不收敛,不予推送
+     *
+     * @param nameAndValueListSorted
+     */
+    private void pushCountFilterBefore(List<NameAndValue> nameAndValueListSorted) {
+        if (nameAndValueListSorted.get(0).getValue() < this.firstRateThreshold)
+            nameAndValueListSorted.clear();
+    }
+
+    /**
+     * 推送个数过滤[概率和和概率骤降过滤]
+     * 规则:
+     * 1- 为了防止一棍子打死,我们还是尽量要推送3个病历的,除非概率骤降。
+     * 2- 概率骤降过滤,当病历收敛到一个或几个疾病之后,再出现的疾病,概率会骤然下降很多倍
+     * ,这时,这个疾病差不多是随机推送的,因此要过滤掉。【都要做】
+     * 2- 概率和,就是概率和不超过某个阈值【只有在剩余疾病个数超过阈值时做】
+     *
+     * @param nameAndValueListSorted
+     */
+    private void pushCountFilterAfter(List<NameAndValue> nameAndValueListSorted) {
+
+        // 如果不超过尽量推送的个数,只做概率骤降判断
+        Iterator<NameAndValue> it = nameAndValueListSorted.iterator();
+        boolean deleteTheRest = false;   // 是否删除剩余的疾病
+        float preRate = 0.0f; // 前一个疾病的概率
+        int restCnt = 0;  // 剩余疾病数
+        float rateSum = 0.0f;  // 概率和
+
+        while (it.hasNext()) {
+            NameAndValue nameAndValue = it.next();
+            if (!deleteTheRest) {
+                // 相对于前一个疾病概率骤降rapidFallTimes倍
+                if (preRate / nameAndValue.getValue() >= this.rapidFallTimes)
+                    deleteTheRest = true;
+                else {
+                    rateSum += nameAndValue.getValue();
+                    preRate = nameAndValue.getValue();
+                    restCnt += 1;
+                }
+            }
+
+            if (deleteTheRest)  // 删除剩下的疾病
+                it.remove();
+
+
+            if (!deleteTheRest && restCnt >= this.numToPush) {
+
+                // 如果超过尽量推送的个数,那么做概率和阈值过滤【从下一个开始删除】
+                if (rateSum >= this.rateSumThreshold)
+                    deleteTheRest = true;
+            }
+        }
+
+    }
+
+    /**
+     * 打包特征名和概率 + 过滤疾病 + 推送个数选择
+     * 基本操作,过滤前20个疾病,如果有疾病留下,否则前50个疾病
      *
      * @param predict 模型输出
      * @return
      */
-    public Map<String, Float> wrapAndFilter(float[][] predict) {
+    public Map<String, Float> wrapAndFilterWithPushCountFilter(float[][] predict) {
         List<NameAndValue> nameAndValueList = new ArrayList<>();
         for (int i = 0; i < predict[0].length; i++)
             nameAndValueList.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
         nameAndValueList.sort(Comparator.reverseOrder());  // 按概率从大到小排列
 
+        // TODO:delete
+        System.out.println("原来__推送:...............................................................");
+        System.out.println(nameAndValueList.subList(0, 10));
+
+        pushCountFilterBefore(nameAndValueList);  // 推送个数过滤【无效病历过滤】
+
+        nameAndValueList = filterDiagnosis(nameAndValueList);  // 疾病过滤
+
+        this.pushCountFilterAfter(nameAndValueList);  // 推送个数过滤【概率骤降和概率和阈值过滤】
+
+        // TODO:delete
+        System.out.println("新版本__最终__推送:.......................................................");
+        System.out.println("长度:" + nameAndValueList.size());
+        System.out.println(nameAndValueList);
+
         Map<String, Float> result = new HashMap<>();
+        for (NameAndValue nameAndValue : nameAndValueList)
+            result.put(nameAndValue.getName(), nameAndValue.getValue());
+
+        return result;
+    }
+
+    /**
+     * 疾病过滤
+     * 基本规则:
+     * 如果没有一个特征与该疾病共现过,那么删除该疾病
+     *
+     * @param nameAndValueListSorted
+     * @return
+     */
+    public List<NameAndValue> filterDiagnosis(List<NameAndValue> nameAndValueListSorted) {
         Integer cnt = 0;
         String diagnosis;
         NameAndValue nameAndValue;
         Map<String, Integer> relatedDiagnoses = null;
-        for (int i = 0; i < nameAndValueList.size(); i++) {
-            nameAndValue = nameAndValueList.get(i);
+        List<NameAndValue> candidateNameAndValues = new ArrayList<>();
+        for (int i = 0; i < nameAndValueListSorted.size(); i++) {
+            nameAndValue = nameAndValueListSorted.get(i);
             diagnosis = nameAndValue.getName();
+
             for (String featureName : this.FEATURE_NAME_STORE) {
                 relatedDiagnoses = this.RELATED_DIAGNOSIS_DICT.get(featureName);
                 if (relatedDiagnoses != null && relatedDiagnoses.get(diagnosis) != null) {
-                    result.put(nameAndValue.getName(), nameAndValue.getValue());
+                    candidateNameAndValues.add(nameAndValue);
                     cnt += 1;
+                    break;  // 有一个共现即可
                 }
             }
             if ((i >= 20 || i >= 50) && cnt > 0)  // 如果前20或50个推送中有相关的疾病,只过滤他们
                 break;
         }
+        return candidateNameAndValues;
+    }
+
+    /**
+     * 打包特征名和概率 + 过滤疾病
+     * 基本操作,过滤前20个疾病,如果
+     *
+     * @param predict 模型输出
+     * @return
+     */
+    public Map<String, Float> wrapAndFilter(float[][] predict) {
+        List<NameAndValue> nameAndValueList = new ArrayList<>();
+        for (int i = 0; i < predict[0].length; i++)
+            nameAndValueList.add(new NameAndValue(this.LABEL_DICT_ARRAY[i], predict[0][i]));
+        nameAndValueList.sort(Comparator.reverseOrder());  // 按概率从大到小排列
+
+        nameAndValueList = filterDiagnosis(nameAndValueList);  // 疾病过滤
+
+        // TODO:delete
+        System.out.println("原版本__最终__推送 ......................................................");
+        System.out.println("长度:" + nameAndValueList.size());
+        System.out.println(nameAndValueList);
+
+        Map<String, Float> result = new HashMap<>();
+        for (NameAndValue nameAndValue : nameAndValueList)
+            result.put(nameAndValue.getName(), nameAndValue.getValue());
         return result;
     }
 
@@ -137,6 +291,14 @@ public abstract class NNDataSet {
         public String getName() {
             return name;
         }
+
+        @Override
+        public String toString() {
+            return "NameAndValue{" +
+                    "name='" + name + '\'' +
+                    ", value=" + value +
+                    '}';
+        }
     }
 
     /**
@@ -147,8 +309,11 @@ public abstract class NNDataSet {
      */
     public Map<String, Float> wrap(float[][] predict) {
         if (this.doFilterDiagnosis)  // 过滤疾病
-            return this.wrapAndFilter(predict);
-        else
+//            r
+        {
+            this.wrapAndFilter(predict);
+            return this.wrapAndFilterWithPushCountFilter(predict);
+        } else
             return this.basicWrap(predict);
     }
 
@@ -175,17 +340,15 @@ public abstract class NNDataSet {
     }
 
     /**
-     *  存储特征名称
+     * 存储特征名称
+     *
      * @param features
      */
-    public void storeFeatureNames(Map<String, Map<String, String>> features){
+    public void storeFeatureNames(Map<String, Map<String, String>> features) {
         this.FEATURE_NAME_STORE.clear();
         this.FEATURE_NAME_STORE.addAll(features.keySet());
     }
 
-    /**
-     * @return
-     */
     public int getNumLabel() {
         return this.NUM_LABEL;
     }
@@ -195,4 +358,18 @@ public abstract class NNDataSet {
         this.doFilterDiagnosis = doFilterDiagnosis;
     }
 
+
+    public int getMAX_LEN() {
+        return MAX_LEN;
+    }
+
+
+    public void setWithSequenceInputs(boolean withSequenceInputs) {
+        this.withSequenceInputs = withSequenceInputs;
+    }
+
+
+    public boolean isWithSequenceInputs() {
+        return withSequenceInputs;
+    }
 }

+ 75 - 4
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSetImpl.java

@@ -1,12 +1,14 @@
 package org.algorithm.core.neural.dataset;
 
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONObject;
 import org.algorithm.util.TextFileReader;
 import org.diagbot.pub.utils.PropertiesUtil;
 
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.*;
 import java.util.Map.Entry;
 
 /**
@@ -66,6 +68,42 @@ public class NNDataSetImpl extends NNDataSet {
         return featureVector;
     }
 
+    @Override
+    public int[] toCharIds(Map<String, Map<String, String>> inputs) {
+        String sentence = inputs.get("sentence").get("sentence");
+        int max_len = this.getMAX_LEN();
+        int[] ids = new int[max_len];
+        char ch = '1';
+        Integer id = null;
+        for (int i = 0; i < sentence.length(); i++) {
+            ch = sentence.charAt(i);
+            id = this.CHAR2ID_DICT.get(String.valueOf(ch));
+            if (id == null) {
+                id = this.CHAR2ID_DICT.get("<UKC>");
+            }
+            ids[i] = id.intValue();
+        }
+        for (int i = sentence.length(); i < max_len; i++)  // padding
+            ids[i] = this.CHAR2ID_DICT.get("<PAD>");
+
+        return ids;
+    }
+
+    @Override
+    public int[] toPosIds(Map<String, Map<String, String>> inputs) {
+        int max_len = this.getMAX_LEN();
+        String sentence = inputs.get("sentence").get("sentence");
+        int[] pos_ids = new int[max_len];
+        for (int j=0; j<max_len; j++)
+            pos_ids[j] = max_len - 1;  // 位置的padding
+
+        // 绝对位置编码
+        for (int i = 0 ; i < (sentence.length() < max_len ? sentence.length() : max_len); i++)
+            pos_ids[i] = i;
+
+        return pos_ids;
+    }
+
     @Override
     public void readDict(String modelAndVersion) {
 
@@ -109,6 +147,39 @@ public class NNDataSetImpl extends NNDataSet {
 
     }
 
+    @Override
+    public void readChar2IdDict(String modelAndVersion) {
+
+        // 获取文件目录
+        PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
+        String filePath = prop.getProperty("basicPath");  // 基本目录
+        filePath = filePath.substring(0, filePath.indexOf("model_version_replacement"));
+
+        filePath = filePath + "char2id.bin";  // 字典文件位置
+
+        // 读取以json字符串保存的数据
+        BufferedReader br = null;
+        try {
+            br = new BufferedReader(new FileReader(filePath));  // 读取原始json文件
+            String s = null;
+            while ((s = br.readLine()) != null) {
+                JSONObject jsonObject = (JSONObject) JSON.parse(s);
+                Set<Entry<String, Object>> entries = jsonObject.entrySet();
+                for (Map.Entry<String, Object> entry : entries)
+                    this.CHAR2ID_DICT.put(entry.getKey(), (Integer) entry.getValue());
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+        } finally {
+            try {
+                br.close();
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
+
+    }
+
     /**
      * 再分词:
      * 基本操作:

+ 24 - 44
algorithm/src/main/java/org/algorithm/test/Test.java

@@ -1,56 +1,36 @@
 package org.algorithm.test;
 
-import java.util.*;
-
-public class Test {
 
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
 
+public class Test {
+    
     public static void main(String[] args) {
-        List<Integer> data = new ArrayList<>();
-        data.add(1);
-        data.add(3);
-        data.add(5);
-        data.add(7);
-        Test t = new Test();
-
-        List<List<Integer>> workSpace = new ArrayList<>();
-        for (int i = 1; i < data.size(); i++) {
-            t.combinerSelect(data, new ArrayList<>(), workSpace, data.size(), i);
-        }
-
-        System.out.println(workSpace);
 
-    }
-
-    /**
-     * 组合生成器
-     *
-     * @param data      原始数据
-     * @param workSpace 自定义一个临时空间,用来存储每次符合条件的值
-     * @param k         C(n,k)中的k
-     */
-    public <E> void combinerSelect(List<E> data, List<E> workSpace, List<List<E>> result, int n, int k) {
-        List<E> copyData;
-        List<E> copyWorkSpace = null;
-
-        if (workSpace.size() == k) {
-            for (E c : workSpace)
-                System.out.print(c);
-
-            result.add(new ArrayList<>(workSpace));
-            System.out.println();
+        List<String> aList = new ArrayList<>();
+        aList.add("del");
+        aList.add("del");
+        aList.add("xx");
+        aList.add("yy");
+
+        Iterator<String> it = aList.iterator();
+        boolean xx = false;
+        while(it.hasNext()){
+            String x = it.next();
+            if (!xx){
+
+                if (x.equals("xx"))
+                    xx = true;
+            }
+            if(xx){
+                it.remove();
+            }
         }
 
-        for (int i = 0; i < data.size(); i++) {
-            copyData = new ArrayList<E>(data);
-            copyWorkSpace = new ArrayList<E>(workSpace);
+        System.out.println(aList);
 
-            copyWorkSpace.add(copyData.get(i));
-            for (int j = i; j >= 0; j--)
-                copyData.remove(j);
-            combinerSelect(copyData, copyWorkSpace, result, n, k);
-        }
     }
 
 }
-

+ 2 - 2
algorithm/src/main/resources/algorithm.properties

@@ -1,8 +1,8 @@
 ################################ model basic url ###################################
 
 #basicPath=E:/project/push/algorithm/src/main/models/model_version_replacement/model
-basicPath=/opt/models/dev/models/model_version_replacement/model
-#basicPath=F:/models/model_version_replacement/model
+#basicPath=/opt/models/dev/models/model_version_replacement/model
+basicPath=E:/models_2019_9_24_16_21_29/model_version_replacement/model
 
 ############################### current model version ################################
 diagnosisPredict.version=outpatient_556_IOE_1

+ 5 - 0
common-push/src/main/java/org/diagbot/common/push/work/ParamsDataProxy.java

@@ -136,6 +136,11 @@ public class ParamsDataProxy {
                 re.addToSearchDataInputs(execute, searchData);
             }
         }
+        //模型需要病历文本信息传入
+        Map<String, String> map = new HashMap<>();
+        map.put("sentence", searchData.getSymptom());
+        searchData.getInputs().put("sentence", map);
+
     }
 
     /**

+ 1 - 1
nlp-web/src/main/resources/application.yml

@@ -12,7 +12,7 @@ spring:
       charset: UTF-8
       enabled: true
   datasource:       # mybatis 配置,使用druid数据源
-    url: jdbc:mysql://1.1.1.1:3306/diagbot-app?useUnicode=true&characterEncoding=UTF-8
+    url: jdbc:mysql://192.168.2.235:3306/med-s?useUnicode=true&characterEncoding=UTF-8
     username: root
     password: diagbot@20180822
     type: com.alibaba.druid.pool.DruidDataSource

+ 3 - 1
nlp/src/main/java/org/diagbot/nlp/feature/extract/CaseTokenFeature.java

@@ -14,7 +14,9 @@ import java.util.Map;
 public class CaseTokenFeature extends CaseToken {
     private NegativeEnum[] nees_symptom = new NegativeEnum[]{NegativeEnum.SYMPTOM,
             NegativeEnum.BODY_PART, NegativeEnum.PROPERTY, NegativeEnum.DEEP, NegativeEnum.DISEASE,
-            NegativeEnum.CAUSE, NegativeEnum.VITAL_RESULT, NegativeEnum.DIAG_STAND};
+            NegativeEnum.CAUSE, NegativeEnum.VITAL_RESULT, NegativeEnum.VITAL_INDEX_VALUE, NegativeEnum.DIAG_STAND,
+            NegativeEnum.SYMPTOM_PERFORMANCE, NegativeEnum.MEDICINE,NegativeEnum.MEDICINE_NAME, NegativeEnum.MEDICINE_PRD,
+            NegativeEnum.OPERATION, NegativeEnum.TREATMENT, NegativeEnum.SYMPTOM_INDEX};
 
     {
         stop_symbol = NlpUtil.extendsSymbol(stop_symbol, new String[]{",", ",", ":", ":"});

+ 7 - 1
nlp/src/main/java/org/diagbot/nlp/util/NegativeEnum.java

@@ -12,7 +12,7 @@ public enum NegativeEnum {
     SYMPTOM_PERFORMANCE("26"), NUMBER_QUANTIFIER("27"), DIGITS("28"),
     OTHER("44"),
     VITAL_INDEX("33"), VITAL_INDEX_VALUE("34"), VITAL_RESULT("35"),
-    ADDRESS("36"), PERSON("38"), PERSON_FEATURE_DESC("39"), PUB_NAME("46"),
+    ADDRESS("36"), PERSON("38"), PERSON_FEATURE_DESC("39"), PUB_NAME("46"), MEDICINE_NAME("53"),MEDICINE_PRD("54"),
     RETURN_VISIT("68"), DIAG_STAND("70");
     private String value;
 
@@ -150,6 +150,12 @@ public enum NegativeEnum {
             case "46":
                 negativeEnum = NegativeEnum.PUB_NAME;
                 break;
+            case "53":
+                negativeEnum = NegativeEnum.MEDICINE_NAME;
+                break;
+            case "54":
+                negativeEnum = NegativeEnum.MEDICINE_PRD;
+                break;
             case "68":
                 negativeEnum = NegativeEnum.RETURN_VISIT;
                 break;

+ 1 - 1
push-web/src/main/resources/static/dist/js/push.js

@@ -1,4 +1,4 @@
-var nlp_web_url = "http://192.168.2.186:5002/nlp-web";
+var nlp_web_url = "http://192.168.3.100:5002/nlp-web";
 var bigdata_web_url = "http://192.168.2.186:5001/bigdata-web";
 var graph_web_url = "http://192.168.2.186:5003/graph-web";
 // var push_web_url = "http://192.168.2.234:5008/push-web";

+ 461 - 0
push-web/src/main/resources/static/pages/eyehospital/list.html

@@ -0,0 +1,461 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="utf-8">
+    <meta http-equiv="X-UA-Compatible" content="IE=edge">
+    <title>AdminLTE 2 | Invoice</title>
+    <!-- Tell the browser to be responsive to screen width -->
+    <meta content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no" name="viewport">
+    <!-- Bootstrap 3.3.6 -->
+    <link rel="stylesheet" href="../bootstrap/css/bootstrap.min.css">
+    <!-- Font Awesome -->
+    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.5.0/css/font-awesome.min.css">
+    <!-- Ionicons -->
+    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/ionicons/2.0.1/css/ionicons.min.css">
+    <!-- Theme style -->
+    <link rel="stylesheet" href="../dist/css/AdminLTE.min.css">
+    <!-- AdminLTE Skins. Choose a skin from the css/skins
+         folder instead of downloading all of them to reduce the load. -->
+    <link rel="stylesheet" href="../dist/css/skins/_all-skins.min.css">
+
+    <!-- HTML5 Shim and Respond.js IE8 support of HTML5 elements and media queries -->
+    <!-- WARNING: Respond.js doesn't work if you view the page via file:// -->
+    <!--[if lt IE 9]>
+    <script src="https://oss.maxcdn.com/html5shiv/3.7.3/html5shiv.min.js"></script>
+    <script src="https://oss.maxcdn.com/respond/1.4.2/respond.min.js"></script>
+    <![endif]-->
+    <style>
+        .interval {
+            padding: 1px 8px !important;
+        }
+
+        #my_file {
+            visibility: hidden; /* 隐藏 */
+        }
+    </style>
+</head>
+<body class="hold-transition skin-blue sidebar-mini">
+    <!-- Left side column. contains the logo and sidebar -->
+    <!-- Content Wrapper. Contains page content -->
+    <div class="content-wrapper">
+        <!-- Content Header (Page header) -->
+        <section class="content-header">
+            <h1>
+                眼科医院
+            </h1>
+            <!--<ol class="breadcrumb">
+                <li><a href="#"><i class="fa fa-dashboard"></i> Home</a></li>
+                <li><a href="#">Tables</a></li>
+                <li class="active">Data tables</li>
+            </ol>-->
+        </section>
+
+        <form role="form">
+            <div class="box-body">
+                <div class="form-group">
+                    <label for="symptom_id">主诉</label>&nbsp;
+                    <input type="text" id="chief_id" placeholder="" size="150">&nbsp;&nbsp;&nbsp;&nbsp;
+                </div>
+                <div class="form-group">
+                    <label for="symptom_id">现病史</label>&nbsp;
+                    <input type="text" id="symptom_id" placeholder="" size="150">&nbsp;&nbsp;&nbsp;&nbsp;
+                </div>
+                <div class="form-group">
+                    <label for="vital_id">专科检查</label>&nbsp;
+                    <input type="text" id="vital_id" placeholder="" size="150">&nbsp;&nbsp;&nbsp;&nbsp;
+                </div>
+                <div class="form-group">
+                    <label for="pacs_id">辅检</label>&nbsp;
+                    <input type="text" id="pacs_id" placeholder="" size="150">&nbsp;&nbsp;&nbsp;&nbsp;
+                </div>
+               <!-- <div class="form-group">
+                    <label for="symptom_id">性别</label>&nbsp;
+                    <select style="padding:2px;width: 150px;" id="sex">
+                        <option></option>
+                        <option value="M">男(M)</option>
+                        <option value="F">女(F)</option>
+                    </select>&nbsp;&nbsp;&nbsp;&nbsp;
+                    <label for="symptom_id">年龄</label>&nbsp;
+                    <input type="text" id="age" placeholder="">
+                </div>-->
+            </div>
+            <!-- /.box-body -->
+            <div class="box-footer">
+                <button type="button" class="btn btn-primary" onclick="clickes();">推送</button>
+            </div>
+        </form>
+
+        <!-- Main content -->
+        <section class="content">
+            <div class="row">
+                <div class="col-xs-12">
+                    <div class="box">
+                        <div class="box-header">
+                            <h3 class="box-title">推送结果</h3>&nbsp;&nbsp;
+                        </div>
+                        <div class="box-body" id="feature_inputs_div">
+                        </div>
+                    </div>
+                    <!-- /.box -->
+                </div>
+            </div>
+            <!-- /.row -->
+                <!-- /.col -->
+            <!-- /.row -->
+        </section>
+        <!-- /.content -->
+    </div>
+    <!-- /.content-wrapper -->
+
+    <!-- /.control-sidebar -->
+    <!-- Add the sidebar's background. This div must be placed
+         immediately after the control sidebar -->
+    <div class="control-sidebar-bg"></div>
+<!-- ./wrapper -->
+<div class="modal fade" id="modal-default">
+    <div class="modal-dialog">
+        <div class="modal-content">
+            <div class="modal-header">
+                <button type="button" class="close" data-dismiss="modal" aria-label="Close">
+                    <span aria-hidden="true">&times;</span></button>
+                <h4 class="modal-title">消息</h4>
+            </div>
+            <div class="row" id="modal-loading">
+                <!-- /.col -->
+                <div class="col-md-12">
+                    <div class="box box-danger box-solid">
+                        <div class="box-header">
+                            <h3 class="box-title">诊断归一</h3>
+                        </div>
+                        <div class="box-body">
+                            此过程可能需要较长时间,请耐心等待... ...
+                        </div>
+                        <!-- /.box-body -->
+                        <!-- Loading (remove the following to stop the loading)-->
+                        <div class="overlay">
+                            <i class="fa fa-refresh fa-spin"></i>
+                        </div>
+                        <!-- end loading -->
+                    </div>
+                    <!-- /.box -->
+                </div>
+                <!-- /.col -->
+            </div>
+            <!-- /.row -->
+            <div class="modal-body">
+                <p></p>
+            </div>
+            <div class="modal-footer">
+                <button type="button" class="btn btn-default pull-left" data-dismiss="modal">Close</button>
+            </div>
+        </div>
+        <!-- /.modal-content -->
+    </div>
+    <!-- /.modal-dialog -->
+</div>
+<!-- /.modal -->
+<!-- jQuery 2.2.3 -->
+<script src="../plugins/jQuery/jquery-2.2.3.min.js"></script>
+<!-- Bootstrap 3.3.6 -->
+<script src="../bootstrap/js/bootstrap.min.js"></script>
+<!-- DataTables -->
+<script src="../plugins/datatables/jquery.dataTables.min.js"></script>
+<script src="../plugins/datatables/dataTables.bootstrap.min.js"></script>
+<!-- SlimScroll -->
+<script src="../plugins/slimScroll/jquery.slimscroll.min.js"></script>
+<!-- FastClick -->
+<script src="../plugins/fastclick/fastclick.js"></script>
+<!-- AdminLTE App -->
+<script src="../dist/js/app.min.js"></script>
+<!-- AdminLTE for demo purposes -->
+<script src="../dist/js/demo.js"></script>
+
+<script src="../dist/js/push.js"></script>
+
+<script>
+    $(function () {
+    });
+    function clickes(){
+        alert("进来了");
+        $.ajax({
+            url:push_web_url+"/eyehospital/people",//访问的地址
+            type:"get",
+            dataType:'JSON',//后台返回的数据格式类型
+            success:function (data) {
+                alert("成功了");
+                $("#feature_inputs_div").append(data.name);
+            }
+        })
+    };
+    function bayesPage(resourceType) {
+        var diag = $("#diag_id").val();
+        var symptom = $("#symptom").val();
+        if (diag != '' && symptom == '') {
+            $('#diag_list').html("");
+            $('#before_combine_diag_list').html("");
+            startDiag('/algorithm/page_neural', '#symptom_list', '1', resourceType, '111', '1');
+            startDiag('/algorithm/page_neural', '#vital_list', '3,2,7', resourceType, '131', '3');
+            startDiag('/algorithm/page_neural', '#lis_list', '4,2,7', resourceType, '141', '4');
+            startDiag('/algorithm/page_neural', '#pacs_list', '5,2,7', resourceType, '151', '5');
+
+        } else {
+            startDiag('/algorithm/page_neural', '#symptom_list', '1', resourceType, '11', '1');
+            startDiag('/algorithm/page_neural', '#vital_list', '3,2,7', resourceType, '31', '3');
+            startDiag('/algorithm/page_neural', '#lis_list', '4,2,7', resourceType, '41', '4');
+            startDiag('/algorithm/page_neural', '#pacs_list', '5,2,7', resourceType, '51', '5');
+
+            startDiagMapping('/algorithm/page_neural', '#diag_list', '2', resourceType, '21', '2');
+            startDiagMapping('/algorithm/page_neural', '#before_combine_diag_list', '2', resourceType, '21', '6');
+        }
+    }
+
+    function startDiagMapping(url, obj, featureType, resourceType, algorithmClassify, tp) {
+        $(obj).DataTable({
+            "paging": false,
+            "bPaginate" : true,
+            "lengthChange": true,
+            "searching": false,
+            "ordering": false,
+            "info": false,
+            "autoWidth": false,
+            "serverSide": true,
+            "destroy": true,
+            "iDisplayLength": 25,
+            "columns": [
+                {"data": "featureName"},
+                {"data": "extraProperty"},
+                {"data": "rate"}
+            ],
+            "ajax": {
+                "url": push_web_url + url,
+                "data": function ( d ) {
+                    d.featureType = featureType;
+                    d.resourceType = resourceType;
+                    d.algorithmClassifyValue =  algorithmClassify;
+                    var symptom = $("#symptom_id").val();
+                    var vital = $("#vital_id").val();
+                    var past = $("#past_id").val();
+                    var other = $("#other_id").val();
+                    var lis = $("#lis_id").val();
+                    var pacs = $("#pacs_id").val();
+                    var lisOrder = $("#lis_order").val();
+                    var pacsOrder = $("#pacs_order").val();
+                    var diag = $("#diag_id").val();
+                    var length = $("#length").val();
+                    var sex = $("#sex").val();
+                    var age = $("#age").val();
+                    var age_start = $("#age_start").val();
+                    var age_end = $("#age_end").val();
+                    d.sysCode = "2";
+                    //添加额外的参数传给服务器
+                    if (symptom != null && symptom != undefined) {
+                        d.symptom = symptom;
+                    }
+                    if (vital != null && vital != undefined) {
+                        d.vital = vital;
+                    }
+                    if (past != null && past != undefined) {
+                        d.past = past;
+                    }
+                    if (other != null && other != undefined) {
+                        d.other = other;
+                    }
+                    if (lis != null && lis != undefined) {
+                        d.lis = lis;
+                    }
+                    if (pacs != null && pacs != undefined) {
+                        d.pacs = pacs;
+                    }
+                    if (lisOrder != null && lisOrder != undefined) {
+                        d.lisOrder = lisOrder;
+                    }
+                    if (pacsOrder != null && pacsOrder != undefined) {
+                        d.pacsOrder = pacsOrder;
+                    }
+                    if (diag != null && diag != undefined && diag != '') {
+                        d.diag = diag;
+                    }
+                    if (length != null && length != undefined) {
+                        d.length = length;
+                    }
+                    if (sex != null && sex != undefined) {
+                        d.sex = sex;
+                    }
+                    if (age != '' && age_start != age && age != undefined) {
+                        d.age = age;
+                    }
+                    if (age_start != '' && age_start != null && age_start != undefined) {
+                        d.age_start = age_start;
+                    }
+                    if (age_end != '' && age_end != null && age_end != undefined) {
+                        d.age_end = age_end;
+                    }
+                },
+                "dataSrc": function (json) {
+                    var inputs = json.data.inputs;
+                    var h = "";
+                    $.each(inputs, function (key, item) {
+                        h += "<div class='form-group'><label>" + key + "&nbsp;</label>";
+                        h += "</div>";
+                    });
+                    $("#feature_inputs_div").html(h);
+
+                    if (tp == '1') {
+                        $("#participle_symptom").html(json.data.participleSymptom);
+                        json.data = json.data.symptom;
+                    }
+                    if (tp == '2') {
+                        $("#participle_diag").html(json.data.participleSymptom);
+                        json.data = json.data.dis;
+                    }
+                    if (tp == '3') {
+                        $("#participle_vital").html(json.data.participleSymptom);
+                        json.data = json.data.vitals;
+                    }
+                    if (tp == '4') {
+                        $("#participle_lis").html(json.data.participleSymptom);
+                        json.data = json.data.labs;
+                    }
+                    if (tp == '5') {
+                        $("#participle_pacs").html(json.data.participleSymptom);
+                        json.data = json.data.pacs;
+                    }
+                    if (tp == '6') {
+                        $("#before_combine_participle_diag").html(json.data.participleSymptom);
+                        json.data = json.data.beforeCombineDis;
+                    }
+                    return json.data;
+                }
+            }
+        });
+    }
+
+
+    function startDiag(url, obj, featureType, resourceType, algorithmClassify, tp) {
+        $(obj).DataTable({
+            "paging": false,
+            "bPaginate" : true,
+            "lengthChange": true,
+            "searching": false,
+            "ordering": false,
+            "info": false,
+            "autoWidth": false,
+            "serverSide": true,
+            "destroy": true,
+            "iDisplayLength": 25,
+            "columns": [
+                {"data": "featureName"},
+                {"data": "rate"}
+            ],
+            "ajax": {
+                "url": push_web_url + url,
+                "data": function ( d ) {
+                    d.featureType = featureType;
+                    d.resourceType = resourceType;
+                    d.algorithmClassifyValue =  algorithmClassify;
+                    var symptom = $("#symptom_id").val();
+                    var vital = $("#vital_id").val();
+                    var past = $("#past_id").val();
+                    var other = $("#other_id").val();
+                    var lis = $("#lis_id").val();
+                    var pacs = $("#pacs_id").val();
+                    var lisOrder = $("#lis_order").val();
+                    var pacsOrder = $("#pacs_order").val();
+                    var diag = $("#diag_id").val();
+                    var length = $("#length").val();
+                    var sex = $("#sex").val();
+                    var age = $("#age").val();
+                    var age_start = $("#age_start").val();
+                    var age_end = $("#age_end").val();
+                    d.sysCode = "2";
+                    //添加额外的参数传给服务器
+                    if (symptom != null && symptom != undefined) {
+                        d.symptom = symptom;
+                    }
+                    if (vital != null && vital != undefined) {
+                        d.vital = vital;
+                    }
+                    if (past != null && past != undefined) {
+                        d.past = past;
+                    }
+                    if (other != null && other != undefined) {
+                        d.other = other;
+                    }
+                    if (lis != null && lis != undefined) {
+                        d.lis = lis;
+                    }
+                    if (pacs != null && pacs != undefined) {
+                        d.pacs = pacs;
+                    }
+                    if (lisOrder != null && lisOrder != undefined) {
+                        d.lisOrder = lisOrder;
+                    }
+                    if (pacsOrder != null && pacsOrder != undefined) {
+                        d.pacsOrder = pacsOrder;
+                    }
+                    if (diag != null && diag != undefined && diag != '') {
+                        d.diag = diag;
+                    }
+                    if (length != null && length != undefined) {
+                        d.length = length;
+                    }
+                    if (sex != null && sex != undefined) {
+                        d.sex = sex;
+                    }
+                    if (age != '' && age_start != age && age != undefined) {
+                        d.age = age;
+                    }
+                    if (age_start != '' && age_start != null && age_start != undefined) {
+                        d.age_start = age_start;
+                    }
+                    if (age_end != '' && age_end != null && age_end != undefined) {
+                        d.age_end = age_end;
+                    }
+                },
+                "dataSrc": function (json) {
+                    var inputs = json.data.inputs;
+                    var h = "";
+                    $.each(inputs, function (key, item) {
+                        h += "<div class='form-group'><label>" + key + ":&nbsp;</label>";
+                        $.each(item,function (k, t) {
+                            if  (t == null) {
+                                t = "";
+                            }
+                            h += "&nbsp;(<label>" + k + ":" + t + "</label>)&nbsp;";
+                        });
+                        h += "</div>";
+                    });
+                    $("#feature_inputs_div").html(h);
+
+                    if (tp == '1') {
+                        $("#participle_symptom").html(json.data.participleSymptom);
+                        json.data = json.data.symptom;
+                    }
+                    if (tp == '2') {
+                        $("#participle_diag").html(json.data.participleSymptom);
+                        json.data = json.data.dis;
+                    }
+                    if (tp == '3') {
+                        $("#participle_vital").html(json.data.participleSymptom);
+                        json.data = json.data.vitals;
+                    }
+                    if (tp == '4') {
+                        $("#participle_lis").html(json.data.participleSymptom);
+                        json.data = json.data.labs;
+                    }
+                    if (tp == '5') {
+                        $("#participle_pacs").html(json.data.participleSymptom);
+                        json.data = json.data.pacs;
+                    }
+                    if (tp == '6') {
+                        $("#before_combine_participle_diag").html(json.data.participleSymptom);
+                        json.data = json.data.beforeCombineDis;
+                    }
+                    return json.data;
+                }
+            }
+        });
+    }
+</script>
+</body>
+</html>