|
@@ -6,43 +6,55 @@ import org.tensorflow.Session;
|
|
import org.tensorflow.Tensor;
|
|
import org.tensorflow.Tensor;
|
|
|
|
|
|
import java.nio.FloatBuffer;
|
|
import java.nio.FloatBuffer;
|
|
|
|
+import java.nio.IntBuffer;
|
|
|
|
+import java.util.HashMap;
|
|
import java.util.Map;
|
|
import java.util.Map;
|
|
|
|
|
|
/**
|
|
/**
|
|
* tensorflow 模型类,要求单个样本是1维向量,而不是高维向量
|
|
* tensorflow 模型类,要求单个样本是1维向量,而不是高维向量
|
|
|
|
+ *
|
|
* @Author: bijl
|
|
* @Author: bijl
|
|
* @Date: 2018年7月19日-下午7:21:24
|
|
* @Date: 2018年7月19日-下午7:21:24
|
|
* @Description:
|
|
* @Description:
|
|
*/
|
|
*/
|
|
public class TensorflowModel {
|
|
public class TensorflowModel {
|
|
-
|
|
|
|
- private final String INPUT_OPERATION_NAME; // 输入op的名称
|
|
|
|
- private final String OUTPUT_OPERATION_NAME; // 输出op的名称
|
|
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ private final String X = "X"; // 输入op x的名字
|
|
|
|
+ private final String Char_ids = "Char_ids"; // 输入op Char_ids的名字
|
|
|
|
+ private final String Pos_ids = "Pos_ids"; // 输入op Pos_ids的名字
|
|
|
|
+ private final String SOFT_MAX = "softmax/softmax"; // 输出op的名称
|
|
|
|
+
|
|
private final int NUM_FEATURE; // 特征个数
|
|
private final int NUM_FEATURE; // 特征个数
|
|
private final int NUM_LABEL; // 标签(类别)个数
|
|
private final int NUM_LABEL; // 标签(类别)个数
|
|
private SavedModelBundle bundle; // 模型捆绑
|
|
private SavedModelBundle bundle; // 模型捆绑
|
|
private Session session; // 会话
|
|
private Session session; // 会话
|
|
private NNDataSet dataSet; // 数据集
|
|
private NNDataSet dataSet; // 数据集
|
|
-
|
|
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ private boolean withSequenceInputs = false; // 是否带有序列输入
|
|
|
|
+ private final int MAX_LEN; // 最大长度
|
|
|
|
+
|
|
|
|
+
|
|
/**
|
|
/**
|
|
- *
|
|
|
|
- * @param exportDir 模型保存地址
|
|
|
|
- * @param inputOpName 输入op的名称
|
|
|
|
- * @param outputOpName 输出op的名称
|
|
|
|
- * @param dataSet 模型使用的数据集
|
|
|
|
|
|
+ * @param exportDir 模型保存地址
|
|
|
|
+ * @param dataSet 模型使用的数据集
|
|
*/
|
|
*/
|
|
- public TensorflowModel(String exportDir, String inputOpName, String outputOpName, NNDataSet dataSet) {
|
|
|
|
- this.INPUT_OPERATION_NAME = inputOpName;
|
|
|
|
- this.OUTPUT_OPERATION_NAME = outputOpName;
|
|
|
|
|
|
+ public TensorflowModel(String exportDir, NNDataSet dataSet) {
|
|
|
|
+
|
|
|
|
+ this.init(exportDir);
|
|
this.dataSet = dataSet;
|
|
this.dataSet = dataSet;
|
|
this.NUM_FEATURE = this.dataSet.getNumFeature();
|
|
this.NUM_FEATURE = this.dataSet.getNumFeature();
|
|
this.NUM_LABEL = this.dataSet.getNumLabel();
|
|
this.NUM_LABEL = this.dataSet.getNumLabel();
|
|
- this.init(exportDir);
|
|
|
|
-
|
|
|
|
|
|
+
|
|
|
|
+ // 序列数据有段的属性
|
|
|
|
+ this.MAX_LEN = this.dataSet.getMAX_LEN();
|
|
|
|
+ this.withSequenceInputs = this.dataSet.isWithSequenceInputs();
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
/**
|
|
/**
|
|
* 初始化:加载模型,获取会话。
|
|
* 初始化:加载模型,获取会话。
|
|
|
|
+ *
|
|
* @param exportDir
|
|
* @param exportDir
|
|
*/
|
|
*/
|
|
public void init(String exportDir) {
|
|
public void init(String exportDir) {
|
|
@@ -54,29 +66,61 @@ public class TensorflowModel {
|
|
}
|
|
}
|
|
|
|
|
|
// create the session from the Bundle
|
|
// create the session from the Bundle
|
|
- this.session = bundle.session();
|
|
|
|
|
|
+ this.session = bundle.session();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ /**
|
|
|
|
+ * 包装序列化输入
|
|
|
|
+ *
|
|
|
|
+ * @param sequenceValuesMap 序列输入的map
|
|
|
|
+ * @param numExamples 样本数
|
|
|
|
+ * @return
|
|
|
|
+ */
|
|
|
|
+ private Map<String, Tensor<Integer>> wrapSequenceInputs(Map<String, int[]> sequenceValuesMap, int numExamples) {
|
|
|
|
+ long[] inputShape = {numExamples, this.MAX_LEN};
|
|
|
|
+ Map<String, Tensor<Integer>> sequenceTensorMap = new HashMap<>();
|
|
|
|
+ for (Map.Entry<String, int[]> entry : sequenceValuesMap.entrySet()) {
|
|
|
|
+ String mapKey = entry.getKey();
|
|
|
|
+ Tensor<Integer> inputTensor = Tensor.create(
|
|
|
|
+ inputShape,
|
|
|
|
+ IntBuffer.wrap(entry.getValue())
|
|
|
|
+ );
|
|
|
|
+ sequenceTensorMap.put(mapKey, inputTensor);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return sequenceTensorMap;
|
|
}
|
|
}
|
|
-
|
|
|
|
|
|
+
|
|
|
|
+
|
|
/**
|
|
/**
|
|
* 运行模型
|
|
* 运行模型
|
|
- * @param inputValues 输入值
|
|
|
|
- * @param numExamples 样本个数
|
|
|
|
|
|
+ *
|
|
|
|
+ * @param inputValues 输入值
|
|
|
|
+ * @param numExamples 样本个数
|
|
* @return 模型的输出
|
|
* @return 模型的输出
|
|
*/
|
|
*/
|
|
- private float[][] run(float[] inputValues, int numExamples){
|
|
|
|
-// long[] inputShape = {numExamples, this.NUM_FEATURE, 4, 1}; // 新模型
|
|
|
|
- long[] inputShape = {numExamples, this.NUM_FEATURE}; // 老模型
|
|
|
|
|
|
+ private float[][] run(float[] inputValues, Map<String, int[]> sequenceValues, int numExamples) {
|
|
|
|
+ long[] inputShape = {numExamples, this.NUM_FEATURE};
|
|
Tensor<Float> inputTensor = Tensor.create(
|
|
Tensor<Float> inputTensor = Tensor.create(
|
|
- inputShape,
|
|
|
|
- FloatBuffer.wrap(inputValues)
|
|
|
|
|
|
+ inputShape,
|
|
|
|
+ FloatBuffer.wrap(inputValues)
|
|
);
|
|
);
|
|
- return this.session.runner().feed(this.INPUT_OPERATION_NAME, inputTensor)
|
|
|
|
|
|
+
|
|
|
|
+ // 序列数据
|
|
|
|
+ if (this.withSequenceInputs){
|
|
|
|
+ Map<String, Tensor<Integer>> sequenceTensorMap = this.wrapSequenceInputs(sequenceValues, numExamples);
|
|
|
|
+ this.session.runner().feed(this.Char_ids, sequenceTensorMap.get(this.Char_ids))
|
|
|
|
+ .feed(this.Pos_ids, sequenceTensorMap.get(this.Pos_ids));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return this.session.runner().feed(this.X, inputTensor)
|
|
.feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
.feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
- .fetch(this.OUTPUT_OPERATION_NAME).run().get(0)
|
|
|
|
|
|
+ .fetch(this.SOFT_MAX).run().get(0)
|
|
.copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
.copyTo(new float[numExamples][this.NUM_LABEL]);
|
|
}
|
|
}
|
|
-
|
|
|
|
-
|
|
|
|
|
|
+
|
|
|
|
+
|
|
/**
|
|
/**
|
|
* 运行模型,并将结果打包成目标格式
|
|
* 运行模型,并将结果打包成目标格式
|
|
*/
|
|
*/
|
|
@@ -85,14 +129,22 @@ public class TensorflowModel {
|
|
float sum = 0;
|
|
float sum = 0;
|
|
for (float f : inputValues)
|
|
for (float f : inputValues)
|
|
sum += f;
|
|
sum += f;
|
|
- if(sum == 0) // 如果输入没有有效特征,则直接返回null
|
|
|
|
|
|
+ if (sum == 0) // 如果输入没有有效特征,则直接返回null
|
|
return null;
|
|
return null;
|
|
-
|
|
|
|
- float[][] predict = this.run(inputValues, 1); // 一次一个样本
|
|
|
|
- return this.dataSet.wrap(predict);
|
|
|
|
|
|
+
|
|
|
|
+ Map<String, int[]> sequenceValues = null;
|
|
|
|
+ if (this.withSequenceInputs){
|
|
|
|
+ sequenceValues = new HashMap<>();
|
|
|
|
+ sequenceValues.put(this.Char_ids, this.dataSet.toCharIds(inputs));
|
|
|
|
+ sequenceValues.put(this.Pos_ids, this.dataSet.toPosIds(inputs));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ float[][] predict = this.run(inputValues, sequenceValues, 1); // 一次一个样本
|
|
|
|
+ return this.dataSet.wrap(predict);
|
|
}
|
|
}
|
|
-
|
|
|
|
-
|
|
|
|
|
|
+
|
|
|
|
+
|
|
/**
|
|
/**
|
|
* 关闭会话,释放资源
|
|
* 关闭会话,释放资源
|
|
*/
|
|
*/
|
|
@@ -101,4 +153,8 @@ public class TensorflowModel {
|
|
this.bundle.close();
|
|
this.bundle.close();
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ public void setWithSequenceInputs(boolean withSequenceInputs) {
|
|
|
|
+ this.withSequenceInputs = withSequenceInputs;
|
|
|
|
+ }
|
|
|
|
+
|
|
}
|
|
}
|