dify_kb.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Header
  2. from typing import List, Optional
  3. from pydantic import BaseModel
  4. from models.response import StandardResponse
  5. from db.database import get_db
  6. from sqlalchemy.orm import Session
  7. from sqlalchemy import text
  8. router = APIRouter(prefix="/dify", tags=["Dify Knowledge Base"])
  9. # --- Data Models ---
  10. class RetrievalSetting(BaseModel):
  11. top_k: int
  12. score_threshold: float
  13. class MetadataCondition(BaseModel):
  14. name: List[str]
  15. comparison_operator: str
  16. value: Optional[str] = None
  17. class MetadataFilter(BaseModel):
  18. logical_operator: str = "and"
  19. conditions: List[MetadataCondition]
  20. class DifyRetrievalRequest(BaseModel):
  21. knowledge_id: str
  22. query: str
  23. retrieval_setting: RetrievalSetting
  24. metadata_condition: Optional[MetadataFilter] = None
  25. class KnowledgeRecord(BaseModel):
  26. content: str
  27. score: float
  28. title: str
  29. metadata: dict
  30. # --- Authentication ---
  31. async def verify_api_key(authorization: str = Header(...)):
  32. if not authorization.startswith("Bearer "):
  33. raise HTTPException(
  34. status_code=403,
  35. detail=StandardResponse(
  36. success=False,
  37. error_code=1001,
  38. error_msg="Invalid Authorization header format"
  39. )
  40. )
  41. api_key = authorization[7:]
  42. # TODO: Implement actual API key validation logic
  43. if not api_key:
  44. raise HTTPException(
  45. status_code=403,
  46. detail=StandardResponse(
  47. success=False,
  48. error_code=1002,
  49. error_msg="Authorization failed"
  50. )
  51. )
  52. return api_key
  53. @router.post("/retrieval", response_model=StandardResponse)
  54. async def dify_retrieval(
  55. request: DifyRetrievalRequest,
  56. api_key: str = Depends(verify_api_key),
  57. db: Session = Depends(get_db)
  58. ):
  59. """
  60. 实现Dify外部知识库检索接口
  61. """
  62. print("dify_retrieval start")
  63. try:
  64. # 检查知识库是否存在
  65. result = db.execute(text("select 1 from kg_graphs where id = :graph_id"), {"graph_id": request.knowledge_id})
  66. kb_exists = result.scalar()
  67. print(kb_exists)
  68. if kb_exists == 0:
  69. raise HTTPException(
  70. status_code=404,
  71. detail=StandardResponse(
  72. success=False,
  73. error_code=2001,
  74. error_msg="The knowledge does not exist"
  75. )
  76. )
  77. print("知识库存在")
  78. # 构建基础查询
  79. query = """
  80. select id,name,category,version from kg_nodes as node where node.graph_id = :graph_id and node.name = :node_name
  81. """
  82. # 添加元数据过滤条件
  83. if request.metadata_condition:
  84. conditions = []
  85. for cond in request.metadata_condition.conditions:
  86. operator_map = {
  87. "contains": "CONTAINS",
  88. "not contains": "NOT CONTAINS",
  89. "start with": "STARTS WITH",
  90. "end with": "ENDS WITH",
  91. "is": "=",
  92. "is not": "<>",
  93. "empty": "IS NULL",
  94. "not empty": "IS NOT NULL",
  95. ">": ">",
  96. "<": "<",
  97. "≥": ">=",
  98. "≤": "<=",
  99. "before": "<",
  100. "after": ">"
  101. }
  102. cypher_op = operator_map.get(cond.comparison_operator, "=")
  103. for field in cond.name:
  104. if cond.comparison_operator in ["empty", "not empty"]:
  105. conditions.append(f"node.{field} {cypher_op}")
  106. else:
  107. conditions.append(
  108. f"node.{field} {cypher_op} ${field}_value"
  109. )
  110. where_clause = " AND ".join(conditions)
  111. query += f" AND {where_clause}"
  112. query += """
  113. ORDER BY node.name DESC
  114. LIMIT :top_k
  115. """
  116. params = {'graph_id': request.knowledge_id, 'node_name': request.query,'top_k':request.retrieval_setting.top_k}
  117. supply_params = {f"{cond.name}_value": cond.value for cond in request.metadata_condition.conditions} if request.metadata_condition else {}
  118. # 执行查询
  119. result = db.execute(text(query),
  120. {
  121. **params,
  122. **supply_params
  123. } )
  124. data_returned = []
  125. for record in result:
  126. id,name,category,version = record
  127. doc_node = {
  128. "id": id,
  129. "name": name,
  130. "category": category,
  131. "version": version
  132. }
  133. data_returned.append({
  134. "content": "",
  135. "score": float(1.0),
  136. "title": doc_node.get("name", "Untitled"),
  137. "metadata": {
  138. "id": doc_node.get("id", 0),
  139. "category": doc_node.get("category", "")
  140. }
  141. })
  142. for data in data_returned:
  143. id = data['metadata']['id']
  144. result = db.execute(text("""
  145. select prop_title, prop_value from kg_props as prop where prop.category=1 and prop.ref_id =:node_id
  146. """),{'node_id':id})
  147. content = []
  148. for record in result:
  149. prop_title, prop_value = record
  150. content.append(f"{prop_title}:{prop_value}\n")
  151. data['content']="\n".join(content)
  152. response_data = StandardResponse(
  153. success=True,
  154. records=data_returned
  155. )
  156. return response_data
  157. except HTTPException:
  158. raise
  159. except Exception as e:
  160. print(e)
  161. raise HTTPException(
  162. status_code=500,
  163. detail=StandardResponse(
  164. success=False,
  165. error_code=500,
  166. error_msg=str(e)
  167. )
  168. )
  169. dify_kb_router = router