Selaa lähdekoodia

二元组模型性能优化

louhr 6 vuotta sitten
vanhempi
commit
49bdf3c98f

+ 2 - 2
algorithm/src/main/java/org/algorithm/core/cnn/AlgorithmCNNExecutor.java

@@ -16,8 +16,8 @@ public abstract class AlgorithmCNNExecutor {
     /**
      *
      * @param content 输入句子
-     * @param lemmas 实体列表
+     * @param triads 实体列表
      * @return
      */
-    public abstract List<Triad> execute(String content, List<Lemma> lemmas);
+    public abstract List<Triad> execute(String content, List<Triad> triads);
 }

+ 31 - 26
algorithm/src/main/java/org/algorithm/core/cnn/model/impl/RelationExtractionModelFromHttp.java

@@ -18,13 +18,13 @@ import java.util.List;
 public class RelationExtractionModelFromHttp extends AlgorithmCNNExecutor {
     private final int MAX_LEN = 512;
     @Override
-    public List<Triad> execute(String content, List<Lemma> lemmas) {
+    public List<Triad> execute(String content, List<Triad> triads) {
         String url = getUrl();
         // 句子长度不超过MAX_LEN,实体数超过两个
-        if (content.length() > this.MAX_LEN || lemmas.size() < 2) {
+        if (content.length() > this.MAX_LEN || triads.size() < 1) {
             return new ArrayList<>();
         }
-        String positions = makePositionsParam(lemmas);
+        String positions = makePositionsParam(triads);
         String indexPairAndRelations = HttpGetAndPost.sendPost(url,
                 "sentence="+content+"&positions="+positions);
 
@@ -34,18 +34,18 @@ public class RelationExtractionModelFromHttp extends AlgorithmCNNExecutor {
             return new ArrayList<>();
         }
         // 合并信息到三元组中
-        return mergePredictInfoToTriads(indexPairAndRelations, lemmas);
+        return mergePredictInfoToTriads(indexPairAndRelations, triads);
     }
 
     /**
      * 用lemma内部属性position组成字符串,按lemma在lemmas中的顺序
-     * @param lemmas 实体list
+     * @param triads 实体list
      * @return
      */
-    private String makePositionsParam(List<Lemma> lemmas){
+    private String makePositionsParam(List<Triad> triads){
         String results = "";
-        for(Lemma lm: lemmas)
-            results += lm.getPosition() + "|";  // 形式:1,2|33,45|
+        for(Triad triad: triads)
+            results += triad.getL_1().getPosition() +  "^" + triad.getL_1().getPosition() + "|";  // 形式:1,2|33,45|
         if (!StringUtils.isEmpty(results)) {
             results = results.substring(0, results.length() - 1);  // 形式:1,2|33,45
         }
@@ -64,29 +64,34 @@ public class RelationExtractionModelFromHttp extends AlgorithmCNNExecutor {
     /**
      * 合并预测信息到三元组中去
      * @param indexPairAndRelations 预测信息,来自网络 // 形式:1,2:有|3,5:无
-     * @param lemmas 实体
+     * @param triads 实体
      * @return 有关系的实体对
      */
-    private List<Triad> mergePredictInfoToTriads(String indexPairAndRelations, List<Lemma> lemmas){
-        List<Triad> triads = new ArrayList<>();
+    private List<Triad> mergePredictInfoToTriads(String indexPairAndRelations, List<Triad> triads){
+        List<Triad> result = new ArrayList<>();
         String[] posRelationArray = indexPairAndRelations.split("\\|");
-        String[] info;
-        String[] indexPair;
-        int index1;
-        int index2;
-        for(String posRelation:posRelationArray){
-            info = posRelation.split(":");  // 形式:1,2:有
-            if ("有".equals(info[1])){  // 仅返回有关系的
-                indexPair = info[0].split(","); // 形式:1,2
-                index1 = Integer.parseInt(indexPair[0]);
-                index2 = Integer.parseInt(indexPair[1]);
-                Triad triad = new Triad();
-                triad.setL_1(lemmas.get(index1));
-                triad.setL_2(lemmas.get(index2));
-                triad.setRelation(info[1]);
-                triads.add(triad);
+        for (int i = 0; i< triads.size(); i++) {
+            if ("1".equals(posRelationArray[i])) {
+                result.add(triads.get(i));
             }
         }
+//        String[] info;
+//        String[] indexPair;
+//        int index1;
+//        int index2;
+//        for(String posRelation:posRelationArray){
+//            info = posRelation.split(":");  // 形式:1,2:有
+//            if ("有".equals(info[1])){  // 仅返回有关系的
+//                indexPair = info[0].split(","); // 形式:1,2
+//                index1 = Integer.parseInt(indexPair[0]);
+//                index2 = Integer.parseInt(indexPair[1]);
+//                Triad triad = new Triad();
+//                triad.setL_1(lemmas.get(index1));
+//                triad.setL_2(lemmas.get(index2));
+//                triad.setRelation(info[1]);
+//                triads.add(triad);
+//            }
+//        }
 
         return triads;
     }

+ 3 - 3
nlp/src/main/java/org/diagbot/nlp/relation/analyze/RelationAnalyze.java

@@ -36,10 +36,10 @@ public class RelationAnalyze {
             List<Lemma> lemmaParticiple = lemmaUtil.lexemeToTriadLemma(lexemes);
             //调用CNN模型
             long start = System.currentTimeMillis();
-            AlgorithmCNNExecutor executor = new RelationExtractionModelFromHttp();
-            List<Triad> triads = executor.execute(part_content, lemmaParticiple);
             //删除不作为训练样本集
-            triads = lemmaUtil.findPairTraids(triads);
+            List<Triad> triads = lemmaUtil.findPairTraids(lemmaParticiple);
+            AlgorithmCNNExecutor executor = new RelationExtractionModelFromHttp();
+            triads = executor.execute(part_content, triads);
             //模型返回的三元组转树形结构
             List<Lemma> lemmaTree = lemmaUtil.traidToTree(triads, featureType);
             long end = System.currentTimeMillis();

+ 11 - 1
nlp/src/main/java/org/diagbot/nlp/relation/util/LemmaUtil.java

@@ -130,7 +130,17 @@ public class LemmaUtil {
         return l;
     }
 
-    public List<Triad> findPairTraids(List<Triad> triads) {
+    public List<Triad> findPairTraids(List<Lemma> lemmaParticiple) {
+        List<Triad> triads = new ArrayList<>();
+        for (int i = 0; i < lemmaParticiple.size(); i++) {
+            for (int j = 0; j < lemmaParticiple.size(); j++) {
+                Triad triad = new Triad();
+                triad.setL_1(lemmaParticiple.get(i));
+                triad.setL_2(lemmaParticiple.get(j));
+                triads.add(triad);
+            }
+        }
+
         Map<String, List<String>> extract_relation_property_pair_map = NlpCache.getExtract_relation_property_pair_map();
 
         String[] prop1;