|
@@ -11,6 +11,7 @@ import org.tensorflow.Session;
|
|
|
import org.tensorflow.Tensor;
|
|
|
|
|
|
import java.nio.FloatBuffer;
|
|
|
+import java.nio.IntBuffer;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.List;
|
|
|
|
|
@@ -29,7 +30,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
|
|
|
private final String X_PLACEHOLDER = "X";
|
|
|
private final String pos1_PLACEHOLDER = "pos1";
|
|
|
private final String pos2_PLACEHOLDER = "pos2";
|
|
|
- private final String y_PLACEHOLDER = "y";
|
|
|
+ private final String SOFT_MAX = "softmax/Softmax";
|
|
|
private final int NUM_LABEL = 2;
|
|
|
private SavedModelBundle bundle; // 模型捆绑
|
|
|
private Session session; // 会话
|
|
@@ -79,7 +80,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
|
|
|
|
|
|
// 遍历组合
|
|
|
for (LemmaInfo[] lemmaInfoPair: combinations) {
|
|
|
- float[][] example = dataSet.getExample(content, lemmaInfoPair[0], lemmaInfoPair[1]);
|
|
|
+ int[][] example = dataSet.getExample(content, lemmaInfoPair[0], lemmaInfoPair[1]);
|
|
|
// 调用模型
|
|
|
float[][] relation = this.run(example, 1);
|
|
|
Triad triad = new Triad();
|
|
@@ -118,19 +119,19 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
|
|
|
* @param batchSize 批量大小
|
|
|
* @return
|
|
|
*/
|
|
|
- private float[][] run(float[][] inputValues, int batchSize){
|
|
|
+ private float[][] run(int[][] inputValues, int batchSize){
|
|
|
long[] shape = {1, dataSet.maxLength}; // 老模型
|
|
|
- Tensor<Float> charId = Tensor.create(
|
|
|
+ Tensor<Integer> charId = Tensor.create(
|
|
|
shape,
|
|
|
- FloatBuffer.wrap(inputValues[0])
|
|
|
+ IntBuffer.wrap(inputValues[0])
|
|
|
);
|
|
|
- Tensor<Float> pos1 = Tensor.create(
|
|
|
+ Tensor<Integer> pos1 = Tensor.create(
|
|
|
shape,
|
|
|
- FloatBuffer.wrap(inputValues[1])
|
|
|
+ IntBuffer.wrap(inputValues[1])
|
|
|
);
|
|
|
- Tensor<Float> pos2 = Tensor.create(
|
|
|
+ Tensor<Integer> pos2 = Tensor.create(
|
|
|
shape,
|
|
|
- FloatBuffer.wrap(inputValues[2])
|
|
|
+ IntBuffer.wrap(inputValues[2])
|
|
|
);
|
|
|
|
|
|
return this.session.runner()
|
|
@@ -138,7 +139,7 @@ public class RelationExtractionModel extends AlgorithmCNNExecutor {
|
|
|
.feed(this.pos1_PLACEHOLDER, pos1)
|
|
|
.feed(this.pos2_PLACEHOLDER, pos2)
|
|
|
.feed("keep_prob", Tensor.create(1.0f, Float.class)) // dropout保留率
|
|
|
- .fetch(this.y_PLACEHOLDER).run().get(0)
|
|
|
+ .fetch(this.SOFT_MAX).run().get(0)
|
|
|
.copyTo(new float[1][this.NUM_LABEL]);
|
|
|
}
|
|
|
|