
Look up the slice (trunk) content for each property based on word-segmentation similarity

yuchengwei 3 weeks ago
parent
commit
ebfff8c62f
3 changed files with 368 additions and 23 deletions
  1. router/text_search.py (+115 -22)
  2. service/cdss_service.py (+1 -1)
  3. utils/text_similarity.py (+252 -0)

+ 115 - 22
router/text_search.py

@@ -19,6 +19,9 @@ from service.kg_prop_service import KGPropService
 
 from cachetools import TTLCache
 
+# Use TextSimilarityFinder for text similarity matching
+from utils.text_similarity import TextSimilarityFinder
+
 logger = logging.getLogger(__name__)
 router = APIRouter(tags=["Text Search"])
 
@@ -409,7 +412,8 @@ def _process_sentence_search(node_name: str, prop_title: str, sentences: list, t
     
     while i < len(sentences):
         sentence = sentences[i]
-        
+
+        search_text = f"{node_name}:{prop_title}:{sentence}"
         if len(sentence) < 10 and i + 1 < len(sentences):
             next_sentence = sentences[i + 1]
             result_sentences.append({"sentence": sentence, "flag": ""})
@@ -419,30 +423,75 @@
             result_sentences.append({"sentence": sentence, "flag": ""})
             i += 1
             continue
-        else:
-            search_text = f"{node_name}:{prop_title}:{sentence}"
         
         i += 1
+
+        # Coarse recall: fetch candidate slices via vector search
+        search_results = trunks_service.search_by_vector(
+            text=search_text,
+            limit=500,
+            type='trunk'
+        )
+        if not search_results:  # guard: load_corpus rejects an empty corpus
+            result_sentences.append({"sentence": sentence, "flag": ""})
+            continue
+        # Alternative: query 1000 slice rows directly
+        # db_trunks = db.query(Trunks).filter(Trunks.type == 'trunk').limit(1000).all()
+
+        # Prepare the corpus data
+        trunk_texts = []
+        trunk_ids = []
+        # Create a dict to store the details of each trunk
+        trunk_details = {}
+
+        for trunk in search_results:
+            trunk_texts.append(trunk.get('content'))
+            trunk_ids.append(trunk.get('id'))
+            # Cache this trunk's details
+            trunk_details[trunk.get('id')] = {
+                'id': trunk.get('id'),
+                'content': trunk.get('content'),
+                'file_path': trunk.get('file_path'),
+                'title': trunk.get('title'),
+                'referrence': trunk.get('referrence'),
+                'page_no': trunk.get('page_no')
+            }
+
+        # Initialize TextSimilarityFinder and load the corpus
+        similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
+        similarity_finder.load_corpus(trunk_texts, trunk_ids)
+
+        # Use TextSimilarityFinder for similarity matching
+        similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
         
-        search_results = trunks_service.search_by_vector(text=search_text, limit=1, type='trunk')
-        
-        if not search_results:
+        if not similar_results:  # no similar text was found
             result_sentences.append({"sentence": sentence, "flag": ""})
             continue
-            
-        for search_result in search_results:
-            if search_result.get("distance", DISTANCE_THRESHOLD) >= DISTANCE_THRESHOLD:
+        
+        # Get the trunk_id of the most similar text
+        trunk_id = similar_results[0]['path']
+
+        # Fetch the trunk details from the cache
+        trunk_info = trunk_details.get(trunk_id)
+
+        if trunk_info:
+            search_result = {
+                **trunk_info,
+                'distance': 1 - similar_results[0]['similarity']  # convert similarity to a distance (1 - similarity)
+            }
+            # Skip the result if its distance exceeds the threshold
+            if search_result['distance'] >= DISTANCE_THRESHOLD:
                 result_sentences.append({"sentence": sentence, "flag": ""})
                 continue
-                
+            # Check whether this reference already exists
             existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
             current_index = int(existing_ref["index"]) if existing_ref else reference_index
-            
+
             if not existing_ref:
                 reference, _ = _process_search_result(search_result, reference_index)
                 all_references.append(reference)
                 reference_index += 1
-                
+
             result_sentences.append({"sentence": sentence, "flag": str(current_index)})
     
     return result_sentences, all_references
@@ -498,21 +547,65 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
 
             # First search with the complete prop_value
             search_text = f"{node_name}:{prop_title}:{prop_value}"
-            full_search_results = trunks_service.search_by_vector(
+            # Coarse recall: fetch candidate slices via vector search
+            search_results = trunks_service.search_by_vector(
                 text=search_text,
-                limit=1,
+                limit=500,
                 type='trunk'
             )
+            
+            # Prepare the corpus data
+            trunk_texts = []
+            trunk_ids = []
+            
+            # Create a dict to store the details of each trunk
+            trunk_details = {}
+            
+            for trunk in search_results:
+                trunk_texts.append(trunk.get('content'))
+                trunk_ids.append(trunk.get('id'))
+                # Cache this trunk's details
+                trunk_details[trunk.get('id')] = {
+                    'id': trunk.get('id'),
+                    'content': trunk.get('content'),
+                    'file_path': trunk.get('file_path'),
+                    'title': trunk.get('title'),
+                    'referrence': trunk.get('referrence'),
+                    'page_no': trunk.get('page_no')
+                }
+            
+            # Guard: skip similarity ranking when vector search returned no candidates
+            similar_results = []
+            if trunk_texts:
+                # Initialize TextSimilarityFinder and load the corpus
+                similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
+                similarity_finder.load_corpus(trunk_texts, trunk_ids)
 
+                similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
+            
             # Process the search results
-            if full_search_results and full_search_results[0].get("distance", DISTANCE_THRESHOLD) < DISTANCE_THRESHOLD:
-                search_result = full_search_results[0]
-                reference, _ = _process_search_result(search_result, 1)
-                prop_result["references"] = [reference]
-                prop_result["answer"] = [{
-                    "sentence": prop_value,
-                    "flag": "1"
-                }]
+            if similar_results:  # a similar slice was found
+                # Get the trunk_id of the most similar text
+                trunk_id = similar_results[0]['path']
+                
+                # Fetch the trunk details from the cache
+                trunk_info = trunk_details.get(trunk_id)
+                
+                if trunk_info:
+                    search_result = {
+                        **trunk_info,
+                        'distance': 1 - similar_results[0]['similarity']  # convert similarity to a distance (1 - similarity)
+                    }
+                    
+                    reference, _ = _process_search_result(search_result, 1)
+                    prop_result["references"] = [reference]
+                    prop_result["answer"] = [{
+                        "sentence": prop_value,
+                        "flag": "1"
+                    }]
+                else:
+                    # If the whole-value search found no match, fall back to sentence-split search
+                    sentences = SentenceUtil.split_text(prop_value)
             else:
                 # If the whole-value search found no match, fall back to sentence-split search
                 sentences = SentenceUtil.split_text(prop_value)

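Note: the recall-then-re-rank block above appears twice (once per sentence in _process_sentence_search, and once for the full prop_value in node_props_search). A minimal sketch of how the shared step could be factored out, assuming only the dict shape that trunks_service.search_by_vector returns above; the helper name rerank_top_trunk is hypothetical and not part of this commit:

    from typing import Optional

    from utils.text_similarity import TextSimilarityFinder

    def rerank_top_trunk(trunks_service, search_text: str) -> Optional[dict]:
        # Hypothetical helper (not part of this commit).
        # Stage 1: coarse recall via vector search, as in this commit (limit=500)
        candidates = trunks_service.search_by_vector(text=search_text, limit=500, type='trunk')
        if not candidates:
            return None
        # Stage 2: re-rank the candidates with TF-IDF over jieba tokens
        finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
        finder.load_corpus([t.get('content') for t in candidates],
                           [t.get('id') for t in candidates])
        best = finder.find_most_similar(search_text, top_n=1)
        if not best:
            return None
        # Map the winning id back to its trunk and attach a distance (1 - similarity)
        by_id = {t.get('id'): t for t in candidates}
        trunk = by_id.get(best[0]['path'])
        return {**trunk, 'distance': 1 - best[0]['similarity']} if trunk else None

As in the committed code, this rebuilds a TF-IDF vocabulary over up to 500 candidates on every call; the per-query cost is the same, the duplication is just removed.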
+ 1 - 1
service/cdss_service.py

@@ -6,7 +6,7 @@ import logging
 from sqlalchemy.exc import IntegrityError
 
 from service.kg_node_service import KGNodeService
-from tests.service.test_kg_node_service import kg_node_service
+# from tests.service.test_kg_node_service import kg_node_service
 from utils import vectorizer
 from utils.vectorizer import Vectorizer
 from sqlalchemy import func

+ 252 - 0
utils/text_similarity.py

@@ -0,0 +1,252 @@
+import os
+import numpy as np
+import re
+from typing import List, Dict, Tuple, Optional
+from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import jieba
+import argparse
+
+class TextSimilarityFinder:
+    """Text similarity finder: picks, from a set of texts, the one most similar to a target text"""
+    
+    def __init__(self, method: str = 'tfidf', use_jieba: bool = True, ngram_range: Tuple[int, int] = (1, 1)):
+        """
+        Initialize the text similarity finder
+        
+        Args:
+            method: similarity method, one of 'tfidf', 'count', 'jaccard'
+            use_jieba: whether to segment words with jieba (recommended for Chinese text)
+            ngram_range: N-gram range, e.g. (1, 2) uses both unigram and bigram features
+        """
+        self.method = method
+        self.use_jieba = use_jieba
+        self.ngram_range = ngram_range
+        self.vectorizer = None
+        self.corpus_vectors = None
+        self.corpus_texts = None
+        self.corpus_paths = None
+    
+    def _tokenize_text(self, text: str) -> str:
+        """Tokenize a text
+        
+        Args:
+            text: the input text
+            
+        Returns:
+            str: the tokenized text
+        """
+        if not self.use_jieba:
+            return text
+        
+        # Segment with jieba and return space-separated tokens
+        return " ".join(jieba.cut(text))
+    
+    def _preprocess_texts(self, texts: List[str]) -> List[str]:
+        """Preprocess a list of texts
+        
+        Args:
+            texts: the list of texts
+            
+        Returns:
+            List[str]: the preprocessed texts
+        """
+        processed_texts = []
+        for text in texts:
+            # Strip punctuation, special characters, and redundant whitespace
+            # Match CJK punctuation (including the bullet •)
+            text = re.sub(r'[\u3000-\u303F]|[\uFF00-\uFFEF]|[•]', ' ', text)
+            # Match ASCII punctuation and special characters
+            text = re.sub(r'[!"#$%&\'()*+,-./:;<=>?@\[\\\]^_`{|}~]', ' ', text)
+            # Match other special characters (tabs, newlines, etc.)
+            text = re.sub(r'[\t\n\r\f\v]', ' ', text)
+            # Collapse redundant whitespace
+            text = re.sub(r'\s+', ' ', text).strip()
+            # Tokenize
+            text = self._tokenize_text(text)
+            processed_texts.append(text)
+        return processed_texts
+    
+    def _initialize_vectorizer(self):
+        """Initialize the text vectorizer"""
+        if self.method == 'tfidf':
+            self.vectorizer = TfidfVectorizer(ngram_range=self.ngram_range)
+        elif self.method == 'count':
+            self.vectorizer = CountVectorizer(ngram_range=self.ngram_range)
+        elif self.method == 'jaccard':
+            # Jaccard similarity needs no vectorizer; it works on token sets
+            self.vectorizer = None
+        else:
+            raise ValueError(f"Unsupported similarity method: {self.method}")
+    
+    def _calculate_jaccard_similarity(self, text1: str, text2: str) -> float:
+        """Compute the Jaccard similarity of two texts
+        
+        Args:
+            text1: the first text
+            text2: the second text
+            
+        Returns:
+            float: the Jaccard similarity
+        """
+        # Turn each text into a set of tokens
+        set1 = set(text1.split())
+        set2 = set(text2.split())
+        
+        # Jaccard similarity: intersection size / union size
+        intersection = len(set1.intersection(set2))
+        union = len(set1.union(set2))
+        
+        # Avoid division by zero
+        if union == 0:
+            return 0.0
+        
+        return intersection / union
+    
+    def load_corpus_from_directory(self, directory: str, file_pattern: str = '*.txt'):
+        """Load a corpus from a directory
+        
+        Args:
+            directory: the directory path
+            file_pattern: the file match pattern
+        """
+        import glob
+        
+        # Collect all matching file paths
+        file_paths = glob.glob(os.path.join(directory, file_pattern))
+        
+        if not file_paths:
+            raise ValueError(f"No files matching {file_pattern} found in directory {directory}")
+        
+        # Read the content of each file
+        texts = []
+        for file_path in file_paths:
+            try:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    texts.append(f.read())
+            except Exception as e:
+                print(f"Failed to read file {file_path}: {str(e)}")
+        
+        self.load_corpus(texts, file_paths)
+    
+    def load_corpus(self, texts: List[str], paths: Optional[List[str]] = None):
+        """Load a corpus
+        
+        Args:
+            texts: the list of texts
+            paths: the list of paths or identifiers corresponding to the texts
+        """
+        if not texts:
+            raise ValueError("The text list must not be empty")
+        
+        # Preprocess the texts
+        processed_texts = self._preprocess_texts(texts)
+        
+        # Initialize the vectorizer
+        self._initialize_vectorizer()
+        
+        # Keep the raw texts and paths
+        self.corpus_texts = texts
+        self.corpus_paths = paths if paths else [f"text_{i}" for i in range(len(texts))]
+        
+        # Vectorize the corpus
+        if self.method in ['tfidf', 'count']:
+            self.corpus_vectors = self.vectorizer.fit_transform(processed_texts)
+        else:
+            # For Jaccard similarity, store the processed texts directly
+            self.corpus_vectors = processed_texts
+    
+    def find_most_similar(self, query_text: str, top_n: int = 1) -> List[Dict]:
+        """Find the texts most similar to a query text
+        
+        Args:
+            query_text: the query text
+            top_n: how many of the most similar results to return
+            
+        Returns:
+            List[Dict]: results with similarity information
+        """
+        if self.corpus_vectors is None:
+            raise ValueError("Load a corpus first")
+        
+        # Preprocess the query text
+        processed_query = self._preprocess_texts([query_text])[0]
+        
+        # Compute the similarities
+        similarities = []
+        
+        if self.method in ['tfidf', 'count']:
+            # Vectorize the query text
+            query_vector = self.vectorizer.transform([processed_query])
+            
+            # Compute cosine similarities
+            similarity_matrix = cosine_similarity(query_vector, self.corpus_vectors)
+            similarities = similarity_matrix[0]
+        else:  # Jaccard similarity
+            for corpus_text in self.corpus_vectors:
+                similarity = self._calculate_jaccard_similarity(processed_query, corpus_text)
+                similarities.append(similarity)
+        
+        # Indices of the corpus texts, sorted by similarity in descending order
+        top_indices = np.argsort(similarities)[::-1][:top_n]
+        
+        # Build the results
+        results = []
+        for idx in top_indices:
+            results.append({
+                'text': self.corpus_texts[idx],
+                'path': self.corpus_paths[idx],
+                'similarity': similarities[idx]
+            })
+        
+        return results
+
+def main():
+    # Parse the command-line arguments
+    parser = argparse.ArgumentParser(description='Text similarity search tool')
+    parser.add_argument('--query', type=str, required=True, help='query text, or the path to a file containing it')
+    parser.add_argument('--dir', type=str, required=True, help='corpus directory path')
+    parser.add_argument('--pattern', type=str, default='*.txt', help='file match pattern')
+    parser.add_argument('--method', type=str, default='tfidf', choices=['tfidf', 'count', 'jaccard'], help='similarity method')
+    parser.add_argument('--top_n', type=int, default=3, help='return the top N most similar results')
+    parser.add_argument('--no_jieba', action='store_true', help='disable jieba word segmentation')
+    parser.add_argument('--ngram_min', type=int, default=1, help='minimum N-gram size')
+    parser.add_argument('--ngram_max', type=int, default=1, help='maximum N-gram size')
+    
+    args = parser.parse_args()
+    
+    # Initialize the text similarity finder
+    finder = TextSimilarityFinder(
+        method=args.method,
+        use_jieba=not args.no_jieba,
+        ngram_range=(args.ngram_min, args.ngram_max)
+    )
+    
+    # Load the corpus
+    finder.load_corpus_from_directory(args.dir, args.pattern)
+    
+    # Get the query text
+    query_text = args.query
+    if os.path.isfile(query_text):
+        try:
+            with open(query_text, 'r', encoding='utf-8') as f:
+                query_text = f.read()
+        except Exception as e:
+            print(f"Failed to read the query file: {str(e)}")
+            return
+    
+    # Find the most similar texts
+    results = finder.find_most_similar(query_text, args.top_n)
+    
+    # Print the results
+    print(f"\nQuery text: {query_text[:100]}..." if len(query_text) > 100 else f"\nQuery text: {query_text}")
+    print(f"\nFound {len(results)} most similar texts:")
+    
+    for i, result in enumerate(results):
+        print(f"\n{i+1}. Similarity: {result['similarity']:.4f}")
+        print(f"   Path: {result['path']}")
+        text_preview = result['text'][:200] + '...' if len(result['text']) > 200 else result['text']
+        print(f"   Content preview: {text_preview}")
+
+if __name__ == "__main__":
+    main()
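For reference, a minimal usage sketch of the new module; the corpus strings and the ./corpus directory below are made-up examples:

    from utils.text_similarity import TextSimilarityFinder

    finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
    finder.load_corpus(
        ["高血压患者应低盐饮食", "糖尿病的典型症状包括多饮多尿"],  # made-up slice contents
        ["trunk_1", "trunk_2"],
    )
    for r in finder.find_most_similar("高血压的饮食建议", top_n=1):
        print(r['path'], round(float(r['similarity']), 4))

The equivalent from the command line, against a directory of .txt files:

    python utils/text_similarity.py --query "高血压的饮食建议" --dir ./corpus --method tfidf --top_n 3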