소스 검색

新增循证查询接口

yuchengwei 3 달 전
부모
커밋
375b1fe446
4개의 변경된 파일189개의 추가작업 그리고 10개의 파일을 삭제
  1. 1 1
      agent/cdss/libs/cdss_helper.py
  2. 4 0
      main.py
  3. 164 9
      router/text_search.py
  4. 20 0
      service/kg_prop_service.py

+ 1 - 1
agent/cdss/libs/cdss_helper.py

@@ -97,7 +97,7 @@ class CDSSHelper(GraphHelper):
                 continue
             if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
                 print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
-                with open(f"{CACHED_DATA_PATH}\\relationship_med_{i}.json", "r", encoding="utf-8") as f:
+                with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
                     data = {"src": [], "dest": [], "type": [], "weight": []}
                     relations = json.load(f)
                     for item in relations:

+ 4 - 0
main.py

@@ -6,6 +6,7 @@ from typing import Optional, Set
 # 导入FastAPI及相关模块
 import os
 import uvicorn
+from fastapi.staticfiles import StaticFiles
 
 from agent.cdss.capbility import CDSSCapability
 from router.knowledge_dify import dify_kb_router
@@ -34,6 +35,9 @@ app.include_router(text_search_router)
 app.include_router(graph_router)
 app.include_router(knowledge_nodes_api_router)
 
+# 挂载静态文件目录,将/books路径映射到本地books文件夹
+app.mount("/books", StaticFiles(directory="books"), name="books")
+
 
 # 需要拦截的 URL 列表(支持通配符)
 INTERCEPT_URLS = {

+ 164 - 9
router/text_search.py

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, HTTPException, Depends
 from pydantic import BaseModel, Field, validator
 from typing import List, Optional
 from service.trunks_service import TrunksService
@@ -9,6 +9,10 @@ from utils.vectorizer import Vectorizer
 DISTANCE_THRESHOLD = 0.8
 import logging
 import time
+from db.session import get_db
+from sqlalchemy.orm import Session
+from service.kg_node_service import KGNodeService
+from service.kg_prop_service import KGPropService
 
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/text", tags=["Text Search"])
@@ -51,6 +55,11 @@ class TextCompareMultiRequest(BaseModel):
     origin: str
     similar: str
 
+class NodePropsSearchRequest(BaseModel):
+    node_id: int
+    props_ids: List[int]
+    conversation_id: Optional[str] = None
+
 @router.post("/search", response_model=StandardResponse)
 async def search_text(request: TextSearchRequest):
     try:
@@ -107,14 +116,14 @@ async def search_text(request: TextSearchRequest):
                 )
             
             # 处理搜索结果
-            for result in search_results:
-                distance = result.get("distance", DISTANCE_THRESHOLD)
+            for search_result in search_results:
+                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                 if distance >= DISTANCE_THRESHOLD:
                     result_sentences.append(sentence)
                     continue
                 
                 # 检查是否已存在相同引用
-                existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
+                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                 current_index = reference_index
                 if existing_ref:
                     current_index = int(existing_ref["index"])
@@ -122,12 +131,12 @@ async def search_text(request: TextSearchRequest):
                     # 添加到引用列表
                     reference = {
                         "index": str(reference_index),
-                        "id": result["id"],
-                        "content": result["content"],
-                        "file_path": result.get("file_path", ""),
-                        "title": result.get("title", ""),
+                        "id": search_result["id"],
+                        "content": search_result["content"],
+                        "file_path": search_result.get("file_path", ""),
+                        "title": search_result.get("title", ""),
                         "distance": distance,
-                        "referrence": result.get("referrence", "")
+                        "referrence": search_result.get("referrence", "")
                     }
                     all_references.append(reference)
                     reference_index += 1
@@ -302,4 +311,150 @@ async def compare_text(request: TextCompareMultiRequest):
         logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
         raise HTTPException(status_code=500, detail=str(e))
 
