|
@@ -4,10 +4,12 @@ import org.algorithm.core.cnn.dataset.RelationExtractionDataSet;
|
|
|
import org.diagbot.pub.utils.PropertiesUtil;
|
|
|
import org.tensorflow.SavedModelBundle;
|
|
|
import org.tensorflow.Session;
|
|
|
+import org.tensorflow.Session.Runner;
|
|
|
import org.tensorflow.Tensor;
|
|
|
|
|
|
import java.io.File;
|
|
|
import java.nio.IntBuffer;
|
|
|
+import java.util.Map;
|
|
|
|
|
|
/**
|
|
|
* @Author: bijl
|
|
@@ -18,18 +20,20 @@ public class RelationExtractionSubModel {
|
|
|
private final String X_PLACEHOLDER = "X";
|
|
|
private final String pos1_PLACEHOLDER = "pos1";
|
|
|
private final String pos2_PLACEHOLDER = "pos2";
|
|
|
+ private Map<String, Tensor<Float>> keep_probs = null;
|
|
|
private String PREDICTION = null;
|
|
|
private final int NUM_LABEL = 1;
|
|
|
private SavedModelBundle bundle; // 模型捆绑
|
|
|
private Session session; // 会话
|
|
|
protected RelationExtractionDataSet dataSet;
|
|
|
|
|
|
- public RelationExtractionSubModel(String modelName) {
|
|
|
+ public RelationExtractionSubModel(String modelName, Map<String, Tensor<Float>> keep_probs) {
|
|
|
+
|
|
|
+ this.keep_probs = keep_probs;
|
|
|
|
|
|
PropertiesUtil prop = new PropertiesUtil("/algorithm.properties");
|
|
|
|
|
|
String modelsPath = prop.getProperty("basicPath"); // 模型基本路径
|
|
|
- String re_path = prop.getProperty("relationExtraction"); // 模型基本路径
|
|
|
this.PREDICTION = modelName + "/prediction/Sigmoid";
|
|
|
String exportDir = modelsPath.replace("model_version_replacement", modelName);
|
|
|
String dataSetPath = modelsPath.substring(0, modelsPath.indexOf("model_version_replacement"));
|
|
@@ -76,12 +80,23 @@ public class RelationExtractionSubModel {
|
|
|
IntBuffer.wrap(inputValues[2])
|
|
|
);
|
|
|
|
|
|
- return this.session.runner()
|
|
|
+// return this.session.runner()
|
|
|
+// .feed(this.X_PLACEHOLDER, charId) // 输入,字符id
|
|
|
+// .feed(this.pos1_PLACEHOLDER, pos1) // 输入,相对位置1
|
|
|
+// .feed(this.pos2_PLACEHOLDER, pos2) // 输入,相对位置2
|
|
|
+// .feed("keep_prob", Tensor.create(1.0f, Float.class)) // 输入,dropout保留率
|
|
|
+// .fetch(this.PREDICTION).run().get(0) // 输出,tensor
|
|
|
+// .copyTo(new float[batchSize][this.NUM_LABEL]); // tensor转float[]对象
|
|
|
+
|
|
|
+ Runner runner = this.session.runner() // 不同子模型共用的输入
|
|
|
.feed(this.X_PLACEHOLDER, charId) // 输入,字符id
|
|
|
.feed(this.pos1_PLACEHOLDER, pos1) // 输入,相对位置1
|
|
|
- .feed(this.pos2_PLACEHOLDER, pos2) // 输入,相对位置2
|
|
|
- .feed("keep_prob", Tensor.create(1.0f, Float.class)) // 输入,dropout保留率
|
|
|
- .fetch(this.PREDICTION).run().get(0) // 输出,tensor
|
|
|
+ .feed(this.pos2_PLACEHOLDER, pos2); // 输入,相对位置2
|
|
|
+
|
|
|
+ for (Map.Entry<String, Tensor<Float>> entry : this.keep_probs.entrySet()) // 非共用的输入
|
|
|
+ runner = runner.feed(entry.getKey(), entry.getValue());
|
|
|
+
|
|
|
+ return runner.fetch(this.PREDICTION).run().get(0) // 输出,tensor
|
|
|
.copyTo(new float[batchSize][this.NUM_LABEL]); // tensor转float[]对象
|
|
|
}
|
|
|
|