Pārlūkot izejas kodu

Merge branch 'master' of http://173.18.12.196:3000/python/knowledge

SGTY 2 mēneši atpakaļ
vecāks
revīzija
4506a7dfa5

+ 1 - 1
agent/cdss/libs/cdss_helper.py

@@ -100,7 +100,7 @@ class CDSSHelper(GraphHelper):
                 continue
             if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
                 print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
-                with open(f"{CACHED_DATA_PATH}\\relationship_med_{i}.json", "r", encoding="utf-8") as f:
+                with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
                     data = {"src": [], "dest": [], "type": [], "weight": []}
                     relations = json.load(f)
                     for item in relations:

+ 16 - 4
main.py

@@ -6,13 +6,15 @@ from typing import Optional, Set
 # 导入FastAPI及相关模块
 import os
 import uvicorn
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
 
-from agent.cdss.capbility import CDSSCapability
+# from agent.cdss.capbility import CDSSCapability
 from router.knowledge_dify import dify_kb_router
 from router.knowledge_saas import saas_kb_router
 from router.text_search import text_search_router
 from router.graph_router import graph_router
-from router.knowledge_nodes_api import knowledge_nodes_api_router
+# from router.knowledge_nodes_api import knowledge_nodes_api_router
 
 # 配置日志
 logging.basicConfig(
@@ -32,8 +34,18 @@ app.include_router(dify_kb_router)
 app.include_router(saas_kb_router)
 app.include_router(text_search_router)
 app.include_router(graph_router)
-app.include_router(knowledge_nodes_api_router)
-
+# app.include_router(knowledge_nodes_api_router)
+
+# 挂载静态文件目录,将/books路径映射到本地books文件夹
+app.mount("/books", StaticFiles(directory="books"), name="books")
+# 允许所有来源(仅用于测试,生产环境应限制)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # 允许所有来源(或指定 ["http://localhost:3000"])
+    allow_credentials=True,  # 允许携带 Cookie
+    allow_methods=["*"],  # 允许所有方法(或指定 ["GET", "POST"])
+    allow_headers=["*"],  # 允许所有请求头
+)
 
 # 需要拦截的 URL 列表(支持通配符)
 INTERCEPT_URLS = {

+ 1 - 0
model/trunks_model.py

@@ -15,6 +15,7 @@ class Trunks(Base):
     title = Column(String(255))
     referrence = Column(String(255))
     meta_header = Column(String(255))
+    page_no = Column(Integer) 
 
     def __repr__(self):
         return f"<Trunks(id={self.id}, file_path={self.file_path})>"

+ 22 - 0
router/knowledge_saas.py

@@ -64,6 +64,28 @@ async def paginated_search(
             'load_props': True
         }
         result = service.paginated_search(search_params)
+        
+        # 定义prop_title的排序顺序
+        prop_title_order = [
+            '基础信息', '概述', '病因学', '流行病学', '发病机制', '病理学',
+            '临床表现', '辅助检查', '诊断', '鉴别诊断', '并发症', '治疗', '护理', '预后', '预防'
+        ]
+        
+        # 处理每个记录的props,过滤并排序
+        for record in result['records']:
+            if 'props' in record:
+                # 只保留指定的prop_title
+                filtered_props = [prop for prop in record['props'] if prop.get('prop_title') in prop_title_order]
+                
+                # 按照指定顺序排序
+                sorted_props = sorted(
+                    filtered_props,
+                    key=lambda x: prop_title_order.index(x.get('prop_title')) if x.get('prop_title') in prop_title_order else len(prop_title_order)
+                )
+                
+                # 更新记录中的props
+                record['props'] = sorted_props
+        
         return StandardResponse(
             success=True,
             data={

+ 208 - 9
router/text_search.py

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, HTTPException, Depends
 from pydantic import BaseModel, Field, validator
 from typing import List, Optional
 from service.trunks_service import TrunksService
@@ -6,9 +6,15 @@ from utils.text_splitter import TextSplitter
 from utils.vector_distance import VectorDistance
 from model.response import StandardResponse
 from utils.vectorizer import Vectorizer
+# from utils.find_text_in_pdf import find_text_in_pdf
+import os
 DISTANCE_THRESHOLD = 0.8
 import logging
 import time
+from db.session import get_db
+from sqlalchemy.orm import Session
+from service.kg_node_service import KGNodeService
+from service.kg_prop_service import KGPropService
 
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/text", tags=["Text Search"])
@@ -51,6 +57,10 @@ class TextCompareMultiRequest(BaseModel):
     origin: str
     similar: str
 
+class NodePropsSearchRequest(BaseModel):
+    node_id: int
+    props_ids: List[int]
+
 @router.post("/search", response_model=StandardResponse)
 async def search_text(request: TextSearchRequest):
     try:
@@ -107,14 +117,14 @@ async def search_text(request: TextSearchRequest):
                 )
             
             # 处理搜索结果
-            for result in search_results:
-                distance = result.get("distance", DISTANCE_THRESHOLD)
+            for search_result in search_results:
+                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                 if distance >= DISTANCE_THRESHOLD:
                     result_sentences.append(sentence)
                     continue
                 
                 # 检查是否已存在相同引用
-                existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
+                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                 current_index = reference_index
                 if existing_ref:
                     current_index = int(existing_ref["index"])
@@ -122,13 +132,15 @@ async def search_text(request: TextSearchRequest):
                     # 添加到引用列表
                     reference = {
                         "index": str(reference_index),
-                        "id": result["id"],
-                        "content": result["content"],
-                        "file_path": result.get("file_path", ""),
-                        "title": result.get("title", ""),
+                        "id": search_result["id"],
+                        "content": search_result["content"],
+                        "file_path": search_result.get("file_path", ""),
+                        "title": search_result.get("title", ""),
                         "distance": distance,
-                        "referrence": result.get("referrence", "")
+                        "referrence": search_result.get("referrence", "")
                     }
+                    
+                    
                     all_references.append(reference)
                     reference_index += 1
                 
@@ -302,4 +314,191 @@ async def compare_text(request: TextCompareMultiRequest):
         logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
         raise HTTPException(status_code=500, detail=str(e))
 
@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
    """Search evidence for a knowledge-graph node's properties.

    For each requested prop id, the prop value is split into sentences and each
    sentence is vector-searched against stored trunks. Matching trunks are
    collected as numbered references; each sentence is returned with a ``flag``
    holding its reference index ("" when no match). File names mentioned in the
    references' ``referrence`` URLs (under ``/books/``) are aggregated into
    ``result["files"]``.

    Args:
        request: node_id plus the list of prop ids to annotate.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        StandardResponse whose ``data`` contains the node info, annotated
        props (``answer`` sentences + ``references``), and referenced files.

    Raises:
        HTTPException: 500 on any failure (including unknown node_id).
    """
    try:
        start_time = time.time()
        # 初始化服务
        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)

        # 根据node_id查询节点信息
        node = node_service.get_node(request.node_id)
        if not node:
            raise ValueError(f"节点不存在: {request.node_id}")

        node_name = node.get('name', '')

        # 初始化结果
        result = {
            "id": request.node_id,
            "name": node_name,
            "category": node.get('category', ''),
            "props": [],
            "files": [],
            "distance": 0
        }

        # 遍历props_ids查询属性信息
        for prop_id in request.props_ids:
            # get_props_by_id may raise ValueError for a missing prop; a single
            # missing id must not abort the whole request — skip it instead.
            try:
                prop = prop_service.get_props_by_id(prop_id)
            except ValueError:
                prop = None

            if not prop:
                logger.warning(f"属性不存在: {prop_id}")
                continue

            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')

            # 拆分属性值为句子
            sentences = TextSplitter.split_text(prop_value)
            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }

            # 添加到结果中
            result["props"].append(prop_result)

            # Per-prop accumulators: annotated sentences and their references.
            result_sentences = []
            all_references = []
            reference_index = 1

            # 对每个句子进行向量搜索
            i = 0
            while i < len(sentences):
                sentence = sentences[i]

                if len(sentence) < 10 and i + 1 < len(sentences):
                    # Sentence too short to search on its own: emit it with an
                    # empty flag, then search using it merged with the next
                    # sentence for context. The citation flag produced by that
                    # search is attached to the NEXT sentence (the original
                    # code re-appended the short fragment and dropped the next
                    # sentence from the output entirely).
                    next_sentence = sentences[i + 1]
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    search_text = f"{node_name}:{prop_title}:{sentence} {next_sentence}"
                    sentence = next_sentence
                    i += 1  # 跳过下一句,因为已经合并使用
                elif len(sentence) < 10:
                    # 如果是最后一句且长度小于10,直接添加到结果,flag为空
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    i += 1
                    continue
                else:
                    # 句子长度足够,直接使用
                    search_text = f"{node_name}:{prop_title}:{sentence}"

                i += 1

                # 进行向量搜索
                search_results = trunks_service.search_by_vector(
                    text=search_text,
                    limit=1,
                    type='trunk'
                )

                # 没有搜索结果:添加原句子,flag为空
                if not search_results:
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    continue

                for search_result in search_results:
                    distance = search_result.get("distance", DISTANCE_THRESHOLD)
                    if distance >= DISTANCE_THRESHOLD:
                        # 距离过大,添加原句子,flag为空
                        result_sentences.append({
                            "sentence": sentence,
                            "flag": ""
                        })
                        continue

                    # 检查是否已存在相同引用
                    existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                    current_index = reference_index
                    if existing_ref:
                        current_index = int(existing_ref["index"])
                    else:
                        # 添加到引用列表
                        reference = {
                            "index": str(reference_index),
                            "id": search_result["id"],
                            "content": search_result["content"],
                            "file_path": search_result.get("file_path", ""),
                            "title": search_result.get("title", ""),
                            "distance": distance,
                            "page_no": search_result.get("page_no", ""),
                            "referrence": search_result.get("referrence", "")
                        }
                        all_references.append(reference)
                        reference_index += 1

                    # 添加句子和引用标记(作为单独的flag字段)
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": str(current_index)
                    })

            # 更新属性值,添加引用信息
            if all_references:
                prop_result["references"] = all_references

            # 将处理后的句子添加到结果中
            if result_sentences:
                prop_result["answer"] = result_sentences

        # 处理所有引用中的文件信息:收集 /books/ 下被引用的文件(去重)
        all_files = set()
        for prop_result in result["props"]:
            for ref in prop_result.get("references", []):
                referrence = ref.get("referrence", "")
                if not (referrence and "/books/" in referrence):
                    continue
                # 提取/books/后面的文件名
                file_name = referrence.split("/books/")[-1]
                if not file_name:
                    continue
                # 根据文件名后缀确定文件类型
                lower_name = file_name.lower()
                if lower_name.endswith(".pdf"):
                    file_type = "pdf"
                elif lower_name.endswith((".doc", ".docx")):
                    file_type = "doc"
                elif lower_name.endswith((".xls", ".xlsx")):
                    file_type = "excel"
                elif lower_name.endswith((".ppt", ".pptx")):
                    file_type = "ppt"
                else:
                    file_type = "other"
                all_files.add((file_name, file_type))

        # 将文件信息添加到结果中
        result["files"] = [{
            "file_name": file_name,
            "file_type": file_type
        } for file_name, file_type in all_files]

        end_time = time.time()
        logger.info(f"node_props_search接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.error(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
+
 text_search_router = router

+ 20 - 0
service/kg_prop_service.py

@@ -29,6 +29,26 @@ class KGPropService:
             logger.error(f"根据ref_id查询属性失败: {str(e)}")
             raise ValueError("查询失败")
 
+    def get_props_by_id(self, id: int, prop_name: str = None) -> List[dict]:
+        try:
+            query = self.db.query(KGProp).filter(KGProp.id == id)
+            if prop_name:
+                query = query.filter(KGProp.prop_name == prop_name)
+            props = query.first()
+            if not props:
+                raise ValueError("props not found")
+            return {
+                'id': props.id,
+                'category': props.category,
+                'prop_name': props.prop_name,
+                'prop_value': props.prop_value,
+                'prop_title': props.prop_title,
+                'type': props.type
+            }
+        except Exception as e:
+            logger.error(f"根据id查询属性失败: {str(e)}")
+            raise ValueError("查询失败")
+
     def create_prop(self, prop_data: dict) -> KGProp:
         try:
             new_prop = KGProp(**prop_data)

+ 2 - 0
service/trunks_service.py

@@ -73,6 +73,7 @@ class TrunksService:
                 Trunks.embedding.l2_distance(embedding).label('distance'),
                 Trunks.title,
                 Trunks.embedding,
+                Trunks.page_no,
                 Trunks.referrence
             )
             if metadata_condition:
@@ -88,6 +89,7 @@ class TrunksService:
                 'distance': round(r.distance, 3),
                 'title': r.title,
                 'embedding': r.embedding.tolist(),
+                'page_no': r.page_no,
                 'referrence': r.referrence
             } for r in results]