Ver código fonte

1-debug:修正序列数据输入的处理,并注释掉print语句。

bijl 5 anos atrás
pai
commit
113a7ba040

+ 13 - 11
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -110,14 +110,20 @@ public class TensorflowModel {
         // 序列数据
         if (this.withSequenceInputs){
             Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
-            this.session.runner().feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
-                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids));
+            this.session.runner();
+
+            return this.session.runner().feed(this.X, inputTensor)
+                    .feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
+                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids))
+                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+                    .fetch(this.SOFT_MAX).run().get(0)
+                    .copyTo(new float[numExamples][this.NUM_LABEL]);
+        }else{
+            return this.session.runner().feed(this.X, inputTensor)
+                    .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
+                    .fetch(this.SOFT_MAX).run().get(0)
+                    .copyTo(new float[numExamples][this.NUM_LABEL]);
         }
-
-        return this.session.runner().feed(this.X, inputTensor)
-                .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                .fetch(this.SOFT_MAX).run().get(0)
-                .copyTo(new float[numExamples][this.NUM_LABEL]);
     }
 
 
@@ -153,8 +159,4 @@ public class TensorflowModel {
         this.bundle.close();
     }
 
-    public void setWithSequenceInputs(boolean withSequenceInputs) {
-        this.withSequenceInputs = withSequenceInputs;
-    }
-
 }

+ 13 - 13
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSetImpl.java

@@ -1,7 +1,5 @@
 package org.algorithm.core.neural.dataset;
 
-import com.alibaba.fastjson.JSON;
-import com.alibaba.fastjson.JSONObject;
 import org.algorithm.util.TextFileReader;
 import org.diagbot.pub.utils.PropertiesUtil;
 
@@ -79,9 +77,9 @@ public class NNDataSetImpl extends NNDataSet {
             ch = sentence.charAt(i);
             id = this.CHAR2ID_DICT.get(String.valueOf(ch));
             if (id == null) {
-                id = this.CHAR2ID_DICT.get("<UKC>");
+                id = this.CHAR2ID_DICT.get("<UNC>");
             }
-            ids[i] = id.intValue();
+            ids[i] = id;
         }
         for (int i = sentence.length(); i < max_len; i++)  // padding
             ids[i] = this.CHAR2ID_DICT.get("<PAD>");
@@ -143,7 +141,7 @@ public class NNDataSetImpl extends NNDataSet {
 
         }
 
-        System.out.println("feature size:" + this.FEATURE_DICT.size());
+//        System.out.println("feature size:" + this.FEATURE_DICT.size());
 
     }
 
@@ -161,12 +159,14 @@ public class NNDataSetImpl extends NNDataSet {
         BufferedReader br = null;
         try {
             br = new BufferedReader(new FileReader(filePath));  // 读取原始json文件
-            String s = null;
-            while ((s = br.readLine()) != null) {
-                JSONObject jsonObject = (JSONObject) JSON.parse(s);
-                Set<Entry<String, Object>> entries = jsonObject.entrySet();
-                for (Map.Entry<String, Object> entry : entries)
-                    this.CHAR2ID_DICT.put(entry.getKey(), (Integer) entry.getValue());
+            String line = null;
+            String[] pair = null;
+            while ((line = br.readLine()) != null) {
+                line = line.trim();
+                if (line.indexOf("_|_") > -1){
+                    pair = line.split("_\\|_");
+                    this.CHAR2ID_DICT.put(pair[0], Integer.parseInt(pair[1]));
+                }
             }
         } catch (Exception e) {
             e.printStackTrace();
@@ -242,7 +242,7 @@ public class NNDataSetImpl extends NNDataSet {
 
         }
 
-        System.out.println("再分词,词条数:" + this.RE_SPLIT_WORD_DICT.size());
+//        System.out.println("再分词,词条数:" + this.RE_SPLIT_WORD_DICT.size());
 
     }
 
@@ -275,7 +275,7 @@ public class NNDataSetImpl extends NNDataSet {
             this.RELATED_DIAGNOSIS_DICT.put(temp[0], diagnosis_map);
         }
 
-        System.out.println("疾病过滤字典大小:" + this.RELATED_DIAGNOSIS_DICT.size());
+//        System.out.println("疾病过滤字典大小:" + this.RELATED_DIAGNOSIS_DICT.size());
     }