Pārlūkot izejas kodu

1- 提交拼接模型。

bijl 5 gadi atpakaļ
vecāks
revīzija
00a71473a4

+ 14 - 18
algorithm/src/main/java/org/algorithm/core/neural/TensorFlowModelLoadFactory.java

@@ -6,6 +6,7 @@ import org.diagbot.pub.utils.PropertiesUtil;
 
 /**
  * Tensorlflow 模型加载工厂
+ *
  * @Author: bijl
  * @Date: 2018年7月19日-下午7:28:58
  * @Description:
@@ -14,32 +15,31 @@ public class TensorFlowModelLoadFactory {
 
     /**
      * 加载并创建模型类
-     * @param modelVersion  模型版本号
+     *
+     * @param modelVersion 模型版本号
      * @return 模型
      */
     public static TensorflowModel create(String modelVersion) {
-        
-        
+
+
         PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
-        
-        String inputOpName = "X";  // 统一输入op名称
-        String outputOpName = "softmax/softmax";  // 统一输出op名称
-        
+
+
 //        NNDataSet dataSet = new NNDataSetImplNonParallel(modelVersion);  // 新模型
         NNDataSet dataSet = new NNDataSetImpl(modelVersion);  // 老模型
 
-        String modelPath =prop.getProperty("basicPath");  // 模型基本路径
+        String modelPath = prop.getProperty("basicPath");  // 模型基本路径
         modelVersion = prop.getProperty(modelVersion);
         modelPath = modelPath.replace("model_version_replacement", modelVersion);  // 生成模型路径
-        
-        TensorflowModel tm = new TensorflowModel(modelPath, inputOpName, outputOpName,
-                dataSet);
+
+        TensorflowModel tm = new TensorflowModel(modelPath, dataSet);
         return tm;
     }
 
     /**
      * 加载并创建模型类
-     * @param modelVersion  模型版本号
+     *
+     * @param modelVersion 模型版本号
      * @return 模型
      */
     public static TensorflowModel createAndFilterDiagnosis(String modelVersion) {
@@ -47,20 +47,16 @@ public class TensorFlowModelLoadFactory {
 
         PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
 
-        String inputOpName = "X";  // 统一输入op名称
-        String outputOpName = "softmax/softmax";  // 统一输出op名称
-
         NNDataSet dataSet = new NNDataSetImpl(modelVersion);  // 老模型
 
         dataSet.setDoFilterDiagnosis(true);
         dataSet.readFilterDiagnosisDict();
 
-        String modelPath =prop.getProperty("basicPath");  // 模型基本路径
+        String modelPath = prop.getProperty("basicPath");  // 模型基本路径
         modelVersion = prop.getProperty(modelVersion);
         modelPath = modelPath.replace("model_version_replacement", modelVersion);  // 生成模型路径
 
-        TensorflowModel tm = new TensorflowModel(modelPath, inputOpName, outputOpName,
-                dataSet);
+        TensorflowModel tm = new TensorflowModel(modelPath, dataSet);
         return tm;
     }
 

+ 90 - 34
algorithm/src/main/java/org/algorithm/core/neural/TensorflowModel.java

@@ -6,43 +6,55 @@ import org.tensorflow.Session;
 import org.tensorflow.Tensor;
 
 import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.util.HashMap;
 import java.util.Map;
 
 /**
  * tensorflow 模型类,要求单个样本是1维向量,而不是高维向量
+ *
  * @Author: bijl
  * @Date: 2018年7月19日-下午7:21:24
  * @Description:
  */
 public class TensorflowModel {
-    
-    private final String INPUT_OPERATION_NAME;   // 输入op的名称
-    private final String OUTPUT_OPERATION_NAME;  // 输出op的名称
+
+
+    private final String X = "X";  // 输入op x的名字
+    private final String Char_ids = "Char_ids";  // 输入op Char_ids的名字
+    private final String Pos_ids = "Pos_ids";  // 输入op Pos_ids的名字
+    private final String SOFT_MAX = "softmax/softmax";  // 输出op的名称
+
     private final int NUM_FEATURE;  // 特征个数
     private final int NUM_LABEL;  //  标签(类别)个数
     private SavedModelBundle bundle; // 模型捆绑
     private Session session;  // 会话
     private NNDataSet dataSet;  // 数据集
-    
+
+
+    private boolean withSequenceInputs = false;  // 是否带有序列输入
+    private final int MAX_LEN; // 最大长度
+
+
     /**
-     * 
-     * @param exportDir  模型保存地址
-     * @param inputOpName  输入op的名称
-     * @param outputOpName  输出op的名称
-     * @param dataSet  模型使用的数据集
+     * @param exportDir 模型保存地址
+     * @param dataSet   模型使用的数据集
      */
-    public TensorflowModel(String exportDir, String inputOpName, String outputOpName, NNDataSet dataSet) {
-        this.INPUT_OPERATION_NAME = inputOpName;
-        this.OUTPUT_OPERATION_NAME = outputOpName;
+    public TensorflowModel(String exportDir, NNDataSet dataSet) {
+
+        this.init(exportDir);
         this.dataSet = dataSet;
         this.NUM_FEATURE = this.dataSet.getNumFeature();
         this.NUM_LABEL = this.dataSet.getNumLabel();
-        this.init(exportDir);
-                
+
+        // 序列数据有段的属性
+        this.MAX_LEN = this.dataSet.getMAX_LEN();
+        this.withSequenceInputs = this.dataSet.isWithSequenceInputs();
     }
-    
+
     /**
      * 初始化:加载模型,获取会话。
+     *
      * @param exportDir
      */
     public void init(String exportDir) {
@@ -54,29 +66,61 @@ public class TensorflowModel {
         }
 
         // create the session from the Bundle
-        this.session = bundle.session(); 
+        this.session = bundle.session();
+    }
+
+
+    /**
+     * 包装序列化输入
+     *
+     * @param sequenceValuesMap 序列输入的map
+     * @param numExamples       样本数
+     * @return
+     */
+    private Map<String, Tensor<Integer>> wrapSequenceInputs(Map<String, int[]> sequenceValuesMap, int numExamples) {
+        long[] inputShape = {numExamples, this.MAX_LEN};
+        Map<String, Tensor<Integer>> sequenceTensorMap = new HashMap<>();
+        for (Map.Entry<String, int[]> entry : sequenceValuesMap.entrySet()) {
+            String mapKey = entry.getKey();
+            Tensor<Integer> inputTensor = Tensor.create(
+                    inputShape,
+                    IntBuffer.wrap(entry.getValue())
+            );
+            sequenceTensorMap.put(mapKey, inputTensor);
+        }
+
+        return sequenceTensorMap;
     }
-    
+
+
     /**
      * 运行模型
-     * @param inputValues  输入值
-     * @param numExamples  样本个数
+     *
+     * @param inputValues 输入值
+     * @param numExamples 样本个数
      * @return 模型的输出
      */
-    private float[][] run(float[] inputValues, int numExamples){
-//        long[] inputShape = {numExamples, this.NUM_FEATURE, 4, 1};  // 新模型
-        long[] inputShape = {numExamples, this.NUM_FEATURE};  // 老模型
+    private float[][] run(float[] inputValues, Map<String, int[]> sequenceValues, int numExamples) {
+        long[] inputShape = {numExamples, this.NUM_FEATURE};
         Tensor<Float> inputTensor = Tensor.create(
-                inputShape,  
-                FloatBuffer.wrap(inputValues) 
+                inputShape,
+                FloatBuffer.wrap(inputValues)
         );
-        return this.session.runner().feed(this.INPUT_OPERATION_NAME, inputTensor)
+
+        // 序列数据
+        if (this.withSequenceInputs){
+            Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
+            this.session.runner().feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
+                    .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids));
+        }
+
+        return this.session.runner().feed(this.X, inputTensor)
                 .feed("keep_prob", Tensor.create(1.0f, Float.class))  // dropout保留率
-                .fetch(this.OUTPUT_OPERATION_NAME).run().get(0)
+                .fetch(this.SOFT_MAX).run().get(0)
                 .copyTo(new float[numExamples][this.NUM_LABEL]);
     }
-    
-    
+
+
     /**
      * 运行模型,并将结果打包成目标格式
      */
@@ -85,14 +129,22 @@ public class TensorflowModel {
         float sum = 0;
         for (float f : inputValues)
             sum += f;
-        if(sum == 0)  // 如果输入没有有效特征,则直接返回null
+        if (sum == 0)  // 如果输入没有有效特征,则直接返回null
             return null;
-        
-        float[][] predict = this.run(inputValues, 1);  // 一次一个样本
-        return this.dataSet.wrap(predict);  
+
+        Map<String, int[]> sequenceValues = null;
+        if (this.withSequenceInputs){
+            sequenceValues = new HashMap<>();
+            sequenceValues.put(this.Char_ids, this.dataSet.toCharIds(inputs));
+            sequenceValues.put(this.Pos_ids, this.dataSet.toPosIds(inputs));
+        }
+
+
+        float[][] predict = this.run(inputValues, sequenceValues, 1);  // 一次一个样本
+        return this.dataSet.wrap(predict);
     }
-    
-    
+
+
     /**
      * 关闭会话,释放资源
      */
@@ -101,4 +153,8 @@ public class TensorflowModel {
         this.bundle.close();
     }
 
+    public void setWithSequenceInputs(boolean withSequenceInputs) {
+        this.withSequenceInputs = withSequenceInputs;
+    }
+
 }

+ 45 - 3
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSet.java

@@ -36,6 +36,11 @@ public abstract class NNDataSet {
     private final int numToPush = 3;  // 推荐推送的个数
     private final float rapidFallTimes = 5;  // 骤降倍数
 
+    // 序列数据
+    private final int MAX_LEN = 257;
+    private boolean withSequenceInputs = false;  // 是否带有序列输入
+    protected final Map<String, Integer> CHAR2ID_DICT = new HashMap<>();
+
 
     public NNDataSet(String modelAndVersion) {
         this.readDict(modelAndVersion);
@@ -45,6 +50,10 @@ public abstract class NNDataSet {
         this.LABEL_DICT_ARRAY = new String[this.NUM_LABEL];
         this.makeDictArr();
         this.readReSplitWordDict();
+
+        // 读取序列数据
+        if (this.withSequenceInputs)
+            this.readChar2IdDict(modelAndVersion);
     }
 
     /**
@@ -55,11 +64,33 @@ public abstract class NNDataSet {
      */
     public abstract float[] toFeatureVector(Map<String, Map<String, String>> inputs);
 
+    /**
+     * 装外部输入转为字符ids
+     *
+     * @param inputs
+     * @return
+     */
+    public abstract int[] toCharIds(Map<String, Map<String, String>> inputs);
+
+    /**
+     * 装外部输入转为位置ids
+     *
+     * @param inputs
+     * @return
+     */
+    public abstract int[] toPosIds(Map<String, Map<String, String>> inputs);
+
     /**
      * 读取特征和类别字典
      */
     public abstract void readDict(String modelAndVersion);
 
+
+    /**
+     * 读取特征和类别字典
+     */
+    public abstract void readChar2IdDict(String modelAndVersion);
+
     /**
      * 读取再分词字典
      */
@@ -320,9 +351,6 @@ public abstract class NNDataSet {
         this.FEATURE_NAME_STORE.addAll(features.keySet());
     }
 
-    /**
-     * @return
-     */
     public int getNumLabel() {
         return this.NUM_LABEL;
     }
@@ -332,4 +360,18 @@ public abstract class NNDataSet {
         this.doFilterDiagnosis = doFilterDiagnosis;
     }
 
+
+    public int getMAX_LEN() {
+        return MAX_LEN;
+    }
+
+
+    public void setWithSequenceInputs(boolean withSequenceInputs) {
+        this.withSequenceInputs = withSequenceInputs;
+    }
+
+
+    public boolean isWithSequenceInputs() {
+        return withSequenceInputs;
+    }
 }

+ 75 - 4
algorithm/src/main/java/org/algorithm/core/neural/dataset/NNDataSetImpl.java

@@ -1,12 +1,14 @@
 package org.algorithm.core.neural.dataset;
 
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONObject;
 import org.algorithm.util.TextFileReader;
 import org.diagbot.pub.utils.PropertiesUtil;
 
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.*;
 import java.util.Map.Entry;
 
 /**
@@ -66,6 +68,42 @@ public class NNDataSetImpl extends NNDataSet {
         return featureVector;
     }
 
+    @Override
+    public int[] toCharIds(Map<String, Map<String, String>> inputs) {
+        String sentence = inputs.get("sentence").get("sentence");
+        int max_len = this.getMAX_LEN();
+        int[] ids = new int[max_len];
+        char ch = '1';
+        Integer id = null;
+        for (int i = 0; i < sentence.length(); i++) {
+            ch = sentence.charAt(i);
+            id = this.CHAR2ID_DICT.get(String.valueOf(ch));
+            if (id == null) {
+                id = this.CHAR2ID_DICT.get("<UKC>");
+            }
+            ids[i] = id.intValue();
+        }
+        for (int i = sentence.length(); i < max_len; i++)  // padding
+            ids[i] = this.CHAR2ID_DICT.get("<PAD>");
+
+        return ids;
+    }
+
+    @Override
+    public int[] toPosIds(Map<String, Map<String, String>> inputs) {
+        int max_len = this.getMAX_LEN();
+        String sentence = inputs.get("sentence").get("sentence");
+        int[] pos_ids = new int[max_len];
+        for (int j=0; j<max_len; j++)
+            pos_ids[j] = max_len - 1;  // 位置的padding
+
+        // 绝对位置编码
+        for (int i = 0 ; i < (sentence.length() < max_len ? sentence.length() : max_len); i++)
+            pos_ids[i] = i;
+
+        return pos_ids;
+    }
+
     @Override
     public void readDict(String modelAndVersion) {
 
@@ -109,6 +147,39 @@ public class NNDataSetImpl extends NNDataSet {
 
     }
 
+    @Override
+    public void readChar2IdDict(String modelAndVersion) {
+
+        // 获取文件目录
+        PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
+        String filePath = prop.getProperty("basicPath");  // 基本目录
+        filePath = filePath.substring(0, filePath.indexOf("model_version_replacement"));
+
+        filePath = filePath + "char2id.bin";  // 字典文件位置
+
+        // 读取以json字符串保存的数据
+        BufferedReader br = null;
+        try {
+            br = new BufferedReader(new FileReader(filePath));  // 读取原始json文件
+            String s = null;
+            while ((s = br.readLine()) != null) {
+                JSONObject jsonObject = (JSONObject) JSON.parse(s);
+                Set<Entry<String, Object>> entries = jsonObject.entrySet();
+                for (Map.Entry<String, Object> entry : entries)
+                    this.CHAR2ID_DICT.put(entry.getKey(), (Integer) entry.getValue());
+            }
+        } catch (Exception e) {
+            e.printStackTrace();
+        } finally {
+            try {
+                br.close();
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        }
+
+    }
+
     /**
      * 再分词:
      * 基本操作: