Explorar el Código

新增nodes表对应的代码

SGTY hace 2 meses
padre
commit
72c70a0deb
Se han modificado 6 ficheros con 4429 adiciones y 8 borrados
  1. 4370 0
      app.log
  2. 5 1
      main.py
  3. 1 6
      model/response.py
  4. 0 0
      router/knowledge_dify.py
  5. 52 0
      service/trunks_service.py
  6. 1 1
      tests/community/test_graph_helper.py

La diferencia del archivo ha sido suprimido porque es demasiado grande
+ 4370 - 0
app.log


+ 5 - 1
main.py

@@ -17,12 +17,16 @@ logger.propagate = True
 import os
 from fastapi import FastAPI
 import uvicorn
-from router.dify_kb import dify_kb_router
+from router.knowledge_dify import dify_kb_router
+from router.knowledge_saas import saas_kb_router
 
 # 创建FastAPI应用
 app = FastAPI(title="医疗百科问答系统")
 app.include_router(dify_kb_router)
+app.include_router(saas_kb_router)
 
 if __name__ == "__main__":
     logger.info('Starting uvicorn server...2222')
+    #uvicorn main:app --host 0.0.0.0 --port 8000 --reload
+    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
 

+ 1 - 6
model/response.py

@@ -5,17 +5,12 @@ from starlette.responses import StreamingResponse, JSONResponse
 from typing import Any, Optional,List
 import json
 
-class ResponseModel(BaseModel):
-    code: int
-    message: str
-    data: Any = None  # 使用Optional[Any]也是可以的
-
-
 class StandardResponse(BaseModel):
     success: bool
     error_code: Optional[int] = None
     error_msg: Optional[str] = None
     records: Optional[Any] = None
+    data: Optional[Any] = None
 
 
 # class ResponseFormatterMiddleware(BaseHTTPMiddleware):

router/dify_kb.py → router/knowledge_dify.py


+ 52 - 0
service/trunks_service.py

@@ -9,6 +9,10 @@ from utils.vectorizer import Vectorizer
 logger = logging.getLogger(__name__)
 
 class TrunksService:
+    def __init__(self):
+        self.db = next(get_db())
+
+
     def create_trunk(self, trunk_data: dict) -> Trunks:
         # 自动生成向量和全文检索字段
         content = trunk_data.get('content')
@@ -108,5 +112,53 @@ class TrunksService:
             db.rollback()
             logger.error(f"删除trunk失败: {str(e)}")
             raise
+        finally:
+            db.close()
+
+    def paginated_search(self, search_params: dict) -> dict:
+        """
+        分页查询方法
+        :param search_params: 包含keyword, pageNo, limit的字典
+        :return: 包含结果列表和分页信息的字典
+        """
+        keyword = search_params.get('keyword', '')
+        page_no = search_params.get('pageNo', 1)
+        limit = search_params.get('limit', 10)
+        
+        if page_no < 1:
+            page_no = 1
+        if limit < 1:
+            limit = 10
+            
+        embedding = Vectorizer.get_embedding(keyword)
+        offset = (page_no - 1) * limit
+        
+        db = SessionLocal()
+        try:
+            # 获取总条数
+            total_count = db.query(func.count(Trunks.id)).scalar()
+            
+            # 执行向量搜索
+            results = db.query(
+                Trunks.id,
+                Trunks.file_path,
+                Trunks.content,
+                Trunks.embedding.l2_distance(embedding).label('distance')
+            ).order_by('distance').offset(offset).limit(limit).all()
+            
+            return {
+                'data': [{
+                    'id': r.id,
+                    'file_path': r.file_path,
+                    'content': r.content,
+                    'distance': r.distance
+                } for r in results],
+                'pagination': {
+                    'total': total_count,
+                    'pageNo': page_no,
+                    'limit': limit,
+                    'totalPages': (total_count + limit - 1) // limit
+                }
+            }
         finally:
             db.close()

+ 1 - 1
tests/community/test_graph_helper.py

@@ -13,7 +13,7 @@ class TestGraphHelper(unittest.TestCase):
     def setUpClass(cls):
         # 初始化图谱助手并构建图谱
         cls.helper = GraphHelper()
-        cls.test_node = "发热"  # 使用实际存在的测试节点
+        cls.test_node = "感染性发热"  # 使用实际存在的测试节点
         cls.test_community_node = "糖尿病"  # 用于社区检测的测试节点
 
     def test_graph_construction(self):