Parcourir la source

相似病例相关接口11

yuchengwei il y a 3 mois
Parent
commit
aef2fb6236
3 fichiers modifiés avec 67 ajouts et 55 suppressions
  1. 13 8
      requirements.txt
  2. 53 46
      router/text_search.py
  3. 1 1
      utils/file_reader.py

+ 13 - 8
requirements.txt

@@ -1,9 +1,14 @@
-fastapi>=0.95.2
-uvicorn[standard]>=0.18.3
-sqlalchemy==1.4.23
-pydantic>=2.7.4,<3.0.0
-jinja2>=3.0.3
-python-multipart==0.0.5
-aiofiles>=24.1.0
+fastapi==0.115.12
+leidenalg==0.10.2
+networkx==3.4.2
+numpy==2.2.4
 pgvector==0.1.8
-sentence-transformers>=2.2.2
+pydantic==2.11.1
+pytest==8.3.5
+python_igraph==0.11.8
+Requests==2.32.3
+SQLAlchemy==2.0.38
+starlette==0.46.1
+tabulate==0.9.0
+urllib3==2.3.0
+uvicorn==0.34.0

+ 53 - 46
router/text_search.py

@@ -8,6 +8,7 @@ from model.response import StandardResponse
 from utils.vectorizer import Vectorizer
 DISTANCE_THRESHOLD = 0.8
 import logging
+import time
 
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/text", tags=["Text Search"])
@@ -95,7 +96,8 @@ async def search_text(request: TextSearchRequest):
                 # 如果没有缓存结果,进行向量搜索
                 search_results = trunks_service.search_by_vector(
                     text=sentence,
-                    limit=1
+                    limit=1,
+                    type='trunk'
                 )
             
             # 处理搜索结果
@@ -216,77 +218,82 @@ async def mr_search_text_content(request: TextMatchRequest):
 
 @router.post("/mr_match", response_model=StandardResponse)
 async def compare_text(request: TextCompareMultiRequest):
+    start_time = time.time()
     try:
         # 拆分两段文本
         origin_sentences = TextSplitter.split_text(request.origin)
         similar_sentences = TextSplitter.split_text(request.similar)
+        end_time = time.time()
+        logger.info(f"mr_match接口处理文本耗时: {(end_time - start_time) * 1000:.2f}ms")
         
         # 初始化结果列表
         origin_results = []
-        similar_results = []
         
-        # 获取origin文本的向量
-        for origin_sent in origin_sentences:
-            if len(origin_sent) < 10:
+        # 过滤短句并预计算向量
+        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
+        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]
+        
+        # 初始化similar_results,所有matched设为False
+        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]
+        
+        # 批量获取向量
+        origin_vectors = {}
+        similar_vectors = {}
+        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
+        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
+        
+        if origin_batch:
+            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
+            origin_vectors = dict(zip(origin_batch, origin_embeddings))
+        
+        if similar_batch:
+            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
+            similar_vectors = dict(zip(similar_batch, similar_embeddings))
+
+        end_time = time.time()
+        logger.info(f"mr_match接口处理向量耗时: {(end_time - start_time) * 1000:.2f}ms") 
+        # 处理origin文本
+        for origin_sent, is_valid in valid_origin_sentences:
+            if not is_valid:
                 origin_results.append({"sentence": origin_sent, "matched": False})
                 continue
-                
-            origin_vector = Vectorizer.get_embedding(origin_sent)
-            min_distance = float('inf')
-            matched_sent = ""
             
-            # 与similar文本的每个句子计算相似度
-            for similar_sent in similar_sentences:
+            origin_vector = origin_vectors[origin_sent]
+            matched = False
+            
+            # 优化的相似度计算
+            for i, similar_result in enumerate(similar_results):
+                if similar_result["matched"]:
+                    continue
+                    
+                similar_sent = similar_result["sentence"]
                 if len(similar_sent) < 10:
                     continue
                     
-                similar_vector = Vectorizer.get_embedding(similar_sent)
-                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
-                
-                if distance < min_distance and distance < DISTANCE_THRESHOLD:
-                    min_distance = distance
-                    matched_sent = similar_sent
-            
-            origin_results.append({
-                "sentence": origin_sent,
-                "matched": bool(matched_sent)
-            })
-        
-        # 获取similar文本的向量
-        for similar_sent in similar_sentences:
-            if len(similar_sent) < 10:
-                similar_results.append({"sentence": similar_sent, "matched": False})
-                continue
-                
-            similar_vector = Vectorizer.get_embedding(similar_sent)
-            min_distance = float('inf')
-            matched_sent = ""
-            
-            # 与origin文本的每个句子计算相似度
-            for origin_sent in origin_sentences:
-                if len(origin_sent) < 10:
+                similar_vector = similar_vectors.get(similar_sent)
+                if not similar_vector:
                     continue
                     
-                origin_vector = Vectorizer.get_embedding(origin_sent)
-                distance = VectorDistance.calculate_distance(similar_vector, origin_vector)
-                
-                if distance < min_distance and distance < DISTANCE_THRESHOLD:
-                    min_distance = distance
-                    matched_sent = origin_sent
+                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
+                if distance < DISTANCE_THRESHOLD:
+                    matched = True
+                    similar_results[i]["matched"] = True
+                    break
             
-            similar_results.append({
-                "sentence": similar_sent,
-                "matched": bool(matched_sent)
-            })
+            origin_results.append({"sentence": origin_sent, "matched": matched})
         
         response_data = {
             "origin": origin_results,
             "similar": similar_results
         }
         
+        end_time = time.time()
+        logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
         return StandardResponse(success=True, data=response_data)
     except Exception as e:
+        end_time = time.time()
         logger.error(f"Text comparison failed: {str(e)}")
+        logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
         raise HTTPException(status_code=500, detail=str(e))
 
 text_search_router = router

+ 1 - 1
utils/file_reader.py

@@ -27,5 +27,5 @@ class FileReader:
                     TrunksService().create_trunk({'file_path': file_path, 'content': content, 'type': 'mr', 'title': title})
 
 if __name__ == '__main__':
-    directory = '/Users/ycw/work/梗死病历模版'
+    directory = '/Users/ycw/work/心肌梗死病历模版'
     FileReader.process_txt_files(directory)