|
@@ -0,0 +1,184 @@
|
|
|
+package org.algorithm.core.cnn.model;
|
|
|
+
|
|
|
+import org.algorithm.core.cnn.AlgorithmCNNExecutor;
|
|
|
+import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
|
|
|
+import org.algorithm.core.cnn.entity.Triad;
|
|
|
+import org.diagbot.pub.utils.PropertiesUtil;
|
|
|
+import org.tensorflow.SavedModelBundle;
|
|
|
+import org.tensorflow.Session;
|
|
|
+import org.tensorflow.Tensor;
|
|
|
+
|
|
|
+import java.io.File;
|
|
|
+import java.nio.FloatBuffer;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.List;
|
|
|
+import java.util.concurrent.*;
|
|
|
+
|
|
|
+/**
|
|
|
+ * @Author: bijl
|
|
|
+ * @Date: 2019/1/22 10:21
|
|
|
+ * @Description: 集成模型
|
|
|
+ */
|
|
|
+public class RelationExtractionEnsembleModel extends AlgorithmCNNExecutor {
|
|
|
+ private final String X_PLACEHOLDER = "X";
|
|
|
+ private final String PREDICTION = "prediction/prediction";
|
|
|
+ private final int NUM_LABEL = 1;
|
|
|
+ private SavedModelBundle bundle; // 模型捆绑
|
|
|
+ private Session session; // 会话
|
|
|
+ private RelationExtractionDataSet dataSet;
|
|
|
+ private RelationExtractionSubModel[] subModels = new RelationExtractionSubModel[3];
|
|
|
+ private ExecutorService executorService = Executors.newCachedThreadPool();
|
|
|
+
|
|
|
+ public RelationExtractionEnsembleModel() {
|
|
|
+ PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
|
|
|
+
|
|
|
+ String modelsPath = prop.getProperty("basicPath"); // 模型基本路径
|
|
|
+ String dataSetPath = modelsPath.substring(0, modelsPath.indexOf("model_version_replacement"));
|
|
|
+ dataSetPath = dataSetPath + File.separator + "char2id.json";
|
|
|
+ String exportDir = modelsPath.replace("model_version_replacement", "ensemble_model_2");
|
|
|
+
|
|
|
+ this.dataSet = new RelationExtractionDataSet(dataSetPath);
|
|
|
+ this.init(exportDir);
|
|
|
+ subModels[0] = new RelationExtractionSubModel("cnn_1d_low");
|
|
|
+ subModels[1] = new RelationExtractionSubModel("cnn_1d_lstm_low");
|
|
|
+ subModels[2] = new RelationExtractionSubModel("lstm_low_api");
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 初始化:加载模型,获取会话。
|
|
|
+ *
|
|
|
+ * @param exportDir 模型地址
|
|
|
+ */
|
|
|
+ public void init(String exportDir) {
|
|
|
+ /* load the model Bundle */
|
|
|
+ try {
|
|
|
+ this.bundle = SavedModelBundle.load(exportDir, "serve");
|
|
|
+ } catch (Exception e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ }
|
|
|
+
|
|
|
+ // create the session from the Bundle
|
|
|
+ this.session = bundle.session();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 转化数据为张量
|
|
|
+ *
|
|
|
+ * @param content 句子
|
|
|
+ * @param triads 三元组list
|
|
|
+ * @return int[3][] 表示charId,pos1,pos2
|
|
|
+ */
|
|
|
+ private int[][] convertData(String content, List<Triad> triads) {
|
|
|
+
|
|
|
+ int[][] inputValues = new int[3][triads.size() * this.dataSet.MAX_LEN];
|
|
|
+ for (int i = 0; i < triads.size(); i++) {
|
|
|
+ Triad triad = triads.get(i);
|
|
|
+ int[][] aInput = this.dataSet.getExample(content, triad.getL_1(), triad.getL_2());
|
|
|
+ for (int j = 0; j < aInput.length; j++)
|
|
|
+ for (int k = 0; k < this.dataSet.MAX_LEN; k++)
|
|
|
+ inputValues[j][i * this.dataSet.MAX_LEN] = aInput[j][k];
|
|
|
+ }
|
|
|
+
|
|
|
+ return inputValues;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public List<Triad> execute(String content, List<Triad> triads) {
|
|
|
+ // 句子长度不超过MAX_LEN,有三元组
|
|
|
+ if (content.length() > this.dataSet.MAX_LEN || triads.size() < 1) {
|
|
|
+ return new ArrayList<>();
|
|
|
+ }
|
|
|
+ int[][] inputValues = this.convertData(content, triads); // shape = [3, batchSize * this.subModels.length]
|
|
|
+ int batchSize = triads.size();
|
|
|
+
|
|
|
+ float[] sigmoidS = new float[batchSize * this.subModels.length]; // 集成模型的输入
|
|
|
+
|
|
|
+ List<Future<float[][]>> futureList = new ArrayList<>();
|
|
|
+
|
|
|
+// // 非并行运行子模型
|
|
|
+// for (int i = 0; i < this.subModels.length; i++) {
|
|
|
+// float[][] sigmoid = subModels[i].sigmoid(inputValues, batchSize); // 子模型预测
|
|
|
+// for (int j = 0; j < batchSize; j++)
|
|
|
+// sigmoidS[i * batchSize + j] = sigmoid[j][0];
|
|
|
+// }
|
|
|
+
|
|
|
+// 多线程运行子模型
|
|
|
+ for (int i = 0; i < this.subModels.length; i++) {
|
|
|
+ int index = i;
|
|
|
+ Future<float[][]> future = this.executorService.submit(new Callable<float[][]>() {
|
|
|
+ @Override
|
|
|
+ public float[][] call() throws Exception {
|
|
|
+ return subModels[index].sigmoid(inputValues, batchSize);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ futureList.add(future);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 从future中获取数据,并填入sigmoidS中
|
|
|
+ for (int i = 0; i < this.subModels.length; i++) {
|
|
|
+ try {
|
|
|
+ float[][] sigmoid = futureList.get(i).get();
|
|
|
+ for (int j = 0; j < batchSize; j++)
|
|
|
+ sigmoidS[i * batchSize + j] = sigmoid[j][0];
|
|
|
+
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ System.err.println("获取数据不成功");
|
|
|
+ } catch (ExecutionException e) {
|
|
|
+ e.printStackTrace();
|
|
|
+ System.err.println("获取数据不成功");
|
|
|
+ }
|
|
|
+ }
|
|
|
+// this.executorService.shutdown();
|
|
|
+
|
|
|
+ float[][] prediction = this.run(sigmoidS, batchSize);
|
|
|
+
|
|
|
+ //设置三元组关系
|
|
|
+ for (int j = 0; j < prediction.length; j++) {
|
|
|
+ if (prediction[j][0] == 1.0)
|
|
|
+ triads.get(j).setRelation("有");
|
|
|
+ else
|
|
|
+ triads.get(j).setRelation("无");
|
|
|
+ }
|
|
|
+
|
|
|
+ //删除无关系三元组
|
|
|
+ List<Triad> deleteTriads = new ArrayList<>();
|
|
|
+ for (Triad triad : triads)
|
|
|
+ if ("无".equals(triad.getRelation())) // 有关系着留下
|
|
|
+ deleteTriads.add(triad);
|
|
|
+ for (Triad triad : deleteTriads)
|
|
|
+ triads.remove(triad);
|
|
|
+
|
|
|
+ return triads;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ /**
|
|
|
+ * @param inputValues 字符id,相对于实体1位置,相对于实体2位置
|
|
|
+ * @param batchSize 批量大小
|
|
|
+ * @return float[][] shape = [batchSize, 1]
|
|
|
+ */
|
|
|
+ private float[][] run(float[] inputValues, int batchSize) {
|
|
|
+ long[] shape = {batchSize, this.subModels.length}; // 老模型
|
|
|
+ Tensor<Float> sigmoidS = Tensor.create(
|
|
|
+ shape,
|
|
|
+ FloatBuffer.wrap(inputValues)
|
|
|
+ );
|
|
|
+
|
|
|
+
|
|
|
+ return this.session.runner()
|
|
|
+ .feed(this.X_PLACEHOLDER, sigmoidS)
|
|
|
+ .fetch(this.PREDICTION).run().get(0)
|
|
|
+ .copyTo(new float[batchSize][this.NUM_LABEL]);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 关闭会话,释放资源
|
|
|
+ */
|
|
|
+ public void close() {
|
|
|
+ this.session.close();
|
|
|
+ this.bundle.close();
|
|
|
+ for (RelationExtractionSubModel subModel : this.subModels)
|
|
|
+ subModel.close();
|
|
|
+ }
|
|
|
+}
|