knowledge_saas.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from fastapi import APIRouter, Depends, HTTPException
  2. from typing import Optional, List
  3. from pydantic import BaseModel
  4. from model.response import StandardResponse
  5. from db.session import get_db
  6. from sqlalchemy.orm import Session
  7. from service.kg_node_service import KGNodeService
  8. from service.trunks_service import TrunksService
  9. import logging
# Router holding every SaaS knowledge-base endpoint in this module; it is
# re-exported at the bottom of the file as ``saas_kb_router``.
router = APIRouter(tags=["SaaS Knowledge Base"])
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class PaginatedSearchRequest(BaseModel):
    # Request body for the node paginated-search endpoints; every field is
    # forwarded unchanged to KGNodeService.paginated_search.
    keyword: Optional[str] = None  # search keyword; None means no keyword filter (presumably)
    category: Optional[str] = None  # category filter; None presumably means any category
    distance: Optional[float] = None  # NOTE(review): semantics decided by the service (likely a similarity threshold) — confirm
    pageNo: int = 1  # page number; presumably 1-based — verify against the service
    limit: int = 10  # page size
    knowledge_ids: Optional[List[str]] = None  # optional restriction to specific knowledge ids
class NodePaginatedSearchRequest(BaseModel):
    # Name-based paginated node search request.
    # NOTE(review): not referenced by any endpoint visible in this module — confirm it is still used.
    name: str  # required node name to search for
    category: Optional[str] = None  # optional category filter
    pageNo: int = 1  # page number; presumably 1-based
    limit: int = 10  # page size
class NodeCreateRequest(BaseModel):
    # Body of POST /nodes; serialized with .dict() and passed to TrunksService.create_node.
    name: str  # required node name
    category: str  # required node category
    layout: Optional[str] = None
    version: Optional[str] = None
    embedding: Optional[List[float]] = None  # optional precomputed embedding vector
class NodeUpdateRequest(BaseModel):
    # Body of PUT /nodes/{node_id}. The endpoint serializes it with
    # exclude_unset=True, so only fields the client actually sent are forwarded.
    layout: Optional[str] = None
    version: Optional[str] = None
    status: Optional[int] = None  # NOTE(review): meaning of status codes is defined elsewhere — confirm
    embedding: Optional[List[float]] = None
class VectorSearchRequest(BaseModel):
    # Body of the trunk vector-search endpoints; fields map to the arguments
    # of TrunksService.search_by_vector.
    text: str  # query text to search with
    limit: int = 10  # maximum number of results requested
    # Optional type filter. The name shadows the builtin `type` but is part of
    # the public JSON schema, so it is kept as-is.
    type: Optional[str] = None
