knowledge_saas.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from fastapi import APIRouter, Depends, HTTPException
  2. from typing import Optional, List
  3. from pydantic import BaseModel
  4. from model.response import StandardResponse
  5. from db.session import get_db
  6. from sqlalchemy.orm import Session
  7. from service.kg_node_service import KGNodeService
  8. from service.trunks_service import TrunksService
  9. from service.kg_edge_service import KGEdgeService
  10. from service.kg_prop_service import KGPropService
  11. import logging
  12. from utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
  13. router = APIRouter(tags=["SaaS Knowledge Base"])
  14. logger = logging.getLogger(__name__)
class PaginatedSearchRequest(BaseModel):
    """Request body for the ``nodes/paginated_search`` endpoints.

    Every field is copied into the search-params dict that the handler passes
    to ``KGNodeService.paginated_search``; filtering/paging semantics are
    implemented by that service, not here.
    """
    keyword: Optional[str] = None  # forwarded as search_params['keyword']
    category: Optional[str] = None  # forwarded as search_params['category']
    pageNo: int = 1  # default 1 suggests 1-based paging -- confirm against the service
    limit: int = 10  # forwarded as search_params['limit'] (page size, presumably)
    knowledge_ids: Optional[List[str]] = None  # optional id list; NOTE(review): semantics defined by the service
class NodePaginatedSearchRequest(BaseModel):
    """Paginated search request keyed by a required node ``name``.

    NOTE(review): no endpoint in this module uses this model -- confirm it is
    referenced elsewhere before changing or removing it.
    """
    name: str  # required
    category: Optional[str] = None  # optional category filter (presumably)
    pageNo: int = 1  # page number; default 1 suggests 1-based paging -- confirm
    limit: int = 10  # page size (presumably)
class NodeCreateRequest(BaseModel):
    """Request body for ``POST /nodes``.

    The handler serializes the whole model with ``.dict()`` (defaults
    included) and passes it to ``TrunksService.create_node``.
    """
    name: str  # required
    category: str  # required
    layout: Optional[str] = None
    version: Optional[str] = None
    embedding: Optional[List[float]] = None  # vector length is not validated here
class NodeUpdateRequest(BaseModel):
    """Request body for ``PUT /nodes/{node_id}``.

    All fields are optional. The handler forwards only the fields the client
    actually sent (``dict(exclude_unset=True)``), so omitted fields are not
    passed to ``TrunksService.update_node`` at all.
    """
    layout: Optional[str] = None
    version: Optional[str] = None
    status: Optional[int] = None  # NOTE(review): meaning of status values is defined elsewhere
    embedding: Optional[List[float]] = None  # vector length is not validated here
class VectorSearchRequest(BaseModel):
    """Request body for the ``trunks/vector_search`` endpoints.

    ``text`` and ``limit`` are passed positionally, and ``type`` as a keyword,
    to ``TrunksService.search_by_vector``.
    """
    text: str  # query text to search with (required)
    limit: int = 10  # forwarded as the result limit
    # Shadows the builtin ``type`` inside the model, but it is the JSON field
    # name clients send, so renaming it would break the API.
    type: Optional[str] = None
