Selaa lähdekoodia

Merge remote-tracking branch 'origin/master'

louhr 6 vuotta sitten
vanhempi
commit
306fab3038

+ 112 - 0
algorithm/src/main/java/org/algorithm/core/cnn/dataset/RelationExtractionDataSet.java

@@ -0,0 +1,112 @@
+package org.algorithm.core.cnn.dataset;
+
+import java.io.BufferedReader;
+import java.io.FileNotFoundException;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.algorithm.util.TextFileReader;
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONObject;
+
+/**
+ * @Author: bijl
+ * @Date: 2019年1月21日-下午2:43:44
+ * @Description: 
+ */
+public class RelationExtractionDataSet {
+    
+    private Map<String, Integer> char2id = new HashMap<>();
+    private int maxLength = 200;
+    
+    
+    /**
+     * 切分句子
+     * @param document 原有文档
+     * @return 句子数组
+     */
+    public String[] splitSentence(String document) {
+        String[] sentences = null;
+        sentences = document.split("。|;|\n|\n\r"); 
+        return sentences;
+    }
+    
+    /**
+     * 加载字典
+     * @param dir
+     */
+    public void loadDictionary(String dir) {
+        List<String> lines= TextFileReader.readLines(dir);
+        BufferedReader br = null;
+        try {
+             br = new BufferedReader(new FileReader(dir));// 读取原始json文件
+            String s = null;
+            while ((s = br.readLine()) != null) {
+                JSONObject jsonObject = (JSONObject) JSON.parse(s);
+                Set<Map.Entry<String, Object>> entries = jsonObject.entrySet();
+                for (Map.Entry<String, Object> entry : entries) {
+                    this.char2id.put(entry.getKey(), (Integer) entry.getValue());
+                }
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+        }finally {
+            try {
+                br.close();
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
+    }
+    
+    /**
+     * 句子转字符ids
+     * @param sentence 句子
+     * @return ids
+     */
+    public int[]  sentence2ids(String sentence) {
+        int[] ids = new int[this.maxLength];
+        char ch = '1';
+        Integer id = null;
+        for (int i = 0; i < sentence.length(); i++) {
+            ch = sentence.charAt(i);
+            id = this.char2id.get(ch);
+            if (id == null) {
+                id = this.char2id.get("<UKC>");
+            }
+            ids[i] = id.intValue();
+        }
+        for(int i=sentence.length(); i<this.maxLength; i++)  // padding
+            ids[i] = this.char2id.get("<PAD>");
+        
+        return ids;
+    }
+    
+    /**
+     * 
+     * @param sentence 句子
+     * @param position 一个实体的位置
+     * @return 句子中各个汉子相对于实体的位置
+     */
+    public int[] getRelativePositions(String sentence, String position) {
+        int[] relativePositions = new int[this.maxLength];
+        String[] positionPair = position.split(",");
+        int startPos = Integer.parseInt(positionPair[0]); 
+        int endtPos = Integer.parseInt(positionPair[1]);
+        
+        char ch = '1';
+        Integer id = null;
+        for (int i = 0; i < sentence.length(); i++) {
+        }
+        
+        
+        
+        
+        return relativePositions;
+    }
+    
+
+}