class NodeRelationshipRequest(BaseModel):
    # Relationship request carrying a source node id.
    # NOTE(review): only src_id is declared and no visible endpoint uses this
    # model; it may be missing a target id — confirm before relying on it.
    src_id: int  # id of the source node
  41. @router.post("/kgrt_api/saas/nodes/paginated_search", response_model=StandardResponse)
  42. @router.post("/knowledge/saas/nodes/paginated_search", response_model=StandardResponse)
  43. async def paginated_search(
  44. payload: PaginatedSearchRequest,
  45. db: Session = Depends(get_db)
  46. ):
  47. try:
  48. service = KGNodeService(db)
  49. search_params = {
  50. 'keyword': payload.keyword,
  51. 'category': payload.category,
  52. 'pageNo': payload.pageNo,
  53. 'limit': payload.limit,
  54. 'knowledge_ids': payload.knowledge_ids,
  55. 'distance': payload.distance,
  56. 'load_props': True
  57. }
  58. result = service.paginated_search(search_params)
  59. # 定义prop_title的排序顺序
  60. prop_title_order = [
  61. '基础信息', '概述', '病因学', '流行病学', '发病机制', '病理学',
  62. '临床表现', '辅助检查', '诊断', '鉴别诊断', '并发症', '治疗', '护理', '预后', '预防'
  63. ]
  64. # 处理每个记录的props,过滤并排序
  65. for record in result['records']:
  66. if 'props' in record:
  67. # 只保留指定的prop_title
  68. filtered_props = [prop for prop in record['props'] if prop.get('prop_title') in prop_title_order]
  69. # 按照指定顺序排序
  70. sorted_props = sorted(
  71. filtered_props,
  72. key=lambda x: prop_title_order.index(x.get('prop_title')) if x.get('prop_title') in prop_title_order else len(prop_title_order)
  73. )
  74. # 更新记录中的props
  75. record['props'] = sorted_props
  76. return StandardResponse(
  77. success=True,
  78. data={
  79. 'records': result['records'],
  80. 'pagination': result['pagination']
  81. }
  82. )
  83. except Exception as e:
  84. logger.error(f"分页查询失败: {str(e)}")
  85. raise HTTPException(
  86. status_code=500,
  87. detail=StandardResponse(
  88. success=False,
  89. error_code=500,
  90. error_msg=str(e)
  91. )
  92. )
  93. @router.post("/nodes", response_model=StandardResponse)
  94. async def create_node(
  95. payload: NodeCreateRequest,
  96. db: Session = Depends(get_db)
  97. ):
  98. try:
  99. service = TrunksService()
  100. result = service.create_node(payload.dict())
  101. return StandardResponse(success=True, data=result)
  102. except Exception as e:
  103. logger.error(f"创建节点失败: {str(e)}")
  104. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  105. @router.get("/nodes/{node_id}", response_model=StandardResponse)
  106. async def get_node(
  107. node_id: int,
  108. db: Session = Depends(get_db)
  109. ):
  110. try:
  111. service = TrunksService()
  112. result = service.get_node(node_id)
  113. return StandardResponse(success=True, data=result)
  114. except Exception as e:
  115. logger.error(f"获取节点失败: {str(e)}")
  116. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  117. @router.put("/nodes/{node_id}", response_model=StandardResponse)
  118. async def update_node(
  119. node_id: int,
  120. payload: NodeUpdateRequest,
  121. db: Session = Depends(get_db)
  122. ):
  123. try:
  124. service = TrunksService()
  125. result = service.update_node(node_id, payload.dict(exclude_unset=True))
  126. return StandardResponse(success=True, data=result)
  127. except Exception as e:
  128. logger.error(f"更新节点失败: {str(e)}")
  129. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  130. @router.post('/kgrt_api/saas/trunks/vector_search', response_model=StandardResponse)
  131. @router.post('/knowledge/saas/trunks/vector_search', response_model=StandardResponse)
  132. async def vector_search(
  133. payload: VectorSearchRequest,
  134. db: Session = Depends(get_db)
  135. ):
  136. try:
  137. service = TrunksService()
  138. result = service.search_by_vector(
  139. payload.text,
  140. payload.limit,
  141. type=payload.type
  142. )
  143. return StandardResponse(success=True, data=result)
  144. except Exception as e:
  145. logger.error(f"向量搜索失败: {str(e)}")
  146. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  147. @router.get('/kgrt_api/saas/trunks/{trunk_id}', response_model=StandardResponse)
  148. @router.get('/knowledge/saas/trunks/{trunk_id}', response_model=StandardResponse)
  149. async def get_trunk(
  150. trunk_id: int,
  151. db: Session = Depends(get_db)
  152. ):
  153. try:
  154. service = TrunksService()
  155. result = service.get_trunk_by_id(trunk_id)
  156. return StandardResponse(success=True, data=result)
  157. except Exception as e:
  158. logger.error(f"获取trunk详情失败: {str(e)}")
  159. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  160. @router.post('/kgrt_api/saas/trunks/{trunk_id}/highlight', response_model=StandardResponse)
  161. @router.post('/knowledge/saas/trunks/{trunk_id}/highlight', response_model=StandardResponse)
  162. async def highlight(
  163. trunk_id: int,
  164. targetSentences: List[str],
  165. db: Session = Depends(get_db)
  166. ):
  167. try:
  168. service = TrunksService()
  169. result = service.highlight(trunk_id, targetSentences)
  170. return StandardResponse(success=True, data=result)
  171. except Exception as e:
  172. logger.error(f"获取trunk高亮信息失败: {str(e)}")
  173. raise HTTPException(500, detail=StandardResponse.error(str(e)))
# Public alias under which the application includes this module's router.
saas_kb_router = router