+@router.post("/eb_search", response_model=StandardResponse)
+async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
+    try:
+        start_time = time.time()
+        # 初始化服务
+        trunks_service = TrunksService()
+        node_service = KGNodeService(db)
+        prop_service = KGPropService(db)
+
+        # 根据node_id查询节点信息
+        node = node_service.get_node(request.node_id)
+        if not node:
+            raise ValueError(f"节点不存在: {request.node_id}")
+
+        node_name = node.get('name', '')
+
+        # 初始化结果
+        result = {
+            "id": request.node_id,
+            "name": node_name,
+            "category": node.get('category', ''),
+            "props": [],
+            "distance": 0
+        }
+
+        # 缓存结果
+        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
+
+        # 遍历props_ids查询属性信息
+        for prop_id in request.props_ids:
+            prop = prop_service.get_props_by_id(prop_id)
+
+            if not prop:
+                logger.warning(f"属性不存在: {prop_id}")
+                continue
+
+            prop_title = prop.get('prop_title', '')
+            prop_value = prop.get('prop_value', '')
+
+            # 拆分属性值为句子
+            sentences = TextSplitter.split_text(prop_value)
+            prop_result = {
+                "id": prop_id,
+                "category": prop.get('category', 0),
+                "prop_name": prop.get('prop_name', ''),
+                "prop_value": prop_value,
+                "prop_title": prop_title,
+                "type": prop.get('type', 1)
+            }
+
+            # 添加到结果中
+            result["props"].append(prop_result)
+
+            # 处理属性值中的句子
+            result_sentences = []
+            all_references = []
+            reference_index = 1
+
+            # 对每个句子进行向量搜索
+            for sentence in sentences:
+                original_sentence = sentence
+                sentence = sentence.replace("\n", "<br>")
+                if len(sentence) < 10:
+                    result_sentences.append(sentence)
+                    continue
+
+                # 构建搜索文本
+                search_text = f"{node_name}:{prop_title}:{sentence}"
+
+                # 检查缓存
+                if cached_results:
+                    min_distance = float('inf')
+                    best_result = None
+                    sentence_vector = Vectorizer.get_embedding(search_text)
+
+                    for cached_result in cached_results:
+                        content_vector = cached_result['embedding']
+                        distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
+                        if distance < min_distance:
+                            min_distance = distance
+                            best_result = {**cached_result, 'distance': distance}
+
+                    if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
+                        search_results = [best_result]
+                    else:
+                        search_results = []
+                else:
+                    # 进行向量搜索
+                    search_results = trunks_service.search_by_vector(
+                        text=search_text,
+                        limit=1,
+                        type='trunk',
+                        conversation_id=request.conversation_id
+                    )
+
+                # 处理搜索结果
+                for search_result in search_results:
+                    distance = search_result.get("distance", DISTANCE_THRESHOLD)
+                    if distance >= DISTANCE_THRESHOLD:
+                        result_sentences.append(sentence)
+                        continue
+
+                    # 检查是否已存在相同引用
+                    existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
+                    current_index = reference_index
+                    if existing_ref:
+                        current_index = int(existing_ref["index"])
+                    else:
+                        # 添加到引用列表
+                        reference = {
+                            "index": str(reference_index),
+                            "id": search_result["id"],
+                            "content": search_result["content"],
+                            "file_path": search_result.get("file_path", ""),
+                            "title": search_result.get("title", ""),
+                            "distance": distance,
+                            "referrence": search_result.get("referrence", "")
+                        }
+                        all_references.append(reference)
+                        reference_index += 1
+
+                    # 添加引用标记
+                    if sentence.endswith('<br>'):
+                        # 如果有多个<br>,在所有<br>前添加^[current_index]^
+                        result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
+                    else:
+                        # 直接在句子末尾添加^[current_index]^
+                        result_sentence = f'{sentence}^[{current_index}]^'
+
+                    result_sentences.append(result_sentence)
+
+            # 更新属性值,添加引用信息
+            if all_references:
+                prop_result["references"] = all_references
+
+            # 将处理后的句子添加到结果中
+            if result_sentences:
+                prop_result["answer"] = result_sentences
+
+        end_time = time.time()
+        logger.info(f"node_props_search接口耗时: {(end_time - start_time) * 1000:.2f}ms")
+        return StandardResponse(success=True, data=result)
+    except Exception as e:
+        logger.error(f"Node props search failed: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
 text_search_router = router

+ 20 - 0
service/kg_prop_service.py

@@ -29,6 +29,26 @@ class KGPropService:
             logger.error(f"根据ref_id查询属性失败: {str(e)}")
             raise ValueError("查询失败")
 
+    def get_props_by_id(self, id: int, prop_name: str = None) -> List[dict]:
+        try:
+            query = self.db.query(KGProp).filter(KGProp.id == id)
+            if prop_name:
+                query = query.filter(KGProp.prop_name == prop_name)
+            props = query.first()
+            if not props:
+                raise ValueError("props not found")
+            return {
+                'id': props.id,
+                'category': props.category,
+                'prop_name': props.prop_name,
+                'prop_value': props.prop_value,
+                'prop_title': props.prop_title,
+                'type': props.type
+            }
+        except Exception as e:
+            logger.error(f"根据id查询属性失败: {str(e)}")
+            raise ValueError("查询失败")
+
     def create_prop(self, prop_data: dict) -> KGProp:
         try:
             new_prop = KGProp(**prop_data)