dify_kb.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
import logging
from typing import List, Optional

from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Header
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session

from models.response import StandardResponse
from db.database import get_db
# The wildcard imports stay after the explicit imports above, in their original
# relative order, so any name they shadow is shadowed exactly as before.
from db.schemas import *
from db.models import *
from utils.response import resp_200
# Router for this module: every route registered on it is served under the
# "/dify" URL prefix and tagged "Dify Knowledge Base" in the generated API docs.
router = APIRouter(prefix="/dify", tags=["Dify Knowledge Base"])
  12. # --- Data Models ---
class RetrievalSetting(BaseModel):
    """Retrieval tuning parameters carried in a retrieval request."""

    # Maximum number of records the caller wants back.
    top_k: int
    # Presumably the minimum relevance score a record needs in order to be
    # returned -- TODO confirm against the Dify external knowledge API spec.
    score_threshold: float
class MetadataCondition(BaseModel):
    """A single metadata comparison applied to one or more fields."""

    # Names of the fields the comparison applies to.
    name: List[str]
    # Comparison operator name as sent by the client (e.g. "contains", "is",
    # "empty").
    comparison_operator: str
    # Value to compare against; optional because some operators (such as
    # "empty" / "not empty") take no operand.
    value: Optional[str] = None
class MetadataFilter(BaseModel):
    """A group of metadata conditions plus the operator that combines them."""

    # How the conditions are combined; defaults to "and".
    logical_operator: str = "and"
    # The individual comparisons that make up the filter.
    conditions: List[MetadataCondition]
class DifyRetrievalRequest(BaseModel):
    """Request body accepted by the ``/dify/retrieval`` endpoint."""

    # Identifier of the knowledge base to search.
    knowledge_id: str
    # The user's query text.
    query: str
    # Result-count and score settings for this retrieval.
    retrieval_setting: RetrievalSetting
    # Optional metadata filter; ``None`` means no additional filtering.
    metadata_condition: Optional[MetadataFilter] = None
class KnowledgeRecord(BaseModel):
    """One retrieved record: text content, score, title and metadata.

    NOTE(review): this model is not referenced elsewhere in the visible part
    of the file; it appears to describe the shape of each entry in the
    retrieval response's ``records`` list -- confirm before relying on it.
    """

    # Text content of the record.
    content: str
    # Relevance score of the record.
    score: float
    # Title of the record.
    title: str
    # Free-form metadata attached to the record.
    metadata: dict
  33. # --- Authentication ---
  34. async def verify_api_key(authorization: str = Header(...)):
  35. if not authorization.startswith("Bearer "):
  36. raise HTTPException(
  37. status_code=403,
  38. detail=StandardResponse(
  39. success=False,
  40. error_code=1001,
  41. error_msg="Invalid Authorization header format"
  42. )
  43. )
  44. api_key = authorization[7:]
  45. # TODO: Implement actual API key validation logic
  46. if not api_key:
  47. raise HTTPException(
  48. status_code=403,
  49. detail=StandardResponse(
  50. success=False,
  51. error_code=1002,
  52. error_msg="Authorization failed"
  53. )
  54. )
  55. return api_key
  56. @router.post("/retrieval", response_model=StandardResponse)
  57. async def dify_retrieval(
  58. request: DifyRetrievalRequest,
  59. api_key: str = Depends(verify_api_key),
  60. db: Session = Depends(get_db)
  61. ):
  62. """
  63. 实现Dify外部知识库检索接口
  64. """
  65. print("dify_retrieval start")
  66. try:
  67. # 检查知识库是否存在
  68. kb_exists = db.query(DbKgGraphs).filter(DbKgGraphs.id == request.knowledge_id).count()
  69. if kb_exists == 0:
  70. raise HTTPException(
  71. status_code=404,
  72. detail=StandardResponse(
  73. success=False,
  74. error_code=2001,
  75. error_msg="The knowledge does not exist"
  76. )
  77. )
  78. print("知识库存在")
  79. # 构建基础查询
  80. query = """
  81. select id,name,category,version from kg_nodes as node where node.graph_id = :graph_id and node.name = :node_name
  82. """
  83. # 添加元数据过滤条件
  84. if request.metadata_condition:
  85. conditions = []
  86. for cond in request.metadata_condition.conditions:
  87. operator_map = {
  88. "contains": "CONTAINS",
  89. "not contains": "NOT CONTAINS",
  90. "start with": "STARTS WITH",
  91. "end with": "ENDS WITH",
  92. "is": "=",
  93. "is not": "<>",
  94. "empty": "IS NULL",
  95. "not empty": "IS NOT NULL",
  96. ">": ">",
  97. "<": "<",
  98. "≥": ">=",
  99. "≤": "<=",
  100. "before": "<",
  101. "after": ">"
  102. }
  103. cypher_op = operator_map.get(cond.comparison_operator, "=")
  104. for field in cond.name:
  105. if cond.comparison_operator in ["empty", "not empty"]:
  106. conditions.append(f"node.{field} {cypher_op}")
  107. else:
  108. conditions.append(
  109. f"node.{field} {cypher_op} ${field}_value"
  110. )
  111. where_clause = " AND ".join(conditions)
  112. query += f" AND {where_clause}"
  113. query += """
  114. ORDER BY node.name DESC
  115. LIMIT :top_k
  116. """
  117. params = {'graph_id': request.knowledge_id, 'node_name': request.query,'top_k':request.retrieval_setting.top_k}
  118. supply_params = {f"{cond.name}_value": cond.value for cond in request.metadata_condition.conditions} if request.metadata_condition else {}
  119. # 执行查询
  120. result = db.execute(text(query),
  121. {
  122. **params,
  123. **supply_params
  124. } )
  125. data_returned = []
  126. for record in result:
  127. id,name,category,version = record
  128. doc_node = {
  129. "id": id,
  130. "name": name,
  131. "category": category,
  132. "version": version
  133. }
  134. data_returned.append({
  135. "content": "",
  136. "score": float(1.0),
  137. "title": doc_node.get("name", "Untitled"),
  138. "metadata": {
  139. "id": doc_node.get("id", 0),
  140. "category": doc_node.get("category", "")
  141. }
  142. })
  143. for data in data_returned:
  144. id = data['metadata']['id']
  145. result = db.execute(text("""
  146. select prop_title, prop_value from kg_props as prop where prop.category=1 and prop.ref_id =:node_id
  147. """),{'node_id':id})
  148. content = []
  149. for record in result:
  150. prop_title, prop_value = record
  151. content.append(f"{prop_title}:{prop_value}\n")
  152. data['content']="\n".join(content)
  153. response_data = StandardResponse(
  154. success=True,
  155. records=data_returned
  156. )
  157. return response_data
  158. except HTTPException:
  159. raise
  160. except Exception as e:
  161. print(e)
  162. raise HTTPException(
  163. status_code=500,
  164. detail=StandardResponse(
  165. success=False,
  166. error_code=500,
  167. error_msg=str(e)
  168. )
  169. )
# Alias for ``router`` under a module-specific name -- presumably what the
# application imports when mounting these routes; confirm against the caller.
dify_kb_router = router