SGTY, 2 weeks ago
parent
commit
bc11d5a260

+ 1 - 1
requirements.txt

@@ -11,7 +11,7 @@ psycopg2-binary==2.9.10
 python-dotenv==1.0.0
 hui-tools[all]==0.5.8
 cachetools==6.1.0
-
+jieba==0.42.1
 
 
 
 
 
 
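The new jieba dependency provides the Chinese word segmentation that TextSimilarityFinder relies on when use_jieba=True. A minimal sketch of what it offers (the sample sentence is illustrative):

import jieba  # pinned above as jieba==0.42.1

# jieba.cut returns a generator of tokens; TextSimilarityFinder joins them with spaces
print(" ".join(jieba.cut("急性胰腺炎患者的护理要点")))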

+ 21 - 0
src/knowledge/model/trunks_model.py

@@ -0,0 +1,21 @@
+from sqlalchemy import Column, Integer, Text, String
+from sqlalchemy.dialects.postgresql import TSVECTOR
+from pgvector.sqlalchemy import Vector
+from ..db.base_class import Base
+
+class Trunks(Base):
+    __tablename__ = 'trunks'
+
+    id = Column(Integer, primary_key=True, index=True)
+    embedding = Column(Vector(1024))
+    content = Column(Text)
+    file_path = Column(String(255))
+    content_tsvector = Column(TSVECTOR)
+    type = Column(String(255))
+    title = Column(String(255))
+    referrence = Column(String(255))
+    meta_header = Column(String(255))
+    page_no = Column(Integer)
+
+    def __repr__(self):
+        return f"<Trunks(id={self.id}, file_path={self.file_path})>"

+ 215 - 0
src/knowledge/router/text_search.py

@@ -0,0 +1,215 @@
+from fastapi import APIRouter, HTTPException, Depends
+from pydantic import BaseModel, Field, validator
+from typing import List, Optional
+from ..service.trunks_service import TrunksService
+from ..utils.sentence_util import SentenceUtil
+
+from ..model.response import StandardResponse
+
+import os
+
+import logging
+
+from ..db.session import get_db
+from sqlalchemy.orm import Session
+
+
+from cachetools import TTLCache
+
+# TextSimilarityFinder is used for lexical text-similarity matching
+from ..utils.text_similarity import TextSimilarityFinder
+
+logger = logging.getLogger(__name__)
+router = APIRouter(tags=["Text Search"])
+DISTANCE_THRESHOLD = 0.73
+# Global cache instance (currently unused)
+#cache = TTLCache(maxsize=1000, ttl=3600)
+
+class FindSimilarTexts(BaseModel):
+    keywords: Optional[List[str]] = None
+    search_text: str
+
+@router.post("/knowledge/text/find_similar_texts", response_model=StandardResponse)
+async def find_similar_texts(request: FindSimilarTexts, db: Session = Depends(get_db)):
+    trunks_service = TrunksService()
+    search_text = request.search_text
+    if request.keywords:
+        search_text = f"{request.keywords}:{search_text}"
+    # 使用向量搜索获取相似内容
+    search_results = trunks_service.search_by_vector(
+        text=search_text,
+        limit=500,
+        type='trunk',
+        distance=0.7
+    )
+
+    # Prepare the corpus for lexical matching
+    trunk_texts = []
+    trunk_ids = []
+
+    # Keep full trunk details keyed by id
+    trunk_details = {}
+    if len(search_results) == 0:
+        return StandardResponse(success=True)
+    for trunk in search_results:
+        trunk_texts.append(trunk.get('content'))
+        trunk_ids.append(trunk.get('id'))
+        # Cache the trunk details
+        trunk_details[trunk.get('id')] = {
+            'id': trunk.get('id'),
+            'content': trunk.get('content'),
+            'file_path': trunk.get('file_path'),
+            'title': trunk.get('title'),
+            'referrence': trunk.get('referrence'),
+            'page_no': trunk.get('page_no')
+        }
+
+    # Initialise TextSimilarityFinder and load the corpus
+    similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
+    similarity_finder.load_corpus(trunk_texts, trunk_ids)
+
+    similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
+    prop_result = {}
+    # Process the search result
+    if similar_results and similar_results[0]['similarity'] >= 0.3:  # similarity threshold
+        # trunk_id of the most similar text
+        trunk_id = similar_results[0]['path']
+
+        # Look up the trunk details cached above
+        trunk_info = trunk_details.get(trunk_id)
+
+        if trunk_info:
+            search_result = {
+                **trunk_info,
+                'distance': 1 - similar_results[0]['similarity']  # convert similarity to a distance-like score
+            }
+
+            reference, _ = _process_search_result(search_result, 1)
+            prop_result["references"] = [reference]
+            prop_result["answer"] = [{
+                 "sentence": request.search_text,
+                 "index": "1"
+            }]
+    else:
+        # If the whole-text search found no match, fall back to sentence-level search
+        sentences = SentenceUtil.split_text(request.search_text, 10)
+        result_sentences, references = _process_sentence_search_keywords(
+            sentences, trunks_service, keywords=request.keywords
+        )
+        if references:
+            prop_result["references"] = references
+        if result_sentences:
+            prop_result["answer"] = result_sentences
+    return StandardResponse(success=True, data=prop_result)
+
+def _process_search_result(search_result: dict, reference_index: int) -> tuple[dict, str]:
+    """处理搜索结果,返回引用信息和文件名"""
+    file_name = ""
+    referrence = search_result.get("referrence", "")
+    if referrence and "/books/" in referrence:
+        file_name = referrence.split("/books/")[-1]
+        file_name = os.path.splitext(file_name)[0]
+
+    reference = {
+        "index": str(reference_index),
+        "id": search_result["id"],
+        "content": search_result["content"],
+        "file_path": search_result.get("file_path", ""),
+        "title": search_result.get("title", ""),
+        "distance": search_result.get("distance", DISTANCE_THRESHOLD),
+        "page_no": search_result.get("page_no", ""),
+        "file_name": file_name,
+        "referrence": referrence
+    }
+    return reference, file_name
+
+
+def _process_sentence_search(node_name: str, prop_title: str, sentences: list, trunks_service: TrunksService) -> tuple[
+    list, list]:
+    keywords = [node_name, prop_title] if node_name and prop_title else None
+    return _process_sentence_search_keywords(sentences, trunks_service, keywords=keywords)
+
+
+def _process_sentence_search_keywords(sentences: list, trunks_service: TrunksService,
+                                      keywords: Optional[List[str]] = None) -> tuple[list, list]:
+    """处理句子搜索,返回结果句子和引用列表"""
+    result_sentences = []
+    all_references = []
+    reference_index = 1
+    i = 0
+
+    while i < len(sentences):
+        sentence = sentences[i]
+        search_text = sentence
+        if keywords:
+            search_text = f"{keywords}:{sentence}"
+
+        i += 1
+
+        # Vector search for candidate content
+        search_results = trunks_service.search_by_vector(
+            text=search_text,
+            limit=500,
+            type='trunk',
+            distance=0.7
+        )
+
+        # Prepare the corpus for lexical matching
+        trunk_texts = []
+        trunk_ids = []
+        # Keep full trunk details keyed by id
+        trunk_details = {}
+
+        for trunk in search_results:
+            trunk_texts.append(trunk.get('content'))
+            trunk_ids.append(trunk.get('id'))
+            # Cache the trunk details
+            trunk_details[trunk.get('id')] = {
+                'id': trunk.get('id'),
+                'content': trunk.get('content'),
+                'file_path': trunk.get('file_path'),
+                'title': trunk.get('title'),
+                'referrence': trunk.get('referrence'),
+                'page_no': trunk.get('page_no')
+            }
+        if len(trunk_texts) == 0:
+            continue
+        # Initialise TextSimilarityFinder and load the corpus
+        similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
+        similarity_finder.load_corpus(trunk_texts, trunk_ids)
+
+        # Run the similarity match with TextSimilarityFinder
+        similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
+
+        if not similar_results:  # no candidates at all
+            result_sentences.append({"sentence": sentence, "index": ""})
+            continue
+
+        # trunk_id of the most similar text
+        trunk_id = similar_results[0]['path']
+
+        # Look up the trunk details cached above
+        trunk_info = trunk_details.get(trunk_id)
+
+        if trunk_info:
+            search_result = {
+                **trunk_info,
+                'distance': 1 - similar_results[0]['similarity']  # convert similarity to a distance-like score
+            }
+            # Treat the match as a miss if the distance exceeds the threshold
+            if search_result['distance'] >= DISTANCE_THRESHOLD:
+                result_sentences.append({"sentence": sentence, "index": ""})
+                continue
+            # Check whether this trunk is already referenced
+            existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
+            current_index = int(existing_ref["index"]) if existing_ref else reference_index
+
+            if not existing_ref:
+                reference, _ = _process_search_result(search_result, reference_index)
+                all_references.append(reference)
+                reference_index += 1
+
+            result_sentences.append({"sentence": sentence, "index": str(current_index)})
+
+    return result_sentences, all_references
+text_search_router = router
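A minimal client-side sketch of calling the new endpoint; the host/port and payload values are illustrative, and keywords is optional:

import requests

resp = requests.post(
    "http://localhost:8000/knowledge/text/find_similar_texts",  # hypothetical deployment address
    json={"search_text": "急性期每4小时评估腹痛程度", "keywords": ["急性胰腺炎", "护理"]},
)
# StandardResponse payload: {"success": true, "data": {"references": [...], "answer": [...]}}
print(resp.json())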

+ 5 - 1
src/knowledge/server.py

@@ -12,7 +12,9 @@ from .config.site import SiteConfig
 from .db.session import get_db
 from .middlewares.base import register_middlewares
 from .model.response import StandardResponse
+from .router.graph_api import graph_router
 from .router.knowledge_nodes_api import knowledge_nodes_api_router, get_request_id, api_key_header
+from .router.text_search import text_search_router
 from .service.kg_edge_service import KGEdgeService
 from .service.kg_node_service import KGNodeService
 from .utils import log_util
@@ -125,7 +127,9 @@ async def startup():
 
     # register routers
     app.include_router(knowledge_nodes_api_router)
-
+    app.include_router(text_search_router)
+    app.include_router(graph_router)
+
     logger.info("fastapi startup success")
 
 

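For local experimentation without the full startup sequence, the router wiring added here can be reproduced on a bare FastAPI app (import paths are assumptions; the real server also registers middlewares and other startup logic):

from fastapi import FastAPI
from src.knowledge.router.graph_api import graph_router          # import paths are assumptions
from src.knowledge.router.text_search import text_search_router

app = FastAPI()
app.include_router(text_search_router)
app.include_router(graph_router)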
+ 319 - 0
src/knowledge/service/trunks_service.py

@@ -0,0 +1,319 @@
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+from ..db.session import get_db
+from typing import List, Optional
+from ..model.trunks_model import Trunks
+from ..db.session import SessionLocal
+import logging
+
+from ..utils.sentence_util import SentenceUtil
+from ..utils.vectorizer import Vectorizer
+
+logger = logging.getLogger(__name__)
+
+class TrunksService:
+    def __init__(self):
+        self.db = next(get_db())
+
+
+    def create_trunk(self, trunk_data: dict) -> Trunks:
+        # Auto-generate the embedding (and, optionally, the full-text search column)
+        content = trunk_data.get('content')
+        if 'embedding' in trunk_data and len(trunk_data['embedding']) != 1024:
+            raise ValueError("Embedding dimension must be 1024")
+        # Note: any caller-supplied embedding is replaced by a freshly computed one
+        trunk_data['embedding'] = Vectorizer.get_embedding(content)
+        if 'type' not in trunk_data:
+            trunk_data['type'] = 'default'
+        if 'title' not in trunk_data:
+            from pathlib import Path
+            trunk_data['title'] = Path(trunk_data['file_path']).stem
+        logger.debug(f"Generated embedding length: {len(trunk_data['embedding'])}, content preview: {content[:20]}")
+        # trunk_data['content_tsvector'] = func.to_tsvector('chinese', content)
+        
+        
+        db = SessionLocal()
+        try:
+            trunk = Trunks(**trunk_data)
+            db.add(trunk)
+            db.commit()
+            db.refresh(trunk)
+            return trunk
+        except Exception as e:
+            db.rollback()
+            logger.error(f"创建trunk失败: {str(e)}")
+            raise
+        finally:
+            db.close()
+
+    def get_trunk_by_id(self, trunk_id: int) -> Optional[dict]:
+        db = SessionLocal()
+        try:
+            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
+            if trunk:
+                return {
+                    'id': trunk.id,
+                    'file_path': trunk.file_path,
+                    'content': trunk.content,
+                    'embedding': trunk.embedding.tolist(),
+                    'type': trunk.type,
+                    'title':trunk.title
+                }
+            return None
+        finally:
+            db.close()
+
+    def search_by_vector(self, text: str, limit: int = 10, file_path: Optional[str]=None, distance: Optional[float]=None, type: Optional[str] = None, conversation_id: Optional[str] = None) -> List[dict]:
+       
+        embedding = Vectorizer.get_instance().get_embedding(text)
+        db = SessionLocal()
+        try:
+            query = db.query(
+                Trunks.id,
+                Trunks.file_path,
+                Trunks.content,
+                Trunks.embedding.l2_distance(embedding).label('distance'),
+                Trunks.title,
+                Trunks.embedding,
+                Trunks.page_no,
+                Trunks.referrence,
+                Trunks.meta_header
+            )
+            if distance:
+                query = query.filter(Trunks.embedding.l2_distance(embedding) <= distance)
+            if type:
+                query = query.filter(Trunks.type == type)
+            if file_path:
+                query = query.filter(Trunks.file_path.like('%'+file_path+'%'))
+            results = query.order_by('distance').limit(limit).all()
+            result_list = [{
+                'id': r.id,
+                'file_path': r.file_path,
+                'content': r.content,
+                # keep three decimal places
+                'distance': round(r.distance, 3),
+                'title': r.title,
+                'embedding': r.embedding.tolist(),
+                'page_no': r.page_no,
+                'referrence': r.referrence,
+                'meta_header': r.meta_header
+            } for r in results]
+
+            if conversation_id:
+                self.set_cache(conversation_id, result_list)
+
+            return result_list
+        finally:
+            db.close()
+
+    def fulltext_search(self, query: str) -> List[Trunks]:
+        db = SessionLocal()
+        try:
+            return db.query(Trunks).filter(
+                Trunks.content_tsvector.match(query)
+            ).all()
+        finally:
+            db.close()
+
+    def update_trunk(self, trunk_id: int, update_data: dict) -> Optional[Trunks]:
+        if 'content' in update_data:
+            content = update_data['content']
+            update_data['embedding'] = Vectorizer.get_embedding(content)
+            if 'type' not in update_data:
+                update_data['type'] = 'default'
+            logger.debug(f"更新生成的embedding长度: {len(update_data['embedding'])}, 内容摘要: {content[:20]}")
+            # update_data['content_tsvector'] = func.to_tsvector('chinese', content)
+
+        db = SessionLocal()
+        try:
+            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
+            if trunk:
+                for key, value in update_data.items():
+                    setattr(trunk, key, value)
+                db.commit()
+                db.refresh(trunk)
+            return trunk
+        except Exception as e:
+            db.rollback()
+            logger.error(f"更新trunk失败: {str(e)}")
+            raise
+        finally:
+            db.close()
+
+    def delete_trunk(self, trunk_id: int) -> bool:
+        db = SessionLocal()
+        try:
+            trunk = db.query(Trunks).filter(Trunks.id == trunk_id).first()
+            if trunk:
+                db.delete(trunk)
+                db.commit()
+                return True
+            return False
+        except Exception as e:
+            db.rollback()
+            logger.error(f"删除trunk失败: {str(e)}")
+            raise
+        finally:
+            db.close()
+
+    def highlight(self, trunk_id: int, targetSentences: List[str]) -> List[int]:
+        trunk = self.get_trunk_by_id(trunk_id)
+        if not trunk:
+            return []
+
+        content = trunk['content']
+        sentence_util = SentenceUtil()
+        cleanedContent = sentence_util.clean_text(content)
+
+        result = []
+        for i, targetSentence in enumerate(targetSentences):
+            cleanedTarget = sentence_util.clean_text(targetSentence)
+            # Skip targets with fewer than 5 valid characters
+            if len(cleanedTarget) < 5:
+                continue
+            if cleanedTarget in cleanedContent:
+                result.append(i)
+
+        # Fill in the indices between matches that are at most 2 apart so highlights are contiguous
+        if result:
+            result.sort()
+            filled_result = []
+            prev = result[0]
+            filled_result.append(prev)
+            for current in result[1:]:
+                if current - prev <= 2:
+                    for i in range(prev + 1, current):
+                        filled_result.append(i)
+                filled_result.append(current)
+                prev = current
+            return filled_result
+        return result
+
+    _cache = {}
+
+    def get_cache(self, conversation_id: str) -> List[dict]:
+        """
+        根据conversation_id获取缓存结果
+        :param conversation_id: 会话ID
+        :return: 结果列表
+        """
+        return self._cache.get(conversation_id, [])
+
+    def set_cache(self, conversation_id: str, result: List[dict]) -> None:
+        """
+        设置缓存结果
+        :param conversation_id: 会话ID
+        :param result: 要缓存的结果
+        """
+        self._cache[conversation_id] = result
+
+    def get_cached_result(self, conversation_id: str) -> List[dict]:
+        """
+        根据conversation_id获取缓存结果
+        :param conversation_id: 会话ID
+        :return: 结果列表
+        """
+        return self.get_cache(conversation_id)
+
+    def paginated_search_by_type_and_filepath(self, search_params: dict) -> dict:
+        """
+        根据type和file_path进行分页查询
+        :param search_params: 包含pageNo, limit的字典
+        :return: 包含结果列表和分页信息的字典
+        """
+        page_no = search_params.get('pageNo', 1)
+        limit = search_params.get('limit', 10)
+        file_path = search_params.get('file_path', None)
+        type = search_params.get('type', None)
+        
+        if page_no < 1:
+            page_no = 1
+        if limit < 1:
+            limit = 10
+            
+        offset = (page_no - 1) * limit
+        
+        db = SessionLocal()
+        try:
+
+            query = db.query(
+                Trunks.id,
+                Trunks.file_path,
+                Trunks.content,
+                Trunks.type,
+                Trunks.title,
+                Trunks.meta_header
+            )
+            if type:
+                query = query.filter(Trunks.type == type)
+            if file_path:
+                query = query.filter(Trunks.file_path.like('%' + file_path + '%'))
+
+            query = query.filter(Trunks.page_no.is_(None))  # only rows without a page number
+            results = query.offset(offset).limit(limit).all()
+            
+            return {
+                'data': [{
+                    'id': r.id,
+                    'file_path': r.file_path,
+                    'content': r.content,
+                    'type': r.type,
+                    'title': r.title,
+                    'meta_header':r.meta_header
+                } for r in results]
+            }
+        finally:
+            db.close()
+        
+        
+
+    def paginated_search(self, search_params: dict) -> dict:
+        """
+        分页查询方法
+        :param search_params: 包含keyword, pageNo, limit的字典
+        :return: 包含结果列表和分页信息的字典
+        """
+        keyword = search_params.get('keyword', '')
+        page_no = search_params.get('pageNo', 1)
+        limit = search_params.get('limit', 10)
+        
+        if page_no < 1:
+            page_no = 1
+        if limit < 1:
+            limit = 10
+            
+        embedding = Vectorizer.get_embedding(keyword)
+        offset = (page_no - 1) * limit
+        
+        db = SessionLocal()
+        try:
+            # Total row count for the given type
+            total_count = db.query(func.count(Trunks.id)).filter(Trunks.type == search_params.get('type')).scalar()
+            
+            # Vector search ordered by L2 distance
+            results = db.query(
+                Trunks.id,
+                Trunks.file_path,
+                Trunks.content,
+                Trunks.embedding.l2_distance(embedding).label('distance'),
+                Trunks.title
+            ).filter(Trunks.type == search_params.get('type')).order_by('distance').offset(offset).limit(limit).all()
+            
+            return {
+                'data': [{
+                    'id': r.id,
+                    'file_path': r.file_path,
+                    'content': r.content,
+                    # keep three decimal places
+                    'distance': round(r.distance, 3),
+                    'title': r.title
+                } for r in results],
+                'pagination': {
+                    'total': total_count,
+                    'pageNo': page_no,
+                    'limit': limit,
+                    'totalPages': (total_count + limit - 1) // limit
+                }
+            }
+        finally:
+            db.close()
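A minimal usage sketch of the service, assuming the database and the embedding backend behind Vectorizer are configured; the sample content, file path, and query strings are illustrative:

from src.knowledge.service.trunks_service import TrunksService  # import path is an assumption

svc = TrunksService()

# Insert a chunk; the embedding and title are derived automatically in create_trunk()
trunk = svc.create_trunk({
    "content": "急性胰腺炎急性期应每4小时评估腹痛程度。",
    "file_path": "/books/内科护理学.pdf",
    "type": "trunk",
})

# Vector search, mirroring the parameters used by the text_search router
hits = svc.search_by_vector(text="急性胰腺炎的护理", limit=5, type="trunk", distance=0.7)
for h in hits:
    print(h["id"], h["distance"], h["title"])

# Report which of the given sentences appear verbatim in the stored chunk
print(svc.highlight(trunk.id, ["每4小时评估腹痛程度"]))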

+ 185 - 0
src/knowledge/utils/sentence_util.py

@@ -0,0 +1,185 @@
+import re
+from typing import List
+import logging
+import argparse
+import sys
+
+logger = logging.getLogger(__name__)
+
+class SentenceUtil:
+    """中文文本句子拆分工具类
+    
+    用于将中文文本按照标点符号拆分成句子列表
+    """
+    
+    def __init__(self):
+        # End-of-sentence symbols, covering common Chinese and English punctuation
+        self.end_symbols = ['。', '!', '?', '!', '?', '\n']
+        # Quote pairs (opening, closing), including curly and fullwidth variants
+        self.quote_pairs = [('\u2018', '\u2019'), ('\u201c', '\u201d'), ('「', '」'), ('『', '』'), ('(', ')'), ('(', ')')]
+        
+    @staticmethod
+    def split_text(text: str, length: int = None) -> List[str]:
+        """将文本拆分成句子列表
+        
+        Args:
+            text: 输入的文本字符串
+            length: 可选参数,指定拆分后句子的最大长度
+            
+        Returns:
+            拆分后的句子列表
+        """
+        sentences = SentenceUtil()._split(text)
+        if length is not None:
+            i = 0
+            while i < len(sentences):
+                if SentenceUtil().get_valid_length(sentences[i]) <= length and i + 1 < len(sentences):
+                    sentences[i] = sentences[i] + sentences[i+1]
+                    del sentences[i+1]
+                else:
+                    i += 1
+        return sentences
+        
+    def _split(self, text: str) -> List[str]:
+        """内部拆分方法
+        
+        Args:
+            text: 输入的文本字符串
+            length: 可选参数,指定拆分后句子的最大长度
+            
+        Returns:
+            拆分后的句子列表
+        """
+        if not text or not text.strip():
+            return []
+        
+        try:
+            # Generic splitting logic
+            sentences = []
+            current_sentence = ""
+
+            # Stack used to track unclosed quotes
+            quote_stack = []
+            
+            i = 0
+            while i < len(text):
+                char = text[i]
+                current_sentence += char
+                
+                # Handle opening quotes
+                for start, end in self.quote_pairs:
+                    if char == start:
+                        if not quote_stack or quote_stack[-1][0] != end:
+                            quote_stack.append((end, i))
+                            break
+                
+                # Handle closing quotes
+                if quote_stack and char == quote_stack[-1][0] and i > quote_stack[-1][1]:
+                    quote_stack.pop()
+                
+                # Handle end symbols, but only outside quotes
+                if not quote_stack and char in self.end_symbols:
+                    if current_sentence.strip():
+                        # Move the trailing newline to the start of the next sentence
+                        if char == '\n':
+                            current_sentence = current_sentence.rstrip('\n')
+                            sentences.append(current_sentence)
+                            current_sentence = '\n'
+                        else:
+                            sentences.append(current_sentence)
+                            current_sentence = ""
+                    
+                    # Keep a following space (but not a newline) at the start of the next sentence
+                    if i + 1 < len(text) and text[i + 1].isspace() and text[i + 1] != '\n':
+                        i += 1
+                        current_sentence = text[i]
+                
+                i += 1
+            
+            # Append whatever remains after the loop
+            if current_sentence.strip():
+                sentences.append(current_sentence)
+            
+            # If no sentence boundary was found, return the whole text as one sentence
+            if not sentences:
+                return [text]
+            
+            return sentences
+            
+        except Exception as e:
+            logger.error(f"拆分文本时发生错误: {str(e)}")
+            return []
+    
+    @staticmethod
+    def clean_text(text: str) -> str:
+        """去除除中英文和数字以外的所有字符
+        
+        Args:
+            text: 输入的文本字符串
+            
+        Returns:
+            处理后的字符串
+        """
+        if not text:
+            return text
+        return re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9]', '', text)
+
+    @staticmethod
+    def get_valid_length(text: str) -> int:
+        """计算只包含中英文和数字的有效长度
+        
+        Args:
+            text: 输入的文本字符串
+            
+        Returns:
+            有效字符的长度
+        """
+        if not text:
+            return 0
+        return len(re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9]', '', text))
+
+    def split_by_regex(self, text: str) -> List[str]:
+        """使用正则表达式拆分文本
+        
+        这是一个备选方法,使用正则表达式进行拆分
+        
+        Args:
+            text: 输入的文本字符串
+            
+        Returns:
+            拆分后的句子列表
+        """
+        if not text or not text.strip():
+            return []
+            
+        try:
+            # Split with a regex, keeping the delimiters
+            pattern = r'([。!?!?]|\n)'
+            parts = re.split(pattern, text)
+            
+            # Re-attach each delimiter to the preceding part
+            sentences = []
+            for i in range(0, len(parts), 2):
+                if i + 1 < len(parts):
+                    sentences.append(parts[i] + parts[i+1])
+                else:
+                    # Handle the last part (no trailing delimiter)
+                    if parts[i].strip():
+                        sentences.append(parts[i])
+            
+            return sentences
+        except Exception as e:
+            logger.error(f"使用正则表达式拆分文本时发生错误: {str(e)}")
+            return [text] if text else []
+
+
+if __name__ == '__main__':
+    input_text = """急性期护理:
+- 每4h评估腹痛程度 3-1 PDF
+延续护理: 1-2 PDF
+患者教育: 3-3 PDF
+- 识别复发症状(发热/黄疸)"""
+    sentences = SentenceUtil.split_text(input_text, 10)
+    for sentence in sentences:
+        print(sentence)
+        print('-----------')
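A quick sketch of the two cleaning helpers as TrunksService.highlight() uses them; the sample strings are illustrative:

from src.knowledge.utils.sentence_util import SentenceUtil  # import path is an assumption

target = "每4h评估腹痛程度!"
content = "护理措施包括:每4h评估腹痛程度,必要时遵医嘱镇痛。"

cleaned_target = SentenceUtil.clean_text(target)              # "每4h评估腹痛程度"
print(SentenceUtil.get_valid_length(cleaned_target))           # 9 (only CJK/letters/digits are counted)
print(cleaned_target in SentenceUtil.clean_text(content))      # True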

+ 252 - 0
src/knowledge/utils/text_similarity.py

@@ -0,0 +1,252 @@
+import os
+import numpy as np
+import re
+from typing import List, Dict, Tuple, Optional, Union, Callable
+from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import jieba
+import argparse
+
+class TextSimilarityFinder:
+    """文本相似度查找器,用于从多个文本中找到与目标文本最相似的文本"""
+    
+    def __init__(self, method: str = 'tfidf', use_jieba: bool = True, ngram_range: Tuple[int, int] = (1, 1)):
+        """
+        初始化文本相似度查找器
+        
+        Args:
+            method: 相似度计算方法,可选值为 'tfidf'、'count'、'jaccard'
+            use_jieba: 是否使用结巴分词(中文文本建议开启)
+            ngram_range: N-gram范围,例如(1,2)表示同时使用unigram和bigram特征
+        """
+        self.method = method
+        self.use_jieba = use_jieba
+        self.ngram_range = ngram_range
+        self.vectorizer = None
+        self.corpus_vectors = None
+        self.corpus_texts = None
+        self.corpus_paths = None
+    
+    def _tokenize_text(self, text: str) -> str:
+        """对文本进行分词处理
+        
+        Args:
+            text: 输入文本
+            
+        Returns:
+            str: 分词后的文本
+        """
+        if not self.use_jieba:
+            return text
+        
+        # Tokenise with jieba and join the tokens with spaces
+        return " ".join(jieba.cut(text))
+    
+    def _preprocess_texts(self, texts: List[str]) -> List[str]:
+        """预处理文本列表
+        
+        Args:
+            texts: 文本列表
+            
+        Returns:
+            List[str]: 预处理后的文本列表
+        """
+        processed_texts = []
+        for text in texts:
+            # Strip punctuation, special characters and extra whitespace
+            # CJK punctuation and fullwidth forms (including the bullet •)
+            text = re.sub(r'[\u3000-\u303F]|[\uFF00-\uFFEF]|[•]', ' ', text)
+            # ASCII punctuation and special characters
+            text = re.sub(r'[!"#$%&\'()*+,-./:;<=>?@\[\\\]^_`{|}~]', ' ', text)
+            # Other whitespace characters (tabs, newlines, etc.)
+            text = re.sub(r'[\t\n\r\f\v]', ' ', text)
+            # Collapse runs of whitespace
+            text = re.sub(r'\s+', ' ', text).strip()
+            # Tokenise
+            text = self._tokenize_text(text)
+            processed_texts.append(text)
+        return processed_texts
+    
+    def _initialize_vectorizer(self):
+        """初始化文本向量化器"""
+        if self.method == 'tfidf':
+            self.vectorizer = TfidfVectorizer(ngram_range=self.ngram_range)
+        elif self.method == 'count':
+            self.vectorizer = CountVectorizer(ngram_range=self.ngram_range)
+        elif self.method == 'jaccard':
+            # Jaccard similarity needs no vectoriser; it works on token sets
+            self.vectorizer = None
+        else:
+            raise ValueError(f"不支持的相似度计算方法: {self.method}")
+    
+    def _calculate_jaccard_similarity(self, text1: str, text2: str) -> float:
+        """计算两个文本的Jaccard相似度
+        
+        Args:
+            text1: 第一个文本
+            text2: 第二个文本
+            
+        Returns:
+            float: Jaccard相似度
+        """
+        # 将文本转换为词集合
+        set1 = set(text1.split())
+        set2 = set(text2.split())
+        
+        # Jaccard similarity: intersection size / union size
+        intersection = len(set1.intersection(set2))
+        union = len(set1.union(set2))
+        
+        # Avoid division by zero
+        if union == 0:
+            return 0.0
+        
+        return intersection / union
+    
+    def load_corpus_from_directory(self, directory: str, file_pattern: str = '*.txt'):
+        """从目录加载语料库
+        
+        Args:
+            directory: 目录路径
+            file_pattern: 文件匹配模式
+        """
+        import glob
+        
+        # Collect all matching file paths
+        file_paths = glob.glob(os.path.join(directory, file_pattern))
+        
+        if not file_paths:
+            raise ValueError(f"在目录 {directory} 中没有找到匹配 {file_pattern} 的文件")
+        
+        # Read every file
+        texts = []
+        for file_path in file_paths:
+            try:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    texts.append(f.read())
+            except Exception as e:
+                print(f"读取文件 {file_path} 失败: {str(e)}")
+        
+        self.load_corpus(texts, file_paths)
+    
+    def load_corpus(self, texts: List[str], paths: Optional[List[str]] = None):
+        """加载语料库
+        
+        Args:
+            texts: 文本列表
+            paths: 文本对应的路径或标识符列表
+        """
+        if not texts:
+            raise ValueError("文本列表不能为空")
+        
+        # Preprocess the texts
+        processed_texts = self._preprocess_texts(texts)
+        
+        # Initialise the vectoriser
+        self._initialize_vectorizer()
+        
+        # Keep the original texts and paths
+        self.corpus_texts = texts
+        self.corpus_paths = paths if paths else [f"text_{i}" for i in range(len(texts))]
+        
+        # Vectorise the corpus
+        if self.method in ['tfidf', 'count']:
+            self.corpus_vectors = self.vectorizer.fit_transform(processed_texts)
+        else:
+            # For Jaccard similarity, keep the preprocessed texts as-is
+            self.corpus_vectors = processed_texts
+    
+    def find_most_similar(self, query_text: str, top_n: int = 1) -> List[Dict]:
+        """查找与查询文本最相似的文本
+        
+        Args:
+            query_text: 查询文本
+            top_n: 返回最相似的前N个结果
+            
+        Returns:
+            List[Dict]: 包含相似度信息的结果列表
+        """
+        if self.corpus_vectors is None:
+            raise ValueError("请先加载语料库")
+        
+        # Preprocess the query text
+        processed_query = self._preprocess_texts([query_text])[0]
+        
+        # Compute similarities
+        similarities = []
+        
+        if self.method in ['tfidf', 'count']:
+            # Vectorise the query text
+            query_vector = self.vectorizer.transform([processed_query])
+            
+            # Cosine similarity against the corpus
+            similarity_matrix = cosine_similarity(query_vector, self.corpus_vectors)
+            similarities = similarity_matrix[0]
+        else:  # Jaccard相似度
+            for corpus_text in self.corpus_vectors:
+                similarity = self._calculate_jaccard_similarity(processed_query, corpus_text)
+                similarities.append(similarity)
+        
+        # Indices of the top-N most similar texts
+        top_indices = np.argsort(similarities)[::-1][:top_n]
+        
+        # Build the result list
+        results = []
+        for idx in top_indices:
+            results.append({
+                'text': self.corpus_texts[idx],
+                'path': self.corpus_paths[idx],
+                'similarity': similarities[idx]
+            })
+        
+        return results
+
+def main():
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(description='Text similarity lookup tool')
+    parser.add_argument('--query', type=str, required=True, help='Query text, or the path to a file containing it')
+    parser.add_argument('--dir', type=str, required=True, help='Corpus directory path')
+    parser.add_argument('--pattern', type=str, default='*.txt', help='File matching pattern')
+    parser.add_argument('--method', type=str, default='tfidf', choices=['tfidf', 'count', 'jaccard'], help='Similarity method')
+    parser.add_argument('--top_n', type=int, default=3, help='Number of top results to return')
+    parser.add_argument('--no_jieba', action='store_true', help='Disable jieba tokenisation')
+    parser.add_argument('--ngram_min', type=int, default=1, help='Minimum n-gram size')
+    parser.add_argument('--ngram_max', type=int, default=1, help='Maximum n-gram size')
+    
+    args = parser.parse_args()
+    
+    # Initialise the text similarity finder
+    finder = TextSimilarityFinder(
+        method=args.method,
+        use_jieba=not args.no_jieba,
+        ngram_range=(args.ngram_min, args.ngram_max)
+    )
+    
+    # Load the corpus
+    finder.load_corpus_from_directory(args.dir, args.pattern)
+    
+    # Resolve the query text (read it from a file if a path was given)
+    query_text = args.query
+    if os.path.isfile(query_text):
+        try:
+            with open(query_text, 'r', encoding='utf-8') as f:
+                query_text = f.read()
+        except Exception as e:
+            print(f"读取查询文件失败: {str(e)}")
+            return
+    
+    # Find the most similar texts
+    results = finder.find_most_similar(query_text, args.top_n)
+
+    # Print the results
+    print(f"\nQuery text: {query_text[:100]}..." if len(query_text) > 100 else f"\nQuery text: {query_text}")
+    print(f"\nFound {len(results)} most similar text(s):")
+
+    for i, result in enumerate(results):
+        print(f"\n{i+1}. Similarity: {result['similarity']:.4f}")
+        print(f"   Path: {result['path']}")
+        text_preview = result['text'][:200] + '...' if len(result['text']) > 200 else result['text']
+        print(f"   Content preview: {text_preview}")
+
+if __name__ == "__main__":
+    main()
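Besides the CLI, the class can be driven in memory the same way text_search.py does; a minimal sketch (the import path, corpus strings, and ids are illustrative):

from src.knowledge.utils.text_similarity import TextSimilarityFinder  # import path is an assumption

finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
finder.load_corpus(
    ["急性胰腺炎急性期的护理要点", "糖尿病患者的饮食教育"],
    paths=[101, 102],  # e.g. trunk ids, as passed in find_similar_texts()
)
best = finder.find_most_similar("急性胰腺炎急性期如何护理", top_n=1)[0]
print(best['path'], round(float(best['similarity']), 3))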