Procházet zdrojové kódy

相似病例相关接口

yuchengwei před 3 měsíci
rodič
revize
df5fcd575b
Změnil 2 soubory; provedl 158 přidání a 5 odebrání
  1. 146 2
      router/text_search.py
  2. 12 3
      utils/file_reader.py

+ 146 - 2
router/text_search.py

@@ -1,12 +1,12 @@
 from fastapi import APIRouter, HTTPException
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, validator
 from typing import List, Optional
 from service.trunks_service import TrunksService
 from utils.text_splitter import TextSplitter
 from utils.vector_distance import VectorDistance
 from model.response import StandardResponse
 from utils.vectorizer import Vectorizer
-DISTANCE_THRESHOLD = 0.65
+DISTANCE_THRESHOLD = 0.8
 import logging
 
 logger = logging.getLogger(__name__)
@@ -21,6 +21,35 @@ class TextCompareRequest(BaseModel):
     sentence: str
     text: str
 
class TextMatchRequest(BaseModel):
    """Request body for /mr_search: free text to vector-match against 'mr' trunks."""
    text: str = Field(..., min_length=1, max_length=10000, description="需要搜索的文本内容")

    @validator('text')
    def validate_text(cls, v):
        """Sanitize the text and escape JSON special characters.

        Keeps printable characters plus newline, carriage return and tab;
        all other control characters are dropped before escaping.
        """
        # Keep printable characters and the whitespace controls escaped below.
        # NOTE: '\t' must survive this filter, otherwise the tab-escaping
        # step further down can never fire (it was dead code before: tabs
        # are not isprintable() and were stripped here, as were \b and \f,
        # whose escape calls have therefore been removed).
        v = ''.join(ch for ch in v if ch.isprintable() or ch in '\n\r\t')

        # Escape backslashes first so later escapes are not double-escaped.
        v = v.replace('\\', '\\\\')
        # Quote, and solidus (optional in JSON, kept for backward compatibility
        # with the original behaviour).
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Escape the control characters preserved by the filter above.
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')

        return v
+
class TextCompareMultiRequest(BaseModel):
    """Request body for /mr_match: two texts compared sentence-by-sentence."""
    # The reference document text.
    origin: str
    # The candidate document compared against the reference.
    similar: str
+
 @router.post("/search", response_model=StandardResponse)
 async def search_text(request: TextSearchRequest):
     try:
@@ -145,4 +174,119 @@ async def match_text(request: TextCompareRequest):
         logger.error(f"Text comparison failed: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    """Vector-search stored 'mr' trunks for content similar to the given text."""
    try:
        service = TrunksService()

        # Similarity search over the "mr" corpus, capped at 10 hits.
        hits = service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        # Keep only hits strictly below the distance threshold. A hit that
        # carries no "distance" defaults to the threshold itself and is
        # therefore dropped, same as the original loop's `continue`.
        records = [
            {
                "content": hit["content"],
                "file_path": hit.get("file_path", ""),
                "title": hit.get("title", ""),
                "distance": hit.get("distance", DISTANCE_THRESHOLD),
            }
            for hit in hits
            if hit.get("distance", DISTANCE_THRESHOLD) < DISTANCE_THRESHOLD
        ]

        return StandardResponse(success=True, data={"records": records})

    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
+
def _embed_filtered(sentences):
    """Embed each distinct sentence of at least 10 characters exactly once.

    Returns a dict mapping sentence -> embedding vector. Sentences shorter
    than 10 characters are excluded; callers report them as unmatched.
    NOTE(review): assumes Vectorizer.get_embedding is deterministic per
    sentence, so caching duplicates cannot change results — confirm.
    """
    return {s: Vectorizer.get_embedding(s) for s in sentences if len(s) >= 10}


def _match_against(sentences, own_vectors, other_vectors):
    """Mark each sentence as matched if any opposite-side sentence is close.

    A sentence counts as matched when some vector in *other_vectors* lies
    strictly below DISTANCE_THRESHOLD; sentences absent from *own_vectors*
    (too short to embed) are always unmatched.
    """
    results = []
    for sent in sentences:
        vec = own_vectors.get(sent)
        if vec is None:
            results.append({"sentence": sent, "matched": False})
            continue
        # any() short-circuits on the first close match; equivalent to the
        # original min-distance bookkeeping, which only its truthiness used.
        matched = any(
            VectorDistance.calculate_distance(vec, other_vec) < DISTANCE_THRESHOLD
            for other_vec in other_vectors.values()
        )
        results.append({"sentence": sent, "matched": matched})
    return results


@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    """Compare two texts sentence-by-sentence in both directions.

    Returns, for every sentence of each text, whether a sufficiently
    similar sentence (vector distance < DISTANCE_THRESHOLD) exists in the
    other text.
    """
    try:
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)

        # Embed every sentence exactly once. The previous implementation
        # recomputed embeddings inside the pairwise loops — O(n*m) calls to
        # Vectorizer.get_embedding — and then duplicated the whole routine
        # for the reverse direction.
        origin_vectors = _embed_filtered(origin_sentences)
        similar_vectors = _embed_filtered(similar_sentences)

        response_data = {
            "origin": _match_against(origin_sentences, origin_vectors, similar_vectors),
            "similar": _match_against(similar_sentences, similar_vectors, origin_vectors),
        }

        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
+
 text_search_router = router

+ 12 - 3
utils/file_reader.py

@@ -1,5 +1,4 @@
 import os
-import sys
 from service.trunks_service import TrunksService
 
 class FileReader:
@@ -16,7 +15,17 @@ class FileReader:
                     meta_header = lines[0]
                     content = ''.join(lines[1:])
                     TrunksService().create_trunk({'file_path': relative_path, 'content': content,'type':'trunk','meta_header':meta_header})
+    @staticmethod
+    def process_txt_files(directory):
+        for root, dirs, files in os.walk(directory):
+            for file in files:
+                if file.endswith('.txt'):
+                    file_path = os.path.join(root, file)
+                    with open(file_path, 'r', encoding='utf-8') as f:
+                        content = f.read()
+                    title = os.path.splitext(file)[0]
+                    TrunksService().create_trunk({'file_path': file_path, 'content': content, 'type': 'mr', 'title': title})
 
 if __name__ == '__main__':
-    directory = 'e:\\project\\knowledge\\utils\\files'
-    FileReader.find_and_print_split_files(directory)
+    directory = '/Users/ycw/work/脑梗死病历模版'
+    FileReader.process_txt_files(directory)