knowledge_dify.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Header, Request
  2. from typing import List, Optional
  3. from pydantic import BaseModel, Field, validator
  4. from model.response import StandardResponse
  5. from db.session import get_db
  6. from sqlalchemy.orm import Session
  7. from service.trunks_service import TrunksService
  8. import logging
  9. router = APIRouter(prefix="/dify", tags=["Dify Knowledge Base"])
  10. # --- Data Models ---
  11. class RetrievalSetting(BaseModel):
  12. top_k: int
  13. score_threshold: float
  14. class MetadataCondition(BaseModel):
  15. name: List[str]
  16. comparison_operator: str = Field(..., pattern=r'^(equals|not_equals|contains|not_contains|starts_with|ends_with|empty|not_empty|greater_than|less_than)$')
  17. value: Optional[str] = Field(None)
  18. @validator('value')
  19. def validate_value(cls, v, values):
  20. operator = values.get('comparison_operator')
  21. if operator in ['empty', 'not_empty'] and v is not None:
  22. raise ValueError('Value must be None for empty/not_empty operators')
  23. if operator not in ['empty', 'not_empty'] and v is None:
  24. raise ValueError('Value is required for this comparison operator')
  25. return v
  26. class MetadataFilter(BaseModel):
  27. logical_operator: str = Field(default="and", pattern=r'^(and|or)$')
  28. conditions: List[MetadataCondition] = Field(..., min_items=1)
  29. @validator('conditions')
  30. def validate_conditions(cls, v):
  31. if len(v) < 1:
  32. raise ValueError('At least one condition is required')
  33. return v
  34. class DifyRetrievalRequest(BaseModel):
  35. knowledge_id: str
  36. query: str
  37. retrieval_setting: RetrievalSetting
  38. metadata_condition: Optional[MetadataFilter] = Field(default=None, exclude=True)
  39. class KnowledgeRecord(BaseModel):
  40. content: str
  41. score: float
  42. title: str
  43. metadata: dict
  44. # --- Authentication ---
  45. async def verify_api_key(authorization: str = Header(...)):
  46. logger.info(f"Received authorization header: {authorization}") # 新增日志
  47. if not authorization.startswith("Bearer "):
  48. raise HTTPException(
  49. status_code=403,
  50. detail=StandardResponse(
  51. success=False,
  52. error_code=1001,
  53. error_msg="Invalid Authorization header format"
  54. )
  55. )
  56. api_key = authorization[7:]
  57. # TODO: Implement actual API key validation logic
  58. if not api_key:
  59. raise HTTPException(
  60. status_code=403,
  61. detail=StandardResponse(
  62. success=False,
  63. error_code=1002,
  64. error_msg="Authorization failed"
  65. )
  66. )
  67. return api_key
  68. logger = logging.getLogger(__name__)
  69. @router.post("/retrieval", response_model=StandardResponse)
  70. async def dify_retrieval(
  71. payload: DifyRetrievalRequest,
  72. request: Request,
  73. authorization: str = Depends(verify_api_key),
  74. db: Session = Depends(get_db),
  75. conversation_id: Optional[str] = None
  76. ):
  77. logger.info(f"All headers: {dict(request.headers)}")
  78. logger.info(f"Request body: {payload.model_dump()}")
  79. try:
  80. logger.info(f"Starting retrieval for knowledge base {payload.knowledge_id} with query: {payload.query}")
  81. trunks_service = TrunksService()
  82. search_results = trunks_service.search_by_vector(payload.query, payload.retrieval_setting.top_k)
  83. if not search_results:
  84. logger.warning(f"No results found for query: {request.query}")
  85. return StandardResponse(
  86. success=True,
  87. records=[]
  88. )
  89. # 格式化返回结果
  90. records = [{
  91. "metadata": {
  92. "path": result["file_path"],
  93. "description": str(result["id"])
  94. },
  95. "score": result["distance"],
  96. "title": result["file_path"].split("/")[-1],
  97. "content": result["content"]
  98. } for result in search_results]
  99. logger.info(f"Retrieval completed successfully for query: {payload.query}")
  100. return StandardResponse(
  101. success=True,
  102. records=records
  103. )
  104. except HTTPException as e:
  105. logger.error(f"HTTPException occurred: {str(e)}")
  106. raise
  107. except Exception as e:
  108. logger.error(f"Unexpected error occurred: {str(e)}")
  109. raise HTTPException(
  110. status_code=500,
  111. detail=StandardResponse(
  112. success=False,
  113. error_code=500,
  114. error_msg=str(e)
  115. )
  116. )
  117. @router.post("/chatflow_retrieval", response_model=StandardResponse)
  118. async def dify_chatflow_retrieval(
  119. knowledge_id: str,
  120. query: str,
  121. top_k: int,
  122. score_threshold: float,
  123. conversation_id: str,
  124. request: Request,
  125. authorization: str = Depends(verify_api_key),
  126. db: Session = Depends(get_db)
  127. ):
  128. payload = DifyRetrievalRequest(
  129. knowledge_id=knowledge_id,
  130. query=query,
  131. retrieval_setting=RetrievalSetting(top_k=top_k, score_threshold=score_threshold)
  132. )
  133. return await dify_retrieval(payload, request, authorization, db)
  134. dify_kb_router = router