123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- from fastapi import APIRouter, Depends, HTTPException
- from typing import Optional, List
- from pydantic import BaseModel
- from model.response import StandardResponse
- from db.session import get_db
- from sqlalchemy.orm import Session
- from service.kg_node_service import KGNodeService
- from service.trunks_service import TrunksService
- from service.kg_edge_service import KGEdgeService
- from service.kg_prop_service import KGPropService
- import logging
- from utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
# Module-level logger named after this module, so its level/handlers can be
# configured independently of the rest of the application.
logger = logging.getLogger(__name__)

# Router shared by every endpoint below: all paths are mounted under /saas and
# grouped under a single OpenAPI tag.
router = APIRouter(
    prefix="/saas",
    tags=["SaaS Knowledge Base"],
)
class PaginatedSearchRequest(BaseModel):
    """Request body for ``POST /saas/nodes/paginated_search``.

    The endpoint copies every field verbatim into the ``search_params`` dict it
    hands to ``KGNodeService.paginated_search``. How each filter is applied is
    decided by that service, not here.
    """

    keyword: Optional[str] = None              # free-text search term
    category: Optional[str] = None             # category filter
    pageNo: int = 1                            # page number to fetch
    limit: int = 10                            # records per page, presumably
    knowledge_ids: Optional[List[str]] = None  # presumably restricts results to these knowledge bases — confirm in KGNodeService
class NodePaginatedSearchRequest(BaseModel):
    """Paginated search request keyed on a required node ``name``.

    NOTE(review): not referenced by any endpoint in this section of the module.
    """

    name: str                       # required search name
    category: Optional[str] = None  # optional category filter
    pageNo: int = 1                 # page number to fetch
    limit: int = 10                 # records per page, presumably
class NodeCreateRequest(BaseModel):
    """Request body for ``POST /saas/nodes``.

    The ``create_node`` endpoint forwards this model as ``payload.dict()``.
    Every field therefore reaches ``TrunksService.create_node``, including
    optional fields the client left unset, which arrive as ``None``.
    """

    name: str                                # required
    category: str                            # required
    layout: Optional[str] = None
    version: Optional[str] = None
    embedding: Optional[List[float]] = None  # optional precomputed vector
class NodeUpdateRequest(BaseModel):
    """Partial-update body for ``PUT /saas/nodes/{node_id}``.

    All fields are optional. The ``update_node`` endpoint serializes this with
    ``exclude_unset=True``, so only fields the client actually sent are
    forwarded to ``TrunksService.update_node``.
    """

    layout: Optional[str] = None
    version: Optional[str] = None
    status: Optional[int] = None
    embedding: Optional[List[float]] = None
class VectorSearchRequest(BaseModel):
    """Request body for ``POST /saas/trunks/vector_search``.

    The endpoint passes these fields to
    ``TrunksService.search_by_vector(text, limit, type=type)``.
    """

    text: str                   # query text to search with
    limit: int = 10             # maximum number of results, presumably
    # Field name is part of the public request schema, so it keeps shadowing
    # the builtin ``type`` inside this class body.
    type: Optional[str] = None
class NodeRelationshipRequest(BaseModel):
    """Request carrying a single source node id.

    NOTE(review): not referenced by any endpoint in this section of the module.
    """

    src_id: int  # id of the source node
# Display order of property sections in node search results. In English:
# basic info, overview, etiology, epidemiology, pathogenesis, pathology,
# clinical manifestations, auxiliary examinations, diagnosis, differential
# diagnosis, complications, treatment, nursing, prognosis, prevention.
# Props whose 'prop_title' is not listed here are dropped from the response.
_PROP_TITLE_ORDER = (
    '基础信息', '概述', '病因学', '流行病学', '发病机制', '病理学',
    '临床表现', '辅助检查', '诊断', '鉴别诊断', '并发症', '治疗', '护理', '预后', '预防',
)

# Title -> position in _PROP_TITLE_ORDER. Built once so ranking a prop is a
# dict lookup instead of a list.index() scan per prop on every request.
_PROP_TITLE_RANK = {title: rank for rank, title in enumerate(_PROP_TITLE_ORDER)}


def _filter_and_sort_props(props):
    """Return the props whose 'prop_title' is in _PROP_TITLE_ORDER, in that order.

    Args:
        props: Iterable of dict-like props (each must support ``.get``).

    Returns:
        A new list. ``sorted`` is stable, so props that share a title keep
        their original relative order.
    """
    kept = [prop for prop in props if prop.get('prop_title') in _PROP_TITLE_RANK]
    return sorted(kept, key=lambda prop: _PROP_TITLE_RANK[prop.get('prop_title')])


