123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
import logging
from typing import List, Optional

from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Header, Request
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, Field, validator
from sqlalchemy.orm import Session

from db.session import get_db
from model.response import StandardResponse
from service.trunks_service import TrunksService
# Router for the Dify knowledge-base endpoints; every route is served
# under the "/dify" prefix.
router = APIRouter(
    prefix="/dify",
    tags=["Dify Knowledge Base"],
)
# --- Data Models ---
class RetrievalSetting(BaseModel):
    """Retrieval parameters sent by the caller with a retrieval request."""

    # Forwarded to the vector search as the result count by the
    # retrieval endpoint in this module.
    top_k: int
    # Accepted from the caller; not read anywhere in the visible module.
    score_threshold: float
class MetadataCondition(BaseModel):
    """A single metadata predicate: field name(s), an operator and a value.

    ``value`` must be ``None`` (or omitted) for the ``empty`` and
    ``not_empty`` operators, and must be provided for every other operator.
    """

    name: List[str]
    comparison_operator: str = Field(..., pattern=r'^(equals|not_equals|contains|not_contains|starts_with|ends_with|empty|not_empty|greater_than|less_than)$')
    value: Optional[str] = Field(None)

    # always=True makes the validator run even when ``value`` is omitted and
    # falls back to its default. Without it the "value is required" rule was
    # never enforced for requests that simply left ``value`` out.
    @validator('value', always=True)
    def validate_value(cls, v, values):
        """Check that the presence of ``value`` matches the operator.

        Raises:
            ValueError: if a value is given for ``empty``/``not_empty``, or
                is missing for any other operator.
        """
        operator = values.get('comparison_operator')
        if operator is None:
            # comparison_operator itself failed validation and is reported
            # on its own; do not pile a misleading second error on top.
            return v

        takes_no_value = operator in ('empty', 'not_empty')
        if takes_no_value and v is not None:
            raise ValueError('Value must be None for empty/not_empty operators')
        if not takes_no_value and v is None:
            raise ValueError('Value is required for this comparison operator')
        return v
class MetadataFilter(BaseModel):
    """A group of metadata conditions joined by ``and`` or ``or``."""

    # Defaults to "and"; the pattern rejects anything but "and" / "or".
    logical_operator: str = Field(default="and", pattern=r'^(and|or)$')
    conditions: List[MetadataCondition] = Field(..., min_items=1)

    @validator('conditions')
    def validate_conditions(cls, conditions):
        """Reject an empty condition list.

        Raises:
            ValueError: if ``conditions`` contains no entries.
        """
        if not conditions:
            raise ValueError('At least one condition is required')
        return conditions
class DifyRetrievalRequest(BaseModel):
    """Request body accepted by the ``/retrieval`` endpoint."""

    knowledge_id: str
    query: str
    retrieval_setting: RetrievalSetting
    # Optional metadata filter; exclude=True keeps it out of serialized
    # dumps of this model.
    metadata_condition: Optional[MetadataFilter] = Field(None, exclude=True)
class KnowledgeRecord(BaseModel):
    """Shape of one retrieved knowledge record.

    The field names match the keys of the record dicts built by the
    retrieval endpoint in this module.
    """

    content: str
    score: float
    title: str
    metadata: dict
# --- Authentication ---
async def verify_api_key(authorization: str = Header(...)):
    """Extract the API key from a ``Bearer`` Authorization header.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The token that follows the ``Bearer `` prefix, with surrounding
        whitespace removed.

    Raises:
        HTTPException: 403 with error code 1001 when the header does not
            start with ``Bearer ``, or 403 with error code 1002 when the
            token is empty.
    """
    # Never log the header value itself: it carries the caller's secret.
    logger.info("Received authorization header (value redacted)")

    bearer_prefix = "Bearer "
    if not authorization.startswith(bearer_prefix):
        raise HTTPException(
            status_code=403,
            # HTTPException.detail must be JSON-serializable; a raw model
            # instance cannot be rendered by FastAPI's default handler.
            detail=jsonable_encoder(
                StandardResponse(
                    success=False,
                    error_code=1001,
                    error_msg="Invalid Authorization header format",
                )
            ),
        )

    # strip() so a header such as "Bearer    " is treated as an empty key.
    api_key = authorization[len(bearer_prefix):].strip()
    # TODO: Implement actual API key validation logic
    if not api_key:
        raise HTTPException(
            status_code=403,
            detail=jsonable_encoder(
                StandardResponse(
                    success=False,
                    error_code=1002,
                    error_msg="Authorization failed",
                )
            ),
        )
    return api_key
logger = logging.getLogger(__name__)

# Header names whose values are credentials and must never reach the logs.
_SENSITIVE_HEADERS = frozenset({"authorization", "proxy-authorization", "cookie"})


def _loggable_headers(request: Request) -> dict:
    """Return the request headers as a dict with credential values masked."""
    return {
        name: "***" if name.lower() in _SENSITIVE_HEADERS else value
        for name, value in request.headers.items()
    }


@router.post("/retrieval", response_model=StandardResponse)
async def dify_retrieval(
    payload: DifyRetrievalRequest,
    request: Request,
    authorization: str = Depends(verify_api_key),
    db: Session = Depends(get_db)
):
    """Run a vector search for ``payload.query`` and return the hits.

    Args:
        payload: Validated request body (query, knowledge id, settings).
        request: Incoming HTTP request, used only for header logging.
        authorization: API key returned by ``verify_api_key``.
        db: Database session from ``get_db`` (not used in this function).

    Returns:
        A ``StandardResponse`` with ``success=True`` and ``records`` set to
        the formatted search hits, or to an empty list when nothing matched.

    Raises:
        HTTPException: re-raised unchanged if one occurs during retrieval;
            any other exception is converted to a 500 response.
    """
    logger.info(f"All headers: {_loggable_headers(request)}")
    logger.info(f"Request body: {payload.model_dump()}")

    try:
        logger.info(f"Starting retrieval for knowledge base {payload.knowledge_id} with query: {payload.query}")

        trunks_service = TrunksService()
        # NOTE(review): only top_k is forwarded; retrieval_setting.score_threshold
        # and metadata_condition are not applied here - confirm this is intended.
        search_results = trunks_service.search_by_vector(payload.query, payload.retrieval_setting.top_k)

        if not search_results:
            # The query lives on the payload; the Starlette Request object
            # has no ``query`` attribute, so reading it there raised.
            logger.warning(f"No results found for query: {payload.query}")
            return StandardResponse(
                success=True,
                records=[]
            )

        # Format the search hits into the response record shape.
        # NOTE(review): "distance" is returned as "score" unchanged - confirm
        # the service's distance is ordered the way the consumer expects.
        records = [{
            "metadata": {
                "path": result["file_path"],
                "description": str(result["id"])
            },
            "score": result["distance"],
            "title": result["file_path"].split("/")[-1],
            "content": result["content"]
        } for result in search_results]

        logger.info(f"Retrieval completed successfully for query: {payload.query}")
        return StandardResponse(
            success=True,
            records=records
        )
    except HTTPException as e:
        logger.error(f"HTTPException occurred: {str(e)}")
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Unexpected error occurred: {str(e)}")
        raise HTTPException(
            status_code=500,
            # HTTPException.detail must be JSON-serializable; encode the
            # model instead of passing the instance itself.
            detail=jsonable_encoder(
                StandardResponse(
                    success=False,
                    error_code=500,
                    error_msg=str(e)
                )
            )
        ) from e


# Alias under which this module's router is exposed.
dify_kb_router = router
|