knowledge_saas.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. from fastapi import APIRouter, Depends, HTTPException
  2. from typing import Optional, List
  3. from pydantic import BaseModel
  4. from model.response import StandardResponse
  5. from db.session import get_db
  6. from sqlalchemy.orm import Session
  7. from service.kg_node_service import KGNodeService
  8. from service.trunks_service import TrunksService
  9. from service.kg_edge_service import KGEdgeService
  10. from service.kg_prop_service import KGPropService
  11. import logging
  12. from utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
  13. router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
  14. logger = logging.getLogger(__name__)
class PaginatedSearchRequest(BaseModel):
    """Request body for ``POST /saas/nodes/paginated_search``.

    Every field is copied into the search-params dict handed to
    ``KGNodeService.paginated_search``.
    """

    # Optional search keyword; forwarded as-is (None included) — matching
    # semantics live in KGNodeService.
    keyword: Optional[str] = None
    # Page number to return (default 1; presumably 1-based — confirm in the service).
    # No lower bound is enforced here.
    pageNo: int = 1
    # Page size (default 10). No bounds are enforced here.
    limit: int = 10
    # Optional knowledge-base ids to restrict the search to — TODO confirm
    # exact filter semantics in KGNodeService.
    knowledge_ids: Optional[List[str]] = None
class NodePaginatedSearchRequest(BaseModel):
    """Request model with a required ``name``, an optional ``category`` and paging fields.

    Presumably intended for a paginated node search by name.
    NOTE(review): no route in this module's visible code uses this model.
    """

    # Node name to search for (required).
    name: str
    # Optional category, presumably used as a filter — confirm with the consumer.
    category: Optional[str] = None
    # Page number (default 1).
    pageNo: int = 1
    # Page size (default 10).
    limit: int = 10
class NodeCreateRequest(BaseModel):
    """Request body for ``POST /saas/nodes``.

    The route dumps it with ``.dict()``, so unset optional fields reach
    ``TrunksService.create_node`` as explicit ``None`` values.
    """

    # Node name (required).
    name: str
    # Node category (required).
    category: str
    # Optional layout string.
    layout: Optional[str] = None
    # Optional version string.
    version: Optional[str] = None
    # Optional precomputed embedding vector; whether the service computes one
    # when this is None is not visible here — confirm in TrunksService.
    embedding: Optional[List[float]] = None
class NodeUpdateRequest(BaseModel):
    """Request body for ``PUT /saas/nodes/{node_id}`` (partial update).

    All fields are optional. The route dumps the model with
    ``exclude_unset=True``, so only fields present in the request body are
    passed to ``TrunksService.update_node``.
    """

    # New layout string.
    layout: Optional[str] = None
    # New version string.
    version: Optional[str] = None
    # New status code; the meaning of individual values is not visible here.
    status: Optional[int] = None
    # Replacement embedding vector.
    embedding: Optional[List[float]] = None
class VectorSearchRequest(BaseModel):
    """Request body for ``POST /saas/trunks/vector_search``."""

    # Query text handed to TrunksService.search_by_vector.
    text: str
    # Maximum number of results requested (default 10).
    limit: int = 10
    # Optional type filter; the route sends {'type': type} only when this is
    # truthy. The name shadows the builtin inside the class body only, and it
    # is part of the JSON contract, so it is kept.
    type: Optional[str] = None
