浏览代码

1- 测试通过。

bijl 6 年之前
父节点
当前提交
8dda9cbcc3

+ 6 - 37
algorithm/src/main/java/org/algorithm/core/cnn/dataset/RelationExtractionDataSet.java

@@ -68,8 +68,8 @@ public class RelationExtractionDataSet {
      * @param sentence 句子
      * @return ids
      */
-    public float[] sentence2ids(String sentence) {
-        float[] ids = new float[this.maxLength];
+    public int[] sentence2ids(String sentence) {
+        int[] ids = new int[this.maxLength];
         char ch = '1';
         Integer id = null;
         for (int i = 0; i < sentence.length(); i++) {
@@ -91,8 +91,8 @@ public class RelationExtractionDataSet {
      * @param position 一个实体的位置
      * @return 句子中各个汉字相对于实体的位置
      */
-    public float[] getRelativePositions(String sentence, String position) {
-        float[] relativePositions = new float[this.maxLength];
+    public int[] getRelativePositions(String sentence, String position) {
+        int[] relativePositions = new int[this.maxLength];
         String[] positionPair = position.split(",");
         int startPos = Integer.parseInt(positionPair[0]);
         int endtPos = Integer.parseInt(positionPair[1]);
@@ -112,45 +112,14 @@ public class RelationExtractionDataSet {
         return relativePositions;
     }
 
-    /**
-     * 获取实体对的组合
-     *
-     * @return
-     */
-    public List<String> getPositionCombinations(String json_content) {
-        List<String> combinations = new ArrayList<>();
-
-        return combinations;
-    }
-
-
-    /**
-     * @param sentence     输入句子
-     * @param json_content 句子content中的实体信息
-     * @return
-     */
-    public List<float[][]> get_examples(String sentence, String json_content) {
-        List<float[][]> examples = new ArrayList<>();
-        List<String> combinations = this.getPositionCombinations(json_content);
-        float[] charId = this.sentence2ids(sentence);
-        for (String combination : combinations) {
-            float[][] example = new float[3][this.maxLength];
-            example[0] = charId;
-            example[1] = this.getRelativePositions(sentence, combination);
-            example[1] = this.getRelativePositions(sentence, combination);
-            examples.add(example);
-        }
-        return examples;
-    }
-
     /**
      * @param sentence     输入句子
      * @param entity1 实体1信息
      * @param entity2 实体2信息
      * @return
      */
-    public float[][] getExample(String sentence, LemmaInfo entity1, LemmaInfo entity2) {
-        float[][] example = new float[3][this.maxLength];
+    public int[][] getExample(String sentence, LemmaInfo entity1, LemmaInfo entity2) {
+        int[][] example = new int[3][this.maxLength];
         int startPos = entity1.getOffset().intValue();
         int endPos = entity1.getOffset().intValue() + entity1.getLength().intValue() - 1;
 

+ 11 - 10
algorithm/src/main/java/org/algorithm/core/cnn/model/RelationExtractionModel.java

@@ -11,6 +11,7 @@ import org.tensorflow.Session;
 import org.tensorflow.Tensor;
 
 import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
 import java.util.ArrayList;
 import java.util.List;
 
@@ -29,7 +30,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
     private final String X_PLACEHOLDER = "X";
     private final String pos1_PLACEHOLDER = "pos1";
     private final String pos2_PLACEHOLDER = "pos2";
-    private final String y_PLACEHOLDER = "y";
+    private final String SOFT_MAX = "softmax/Softmax";
     private final int NUM_LABEL = 2;
     private SavedModelBundle bundle; // 模型捆绑
     private Session session;  // 会话
@@ -79,7 +80,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
 
         // 遍历组合
         for (LemmaInfo[] lemmaInfoPair: combinations) {
-            float[][] example = dataSet.getExample(content, lemmaInfoPair[0], lemmaInfoPair[1]);
+            int[][] example = dataSet.getExample(content, lemmaInfoPair[0], lemmaInfoPair[1]);
             // 调用模型
             float[][] relation = this.run(example, 1);
             Triad triad = new Triad();
@@ -118,19 +119,19 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
      * @param batchSize 批量大小
      * @return
      */
-    private float[][] run(float[][] inputValues, int batchSize){
+    private float[][] run(int[][] inputValues, int batchSize){
         long[] shape = {1, dataSet.maxLength};  // 老模型
-        Tensor<Float> charId = Tensor.create(
+        Tensor<Integer> charId = Tensor.create(
                 shape,
-                FloatBuffer.wrap(inputValues[0])
+                IntBuffer.wrap(inputValues[0])
         );
-        Tensor<Float> pos1 = Tensor.create(
+        Tensor<Integer> pos1 = Tensor.create(
                 shape,
-                FloatBuffer.wrap(inputValues[1])
+                IntBuffer.wrap(inputValues[1])
         );
-        Tensor<Float> pos2 = Tensor.create(
+        Tensor<Integer> pos2 = Tensor.create(
                 shape,
-                FloatBuffer.wrap(inputValues[2])
+                IntBuffer.wrap(inputValues[2])
         );
 
         return this.session.runner()
@@ -138,7 +139,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
                 .feed(this.pos1_PLACEHOLDER, pos1)
                 .feed(this.pos2_PLACEHOLDER, pos2)
                 .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                .fetch(this.y_PLACEHOLDER).run().get(0)
+                .fetch(this.SOFT_MAX).run().get(0)
                 .copyTo(new float[1][this.NUM_LABEL]);
     }
 

文件差异内容过多而无法显示
+ 29 - 0
algorithm/src/main/java/org/algorithm/test/RelationExtractionModelTest.java


+ 3 - 3
nlp/src/main/java/org/diagbot/nlp/relation/RelationAnalyze.java

@@ -20,9 +20,9 @@ import java.util.List;
 public class RelationAnalyze {
     public void analyze(String content, FeatureType featureType) throws Exception {
         LexemePath<Lexeme> lexemes = ParticipleUtil.participle(content, true);
-        AlgorithmCNNExecutor executor = new AlgorithmCNNExecutor();
+//        AlgorithmCNNExecutor executor = new AlgorithmCNNExecutor();
 
-        String json_content = JSON.toJSONString(lexemes);
-        List<Triad> triads = executor.execute(content, json_content);
+//        String json_content = JSON.toJSONString(lexemes);
+//        List<Triad> triads = executor.execute(content, json_content);
     }
 }

+ 1 - 1
nlp/src/test/java/org/diagbot/nlp/test/ParticipleTest.java

@@ -23,7 +23,7 @@ import java.util.List;
 public class ParticipleTest {
     public static void main(String[] args) {
         try {
-            String content = "剑突下痛胀痛1天,";
+            String content = "有双手麻木感,活动后好转,颈部及肩部活动度无殊,";
             ParticipleTest test = new ParticipleTest();
 //            InputStream is = test.getClass().getClassLoader().getResourceAsStream("present.txt");
 //            BufferedReader br = new BufferedReader(new InputStreamReader(is, "UTF-8"), 512);