SGTY 3 mēneši atpakaļ
vecāks
revīzija
c163304e0f
5 mainītis faili ar 218 papildinājumiem un 90 dzēšanām
  1. 99 10
      main.py
  2. 1 0
      model/response.py
  3. 115 0
      router/knowledge_nodes_api.py
  4. 0 78
      router/knowledge_saas.py
  5. 3 2
      service/kg_node_service.py

+ 99 - 10
main.py

@@ -1,5 +1,16 @@
 import logging
+import uuid
 from logging.handlers import RotatingFileHandler
+from fastapi import FastAPI, Request, Response, status
+from typing import Optional, Set
+# 导入FastAPI及相关模块
+import os
+import uvicorn
+from router.knowledge_dify import dify_kb_router
+from router.knowledge_saas import saas_kb_router
+from router.text_search import text_search_router
+from router.graph_router import graph_router
+from router.knowledge_nodes_api import knowledge_nodes_api_router
 
 # 配置日志
 logging.basicConfig(
@@ -13,21 +24,99 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 logger.propagate = True
 
-# 导入FastAPI及相关模块
-import os
-from fastapi import FastAPI
-import uvicorn
-from router.knowledge_dify import dify_kb_router
-from router.knowledge_saas import saas_kb_router
-from router.text_search import text_search_router
-from router.graph_router import graph_router
-
 # 创建FastAPI应用
-app = FastAPI(title="医疗百科问答系统")
+app = FastAPI(title="知识图谱")
 app.include_router(dify_kb_router)
 app.include_router(saas_kb_router)
 app.include_router(text_search_router)
 app.include_router(graph_router)
+app.include_router(knowledge_nodes_api_router)
+
+
+# 需要拦截的 URL 列表(支持通配符)
+INTERCEPT_URLS = {
+    "/v1/knowledge/*"
+}
+
+# 白名单 URL(不需要拦截的路径)
+WHITE_LIST = {
+    "/api/public",
+    "/admin/login"
+}
+
+
+async def verify_token(authorization: str) -> Optional[dict]:
+    """
+    验证 token 有效性
+    返回:验证成功返回用户信息字典,失败返回 None
+    """
+    if not authorization.startswith("Bearer "):
+        return None
+
+    token = authorization[7:]
+    # 这里添加实际的 token 验证逻辑
+    # 示例:简单验证 token 是否等于 secret-token
+    if token == "secret-token":
+        return {"id": 1, "username": "admin", "role": "admin"}
+    return None
+
+
+def should_intercept(path: str) -> bool:
+    """
+    判断是否需要拦截当前路径
+    """
+    if path in WHITE_LIST:
+        return False
+
+    for pattern in INTERCEPT_URLS:
+        # 处理通配符匹配
+        if pattern.endswith("/*"):
+            if path.startswith(pattern[:-1]):
+                return True
+        # 精确匹配
+        elif path == pattern:
+            return True
+    return False
+
+
+@app.middleware("http")
+async def interceptor_middleware(request: Request, call_next):
+    path = request.url.path
+
+    if not should_intercept(path):
+        return await call_next(request)
+
+    # 权限校验
+    auth_header = request.headers.get("Authorization")
+    if not auth_header:
+        return Response(
+            content="Missing Authorization header",
+            status_code=status.HTTP_401_UNAUTHORIZED
+        )
+
+    user_info = await verify_token(auth_header)
+    if not user_info:
+        return Response(
+            content="Invalid token",
+            status_code=status.HTTP_401_UNAUTHORIZED
+        )
+
+    # 初始化操作:将用户信息添加到请求状态中
+    request.state.user = user_info
+
+    # 添加请求上下文(示例)
+    request.state.context = {
+        "request_id": request.headers.get("request-id", str(uuid.uuid4())),
+        "client_ip": request.client.host
+    }
+
+    # 继续处理请求
+    response = await call_next(request)
+    # 可以在返回前添加统一响应处理(如添加头信息)
+    response.headers["request-id"]=request.state.context["request_id"]
+
+    return response
+
 
 if __name__ == "__main__":
     logger.info('Starting uvicorn server...2222')

+ 1 - 0
model/response.py

@@ -7,6 +7,7 @@ import json
 
 class StandardResponse(BaseModel):
     success: bool
+    requestId: Optional[str] = None
     errorCode: Optional[int] = None
     errorMsg: Optional[str] = None
     records: Optional[Any] = None

+ 115 - 0
router/knowledge_nodes_api.py

@@ -0,0 +1,115 @@
+from fastapi import APIRouter, Depends, HTTPException, Request
+from typing import Optional, List
+from pydantic import BaseModel
+from model.response import StandardResponse
+from db.session import get_db
+from sqlalchemy.orm import Session
+from service.kg_node_service import KGNodeService
+from service.kg_edge_service import KGEdgeService
+from service.kg_prop_service import KGPropService
+import logging
+from utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
+
+router = APIRouter(prefix="/v1/knowledge", tags=["SaaS Knowledge Base"])
+
+logger = logging.getLogger(__name__)
+
+class PaginatedSearchRequest(BaseModel):
+    name: str
+    distance: float = 1.45
+    category: Optional[str] = None
+    pageNo: int = 1
+    limit: int = 10
+
+async def get_request_id(request: Request):
+    return request.state.context["request_id"]
+
+@router.post("/nodes/paginated_search", response_model=StandardResponse)
+async def paginated_search(
+    payload: PaginatedSearchRequest,
+    db: Session = Depends(get_db),
+    request_id: str = Depends(get_request_id)
+):
+    try:
+        service = KGNodeService(db)
+        search_params = {
+            'keyword': payload.name,
+            'category': payload.category,
+            'pageNo': payload.pageNo,
+            'limit': payload.limit,
+            'load_props': True,
+            'distance': 2.0-payload.distance,
+        }
+        result = service.paginated_search(search_params)
+        #result[data]的distance属性去掉
+        for item in result['records']:
+            item.pop('distance', None)
+        #result[pagination]去掉
+        result.pop('pagination', None)
+        #result[records]改为result[nodes]
+        result['nodes'] = result['records']
+        result.pop('records', None)   
+        return StandardResponse(
+            success=True,
+            requestId=request_id,
+            data=ObjectToJsonArrayConverter.convert(result)
+        )
+    except Exception as e:
+        logger.error(f"分页查询失败: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=StandardResponse(
+                success=False,
+                error_code=500,
+                error_msg=str(e)
+            )
+        )
+
+@router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
+async def get_node_relationships(
+    src_id: int,
+    db: Session = Depends(get_db),
+    request_id: str = Depends(get_request_id)
+):
+    try:
+        edge_service = KGEdgeService(db)
+        prop_service = KGPropService(db)
+        
+        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
+        relationships = []
+               
+        #count = 0
+        for edge in edges:
+            #if count >= 2:
+                #break
+            dest_node = edge['dest_node']
+            dest_props = []
+            edge_props = []
+            #count += 1
+            #dest_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
+            #              for p in prop_service.get_props_by_ref_id(dest_node['id'])]
+            
+            #edge_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
+            #             for p in prop_service.get_props_by_ref_id(edge['id'])]
+
+            relationships.append({
+                "name": edge['name'],
+                "props": edge_props,
+                "destNode": {
+                    "category": dest_node['category'],
+                    "id": str(dest_node['id']),
+                    "name": dest_node['name'],
+                    "props": dest_props
+                }
+            })
+        
+        return StandardResponse(
+            success=True,
+            requestId=request_id,
+            data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
+        )
+    except Exception as e:
+        logger.error(f"获取节点关系失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+knowledge_nodes_api_router = router

+ 0 - 78
router/knowledge_saas.py

@@ -24,7 +24,6 @@ class PaginatedSearchRequest(BaseModel):
 
 class NodePaginatedSearchRequest(BaseModel):
     name: str
-    appid: str
     category: Optional[str] = None
     pageNo: int = 1
     limit: int = 10  
@@ -83,44 +82,6 @@ async def paginated_search(
             )
         )
 
