فهرست منبع

1- 添加测试类。

bijl 6 سال پیش
والد
کامیت
080e3614f2

+ 79 - 31
algorithm/src/main/java/org/algorithm/core/cnn/dataset/RelationExtractionDataSet.java

@@ -4,10 +4,9 @@ import java.io.BufferedReader;
 import java.io.FileNotFoundException;
 import java.io.FileReader;
 import java.io.IOException;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
+
+import org.algorithm.core.cnn.model.LemmaInfo;
 import org.algorithm.util.TextFileReader;
 import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSONObject;
@@ -15,45 +14,46 @@ import com.alibaba.fastjson.JSONObject;
 /**
  * @Author: bijl
  * @Date: 2019年1月21日-下午2:43:44
- * @Description: 
+ * @Description:
  */
 public class RelationExtractionDataSet {
-    
+
     private Map<String, Integer> char2id = new HashMap<>();
-    private int maxLength = 200;
-    
-    
+    private Map<Integer, Map<String, String>> entities_info = new HashMap<>();
+    public int maxLength = 200;
+
+
     /**
      * 切分句子
+     *
      * @param document 原有文档
      * @return 句子数组
      */
     public String[] splitSentence(String document) {
         String[] sentences = null;
-        sentences = document.split("。|;|\n|\n\r"); 
+        // "\r\n" has to come before "\n": regex alternation takes the first branch that
+        // matches, so in the previous pattern "\n|\n\r" the last branch was never reached
+        // and Windows line endings left a stray '\r' at the end of each sentence.
+        sentences = document.split("。|;|\r\n|\n");
         return sentences;
     }
-    
+
     /**
      * 加载字典
+     *
      * @param dir
      */
     public void loadDictionary(String dir) {
-        List<String> lines= TextFileReader.readLines(dir);
         BufferedReader br = null;
         try {
-             br = new BufferedReader(new FileReader(dir));// 读取原始json文件
+            br = new BufferedReader(new FileReader(dir));// 读取原始json文件
             String s = null;
             while ((s = br.readLine()) != null) {
                 JSONObject jsonObject = (JSONObject) JSON.parse(s);
                 Set<Map.Entry<String, Object>> entries = jsonObject.entrySet();
-                for (Map.Entry<String, Object> entry : entries) {
+                for (Map.Entry<String, Object> entry : entries)
                     this.char2id.put(entry.getKey(), (Integer) entry.getValue());
-                }
             }
         } catch (Exception e) {
             e.printStackTrace();
-        }finally {
+        } finally {
             try {
                 br.close();
             } catch (IOException e) {
@@ -61,51 +61,49 @@ public class RelationExtractionDataSet {
             }
         }
     }
-    
+
     /**
      * 句子转字符ids
+     *
      * @param sentence 句子
      * @return ids
      */
-    public int[]  sentence2ids(String sentence) {
-        int[] ids = new int[this.maxLength];
+    public float[] sentence2ids(String sentence) {
+        float[] ids = new float[this.maxLength];
         char ch = '1';
         Integer id = null;
         for (int i = 0; i < sentence.length(); i++) {
             ch = sentence.charAt(i);
-            id = this.char2id.get(ch);
+            id = this.char2id.get(String.valueOf(ch));
             if (id == null) {
                 id = this.char2id.get("<UKC>");
             }
             ids[i] = id.intValue();
         }
-        for(int i=sentence.length(); i<this.maxLength; i++)  // padding
+        for (int i = sentence.length(); i < this.maxLength; i++)  // padding
             ids[i] = this.char2id.get("<PAD>");
-        
+
         return ids;
     }
-    
+
     /**
-     * 
      * @param sentence 句子
      * @param position 一个实体的位置
      * @return 句子中各个汉子相对于实体的位置
      */
-    public int[] getRelativePositions(String sentence, String position) {
-        int[] relativePositions = new int[this.maxLength];
+    public float[] getRelativePositions(String sentence, String position) {
+        float[] relativePositions = new float[this.maxLength];
         String[] positionPair = position.split(",");
-        int startPos = Integer.parseInt(positionPair[0]); 
+        int startPos = Integer.parseInt(positionPair[0]);
         int endtPos = Integer.parseInt(positionPair[1]);
-        
-        char ch = '1';
-        Integer id = null;
+
         for (int i = 0; i < sentence.length(); i++) {
             if (i < startPos)
                 relativePositions[i] = startPos - i;
             else if (i >= startPos && i <= endtPos)
                 relativePositions[i] = 0;
             else
-                relativePositions[i] = endtPos - i;
+                relativePositions[i] = i - endtPos;
         }
 
         for (int i = sentence.length(); i < this.maxLength; i++)
@@ -114,6 +112,56 @@ public class RelationExtractionDataSet {
         return relativePositions;
     }
 
+    /**
+     * Returns the entity-pair position combinations for one sentence.
+     *
+     * <p>Placeholder implementation: the entity JSON is not inspected, so the result is
+     * always empty.
+     *
+     * @param json_content entity information of the sentence (currently ignored)
+     * @return a new, empty, mutable list
+     */
+    public List<String> getPositionCombinations(String json_content) {
+        return new ArrayList<>();
+    }
+
+
+    /**
+     * Builds one model input per entity-position combination of a sentence.
+     *
+     * <p>Each example has three rows of length {@code maxLength}: row 0 holds the character
+     * ids of the sentence, rows 1 and 2 hold relative positions.
+     *
+     * @param sentence     input sentence
+     * @param json_content entity information of the sentence, forwarded to
+     *                     {@link #getPositionCombinations(String)}
+     * @return one {@code float[3][maxLength]} example per combination; empty when there are
+     *         no combinations
+     */
+    public List<float[][]> get_examples(String sentence, String json_content) {
+        List<float[][]> examples = new ArrayList<>();
+        List<String> combinations = this.getPositionCombinations(json_content);
+        float[] charId = this.sentence2ids(sentence);  // shared by every example of this sentence
+        for (String combination : combinations) {
+            float[][] example = new float[3][this.maxLength];
+            example[0] = charId;
+            example[1] = this.getRelativePositions(sentence, combination);
+            // Previously row 1 was assigned twice and row 2 stayed all zeros.
+            // NOTE(review): a combination is passed as one "start,end" string, so both
+            // position rows are derived from it for now; confirm how the second entity's
+            // position is encoded once getPositionCombinations is implemented.
+            example[2] = this.getRelativePositions(sentence, combination);
+            examples.add(example);
+        }
+        return examples;
+    }
+
+    /**
+     * Builds the three-row model input for one entity pair.
+     *
+     * <p>Row 0 holds the character ids of the sentence, row 1 the positions relative to
+     * {@code entity1} and row 2 the positions relative to {@code entity2}. An entity spans
+     * the inclusive range {@code offset .. offset + length - 1}.
+     *
+     * @param sentence input sentence
+     * @param entity1  first entity
+     * @param entity2  second entity
+     * @return the rows described above
+     */
+    public float[][] getExample(String sentence, LemmaInfo entity1, LemmaInfo entity2) {
+        final int firstStart = entity1.getOffset();
+        final int firstEnd = firstStart + entity1.getLength() - 1;
+        final float[] charIds = this.sentence2ids(sentence);
+        final float[] firstPositions = this.getRelativePositions(sentence, firstStart + "," + firstEnd);
+
+        final int secondStart = entity2.getOffset();
+        final int secondEnd = secondStart + entity2.getLength() - 1;
+        final float[] secondPositions = this.getRelativePositions(sentence, secondStart + "," + secondEnd);
+
+        return new float[][] {charIds, firstPositions, secondPositions};
+    }
 
 
 }

+ 59 - 0
algorithm/src/main/java/org/algorithm/core/cnn/model/LemmaInfo.java

@@ -0,0 +1,59 @@
+package org.algorithm.core.cnn.model;
+
+/**
+ * Mutable data holder for one entity: its text, its offset and length, and a property value.
+ *
+ * @Author: bijl
+ * @Date: 2019/1/22 12:23
+ */
+public class LemmaInfo {
+    private Integer length;
+    private Integer offset;
+    private String property;
+    private String text;
+
+    /** @return the entity length */
+    public Integer getLength() {
+        return this.length;
+    }
+
+    /** @param length the entity length */
+    public void setLength(Integer length) {
+        this.length = length;
+    }
+
+    /** @return the entity offset */
+    public Integer getOffset() {
+        return this.offset;
+    }
+
+    /** @param offset the entity offset */
+    public void setOffset(Integer offset) {
+        this.offset = offset;
+    }
+
+    /** @return the property value */
+    public String getProperty() {
+        return this.property;
+    }
+
+    /** @param property the property value */
+    public void setProperty(String property) {
+        this.property = property;
+    }
+
+    /** @return the entity text */
+    public String getText() {
+        return this.text;
+    }
+
+    /** @param text the entity text */
+    public void setText(String text) {
+        this.text = text;
+    }
+
+    /**
+     * Converts this object into a {@code Lemma}.
+     *
+     * <p>Length, property and text are copied; the position is written as
+     * {@code "offset,offset+length-1"}.
+     *
+     * @return a new {@code Lemma}
+     */
+    public Lemma toLemma(){
+        Lemma result = new Lemma();
+        result.setLen(length);
+        result.setProperty(property);
+        result.setText(text);
+        int lastIndex = offset + length - 1;
+        result.setPosition(offset + "," + lastIndex);
+        return result;
+    }
+
+}

+ 114 - 4
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionModel.java

@@ -1,9 +1,17 @@
 package org.algorithm.core.cnn.model;
 
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONArray;
+import com.alibaba.fastjson.JSONObject;
+import com.alibaba.fastjson.TypeReference;
 import org.algorithm.core.cnn.AlgorithmCNNExecutor;
+import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
 import org.tensorflow.SavedModelBundle;
 import org.tensorflow.Session;
+import org.tensorflow.Tensor;
 
+import java.nio.FloatBuffer;
+import java.util.ArrayList;
 import java.util.List;
 
 /**
@@ -12,9 +20,30 @@ import java.util.List;
  * @Decription:
  */
 public class RelationExtractionModel extends AlgorithmCNNExecutor {
-    private final String xx = null;
+//    self.X = tf.placeholder(tf.int32, shape=[None, self.max_length], name='X')
+//    self.pos1 = tf.placeholder(tf.int32, shape=[None, self.max_length], name='pos1')
+//    self.pos2 = tf.placeholder(tf.int32, shape=[None, self.max_length], name='pos2')
+//    NOTE(review): the three placeholders above are declared tf.int32, while run() feeds Float tensors - confirm against the exported model.
+//    self.y = tf.placeholder(tf.float32, shape=[None, self.n_classes], name='y')
+//    self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
+    private final String X_PLACEHOLDER = "X";  // node fed with the character ids in run()
+    private final String pos1_PLACEHOLDER = "pos1";  // node fed with the positions relative to entity 1
+    private final String pos2_PLACEHOLDER = "pos2";  // node fed with the positions relative to entity 2
+    private final String y_PLACEHOLDER = "y";  // node fetched as the model output in run()
+    private final int NUM_LABEL = 2;  // number of columns of the output row read back in run()
     private SavedModelBundle bundle; // model bundle
     private Session session;  // session
+    private RelationExtractionDataSet dataSet;  // builds the input rows and supplies maxLength
+
+
+    /**
+     * Stores the data set and then calls {@code init(exportDir)}.
+     *
+     * @param exportDir directory the model was saved to
+     * @param dataSet   data set used to turn a sentence and an entity pair into model inputs
+     */
+    public RelationExtractionModel(String exportDir, RelationExtractionDataSet dataSet){
+        this.dataSet = dataSet;
+        this.init(exportDir);
+    }
 
     /**
      * 初始化:加载模型,获取会话。
@@ -32,11 +61,92 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
         this.session = bundle.session();
     }
 
+    @Override
+    public List<Triad> execute(String content, String json_content) {
+        List<LemmaInfo[]> combinations = new ArrayList<>();
+        List<LemmaInfo> lemmaInfos = this.stringToObjects(json_content);  // parse the JSON array string into a list of objects
+        // Build the combinations
+        for(int i=0; i < lemmaInfos.size() - 1; i++){  // pair every entity with each later entity
+            for (int j = i + 1; j< lemmaInfos.size(); j++){
+                LemmaInfo[] pair = new LemmaInfo[2];
+                pair[0] = lemmaInfos.get(i);
+                pair[1] = lemmaInfos.get(j);
+                combinations.add(pair);
+            }
+        }
 
+        List<Triad> triads = new ArrayList<>();
 
+        // Iterate over the entity pairs
+        for (LemmaInfo[] lemmaInfoPair: combinations) {
+            float[][] example = dataSet.getExample(content, lemmaInfoPair[0], lemmaInfoPair[1]);
+            // Invoke the model, one example per call
+            float[][] relation = this.run(example, 1);
+            Triad triad = new Triad();
 
-    @Override
-    public List<Triad> execute(String content, String json_content) {
-        return null;
+            // TODO: revise triad
+            // Build the Triad (triple): both entities plus a label chosen by comparing the two output scores
+            triad.setL_1(lemmaInfoPair[0].toLemma());
+            triad.setL_2(lemmaInfoPair[1].toLemma());
+            triad.setRelation(relation[0][0] > relation[0][1] ? "无":"有");
+            triads.add(triad);
+        }
+        return triads;
+    }
+
+    /**
+     * Converts a JSON array string into a list of {@link LemmaInfo} objects.
+     *
+     * @param json_content JSON array string with one object per entity
+     * @return the parsed entities in array order; an empty list when {@code json_content}
+     *         does not yield an array (for example when it is {@code null})
+     */
+    public List<LemmaInfo> stringToObjects(String json_content) {
+        List<LemmaInfo> lemmaInfos = new ArrayList<>();
+        JSONArray jsonArray = JSONArray.parseArray(json_content);
+        // parseArray returns null instead of an empty array for null input; without this
+        // guard the loop below failed with a NullPointerException on jsonArray.size().
+        if (jsonArray == null) {
+            return lemmaInfos;
+        }
+        for (int i = 0; i < jsonArray.size(); i++) {
+            JSONObject job = jsonArray.getJSONObject(i);
+            LemmaInfo info = JSON.parseObject(job.toJSONString(), new TypeReference<LemmaInfo>() {
+            });
+            lemmaInfos.add(info);
+        }
+        return lemmaInfos;
+    }
+
+
+    /**
+     * Runs the model on a single example and returns its output row.
+     *
+     * <p>The input tensors and the fetched output tensor are closed through
+     * try-with-resources once the result has been copied out; previously none of them was
+     * closed, which leaked their native resources on every call.
+     *
+     * @param inputValues three rows of length {@code dataSet.maxLength}: character ids,
+     *                    positions relative to entity 1, positions relative to entity 2
+     * @param batchSize   batch size; currently unused, the tensor shape is fixed to one row
+     * @return a {@code float[1][NUM_LABEL]} array holding the model output
+     */
+    private float[][] run(float[][] inputValues, int batchSize){
+        long[] shape = {1, dataSet.maxLength};  // old model
+        try (Tensor<Float> charId = Tensor.create(shape, FloatBuffer.wrap(inputValues[0]));
+             Tensor<Float> pos1 = Tensor.create(shape, FloatBuffer.wrap(inputValues[1]));
+             Tensor<Float> pos2 = Tensor.create(shape, FloatBuffer.wrap(inputValues[2]));
+             Tensor<Float> keepProb = Tensor.create(1.0f, Float.class);  // dropout keep probability
+             Tensor<?> output = this.session.runner()
+                     .feed(this.X_PLACEHOLDER, charId)
+                     .feed(this.pos1_PLACEHOLDER, pos1)
+                     .feed(this.pos2_PLACEHOLDER, pos2)
+                     .feed("keep_prob", keepProb)
+                     .fetch(this.y_PLACEHOLDER)
+                     .run()
+                     .get(0)) {
+            return output.copyTo(new float[1][this.NUM_LABEL]);
+        }
+    }
+
+    /**
+     * Closes the session and then the model bundle to release their resources.
+     */
+    public void close() {
+        this.session.close();
+        this.bundle.close();
     }
 }

+ 28 - 0
algorithm/src/main/java/org/algorithm/test/RelationExtractionDataSetTest.java

@@ -0,0 +1,28 @@
+package org.algorithm.test;
+
+import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
+
+/**
+ * Manual smoke test for {@link RelationExtractionDataSet}: loads a character dictionary and
+ * prints the relative positions computed for a sample sentence.
+ *
+ * @Author: bijl
+ * @Date: 2019/1/22 13:46
+ */
+public class RelationExtractionDataSetTest {
+
+    /** Dictionary file used when no path is given on the command line. */
+    private static final String DEFAULT_DICTIONARY_PATH =
+            "E:\\relation_extraction\\shao_yi_fu_data\\char2id.json";
+
+    /**
+     * @param args optional; {@code args[0]} overrides the default dictionary path so the
+     *             check can run on machines without the original E: drive layout
+     */
+    public static void main(String[] args) {
+        String filePath = args.length > 0 ? args[0] : DEFAULT_DICTIONARY_PATH;
+        String sentence = "有双手麻木感,活动后好转,颈部及肩部活动度无殊";
+
+        RelationExtractionDataSet dataSet = new RelationExtractionDataSet();
+        dataSet.loadDictionary(filePath);
+
+        // Print the positions relative to the entity given as "1,2" (start,end).
+        for (float position : dataSet.getRelativePositions(sentence, "1,2")) {
+            System.out.println(position);
+        }
+    }
+}

+ 37 - 7
algorithm/src/main/java/org/algorithm/test/Test.java

@@ -1,4 +1,8 @@
 package org.algorithm.test;
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.*;
+import com.alibaba.fastjson.TypeReference;
+import org.algorithm.core.cnn.model.LemmaInfo;
 
 
 public class Test {
@@ -11,13 +15,39 @@ public class Test {
 //        for(int i=1; i< 955; i++) {
 //            xx = (float)(Math.round(1.0f * i / bb*100000))/100000;
 //            System.out.println(i+":"+xx);
-//        }
-        String filePath = "/opt/models/model_version_replacement/model";
-        int index = filePath.indexOf("model_version_replacement");
-        
-        System.out.println(filePath.substring(0, index));
-        
-                
+////        }
+//        String filePath = "/opt/models/model_version_replacement/model";
+//        int index = filePath.indexOf("model_version_replacement");
+//
+//        System.out.println(filePath.substring(0, index));
+//            public static void testJSONStrToJavaBeanObj(){
+//
+//        Student student = JSON.parseObject(JSON_OBJ_STR, new TypeReference<Student>() {});
+//        //Student student1 = JSONObject.parseObject(JSON_OBJ_STR, new TypeReference<Student>() {});//因为JSONObject继承了JSON,所以这样也是可以的
+//
+//        System.out.println(student.getStudentName()+":"+student.getStudentAge());
+//
+        String JSON_ARRAY_STR = "[{\"length\":4,\"offset\":0,\"property\":\"1\",\"text\":\"剑突下痛\",\"threshold\":0.0},{\"length\":2,\"offset\":4,\"property\":\"1\",\"text\":\"胀痛\",\"threshold\":0.0},{\"length\":2,\"offset\":6,\"property\":\"2\",\"text\":\"1天\",\"threshold\":0.0},{\"length\":1,\"offset\":8,\"text\":\",\",\"threshold\":0.0}]\n";
+//        JSONArray jsonArray = JSONArray.parseArray(JSON_ARRAY_STR);
+////        String jsonString = "{\"length\":4,\"offset\":0,\"property\":\"1\",\"text\":\"剑突下痛\",\"threshold\":0.0}";
+//
+//       for (int i = 0; i < jsonArray.size(); i++){
+//           JSONObject job = jsonArray.getJSONObject(i);
+//           LemmaInfo info = JSON.parseObject(job.toJSONString(), new TypeReference<LemmaInfo>() {});
+//           //Student student1 = JSONObject.parseObject(JSON_OBJ_STR, new TypeReference<Student>() {});//因为JSONObject继承了JSON,所以这样也是可以的
+//
+//           System.out.println(info.getLength()+":"+info.getText());
+//       }
+
+        int index = 0;
+        for (int i=0; i<5; i++)
+            for (int j = i+1; j< 6; j++){
+                System.out.println(i + "," + j);
+                index ++;
+            }
+
+        System.out.println(index);
+
     }
 
 }