@@ -0,0 +1,252 @@
+import os
+import re
+import glob
+import argparse
+from typing import List, Dict, Tuple, Optional
+
+import numpy as np
+import jieba
+from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+
+class TextSimilarityFinder:
+    """Text similarity finder: finds the texts in a corpus most similar to a query text."""
+
+    def __init__(self, method: str = 'tfidf', use_jieba: bool = True, ngram_range: Tuple[int, int] = (1, 1)):
+        """
+        Initialize the text similarity finder.
+
+        Args:
+            method: similarity method; one of 'tfidf', 'count', or 'jaccard'
+            use_jieba: whether to use jieba word segmentation (recommended for Chinese text)
+            ngram_range: n-gram range, e.g. (1, 2) uses both unigram and bigram features
+        """
+        self.method = method
+        self.use_jieba = use_jieba
+        self.ngram_range = ngram_range
+        self.vectorizer = None
+        self.corpus_vectors = None
+        self.corpus_texts = None
+        self.corpus_paths = None
+
+    def _tokenize_text(self, text: str) -> str:
+        """Tokenize a text.
+
+        Args:
+            text: input text
+
+        Returns:
+            str: the tokenized text
+        """
+        if not self.use_jieba:
+            return text
+
+        # Segment with jieba and rejoin the tokens with single spaces
+        return " ".join(jieba.cut(text))
+
+    def _preprocess_texts(self, texts: List[str]) -> List[str]:
+        """Preprocess a list of texts.
+
+        Args:
+            texts: list of texts
+
+        Returns:
+            List[str]: the preprocessed texts
+        """
+        processed_texts = []
+        for text in texts:
+            # Strip punctuation, special characters, and extra whitespace.
+            # CJK punctuation (U+3000-U+303F), fullwidth forms (U+FF00-U+FFEF), and the bullet •
+            text = re.sub(r'[\u3000-\u303F\uFF00-\uFFEF•]', ' ', text)
+            # ASCII punctuation and special characters
+            text = re.sub(r'[!"#$%&\'()*+,\-./:;<=>?@\[\\\]^_`{|}~]', ' ', text)
+            # Remaining whitespace characters (tabs, newlines, form feeds, etc.)
+            text = re.sub(r'[\t\n\r\f\v]', ' ', text)
+            # Collapse runs of whitespace
+            text = re.sub(r'\s+', ' ', text).strip()
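+            # Illustrative example: at this point "Hello, world! 你好,世界。"
+            # has become "Hello world 你好 世界"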
+            # Tokenize
+            text = self._tokenize_text(text)
+            processed_texts.append(text)
+        return processed_texts
+
+    def _initialize_vectorizer(self):
+        """Initialize the text vectorizer."""
+        if self.method == 'tfidf':
+            self.vectorizer = TfidfVectorizer(ngram_range=self.ngram_range)
+        elif self.method == 'count':
+            self.vectorizer = CountVectorizer(ngram_range=self.ngram_range)
+        elif self.method == 'jaccard':
+            # Jaccard similarity works on token sets and needs no vectorizer
+            self.vectorizer = None
+        else:
+            raise ValueError(f"Unsupported similarity method: {self.method}")
+
+    def _calculate_jaccard_similarity(self, text1: str, text2: str) -> float:
+        """Compute the Jaccard similarity of two texts.
+
+        Args:
+            text1: first text
+            text2: second text
+
+        Returns:
+            float: the Jaccard similarity
+        """
+        # Turn each text into a set of tokens
+        set1 = set(text1.split())
+        set2 = set(text2.split())
+
+        # Jaccard similarity: |intersection| / |union|
+        intersection = len(set1.intersection(set2))
+        union = len(set1.union(set2))
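+        # Worked example: "a b c" vs. "b c d" -> intersection {b, c} (2),
+        # union {a, b, c, d} (4), similarity 2 / 4 = 0.5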
+
+        # Avoid division by zero when both token sets are empty
+        if union == 0:
+            return 0.0
+
+        return intersection / union
+
+    def load_corpus_from_directory(self, directory: str, file_pattern: str = '*.txt'):
+        """Load a corpus from a directory.
+
+        Args:
+            directory: directory path
+            file_pattern: glob pattern for matching files
+        """
+        # Collect all matching file paths
+        file_paths = glob.glob(os.path.join(directory, file_pattern))
+
+        if not file_paths:
+            raise ValueError(f"No files matching {file_pattern} found in directory {directory}")
+
+        # Read every file, keeping paths aligned with the texts that were read successfully
+        texts = []
+        valid_paths = []
+        for file_path in file_paths:
+            try:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    texts.append(f.read())
+                valid_paths.append(file_path)
+            except Exception as e:
+                print(f"Failed to read file {file_path}: {str(e)}")
+
+        self.load_corpus(texts, valid_paths)
+
+    def load_corpus(self, texts: List[str], paths: Optional[List[str]] = None):
+        """Load a corpus.
+
+        Args:
+            texts: list of texts
+            paths: list of paths or identifiers corresponding to the texts
+        """
+        if not texts:
+            raise ValueError("The list of texts must not be empty")
+
+        # Preprocess the texts
+        processed_texts = self._preprocess_texts(texts)
+
+        # Initialize the vectorizer
+        self._initialize_vectorizer()
+
+        # Keep the original texts and paths
+        self.corpus_texts = texts
+        self.corpus_paths = paths if paths else [f"text_{i}" for i in range(len(texts))]
+
+        # Vectorize the corpus
+        if self.method in ['tfidf', 'count']:
+            self.corpus_vectors = self.vectorizer.fit_transform(processed_texts)
+        else:
+            # For Jaccard similarity, store the preprocessed texts directly
+            self.corpus_vectors = processed_texts
+
+    def find_most_similar(self, query_text: str, top_n: int = 1) -> List[Dict]:
+        """Find the texts most similar to a query text.
+
+        Args:
+            query_text: the query text
+            top_n: number of most similar results to return
+
+        Returns:
+            List[Dict]: results with similarity information
+        """
+        if self.corpus_vectors is None:
+            raise ValueError("Load a corpus first")
+
+        # Preprocess the query text
+        processed_query = self._preprocess_texts([query_text])[0]
+
+        # Compute similarities
+        similarities = []
+
+        if self.method in ['tfidf', 'count']:
+            # Vectorize the query text
+            query_vector = self.vectorizer.transform([processed_query])
+
+            # Cosine similarity between the query and every corpus document
+            similarity_matrix = cosine_similarity(query_vector, self.corpus_vectors)
+            similarities = similarity_matrix[0]
+        else:  # Jaccard similarity
+            for corpus_text in self.corpus_vectors:
+                similarity = self._calculate_jaccard_similarity(processed_query, corpus_text)
+                similarities.append(similarity)
+
+        # Indices sorted by similarity, highest first
+        top_indices = np.argsort(similarities)[::-1][:top_n]
+
+        # Assemble the results
+        results = []
+        for idx in top_indices:
+            results.append({
+                'text': self.corpus_texts[idx],
+                'path': self.corpus_paths[idx],
+                'similarity': similarities[idx]
+            })
+
+        return results
+
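+# A minimal programmatic usage sketch (hypothetical texts and identifiers,
+# separate from the CLI below):
+#
+#   finder = TextSimilarityFinder(method='tfidf', use_jieba=False)
+#   finder.load_corpus(["the cat sat down", "dogs bark loudly"], ["a.txt", "b.txt"])
+#   print(finder.find_most_similar("a cat sat", top_n=1))
+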
+def main():
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(description='Text similarity finder')
+    parser.add_argument('--query', type=str, required=True, help='query text, or path to a file containing it')
+    parser.add_argument('--dir', type=str, required=True, help='corpus directory path')
+    parser.add_argument('--pattern', type=str, default='*.txt', help='file matching pattern')
+    parser.add_argument('--method', type=str, default='tfidf', choices=['tfidf', 'count', 'jaccard'], help='similarity method')
+    parser.add_argument('--top_n', type=int, default=3, help='number of most similar results to return')
+    parser.add_argument('--no_jieba', action='store_true', help='disable jieba word segmentation')
+    parser.add_argument('--ngram_min', type=int, default=1, help='minimum n-gram size')
+    parser.add_argument('--ngram_max', type=int, default=1, help='maximum n-gram size')
+
+    args = parser.parse_args()
+
+    # Initialize the text similarity finder
+    finder = TextSimilarityFinder(
+        method=args.method,
+        use_jieba=not args.no_jieba,
+        ngram_range=(args.ngram_min, args.ngram_max)
+    )
+
+    # Load the corpus
+    finder.load_corpus_from_directory(args.dir, args.pattern)
+
+    # Get the query text, reading it from a file if the argument is a path
+    query_text = args.query
+    if os.path.isfile(query_text):
+        try:
+            with open(query_text, 'r', encoding='utf-8') as f:
+                query_text = f.read()
+        except Exception as e:
+            print(f"Failed to read query file: {str(e)}")
+            return
+
+    # Find the most similar texts
+    results = finder.find_most_similar(query_text, args.top_n)
+
+    # Print the results
+    query_preview = query_text[:100] + '...' if len(query_text) > 100 else query_text
+    print(f"\nQuery text: {query_preview}")
+    print(f"\nFound the {len(results)} most similar texts:")
+
+    for i, result in enumerate(results):
+        print(f"\n{i+1}. Similarity: {result['similarity']:.4f}")
+        print(f"   Path: {result['path']}")
+        text_preview = result['text'][:200] + '...' if len(result['text']) > 200 else result['text']
+        print(f"   Content preview: {text_preview}")
+
+
+if __name__ == "__main__":
+    main()
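+
+# Example invocation (hypothetical script name and paths):
+#   python text_similarity.py --query query.txt --dir ./corpus --pattern "*.txt" --method tfidf --top_n 3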