-@router.post("/nodes/query", response_model=StandardResponse)
-async def paginated_search(
-    payload: NodePaginatedSearchRequest,
-    db: Session = Depends(get_db)
-):
-    try:
-        service = KGNodeService(db)
-        search_params = {
-            'keyword': payload.name,
-            'category': payload.category,
-            'pageNo': payload.pageNo,
-            'limit': payload.limit,
-            'load_props': True
-        }
-        result = service.paginated_search(search_params)
-        #result[data]的distance属性去掉
-        for item in result['records']:
-            item.pop('distance', None)
-        #result[pagination]去掉
-        result.pop('pagination', None)
-        #result[records]改为result[nodes]
-        result['nodes'] = result['records']
-        result.pop('records', None)   
-        return StandardResponse(
-            success=True,
-            data=ObjectToJsonArrayConverter.convert(result)
-        )
-    except Exception as e:
-        logger.error(f"分页查询失败: {str(e)}")
-        raise HTTPException(
-            status_code=500,
-            detail=StandardResponse(
-                success=False,
-                error_code=500,
-                error_msg=str(e)
-            )
-        )
-
 @router.post("/nodes", response_model=StandardResponse)
 async def create_node(
     payload: NodeCreateRequest,
@@ -204,43 +165,4 @@ async def get_trunk(
         logger.error(f"获取trunk详情失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
-@router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
-async def get_node_relationships(
-    src_id: int,
-    db: Session = Depends(get_db)
-):
-    try:
-        edge_service = KGEdgeService(db)
-        prop_service = KGPropService(db)
-        
-        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
-        relationships = []
-        
-        for edge in edges:
-            dest_node = edge['dest_node']
-            dest_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']} 
-                          for p in prop_service.get_props_by_ref_id(dest_node['id'])]
-            
-            edge_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
-                         for p in prop_service.get_props_by_ref_id(edge['id'])]
-            
-            relationships.append({
-                "name": edge['name'],
-                "properties": edge_props,
-                "targetNode": {
-                    "category": dest_node['category'],
-                    "id": str(dest_node['id']),
-                    "name": dest_node['name'],
-                    "properties": dest_props
-                }
-            })
-        
-        return StandardResponse(
-            success=True,
-            data={"relationships": relationships}
-        )
-    except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
-        raise HTTPException(500, detail=StandardResponse.error(str(e)))
-
 saas_kb_router = router

+ 3 - 2
service/kg_node_service.py

@@ -61,6 +61,7 @@ class KGNodeService:
         keyword = search_params.get('keyword', '')
         category = search_params.get('category', '')
         page_no = search_params.get('pageNo', 1)
+        distance = search_params.get('distance',DISTANCE_THRESHOLD)
         limit = search_params.get('limit', 10)
 
         if page_no < 1:
@@ -75,7 +76,7 @@ class KGNodeService:
             # 构建基础查询条件
             base_query = self.db.query(func.count(KGNode.id)).filter(
                 KGNode.status == 0,
-                KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD
+                KGNode.embedding.l2_distance(embedding) < distance
             )
             # 如果有category,则添加额外过滤条件
             if category:
@@ -100,7 +101,7 @@ class KGNodeService:
                 query = query.filter(KGNode.category == category)
             if search_params.get('knowledge_ids'):
                 query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
-            query = query.filter(KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD)
+            query = query.filter(KGNode.embedding.l2_distance(embedding) < distance)
             results = query.order_by('distance').offset(offset).limit(limit).all()
             #将results相同distance的category=疾病的放在前面
             results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))