text_similarity.py

import argparse
import glob
import os
import re
from typing import Dict, List, Optional, Tuple

import numpy as np
import jieba
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity


class TextSimilarityFinder:
    """Text similarity finder: finds the texts most similar to a query text within a corpus."""

    def __init__(self, method: str = 'tfidf', use_jieba: bool = True, ngram_range: Tuple[int, int] = (1, 1)):
        """
        Initialize the text similarity finder.

        Args:
            method: similarity method, one of 'tfidf', 'count', or 'jaccard'
            use_jieba: whether to segment text with jieba (recommended for Chinese text)
            ngram_range: n-gram range, e.g. (1, 2) uses both unigram and bigram features
        """
        self.method = method
        self.use_jieba = use_jieba
        self.ngram_range = ngram_range
        self.vectorizer = None
        self.corpus_vectors = None
        self.corpus_texts = None
        self.corpus_paths = None

    def _tokenize_text(self, text: str) -> str:
        """Tokenize the text.

        Args:
            text: input text

        Returns:
            str: the tokenized text, space-separated
        """
        if not self.use_jieba:
            return text
        # Use jieba segmentation; join the tokens with spaces
        return " ".join(jieba.cut(text))

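    # For reference, jieba's default-mode segmentation splits a Chinese sentence
    # into space-joinable tokens (canonical example from the jieba README; exact
    # output can vary with the dictionary version):
    #   " ".join(jieba.cut("我来到北京清华大学"))  ->  "我 来到 北京 清华大学"
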
    def _preprocess_texts(self, texts: List[str]) -> List[str]:
        """Preprocess a list of texts.

        Args:
            texts: list of texts

        Returns:
            List[str]: the preprocessed texts
        """
        processed_texts = []
        for text in texts:
            # Remove punctuation, special characters, and redundant whitespace.
            # Match CJK punctuation (including the bullet •)
            text = re.sub(r'[\u3000-\u303F]|[\uFF00-\uFFEF]|[•]', ' ', text)
            # Match ASCII punctuation and special characters
            text = re.sub(r'[!"#$%&\'()*+,\-./:;<=>?@\[\\\]^_`{|}~]', ' ', text)
            # Match other whitespace characters (tabs, newlines, form feeds, etc.)
            text = re.sub(r'[\t\n\r\f\v]', ' ', text)
            # Collapse runs of whitespace
            text = re.sub(r'\s+', ' ', text).strip()
            # Tokenize
            text = self._tokenize_text(text)
            processed_texts.append(text)
        return processed_texts

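    # A small illustration of the cleanup steps (with use_jieba=False):
    #   "Hello,\tworld! (测试•文本)"  ->  "Hello world 测试 文本"
    # punctuation maps to spaces, then whitespace runs collapse to single spaces.
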
    def _initialize_vectorizer(self):
        """Initialize the text vectorizer."""
        if self.method == 'tfidf':
            self.vectorizer = TfidfVectorizer(ngram_range=self.ngram_range)
        elif self.method == 'count':
            self.vectorizer = CountVectorizer(ngram_range=self.ngram_range)
        elif self.method == 'jaccard':
            # Jaccard similarity needs no vectorizer; it works on token sets
            self.vectorizer = None
        else:
            raise ValueError(f"Unsupported similarity method: {self.method}")

    def _calculate_jaccard_similarity(self, text1: str, text2: str) -> float:
        """Compute the Jaccard similarity of two texts.

        Args:
            text1: first text
            text2: second text

        Returns:
            float: Jaccard similarity
        """
        # Turn each text into a set of tokens
        set1 = set(text1.split())
        set2 = set(text2.split())
        # Jaccard similarity: |intersection| / |union|
        intersection = len(set1.intersection(set2))
        union = len(set1.union(set2))
        # Guard against division by zero
        if union == 0:
            return 0.0
        return intersection / union

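    # Worked example: "the cat sat" vs. "the cat ran" gives
    #   set1 = {the, cat, sat}, set2 = {the, cat, ran}
    #   |intersection| = 2 ({the, cat}), |union| = 4, so similarity = 2 / 4 = 0.5
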
    def load_corpus_from_directory(self, directory: str, file_pattern: str = '*.txt'):
        """Load the corpus from a directory.

        Args:
            directory: directory path
            file_pattern: glob pattern for matching files
        """
        # Collect all matching file paths
        file_paths = glob.glob(os.path.join(directory, file_pattern))
        if not file_paths:
            raise ValueError(f"No files matching {file_pattern} found in directory {directory}")
        # Read every file, keeping texts and paths in sync even if a read fails
        texts = []
        read_paths = []
        for file_path in file_paths:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    texts.append(f.read())
                read_paths.append(file_path)
            except Exception as e:
                print(f"Failed to read file {file_path}: {e}")
        self.load_corpus(texts, read_paths)

    def load_corpus(self, texts: List[str], paths: Optional[List[str]] = None):
        """Load the corpus.

        Args:
            texts: list of texts
            paths: list of paths or identifiers corresponding to the texts
        """
        if not texts:
            raise ValueError("The text list must not be empty")
        # Preprocess the texts
        processed_texts = self._preprocess_texts(texts)
        # Initialize the vectorizer
        self._initialize_vectorizer()
        # Keep the original texts and paths
        self.corpus_texts = texts
        self.corpus_paths = paths if paths else [f"text_{i}" for i in range(len(texts))]
        # Vectorize the corpus
        if self.method in ['tfidf', 'count']:
            self.corpus_vectors = self.vectorizer.fit_transform(processed_texts)
        else:
            # For Jaccard similarity, store the preprocessed texts directly
            self.corpus_vectors = processed_texts

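    # Note: for 'tfidf' and 'count', corpus_vectors is the scipy sparse matrix of
    # shape (n_documents, n_features) returned by fit_transform; for 'jaccard' it
    # is simply the list of preprocessed, space-separated token strings.
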
    def find_most_similar(self, query_text: str, top_n: int = 1) -> List[Dict]:
        """Find the texts most similar to the query text.

        Args:
            query_text: query text
            top_n: number of most similar results to return

        Returns:
            List[Dict]: result list with similarity information
        """
        if self.corpus_vectors is None:
            raise ValueError("Load a corpus first")
        # Preprocess the query text
        processed_query = self._preprocess_texts([query_text])[0]
        # Compute similarities
        if self.method in ['tfidf', 'count']:
            # Vectorize the query text
            query_vector = self.vectorizer.transform([processed_query])
            # Cosine similarity between the query and every corpus document
            similarity_matrix = cosine_similarity(query_vector, self.corpus_vectors)
            similarities = similarity_matrix[0]
        else:  # Jaccard similarity
            similarities = []
            for corpus_text in self.corpus_vectors:
                similarities.append(self._calculate_jaccard_similarity(processed_query, corpus_text))
        # Indices of the corpus texts, sorted by descending similarity
        top_indices = np.argsort(similarities)[::-1][:top_n]
        # Build the results
        results = []
        for idx in top_indices:
            results.append({
                'text': self.corpus_texts[idx],
                'path': self.corpus_paths[idx],
                'similarity': similarities[idx]
            })
        return results

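# A minimal programmatic usage sketch (kept as a comment so nothing runs on
# import; file names and toy strings are illustrative only):
#
#   finder = TextSimilarityFinder(method='jaccard', use_jieba=False)
#   finder.load_corpus(["the cat sat", "dogs bark loudly"], ["a.txt", "b.txt"])
#   results = finder.find_most_similar("the cat ran", top_n=1)
#   # results[0] -> {'text': 'the cat sat', 'path': 'a.txt', 'similarity': 0.5}
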
def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Text similarity search tool')
    parser.add_argument('--query', type=str, required=True, help='query text, or path to a file containing it')
    parser.add_argument('--dir', type=str, required=True, help='corpus directory path')
    parser.add_argument('--pattern', type=str, default='*.txt', help='file matching pattern')
    parser.add_argument('--method', type=str, default='tfidf', choices=['tfidf', 'count', 'jaccard'], help='similarity method')
    parser.add_argument('--top_n', type=int, default=3, help='number of most similar results to return')
    parser.add_argument('--no_jieba', action='store_true', help='disable jieba segmentation')
    parser.add_argument('--ngram_min', type=int, default=1, help='minimum n-gram size')
    parser.add_argument('--ngram_max', type=int, default=1, help='maximum n-gram size')
    args = parser.parse_args()

    # Initialize the text similarity finder
    finder = TextSimilarityFinder(
        method=args.method,
        use_jieba=not args.no_jieba,
        ngram_range=(args.ngram_min, args.ngram_max)
    )
    # Load the corpus
    finder.load_corpus_from_directory(args.dir, args.pattern)

    # Resolve the query text: if --query names an existing file, read it
    query_text = args.query
    if os.path.isfile(query_text):
        try:
            with open(query_text, 'r', encoding='utf-8') as f:
                query_text = f.read()
        except Exception as e:
            print(f"Failed to read query file: {e}")
            return

    # Find the most similar texts
    results = finder.find_most_similar(query_text, args.top_n)

    # Print the results
    query_preview = query_text[:100] + '...' if len(query_text) > 100 else query_text
    print(f"\nQuery text: {query_preview}")
    print(f"\nFound {len(results)} most similar texts:")
    for i, result in enumerate(results):
        print(f"\n{i + 1}. Similarity: {result['similarity']:.4f}")
        print(f"   Path: {result['path']}")
        text_preview = result['text'][:200] + '...' if len(result['text']) > 200 else result['text']
        print(f"   Content preview: {text_preview}")


if __name__ == "__main__":
    main()
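# Example invocations (assumed paths; corpus/ holds UTF-8 .txt files):
#   python text_similarity.py --query "你的查询文本" --dir corpus --method tfidf --top_n 3
#   python text_similarity.py --query query.txt --dir corpus --no_jieba --ngram_max 2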