class NodeRelationshipRequest(BaseModel):
    """Request model carrying a source node id.

    NOTE(review): not used by any route in this module's visible code and only
    declares ``src_id`` — looks incomplete (no target/relationship type); confirm.
    """

    # Id of the source node.
    src_id: int
  42. @router.post("/nodes/paginated_search", response_model=StandardResponse)
  43. async def paginated_search(
  44. payload: PaginatedSearchRequest,
  45. db: Session = Depends(get_db)
  46. ):
  47. try:
  48. service = KGNodeService(db)
  49. search_params = {
  50. 'keyword': payload.keyword,
  51. 'pageNo': payload.pageNo,
  52. 'limit': payload.limit,
  53. 'knowledge_ids': payload.knowledge_ids,
  54. 'load_props': True
  55. }
  56. result = service.paginated_search(search_params)
  57. # 定义prop_title的排序顺序
  58. prop_title_order = [
  59. '基础信息', '概述', '病因学', '流行病学', '发病机制', '病理学',
  60. '临床表现', '辅助检查', '诊断', '鉴别诊断', '并发症', '治疗', '护理', '预后', '预防'
  61. ]
  62. # 处理每个记录的props,过滤并排序
  63. for record in result['records']:
  64. if 'props' in record:
  65. # 只保留指定的prop_title
  66. filtered_props = [prop for prop in record['props'] if prop.get('prop_title') in prop_title_order]
  67. # 按照指定顺序排序
  68. sorted_props = sorted(
  69. filtered_props,
  70. key=lambda x: prop_title_order.index(x.get('prop_title')) if x.get('prop_title') in prop_title_order else len(prop_title_order)
  71. )
  72. # 更新记录中的props
  73. record['props'] = sorted_props
  74. return StandardResponse(
  75. success=True,
  76. data={
  77. 'records': result['records'],
  78. 'pagination': result['pagination']
  79. }
  80. )
  81. except Exception as e:
  82. logger.error(f"分页查询失败: {str(e)}")
  83. raise HTTPException(
  84. status_code=500,
  85. detail=StandardResponse(
  86. success=False,
  87. error_code=500,
  88. error_msg=str(e)
  89. )
  90. )
  91. @router.post("/nodes", response_model=StandardResponse)
  92. async def create_node(
  93. payload: NodeCreateRequest,
  94. db: Session = Depends(get_db)
  95. ):
  96. try:
  97. service = TrunksService()
  98. result = service.create_node(payload.dict())
  99. return StandardResponse(success=True, data=result)
  100. except Exception as e:
  101. logger.error(f"创建节点失败: {str(e)}")
  102. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  103. @router.get("/nodes/{node_id}", response_model=StandardResponse)
  104. async def get_node(
  105. node_id: int,
  106. db: Session = Depends(get_db)
  107. ):
  108. try:
  109. service = TrunksService()
  110. result = service.get_node(node_id)
  111. return StandardResponse(success=True, data=result)
  112. except Exception as e:
  113. logger.error(f"获取节点失败: {str(e)}")
  114. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  115. @router.put("/nodes/{node_id}", response_model=StandardResponse)
  116. async def update_node(
  117. node_id: int,
  118. payload: NodeUpdateRequest,
  119. db: Session = Depends(get_db)
  120. ):
  121. try:
  122. service = TrunksService()
  123. result = service.update_node(node_id, payload.dict(exclude_unset=True))
  124. return StandardResponse(success=True, data=result)
  125. except Exception as e:
  126. logger.error(f"更新节点失败: {str(e)}")
  127. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  128. @router.delete("/nodes/{node_id}", response_model=StandardResponse)
  129. async def delete_node(
  130. node_id: int,
  131. db: Session = Depends(get_db)
  132. ):
  133. try:
  134. service = TrunksService()
  135. service.delete_node(node_id)
  136. return StandardResponse(success=True)
  137. except Exception as e:
  138. logger.error(f"删除节点失败: {str(e)}")
  139. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  140. @router.post('/trunks/vector_search', response_model=StandardResponse)
  141. async def vector_search(
  142. payload: VectorSearchRequest,
  143. db: Session = Depends(get_db)
  144. ):
  145. try:
  146. service = TrunksService()
  147. result = service.search_by_vector(
  148. payload.text,
  149. payload.limit,
  150. {'type': payload.type} if payload.type else None
  151. )
  152. return StandardResponse(success=True, data=result)
  153. except Exception as e:
  154. logger.error(f"向量搜索失败: {str(e)}")
  155. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  156. @router.get('/trunks/{trunk_id}', response_model=StandardResponse)
  157. async def get_trunk(
  158. trunk_id: int,
  159. db: Session = Depends(get_db)
  160. ):
  161. try:
  162. service = TrunksService()
  163. result = service.get_trunk_by_id(trunk_id)
  164. return StandardResponse(success=True, data=result)
  165. except Exception as e:
  166. logger.error(f"获取trunk详情失败: {str(e)}")
  167. raise HTTPException(500, detail=StandardResponse.error(str(e)))
# Public alias for this module's router. It is presumably imported by the
# application entry point to register the /saas routes; confirm against the app setup.
saas_kb_router = router