@router.post("/nodes/paginated_search", response_model=StandardResponse)
async def paginated_search(
    payload: PaginatedSearchRequest,
    db: Session = Depends(get_db)
):
    """Search knowledge-graph nodes page by page, with props filtered and ordered.

    Forwards the request filters to ``KGNodeService.paginated_search`` with
    ``load_props=True``. It then rewrites each returned record's ``props`` so
    that only sections listed in ``_PROP_TITLE_ORDER`` remain, in that order.

    Args:
        payload: Search filters and paging parameters.
        db: Injected session, passed to ``KGNodeService``.

    Returns:
        ``StandardResponse`` whose ``data`` holds the service's ``records`` and
        ``pagination``.

    Raises:
        HTTPException: 500 on any failure, with the original exception chained.
    """
    # NOTE(review): the service call below is not awaited. If it does blocking
    # DB I/O, this async handler stalls the event loop while it runs; a plain
    # `def` handler would run in FastAPI's threadpool instead. Confirm.
    try:
        service = KGNodeService(db)
        search_params = {
            'keyword': payload.keyword,
            'category': payload.category,
            'pageNo': payload.pageNo,
            'limit': payload.limit,
            'knowledge_ids': payload.knowledge_ids,
            'load_props': True,
        }
        result = service.paginated_search(search_params)

        # Keep only the known sections on each record, in display order.
        for record in result['records']:
            if 'props' in record:
                record['props'] = _filter_and_sort_props(record['props'])

        return StandardResponse(
            success=True,
            data={
                'records': result['records'],
                'pagination': result['pagination'],
            },
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone lost.
        logger.exception("分页查询失败: %s", e)
        # NOTE(review): FastAPI's default HTTPException handler JSON-encodes
        # `detail`. A StandardResponse model instance only serializes if the app
        # registers a custom handler. Confirm one exists.
        raise HTTPException(
            status_code=500,
            detail=StandardResponse(
                success=False,
                error_code=500,
                error_msg=str(e),
            ),
        ) from e
@router.post("/nodes", response_model=StandardResponse)
async def create_node(
    payload: NodeCreateRequest,
    db: Session = Depends(get_db)
):
    """Create a node from the request body via ``TrunksService.create_node``.

    Args:
        payload: Node fields, forwarded as a plain dict (``payload.dict()``).
        db: Injected session. This handler does not use it, because
            ``TrunksService`` is constructed without it.

    Returns:
        ``StandardResponse`` wrapping whatever ``create_node`` returns.

    Raises:
        HTTPException: 500 on any failure, with the original exception chained.
    """
    try:
        service = TrunksService()
        result = service.create_node(payload.dict())
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone lost.
        logger.exception("创建节点失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
@router.get("/nodes/{node_id}", response_model=StandardResponse)
async def get_node(
    node_id: int,
    db: Session = Depends(get_db)
):
    """Fetch a single node by id via ``TrunksService.get_node``.

    No not-found handling happens here. What ``get_node`` returns or raises
    decides the response.

    Args:
        node_id: Node id taken from the URL path.
        db: Injected session. This handler does not use it, because
            ``TrunksService`` is constructed without it.

    Returns:
        ``StandardResponse`` wrapping whatever ``get_node`` returns.

    Raises:
        HTTPException: 500 on any failure, with the original exception chained.
    """
    try:
        service = TrunksService()
        result = service.get_node(node_id)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone lost.
        logger.exception("获取节点失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
@router.put("/nodes/{node_id}", response_model=StandardResponse)
async def update_node(
    node_id: int,
    payload: NodeUpdateRequest,
    db: Session = Depends(get_db)
):
    """Partially update a node via ``TrunksService.update_node``.

    Args:
        node_id: Node id taken from the URL path.
        payload: Fields to change. Only fields the client actually sent are
            forwarded (``exclude_unset=True``), so omitted fields are not
            overwritten with None.
        db: Injected session. This handler does not use it, because
            ``TrunksService`` is constructed without it.

    Returns:
        ``StandardResponse`` wrapping whatever ``update_node`` returns.

    Raises:
        HTTPException: 500 on any failure, with the original exception chained.
    """
    try:
        service = TrunksService()
        result = service.update_node(node_id, payload.dict(exclude_unset=True))
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone lost.
        logger.exception("更新节点失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
@router.post('/trunks/vector_search', response_model=StandardResponse)
async def vector_search(
    payload: VectorSearchRequest,
    db: Session = Depends(get_db)
):
    """Search trunks with ``TrunksService.search_by_vector`` for ``payload.text``.

    Args:
        payload: Query text, result limit and optional ``type`` filter.
        db: Injected session. This handler does not use it, because
            ``TrunksService`` is constructed without it.

    Returns:
        ``StandardResponse`` wrapping whatever ``search_by_vector`` returns.

    Raises:
        HTTPException: 500 on any failure, with the original exception chained.
    """
    try:
        service = TrunksService()
        result = service.search_by_vector(
            payload.text,
            payload.limit,
            type=payload.type
        )
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone lost.
        logger.exception("向量搜索失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
@router.get('/trunks/{trunk_id}', response_model=StandardResponse)
async def get_trunk(
    trunk_id: int,
    db: Session = Depends(get_db)
):
    """Fetch a single trunk by id via ``TrunksService.get_trunk_by_id``.

    Args:
        trunk_id: Trunk id taken from the URL path.
        db: Injected session. This handler does not use it, because
            ``TrunksService`` is constructed without it.

    Returns:
        ``StandardResponse`` wrapping whatever ``get_trunk_by_id`` returns.

    Raises:
        HTTPException: 500 on any failure, with the original exception chained.
    """
    try:
        service = TrunksService()
        result = service.get_trunk_by_id(trunk_id)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone lost.
        logger.exception("获取trunk详情失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
@router.post('/trunks/{trunk_id}/highlight', response_model=StandardResponse)
async def highlight(
    trunk_id: int,
    targetSentences: List[str],
    db: Session = Depends(get_db)
):
    """Compute highlight info for sentences in a trunk via ``TrunksService.highlight``.

    Args:
        trunk_id: Trunk id taken from the URL path.
        targetSentences: Sentences to highlight. FastAPI reads a bare
            ``List[str]`` parameter from the request body, so the body is a
            JSON array of strings.
        db: Injected session. This handler does not use it, because
            ``TrunksService`` is constructed without it.

    Returns:
        ``StandardResponse`` wrapping whatever ``highlight`` returns.

    Raises:
        HTTPException: 500 on any failure, with the original exception chained.
    """
    try:
        service = TrunksService()
        result = service.highlight(trunk_id, targetSentences)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback, which logging str(e) alone lost.
        logger.exception("获取trunk高亮信息失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
# Public alias for this module's router. Presumably the application imports
# this name to mount the /saas routes (confirm against the caller).
saas_kb_router: APIRouter = router
|