class NodeRelationshipRequest(BaseModel):
    """Request body carrying a single source-node id.

    NOTE(review): no endpoint in this module uses this model -- confirm it is
    referenced elsewhere before changing or removing it.
    """
    src_id: int  # required source id (presumably a node id)
  43. @router.post("/kgrt_api/saas/nodes/paginated_search", response_model=StandardResponse)
  44. @router.post("/knowledge/saas/nodes/paginated_search", response_model=StandardResponse)
  45. async def paginated_search(
  46. payload: PaginatedSearchRequest,
  47. db: Session = Depends(get_db)
  48. ):
  49. try:
  50. service = KGNodeService(db)
  51. search_params = {
  52. 'keyword': payload.keyword,
  53. 'category': payload.category,
  54. 'pageNo': payload.pageNo,
  55. 'limit': payload.limit,
  56. 'knowledge_ids': payload.knowledge_ids,
  57. 'load_props': True
  58. }
  59. result = service.paginated_search(search_params)
  60. # 定义prop_title的排序顺序
  61. prop_title_order = [
  62. '基础信息', '概述', '病因学', '流行病学', '发病机制', '病理学',
  63. '临床表现', '辅助检查', '诊断', '鉴别诊断', '并发症', '治疗', '护理', '预后', '预防'
  64. ]
  65. # 处理每个记录的props,过滤并排序
  66. for record in result['records']:
  67. if 'props' in record:
  68. # 只保留指定的prop_title
  69. filtered_props = [prop for prop in record['props'] if prop.get('prop_title') in prop_title_order]
  70. # 按照指定顺序排序
  71. sorted_props = sorted(
  72. filtered_props,
  73. key=lambda x: prop_title_order.index(x.get('prop_title')) if x.get('prop_title') in prop_title_order else len(prop_title_order)
  74. )
  75. # 更新记录中的props
  76. record['props'] = sorted_props
  77. return StandardResponse(
  78. success=True,
  79. data={
  80. 'records': result['records'],
  81. 'pagination': result['pagination']
  82. }
  83. )
  84. except Exception as e:
  85. logger.error(f"分页查询失败: {str(e)}")
  86. raise HTTPException(
  87. status_code=500,
  88. detail=StandardResponse(
  89. success=False,
  90. error_code=500,
  91. error_msg=str(e)
  92. )
  93. )
  94. @router.post("/nodes", response_model=StandardResponse)
  95. async def create_node(
  96. payload: NodeCreateRequest,
  97. db: Session = Depends(get_db)
  98. ):
  99. try:
  100. service = TrunksService()
  101. result = service.create_node(payload.dict())
  102. return StandardResponse(success=True, data=result)
  103. except Exception as e:
  104. logger.error(f"创建节点失败: {str(e)}")
  105. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  106. @router.get("/nodes/{node_id}", response_model=StandardResponse)
  107. async def get_node(
  108. node_id: int,
  109. db: Session = Depends(get_db)
  110. ):
  111. try:
  112. service = TrunksService()
  113. result = service.get_node(node_id)
  114. return StandardResponse(success=True, data=result)
  115. except Exception as e:
  116. logger.error(f"获取节点失败: {str(e)}")
  117. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  118. @router.put("/nodes/{node_id}", response_model=StandardResponse)
  119. async def update_node(
  120. node_id: int,
  121. payload: NodeUpdateRequest,
  122. db: Session = Depends(get_db)
  123. ):
  124. try:
  125. service = TrunksService()
  126. result = service.update_node(node_id, payload.dict(exclude_unset=True))
  127. return StandardResponse(success=True, data=result)
  128. except Exception as e:
  129. logger.error(f"更新节点失败: {str(e)}")
  130. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  131. @router.post('/kgrt_api/saas/trunks/vector_search', response_model=StandardResponse)
  132. @router.post('/knowledge/saas/trunks/vector_search', response_model=StandardResponse)
  133. async def vector_search(
  134. payload: VectorSearchRequest,
  135. db: Session = Depends(get_db)
  136. ):
  137. try:
  138. service = TrunksService()
  139. result = service.search_by_vector(
  140. payload.text,
  141. payload.limit,
  142. type=payload.type
  143. )
  144. return StandardResponse(success=True, data=result)
  145. except Exception as e:
  146. logger.error(f"向量搜索失败: {str(e)}")
  147. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  148. @router.get('/kgrt_api/saas/trunks/{trunk_id}', response_model=StandardResponse)
  149. @router.get('/knowledge/saas/trunks/{trunk_id}', response_model=StandardResponse)
  150. async def get_trunk(
  151. trunk_id: int,
  152. db: Session = Depends(get_db)
  153. ):
  154. try:
  155. service = TrunksService()
  156. result = service.get_trunk_by_id(trunk_id)
  157. return StandardResponse(success=True, data=result)
  158. except Exception as e:
  159. logger.error(f"获取trunk详情失败: {str(e)}")
  160. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  161. @router.post('/kgrt_api/saas/trunks/{trunk_id}/highlight', response_model=StandardResponse)
  162. @router.post('/knowledge/saas/trunks/{trunk_id}/highlight', response_model=StandardResponse)
  163. async def highlight(
  164. trunk_id: int,
  165. targetSentences: List[str],
  166. db: Session = Depends(get_db)
  167. ):
  168. try:
  169. service = TrunksService()
  170. result = service.highlight(trunk_id, targetSentences)
  171. return StandardResponse(success=True, data=result)
  172. except Exception as e:
  173. logger.error(f"获取trunk高亮信息失败: {str(e)}")
  174. raise HTTPException(500, detail=StandardResponse.error(str(e)))
# Module-specific alias for ``router``. It is presumably the name the
# application imports and mounts (e.g. via include_router); verify against
# the caller before renaming.
saas_kb_router = router