nlp_tools.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # coding: utf-8
  2. from utils.cws_constant import *
  3. class InputFeatures(object):
  4. def __init__(self, text, label, input_id, label_id, input_mask, length):
  5. self.text = text
  6. self.label = label
  7. self.input_id = input_id
  8. self.label_id = label_id
  9. self.input_mask = input_mask
  10. self.lenght = length
  11. def load_vocab(vocab_file):
  12. """Loads a vocabulary file into a dictionary."""
  13. vocab = {}
  14. index = 0
  15. with open(vocab_file, "r", encoding="utf-8") as reader:
  16. while True:
  17. token = reader.readline()
  18. if not token:
  19. break
  20. token = token.strip()
  21. vocab[token] = index
  22. index += 1
  23. return vocab
  24. def load_file(file_path):
  25. contents = open(file_path, encoding='utf-8').readlines()
  26. text = []
  27. label = []
  28. texts = []
  29. labels = []
  30. for line in contents:
  31. if line != '\n':
  32. line = line.strip().split('\t')
  33. text.append(line[0])
  34. label.append(line[-1])
  35. else:
  36. texts.append(text)
  37. labels.append(label)
  38. text = []
  39. label = []
  40. return texts, labels
  41. def load_data(file_path, max_length, label_dic, vocab):
  42. # 返回InputFeatures的list
  43. texts, labels = load_file(file_path)
  44. assert len(texts) == len(labels)
  45. result = []
  46. for i in range(len(texts)):
  47. assert len(texts[i]) == len(labels[i])
  48. token = texts[i]
  49. label = labels[i]
  50. if len(token) > max_length - 2:
  51. token = token[0:(max_length - 2)]
  52. label = label[0:(max_length - 2)]
  53. tokens_f = ['[CLS]'] + token + ['[SEP]']
  54. label_f = ["<start>"] + label + ['<eos>']
  55. input_ids = [int(vocab[i]) if i in vocab else int(vocab['[UNK]']) for i in tokens_f]
  56. label_ids = [label_dic[i] for i in label_f]
  57. input_mask = [1] * len(input_ids)
  58. length = [len(tokens_f)]
  59. while len(input_ids) < max_length:
  60. input_ids.append(0)
  61. input_mask.append(0)
  62. label_ids.append(label_dic['<pad>'])
  63. assert len(input_ids) == max_length
  64. assert len(input_mask) == max_length
  65. assert len(label_ids) == max_length
  66. feature = InputFeatures(text=tokens_f, label=label_f, input_id=input_ids, input_mask=input_mask,
  67. label_id=label_ids, length=length)
  68. result.append(feature)
  69. return result
  70. def recover_label(pred_var, gold_var, l2i_dic, i2l_dic):
  71. assert len(pred_var) == len(gold_var)
  72. pred_variable = []
  73. gold_variable = []
  74. for i in range(len(gold_var)):
  75. start_index = gold_var[i].index(l2i_dic['<start>'])
  76. end_index = gold_var[i].index(l2i_dic['<eos>'])
  77. pred_variable.append(pred_var[i][start_index:end_index])
  78. gold_variable.append(gold_var[i][start_index:end_index])
  79. pred_label = []
  80. gold_label = []
  81. for j in range(len(gold_variable)):
  82. pred_label.append([i2l_dic[t] for t in pred_variable[j]])
  83. gold_label.append([i2l_dic[t] for t in gold_variable[j]])
  84. return pred_label, gold_label
  85. class SegmenterEvaluation():
  86. def evaluate(self, original_labels, predict_labels):
  87. right, predict = self.get_order(original_labels, predict_labels)
  88. print('right, predict: ', right, predict)
  89. right_count = self.rightCount(right, predict)
  90. if right_count == 0:
  91. recall = 0
  92. precision = 0
  93. f1 = 0
  94. error = 1
  95. else:
  96. recall = right_count / len(right)
  97. precision = right_count / len(predict)
  98. f1 = (2 * recall * precision) / (precision + recall)
  99. error = (len(predict) - right_count) / len(right)
  100. return precision, recall, f1, error, right, predict
  101. def rightCount(self, rightList, predictList):
  102. count = set(rightList) & set(predictList)
  103. return len(count)
  104. def get_order(self, original_labels, predict_labels):
  105. assert len(original_labels) == len(predict_labels)
  106. start = 1
  107. end = len(original_labels) - 1 # 当 len(original_labels) -1 > 1的时候,只要有一个字就没问题
  108. # 按照origin的长度,且删去开头结尾符
  109. original_labels = original_labels[start:end]
  110. predict_labels = predict_labels[start:end]
  111. def merge(labelList):
  112. # 输入标签字符串,返回一个个词的(begin,end+1)元组
  113. new_label = []
  114. chars = ""
  115. for i, label in enumerate(labelList):
  116. if label not in ("B", "M", "E", "S"): # 可能是其他标签
  117. if len(chars) != 0:
  118. new_label.append(chars)
  119. new_label.append(label)
  120. chars = ""
  121. elif label == "B":
  122. if len(chars) != 0:
  123. new_label.append(chars)
  124. chars = "B"
  125. elif label == "M":
  126. chars += "M"
  127. elif label == "S":
  128. if len(chars) != 0:
  129. new_label.append(chars)
  130. new_label.append("S")
  131. chars = ""
  132. else:
  133. new_label.append(chars + "E")
  134. chars = ""
  135. if len(chars) != 0:
  136. new_label.append(chars)
  137. orderList = []
  138. start = 0
  139. end = 0
  140. for each in new_label:
  141. end = start + len(each)
  142. orderList.append((start, end))
  143. start = end
  144. return orderList
  145. right = merge(original_labels)
  146. predict = merge(predict_labels)
  147. return right, predict
  148. def get_f1(gold_label, pred_label):
  149. assert len(gold_label) == len(pred_label)
  150. sege = SegmenterEvaluation()
  151. total_right = 0
  152. total_pred = 0
  153. total_gold = 0
  154. for i in range(len(gold_label)):
  155. temp_gold, temp_predict = sege.get_order(gold_label[i], pred_label[i])
  156. temp_right = sege.rightCount(temp_gold, temp_predict)
  157. total_right += temp_right
  158. total_gold += len(temp_gold)
  159. total_pred += len(temp_predict)
  160. recall = total_right / total_gold
  161. precision = total_right / total_pred
  162. f1 = (2 * recall * precision) / (precision + recall)
  163. return precision, recall, f1
  164. def save_model(path, model, epoch):
  165. pass
  166. def load_model(path, model):
  167. return model
  168. if __name__ == "__main__":
  169. pass