Browse source code

Merge remote-tracking branch 'origin/master'

louhr 6 years ago
parent
commit
37946f45f6

+ 7 - 9
algorithm/src/main/java/org/algorithm/core/cnn/dataset/RelationExtractionDataSet.java

@@ -6,7 +6,6 @@ import java.io.IOException;
 import java.util.*;
 
 import org.algorithm.core.cnn.entity.Lemma;
-import org.algorithm.core.cnn.entity.LemmaInfo;
 import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSONObject;
 
@@ -18,8 +17,7 @@ import com.alibaba.fastjson.JSONObject;
 public class RelationExtractionDataSet {
 
     private Map<String, Integer> char2id = new HashMap<>();
-    private Map<Integer, Map<String, String>> entities_info = new HashMap<>();
-    public int maxLength = 500;
+    public final int MAX_LEN = 512;
 
 
     public RelationExtractionDataSet(String dir) {
@@ -60,7 +58,7 @@ public class RelationExtractionDataSet {
      * @return ids
      */
     public int[] sentence2ids(String sentence) {
-        int[] ids = new int[this.maxLength];
+        int[] ids = new int[this.MAX_LEN];
         char ch = '1';
         Integer id = null;
         for (int i = 0; i < sentence.length(); i++) {
@@ -71,7 +69,7 @@ public class RelationExtractionDataSet {
             }
             ids[i] = id.intValue();
         }
-        for (int i = sentence.length(); i < this.maxLength; i++)  // padding
+        for (int i = sentence.length(); i < this.MAX_LEN; i++)  // padding
             ids[i] = this.char2id.get("<PAD>");
 
         return ids;
@@ -83,7 +81,7 @@ public class RelationExtractionDataSet {
      * @return relative positions of each character in the sentence with respect to the entity
      */
     public int[] getRelativePositions(String sentence, String position) {
-        int[] relativePositions = new int[this.maxLength];
+        int[] relativePositions = new int[this.MAX_LEN];
         String[] positionPair = position.split(",");
         int startPos = Integer.parseInt(positionPair[0]);
         int endtPos = Integer.parseInt(positionPair[1]);
@@ -97,8 +95,8 @@ public class RelationExtractionDataSet {
                 relativePositions[i] = i - endtPos;
         }
 
-        for (int i = sentence.length(); i < this.maxLength; i++)
-            relativePositions[i] = this.maxLength - 1;
+        for (int i = sentence.length(); i < this.MAX_LEN; i++)
+            relativePositions[i] = this.MAX_LEN - 1;
 
         return relativePositions;
     }
@@ -110,7 +108,7 @@ public class RelationExtractionDataSet {
      * @return
      */
     public int[][] getExample(String sentence, Lemma entity1, Lemma entity2) {
-        int[][] example = new int[3][this.maxLength];
+        int[][] example = new int[3][this.MAX_LEN];
 
         example[0] = this.sentence2ids(sentence);
         example[1] = this.getRelativePositions(sentence, entity1.getPosition());
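
For reference, a minimal usage sketch of the data set API touched above, assuming only classes already in this repository (RelationExtractionDataSet, Lemma) and a local char2id.json whose path below is a placeholder. It shows that getExample returns an int[3][MAX_LEN] block: char ids plus the two relative-position rows, each padded to the new MAX_LEN of 512. The Lemma setters mirror the ones used in ReEnsembleModelTest further down.

import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
import org.algorithm.core.cnn.entity.Lemma;

public class DataSetUsageSketch {
    public static void main(String[] args) {
        // Path is a placeholder; point it at the actual char2id.json.
        RelationExtractionDataSet dataSet = new RelationExtractionDataSet("/path/to/char2id.json");

        Lemma entity1 = new Lemma();
        entity1.setPosition("3,4");   // "start,end" character offsets inside the sentence
        entity1.setText("剧烈");
        Lemma entity2 = new Lemma();
        entity2.setPosition("5,6");
        entity2.setText("胸痛");

        // Row 0: char ids; rows 1 and 2: positions relative to the two entities.
        // Every row is padded out to dataSet.MAX_LEN (512).
        int[][] example = dataSet.getExample("患者剧烈胸痛头痛失眠不安", entity1, entity2);
        System.out.println(example.length + " x " + example[0].length);  // prints 3 x 512
    }
}
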

+ 184 - 0
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionEnsembleModel.java

@@ -0,0 +1,184 @@
+package org.algorithm.core.cnn.model;
+
+import org.algorithm.core.cnn.AlgorithmCNNExecutor;
+import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
+import org.algorithm.core.cnn.entity.Triad;
+import org.diagbot.pub.utils.PropertiesUtil;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+
+import java.io.File;
+import java.nio.FloatBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.*;
+
+/**
+ * @Author: bijl
+ * @Date: 2019/1/22 10:21
+ * @Description: ensemble model
+ */
+public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
+    private final String X_PLACEHOLDER = "X";
+    private final String PREDICTION = "prediction/prediction";
+    private final int NUM_LABEL = 1;
+    private SavedModelBundle bundle; // saved model bundle
+    private Session session;  // TensorFlow session
+    private RelationExtractionDataSet dataSet;
+    private RelationExtractionSubModel[] subModels = new RelationExtractionSubModel[3];
+    private ExecutorService executorService = Executors.newCachedThreadPool();
+
+    public RelationExtractionEnsembleModel() {
+        PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
+
+        String modelsPath = prop.getProperty("basicPath");  // 模型基本路径
+        String dataSetPath = modelsPath.substring(0, modelsPath.indexOf("model_version_replacement"));
+        dataSetPath = dataSetPath + File.separator + "char2id.json";
+        String exportDir = modelsPath.replace("model_version_replacement", "ensemble_model_2");
+
+        this.dataSet = new RelationExtractionDataSet(dataSetPath);
+        this.init(exportDir);
+        subModels[0] = new RelationExtractionSubModel("cnn_1d_low");
+        subModels[1] = new RelationExtractionSubModel("cnn_1d_lstm_low");
+        subModels[2] = new RelationExtractionSubModel("lstm_low_api");
+    }
+
+    /**
+     * Initialization: load the model and obtain a session.
+     *
+     * @param exportDir model directory
+     */
+    public void init(String exportDir) {
+        /* load the model Bundle */
+        try {
+            this.bundle = SavedModelBundle.load(exportDir, "serve");
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+
+        // create the session from the Bundle
+        this.session = bundle.session();
+    }
+
+    /**
+     * Convert the data into the model's input arrays.
+     *
+     * @param content the sentence
+     * @param triads  list of triads
+     * @return int[3][] representing charId, pos1, pos2
+     */
+    private int[][] convertData(String content, List<Triad> triads) {
+
+        int[][] inputValues = new int[3][triads.size() * this.dataSet.MAX_LEN];
+        for (int i = 0; i < triads.size(); i++) {
+            Triad triad = triads.get(i);
+            int[][] aInput = this.dataSet.getExample(content, triad.getL_1(), triad.getL_2());
+            for (int j = 0; j < aInput.length; j++)
+                for (int k = 0; k < this.dataSet.MAX_LEN; k++)
+                    inputValues[j][i * this.dataSet.MAX_LEN + k] = aInput[j][k];  // offset by k so each example fills its own MAX_LEN slice
+        }
+
+        return inputValues;
+    }
+
+    @Override
+    public List<Triad> execute(String content, List<Triad> triads) {
+        // the sentence must not exceed MAX_LEN and there must be at least one triad
+        if (content.length() > this.dataSet.MAX_LEN || triads.size() < 1) {
+            return new ArrayList<>();
+        }
+        int[][] inputValues = this.convertData(content, triads);  // shape = [3, batchSize * MAX_LEN]
+        int batchSize = triads.size();
+
+        float[] sigmoidS = new float[batchSize * this.subModels.length];  // input of the ensemble model
+
+        List<Future<float[][]>> futureList = new ArrayList<>();
+
+//        // run the sub-models serially (kept for reference)
+//        for (int i = 0; i < this.subModels.length; i++) {
+//            float[][] sigmoid = subModels[i].sigmoid(inputValues, batchSize);  // sub-model prediction
+//            for (int j = 0; j < batchSize; j++)
+//                sigmoidS[j * this.subModels.length + i] = sigmoid[j][0];
+//        }
+
+        // run the sub-models in parallel
+        for (int i = 0; i < this.subModels.length; i++) {
+            int index = i;
+            Future<float[][]> future = this.executorService.submit(new Callable<float[][]>() {
+                @Override
+                public float[][] call() throws Exception {
+                    return subModels[index].sigmoid(inputValues, batchSize);
+                }
+            });
+            futureList.add(future);
+        }
+
+        // collect each sub-model's output from its future and fill sigmoidS
+        for (int i = 0; i < this.subModels.length; i++) {
+            try {
+                float[][] sigmoid = futureList.get(i).get();
+                for (int j = 0; j < batchSize; j++)
+                    // row-major layout matching the {batchSize, numSubModels} shape fed to the ensemble graph
+                    sigmoidS[j * this.subModels.length + i] = sigmoid[j][0];
+
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+                System.err.println("获取数据不成功");
+            } catch (ExecutionException e) {
+                e.printStackTrace();
+                System.err.println("获取数据不成功");
+            }
+        }
+//        this.executorService.shutdown();
+
+        float[][] prediction = this.run(sigmoidS, batchSize);
+
+        // set the relation on each triad
+        for (int j = 0; j < prediction.length; j++) {
+            if (prediction[j][0] == 1.0)
+                triads.get(j).setRelation("有");
+            else
+                triads.get(j).setRelation("无");
+        }
+
+        // remove triads that have no relation
+        List<Triad> deleteTriads = new ArrayList<>();
+        for (Triad triad : triads)
+            if ("无".equals(triad.getRelation()))  // keep only the triads that do have a relation
+                deleteTriads.add(triad);
+        for (Triad triad : deleteTriads)
+            triads.remove(triad);
+
+        return triads;
+    }
+
+
+    /**
+     * @param inputValues the three sub-models' sigmoid outputs, flattened row-major as [batchSize, numSubModels]
+     * @param batchSize   batch size
+     * @return float[][] shape = [batchSize, 1]
+     */
+    private float[][] run(float[] inputValues, int batchSize) {
+        long[] shape = {batchSize, this.subModels.length};  // one row of sub-model outputs per sample
+        Tensor<Float> sigmoidS = Tensor.create(
+                shape,
+                FloatBuffer.wrap(inputValues)
+        );
+
+
+        return this.session.runner()
+                .feed(this.X_PLACEHOLDER, sigmoidS)
+                .fetch(this.PREDICTION).run().get(0)
+                .copyTo(new float[batchSize][this.NUM_LABEL]);
+    }
+
+    /**
+     * Close the session and release resources.
+     */
+    public void close() {
+        this.session.close();
+        this.bundle.close();
+        for (RelationExtractionSubModel subModel : this.subModels)
+            subModel.close();
+    }
+}
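
A note on the buffer layout the ensemble input depends on: Tensor.create(shape, FloatBuffer.wrap(...)) reads the flat buffer in row-major order, so with shape = {batchSize, 3} the sigmoid of sub-model m for sample b should sit at index b * 3 + m. The self-contained sketch below (made-up values, no TensorFlow dependency) only illustrates that flattening; it is the assumption behind the j * this.subModels.length + i indexing used when filling sigmoidS above.

public class RowMajorLayoutSketch {
    public static void main(String[] args) {
        int batchSize = 2;
        int numSubModels = 3;
        // [batch][model] sigmoid outputs, values made up for illustration
        float[][] sigmoid = {{0.1f, 0.7f, 0.4f}, {0.9f, 0.2f, 0.6f}};

        float[] flat = new float[batchSize * numSubModels];
        for (int b = 0; b < batchSize; b++)
            for (int m = 0; m < numSubModels; m++)
                flat[b * numSubModels + m] = sigmoid[b][m];  // batch-major, matching shape {batchSize, 3}

        // Row-major read-back of row b, column m is flat[b * numSubModels + m].
        for (float v : flat)
            System.out.print(v + " ");  // 0.1 0.7 0.4 0.9 0.2 0.6
        System.out.println();
    }
}
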

+ 1 - 2
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionModel.java

@@ -6,7 +6,6 @@ import com.alibaba.fastjson.JSONObject;
 import com.alibaba.fastjson.TypeReference;
 import org.algorithm.core.cnn.AlgorithmCNNExecutor;
 import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
-import org.algorithm.core.cnn.entity.Lemma;
 import org.algorithm.core.cnn.entity.LemmaInfo;
 import org.algorithm.core.cnn.entity.Triad;
 import org.tensorflow.SavedModelBundle;
@@ -112,7 +111,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
      * @return
      */
     private float[][] run(int[][] inputValues, int batchSize){
-        long[] shape = {1, dataSet.maxLength};  // 老模型
+        long[] shape = {1, dataSet.MAX_LEN};  // legacy model: one example at a time
         Tensor<Integer> charId = Tensor.create(
                 shape,
                 IntBuffer.wrap(inputValues[0])

+ 95 - 0
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionSubModel.java

@@ -0,0 +1,95 @@
+package org.algorithm.core.cnn.model;
+
+import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
+import org.diagbot.pub.utils.PropertiesUtil;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+
+import java.io.File;
+import java.nio.IntBuffer;
+
+/**
+ * @Author: bijl
+ * @Date 2019/6/12 14:02:31
+ * @Description: relation extraction sub-model
+ */
+public class RelationExtractionSubModel {
+    private final String X_PLACEHOLDER = "X";
+    private final String pos1_PLACEHOLDER = "pos1";
+    private final String pos2_PLACEHOLDER = "pos2";
+    private String PREDICTION = null;
+    private final int NUM_LABEL = 1;
+    private SavedModelBundle bundle; // saved model bundle
+    private Session session;  // TensorFlow session
+    protected RelationExtractionDataSet dataSet;
+
+    public RelationExtractionSubModel(String modelName) {
+
+        PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
+
+        String modelsPath = prop.getProperty("basicPath");  // 模型基本路径
+        String re_path = prop.getProperty("relationExtraction");  // 模型基本路径
+        this.PREDICTION = modelName + "/prediction/Sigmoid";
+        String exportDir = modelsPath.replace("model_version_replacement", modelName);
+        String dataSetPath = modelsPath.substring(0, modelsPath.indexOf("model_version_replacement"));
+        dataSetPath = dataSetPath + File.separator + "char2id.json";
+
+        this.dataSet = new RelationExtractionDataSet(dataSetPath);
+        this.init(exportDir);
+    }
+
+    /**
+     * Initialization: load the model and obtain a session.
+     *
+     * @param exportDir model directory
+     */
+    public void init(String exportDir) {
+        /* load the model Bundle */
+        try {
+            this.bundle = SavedModelBundle.load(exportDir, "serve");
+        } catch (Exception e) {
+            e.printStackTrace();
+        }
+
+        // create the session from the Bundle
+        this.session = bundle.session();
+    }
+
+    /**
+     * @param inputValues char ids, positions relative to entity 1, positions relative to entity 2
+     * @param batchSize   batch size
+     * @return float[batchSize][NUM_LABEL]
+     */
+    public float[][] sigmoid(int[][] inputValues, int batchSize) {
+        long[] shape = {batchSize, dataSet.MAX_LEN};  // input shape: [batchSize, MAX_LEN]
+        Tensor<Integer> charId = Tensor.create(
+                shape,
+                IntBuffer.wrap(inputValues[0])
+        );
+        Tensor<Integer> pos1 = Tensor.create(
+                shape,
+                IntBuffer.wrap(inputValues[1])
+        );
+        Tensor<Integer> pos2 = Tensor.create(
+                shape,
+                IntBuffer.wrap(inputValues[2])
+        );
+
+        return this.session.runner()
+                .feed(this.X_PLACEHOLDER, charId)  // input: char ids
+                .feed(this.pos1_PLACEHOLDER, pos1)  // input: positions relative to entity 1
+                .feed(this.pos2_PLACEHOLDER, pos2)  // input: positions relative to entity 2
+                .feed("keep_prob", Tensor.create(1.0f, Float.class))  // input: dropout keep probability
+                .fetch(this.PREDICTION).run().get(0)  // output tensor
+                .copyTo(new float[batchSize][this.NUM_LABEL]);  // copy the tensor into a float[][]
+    }
+
+    /**
+     * Close the session and release resources.
+     */
+    public void close() {
+        this.session.close();
+        this.bundle.close();
+    }
+}

+ 75 - 0
algorithm/src/main/java/org/algorithm/test/MultiThreadsTest.java

@@ -0,0 +1,75 @@
+package org.algorithm.test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Timer;
+import java.util.concurrent.*;
+
+/**
+ * @Author: bijl
+ * @Date: 2019/6/13 10:32
+ * @Description: multi-thread test
+ */
+public class MultiThreadsTest {
+    public static void main(String[] args) {
+        ExecutorService exe = Executors.newCachedThreadPool();
+//        for (int i = 0; i < 10; i++)
+//            exe.execute(new MyThread());
+        List<Future<Float>> futures = new ArrayList<>();
+        Future<Float> future = null;
+        for (int j = 0; j < 2; j++) {
+            for (int i = 0; i < 10; i++) {
+                future = exe.submit(new MyCallable("" + i));
+                System.out.println("............_" + i);
+                futures.add(future);
+
+//                try {
+//                    System.out.println(future.get());
+//                } catch (InterruptedException e) {
+//                    e.printStackTrace();
+//                } catch (ExecutionException e) {
+//                    e.printStackTrace();
+//                }
+
+            }
+            for (Future<Float> future1 : futures) {
+                try {
+                    System.out.println(future1.get());
+                } catch (InterruptedException e) {
+                    e.printStackTrace();
+                } catch (ExecutionException e) {
+                    e.printStackTrace();
+                }
+            }
+//            exe.shutdown();
+        }
+        System.out.println("All done.");
+    }
+}
+
+class MyThread implements Runnable {
+
+
+    public void run() {
+        System.out.println(Thread.currentThread());
+        System.out.println("............");
+    }
+}
+
+class MyCallable implements Callable<Float> {
+    private String values = null;
+
+    public MyCallable(String values) {
+        this.values = values;
+    }
+
+    @Override
+    public Float call() {
+        try {
+            Thread.sleep(5000);
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+        return Float.parseFloat(this.values);
+    }
+}
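
The test above leaves shutdown() commented out and waits on each Future with an unbounded get(). Below is a small sketch of the same submit/get pattern with a bounded wait and an explicit shutdown, using only standard java.util.concurrent API (Java 8+ assumed for the lambda); it is an illustration, not part of the commit.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

public class ShutdownSketch {
    public static void main(String[] args) throws Exception {
        ExecutorService exe = Executors.newFixedThreadPool(3);
        Future<Float> future = exe.submit(() -> 1.5f);          // same submit/Future pattern as in the test
        System.out.println(future.get(10, TimeUnit.SECONDS));   // bounded wait instead of an indefinite get()
        exe.shutdown();                                         // finish queued tasks, then release the threads
        exe.awaitTermination(10, TimeUnit.SECONDS);
    }
}
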

+ 44 - 0
algorithm/src/main/java/org/algorithm/test/ReEnsembleModelTest.java

@@ -0,0 +1,44 @@
+package org.algorithm.test;
+
+import org.algorithm.core.cnn.entity.Lemma;
+import org.algorithm.core.cnn.entity.Triad;
+import org.algorithm.core.cnn.model.RelationExtractionEnsembleModel;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Test the ensemble model
+ *
+ * @Author: bijl
+ * @Date: 2019/6/12 17:53
+ * @Description:
+ */
+public class ReEnsembleModelTest {
+
+    public static void main(String[] args) {
+        RelationExtractionEnsembleModel ensembleModel = new RelationExtractionEnsembleModel();
+
+        List<Triad> triads = new ArrayList<>();
+        Triad triad_1 = new Triad();
+        Lemma l_1 = new Lemma();
+        l_1.setPosition("3,4");
+        l_1.setText("剧烈");
+
+        Lemma l_2 = new Lemma();
+        l_2.setPosition("5,6");
+        l_2.setText("胸痛");
+
+        triad_1.setL_1(l_1);
+        triad_1.setL_2(l_2);
+        for (int i = 0; i < 500; i++)  // 500 samples
+            triads.add(triad_1);
+
+        long start = System.nanoTime();
+        for (int i = 0; i < 200; i++)  // repeat 200 times
+            triads = ensembleModel.execute("患者剧烈胸痛头痛失眠不安", triads);
+        long elapsedTime = System.nanoTime() - start;
+        System.out.println(triads.size());
+        System.out.println(elapsedTime);
+    }
+}

+ 27 - 0
algorithm/src/main/java/org/algorithm/test/ReSubModelTest.java

@@ -0,0 +1,27 @@
+package org.algorithm.test;
+
+import org.algorithm.core.cnn.model.RelationExtractionSubModel;
+
+/** Test loading a single sub-model
+ * @Author: bijl
+ * @Date: 2019/6/12 16:57
+ * @Description:
+ */
+public class ReSubModelTest {
+
+    public static void main(String[] args) {
+//        RelationExtractionSubModel subModel = new RelationExtractionSubModel("cnn_1d_low");
+//        RelationExtractionSubModel subModel = new RelationExtractionSubModel("cnn_1d_lstm_low");
+        RelationExtractionSubModel subModel = new RelationExtractionSubModel("lstm_low_api");
+
+
+        int[][] inputValues = new int[3][512];
+        for (int i=0; i< inputValues.length; i++)
+            for (int j=0; j<inputValues[0].length; j++)
+                inputValues[i][j] = 0;
+        float[][] result = subModel.sigmoid(inputValues, 1);
+        System.out.println(result[0][0]);
+        System.out.println(result.length);
+        System.out.println(result[0].length);
+    }
+}

+ 5 - 1
algorithm/src/main/java/org/algorithm/test/RelationExtractionDataSetTest.java

@@ -11,7 +11,7 @@ public class RelationExtractionDataSetTest {
 
     public static void main(String[] args) {
 
-        String filePath = "E:\\relation_extraction\\shao_yi_fu_data\\char2id.json";
+        String filePath = "E:\\char2id.json";
         String sentence = "有双手麻木感,活动后好转,颈部及肩部活动度无殊";
         RelationExtractionDataSet dataSet = new RelationExtractionDataSet(filePath);
 //        for (float id:dataSet.sentence2ids(sentence)) {
@@ -20,6 +20,10 @@ public class RelationExtractionDataSetTest {
 //        for (float id:dataSet.getRelativePositions(sentence, "1,2")) {
 //            System.out.println(id); //pass
 //        }
+        for (float id:dataSet.sentence2ids(sentence)) {
+            System.out.println(id);
+        }
+        System.out.println("...............");
 
         for (float id:dataSet.getRelativePositions(sentence, "1,2")) {
             System.out.println(id);

+ 29 - 0
algorithm/src/main/java/org/algorithm/test/TensorMethodTest.java

@@ -0,0 +1,29 @@
+package org.algorithm.test;
+
+import org.tensorflow.Tensor;
+
+import java.nio.FloatBuffer;
+
+/**
+ * Test the TensorFlow Tensor API
+ *
+ * @Author: bijl
+ * @Date: 2019/6/12 16:44
+ * @Description:
+ */
+public class TensorMethodTest {
+
+    public static void main(String[] args) {
+        float[] inputValues = {1, 2, 3, 4, 5, 6};
+        long[] shape = {2, 3};  // 2 x 3 tensor
+        Tensor<Float> tensor = Tensor.create(
+                shape,
+                FloatBuffer.wrap(inputValues)
+        );
+        float[][] xx = tensor.copyTo(new float[2][3]);
+        for (int i = 0; i < xx.length; i++)
+            for (int j = 0; j < xx[0].length; j++)
+                System.out.println(i + "," + j + " --> " + xx[i][j]);
+
+    }
+}

+ 2 - 2
algorithm/src/main/resources/algorithm.properties

@@ -1,7 +1,7 @@
 ################################ model basic url ###################################
 
 #basicPath=E:/git/push/algorithm/src/main/models/model_version_replacement/model
-basicPath=/opt/models/dev/models/model_version_replacement/model
+#basicPath=/opt/models/dev/models/model_version_replacement/model
 #basicPath=E:/xxx/model_version_replacement/model
 
 ############################### current model version ################################
@@ -18,4 +18,4 @@ diagnosisToVital.version=diagnosis_to_vital_1
 
 ############################ relation extraction ######################################
 relationExtraction=relation_extraction
-relationExtractionUrl=http://192.168.2.234:54321/api/re/predict
+relationExtractionUrl=http://192.168.3.40:54321/api/re/predict

+ 9 - 0
common-service/src/main/java/org/diagbot/common/work/ResponseData.java

@@ -26,6 +26,15 @@ public class ResponseData {
     private List<MedicalIndication> medicalIndications;//scale and indicator push
 //    private Map<String,JSONObject> managementEvaluation; //management evaluation
     private Map managementEvaluation;
+    private List<String> diffDiag;//differential diagnosis
+
+    public List<String> getDiffDiag() {
+        return diffDiag;
+    }
+
+    public void setDiffDiag(List<String> diffDiag) {
+        this.diffDiag = diffDiag;
+    }
 
     public Map getManagementEvaluation() {
         return managementEvaluation;

+ 15 - 6
graph-web/src/main/java/org/diagbot/graphWeb/work/GraphCalculate.java

@@ -1,5 +1,6 @@
 package org.diagbot.graphWeb.work;
 
+import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSONArray;
 import com.alibaba.fastjson.JSONObject;
 import org.diagbot.common.javabean.MangementEvaluation;
@@ -43,24 +44,25 @@ public class GraphCalculate {
         List<String> featureTypeList = Arrays.asList(featureTypes);
         logger.info("featureTypeList : " + featureTypeList);
         inputList.addAll(ss);
-        //        Driver driver = DriverManager.newDrive("192.168.2.232", "neo4j", "root");
         Neo4jAPI neo4jAPI = new Neo4jAPI(DriverManager.newDrive());
         logger.info("图谱开始推送诊断!!!!!!!!!!!");
+        String webDiag = searchData.getDiag();
+        logger.info("页面诊断为 :"+webDiag);
         // compute diagnoses
-        Map<String, String> condition = neo4jAPI.getCondition((String[]) inputList.toArray(new String[inputList.size()]));
+        Map<String, Map<String,String>> condition = neo4jAPI.getCondition((String[]) inputList.toArray(new String[inputList.size()]),webDiag );
         List<FeatureRate> featureRates = new ArrayList<>();
-        for (Map.Entry<String, String> d : condition.entrySet()) {
+        for (Map.Entry<String, Map<String,String>> d : condition.entrySet()) {
             if (!"低血糖反应".equals(d.getKey()) && !"胃肠道不良反应".equals(d.getKey())) {
                 FeatureRate featureRate = new FeatureRate();
                 featureRate.setFeatureName(d.getKey());
                 featureRate.setExtraProperty("");
-                featureRate.setDesc(d.getValue());
+                Map<String, String> value = d.getValue();
+                String s = JSON.toJSONString(value);
+                featureRate.setDesc(s);
                 featureRate.setRate("neo4j");
                 featureRates.add(featureRate);
             }
         }
-        String webDiag = searchData.getDiag();
-        logger.info("页面诊断为 :"+webDiag);
         Set<String> diseaseSet = condition.keySet();
         logger.info("diseaseSet :" + diseaseSet);
         Integer diseaseType = searchData.getDisType();
@@ -77,6 +79,13 @@ public class GraphCalculate {
             }
         }
         logger.info("页面导入的所有化验项为 :" +lisSet);
+        // differential diagnosis
+        /*if(webDiag != null && webDiag.trim() != ""){
+            String[] webDiagSplits = webDiag.split(",");
+            String mainDiag = webDiagSplits[0];
+            List<String> differentialDiagnose = neo4jAPI.getDifferentialDiagnose(mainDiag);
+            responseData.setDiffDiag(differentialDiagnose);
+        }*/
         // proceed to treatment
         if (webDiag.trim() != null && webDiag.trim() != "" && featureTypeList.contains("8")) {
             // 查找页面诊断里是否有不良反应

+ 59 - 13
graph/src/main/java/org/diagbot/graph/jdbc/Neo4jAPI.java

@@ -646,8 +646,8 @@ public class Neo4jAPI {
      * @param keys
      * @return
      */
-    public Map<String, String> getCondition(String[] keys) {
-        Map<String, String> diseaseCondition = new LinkedHashMap<>();
+    public Map<String, Map<String,String>> getCondition(String[] keys,String webDiag) {
+        Map<String, Map<String,String>> diseaseCondition = new LinkedHashMap<>();
         List<String> newList = new ArrayList<>();
         ArrayList<String> fildList = new ArrayList<>();
         // output the set of confirmed diagnoses
@@ -690,16 +690,6 @@ public class Neo4jAPI {
                         fildList.add(js);
                     }
                 }
-                //                else {
-                //                    fildList.add(fild);
-                //                    query = "match(l)-[r:近义词]->(h) where l.name="+fild+" return h.name as js";
-                //                    StatementResult run1 = session.run(query);
-                //                    while (run1.hasNext()){
-                //                        Record next = run1.next();
-                //                        String js = next.get("js").toString();
-                //                        fildList.add(js);
-                //                    }
-                //                }
             }
             newList.addAll(fildList);
             int i = 0;
@@ -735,9 +725,34 @@ public class Neo4jAPI {
             for (String qu : quezhen) {
                 Map<String, String> dis_res = new HashMap<>();
                 dis_res.put("确诊", "");
-                diseaseCondition.put(qu, JSON.toJSONString(dis_res));
+//                diseaseCondition.put(qu, JSON.toJSONString(dis_res));
+                diseaseCondition.put(qu,dis_res);
                 logger.info("图谱推出的诊断为: " + qu);
             }
+            // mark differential diagnoses of the main diagnosis entered on the page
+            Set<String> queSets = diseaseCondition.keySet();
+            if (webDiag != null && !webDiag.trim().isEmpty()) {
+                String[] webDiagSplits = webDiag.split(",");
+                String mainDiag = webDiagSplits[0];
+                query = propertiesUtil.getProperty("searchDifferentialDiagnose").replace("mainDis", mainDiag);
+                result = session.run(query);
+                while (result.hasNext()) {
+                    Record record = result.next();
+                    List<Object> coll = record.get("coll").asList();
+                    if (coll != null && coll.size() > 0) {
+                        for (Object o : coll) {
+                            String disease = o.toString().replace("\"", "");
+                            if (queSets.contains(disease)) {
+                                Map<String, String> diseaseMap = diseaseCondition.get(disease);
+                                diseaseMap.put("鉴别诊断", "");
+                                diseaseCondition.put(disease, diseaseMap);
+                            } else {
+                                Map<String, String> diffMap = new HashMap<>();
+                                diffMap.put("鉴别诊断", "");
+                                diseaseCondition.put(disease, diffMap);
+                            }
+                        }
+                    }
+                }
+            }
             //查找指标推送
             Set<String> indSet = new HashSet<>();
             query = propertiesUtil.getProperty("searchIndication").replace("fildList", fildList.toString());
@@ -755,6 +770,37 @@ public class Neo4jAPI {
         }
     }
 
+    /**
+     * Differential diagnosis lookup
+     * @param mainDiag main diagnosis
+     * @return list of differential diagnoses
+     */
+    public List<String> getDifferentialDiagnose(String mainDiag){
+        List<String> differentialDiagnoseList = new LinkedList<>();
+        Session session = null;
+        StatementResult result = null;
+        String query = "";
+        try {
+            session = driver.session(AccessMode.WRITE);
+            logger.info("session 为: " + session);
+            query = propertiesUtil.getProperty("searchDifferentialDiagnose").replace("mainDis", mainDiag);
+            result = session.run(query);
+            while (result.hasNext()) {
+                Record record = result.next();
+                List<Object> coll = record.get("coll").asList();
+                if(coll != null && coll.size()>0){
+                    for (Object o:coll) {
+                        differentialDiagnoseList.add(o.toString().replace("\"",""));
+                    }
+                }
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+        } finally {
+            CloseSession(session);  // always release the session
+        }
+        // return outside the finally block so exceptions are not silently swallowed
+        return differentialDiagnoseList;
+    }
     /**
      * Adverse reaction push
      */
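
A hedged usage sketch for the new getDifferentialDiagnose method, mirroring the (currently commented-out) call site in GraphCalculate. The import package for DriverManager is assumed from the repository layout, the diagnosis name is illustrative, and a reachable Neo4j instance behind DriverManager.newDrive() is required.

import java.util.List;
import org.diagbot.graph.jdbc.DriverManager;  // package assumed from the repo layout
import org.diagbot.graph.jdbc.Neo4jAPI;

public class DifferentialDiagnoseSketch {
    public static void main(String[] args) {
        // Same construction as in GraphCalculate: new Neo4jAPI(DriverManager.newDrive())
        Neo4jAPI neo4jAPI = new Neo4jAPI(DriverManager.newDrive());
        // "糖尿病" is only an illustrative main diagnosis; any Disease node name works.
        List<String> diffDiag = neo4jAPI.getDifferentialDiagnose("糖尿病");
        System.out.println(diffDiag);
    }
}
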

+ 3 - 0
graph/src/main/resources/bolt.properties

@@ -42,6 +42,9 @@ where n.name=row\n \
 with distinct m,r\n \
 return m.name as name, labels(m)[0] as label,type(r) as relationType
 
+# query statement for differential diagnosis lookup
+searchDifferentialDiagnose=match(d:Disease)-[r:\u9274\u522B\u8BCA\u65AD]->(h) where d.name='mainDis' return collect(h.name) as coll
+
 #\u67E5\u627E\u6307\u6807\u7684\u8BED\u53E5
 searchIndication=with fildList  as data unwind data as row\n \
 match (n)-[r:\u786E\u8BCA|:\u62DF\u8BCA]->(m:Indicators)\n \

+ 15 - 1
graph/src/main/resources/query.properties

@@ -6,4 +6,18 @@ diseaseDrugsMedication=\
 # \u7B2C\u4E8C\u6B65,\u75BE\u75C5\u548C\u836F\u7684\u5173\u7CFB
 diseaseMedication=match(d:Disease{name:"diseaseName"})-[r:\u63A8\u8350]->(l:Medicine) return d.name,r.rate as p,l.name as collectName
 # \u7B2C\u4E09\u6B65,\u5927\u7C7B\u548C\u5B50\u7C7B\u7684\u5173\u7CFB
-bigDrugAndSubDrugs=match(d:Disease)-[r:\u63A8\u8350]->(s:Drugs)-[r1:\u5305\u542B]->(j:Drugs)where d.name="diseaseName" return s.name as big,j.name as sub
+bigDrugAndSubDrugs=match(d:Disease)-[r:\u63A8\u8350]->(s:Drugs)-[r1:\u5305\u542B]->(j:Drugs)where d.name="diseaseName" return s.name as big,j.name as sub
+
+# process the indicator data from 232 and store it into 112
+allStruct=match(d:Indicators{name:'idn'})-[r2:\u5185\u5BB9]->(n:Content) where r2.p=1 \
+return r2.p as sort1 ,n.name as conten,null as sort2,null as item,null as controlType,null as state,null as value,null as uint,null as details \
+union \
+match(d:Indicators{name:'idn'})-[r2:\u5185\u5BB9]->(n:Content)-[r3:\u9879\u76EE]->(i:Item) \
+return r2.p as sort1 ,n.name as conten,r3.p as sort2,i.name as item,i.controlType as controlType,i.state as state,i.value as value,i.uint as uint,i.details as details order by sort1,sort2
+
+# process the management evaluation data from 232 and store it into 112
+allMange=match(d:Disease)-[r1:\u7BA1\u7406\u8BC4\u4F30]-(m:Management)-[r2:\u7597\u6548\u8BC4\u4F30]->(e:Effect)-[r3:\u7ED3\u679C]-(u:Result) \
+where d.name='dis' \
+return m.name as ma,r2.p as p1,e.name as eff,r3.p as p2,u.name as res,u.state as st order by p1,p2
+xuezhi=match(e:Effect)-[r:\u6307\u6807]->(l:LIS)-[r1:\u8BA1\u7B97\u7ED3\u679C]->(k:Result) where e.name='\u8840\u8102\u63A7\u5236\u60C5\u51B5' \
+return l.name as name,r1.between as between ,k.name as res

+ 1 - 9
graphdb/src/main/java/org/diagbot/service/impl/KnowledgeServiceImpl.java

@@ -181,15 +181,7 @@ public class KnowledgeServiceImpl implements KnowledgeService {
 
     @Override
     public Map<String, Object> getHighRiskDisease(SearchData searchData) {
-//        List<FeatureRate> diags = searchData.getPushDiags();
-//        List<String> diseaseList = new ArrayList<>();
-//        if(diags != null && diags.size()>0){
-//            for (FeatureRate fe:diags) {
-//              String featureName = fe.getFeatureName();
-//              diseaseList.add(featureName);
-//            }
-//        }
-        String[] splitsDiag = searchData.getDiag().split(",");
+        String[] splitsDiag = searchData.getDiag().split(",|,|、");
         List<String> diseaseNameList = Arrays.asList(splitsDiag);
         Map<String, Object> map = new HashMap<>();
         List<Map<String, Object>> list = baseNodeRepository.getHighRisk(diseaseNameList);

+ 4 - 1
push-web/src/main/java/org/diagbot/push/controller/AlgorithmController.java

@@ -90,10 +90,13 @@ public class AlgorithmController extends BaseController {
         // pass the pushed diagnoses to the knowledge graph as parameters
         List<FeatureRate> pushDiags = new ArrayList<>();
         for (FeatureRate fr : bigDataResponseData.getDis()) {
-            pushDiags.add(fr);
+//            if(!"鉴别诊断".equals(fr.getDesc())){
+                pushDiags.add(fr);
+//            }
             logger.info("合并知识图谱、大数据后推送的诊断信息....: " + fr.getFeatureName());
         }
         searchData.setPushDiags(pushDiags);
+//        bigDataResponseData.setDiffDiag(graphResponseData.getDiffDiag());
 
         bigDataResponseData.setTreat(graphResponseData.getTreat());
         // scale and indicator push