"""Dify external knowledge-base endpoints.

Exposes ``POST /dify/retrieval`` (JSON body) and ``POST /dify/chatflow_retrieval``
(query parameters). Both run a vector search through ``TrunksService`` and return
the hits wrapped in a ``StandardResponse``.
"""
import logging
from typing import List, Optional

from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Header, Request
from pydantic import BaseModel, Field, model_validator, validator
from sqlalchemy.orm import Session

from model.response import StandardResponse
from db.session import get_db
from service.trunks_service import TrunksService

# Module logger; defined before the functions that use it.
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/dify", tags=["Dify Knowledge Base"])

# Header names (lower-case) whose values must never be written to the logs.
_SENSITIVE_HEADERS = frozenset({"authorization", "cookie", "x-api-key"})

_BEARER_PREFIX = "Bearer "


# --- Data Models ---

class RetrievalSetting(BaseModel):
    """Search parameters sent with each retrieval request."""

    top_k: int  # maximum number of records requested from the vector search
    # NOTE(review): accepted but not applied by the endpoints in this module.
    score_threshold: float


class MetadataCondition(BaseModel):
    """A single metadata filter condition.

    NOTE(review): the accepted operator spellings use underscores
    (e.g. ``not_equals``); confirm they match what the Dify client sends,
    otherwise any request carrying a metadata filter is rejected with 422.
    """

    name: List[str]
    comparison_operator: str = Field(..., pattern=r'^(equals|not_equals|contains|not_contains|starts_with|ends_with|empty|not_empty|greater_than|less_than)$')
    value: Optional[str] = Field(None)

    @model_validator(mode="after")
    def validate_value(self):
        """Require ``value`` exactly when the operator needs an operand.

        Implemented as an "after" model validator so it also runs when
        ``value`` is omitted. A field validator is skipped for unset defaults,
        which previously let e.g. ``equals`` through with no value.
        """
        if self.comparison_operator in ("empty", "not_empty"):
            if self.value is not None:
                raise ValueError('Value must be None for empty/not_empty operators')
        elif self.value is None:
            raise ValueError('Value is required for this comparison operator')
        return self


class MetadataFilter(BaseModel):
    """A list of metadata conditions combined with ``and`` / ``or``."""

    logical_operator: str = Field(default="and", pattern=r'^(and|or)$')
    conditions: List[MetadataCondition] = Field(..., min_items=1)

    @validator('conditions')
    def validate_conditions(cls, v):
        """Reject an empty condition list (duplicates the ``min_items=1`` constraint)."""
        if len(v) < 1:
            raise ValueError('At least one condition is required')
        return v


class DifyRetrievalRequest(BaseModel):
    """Request body for ``POST /dify/retrieval``."""

    knowledge_id: str
    query: str
    retrieval_setting: RetrievalSetting
    # Parsed and validated but excluded from model_dump(); not used for filtering here.
    metadata_condition: Optional[MetadataFilter] = Field(default=None, exclude=True)


class KnowledgeRecord(BaseModel):
    """Shape of one retrieved record (not referenced by the endpoints in this module)."""

    content: str
    score: float
    title: str
    metadata: dict


def _redacted_headers(headers) -> dict:
    """Return *headers* as a plain dict with sensitive values masked for logging."""
    return {
        key: ("***" if key.lower() in _SENSITIVE_HEADERS else value)
        for key, value in headers.items()
    }


def _format_record(result: dict) -> dict:
    """Convert one ``search_by_vector`` hit into the record dict returned to the client."""
    return {
        "metadata": {
            "path": result["file_path"],
            "description": str(result["id"]),
        },
        # NOTE(review): the raw ``distance`` is reported as ``score``; confirm it has
        # the orientation/range the caller expects for a relevance score.
        "score": result["distance"],
        "title": result["file_path"].split("/")[-1],
        "content": result["content"],
    }


# --- Authentication ---

async def verify_api_key(authorization: str = Header(...)):
    """Validate the ``Authorization: Bearer <key>`` header and return the key.

    Returns:
        The API key following the ``"Bearer "`` prefix, with surrounding
        whitespace stripped.

    Raises:
        HTTPException: 403 with ``error_code`` 1001 if the header does not start
            with ``"Bearer "``, or 1002 if the key after the prefix is empty.

    NOTE(review): ``detail`` is a Pydantic model. FastAPI's default
    HTTPException handler JSON-encodes ``detail`` directly, so this presumably
    relies on a custom exception handler registered elsewhere. Confirm.
    """
    # Never log the header value itself: it carries the secret key.
    logger.info("Received authorization header (%d chars)", len(authorization))
    if not authorization.startswith(_BEARER_PREFIX):
        raise HTTPException(
            status_code=403,
            detail=StandardResponse(
                success=False,
                error_code=1001,
                error_msg="Invalid Authorization header format"
            )
        )

    # Strip so a whitespace-only key is treated as missing.
    api_key = authorization[len(_BEARER_PREFIX):].strip()
    # TODO: Implement actual API key validation logic -- any non-empty key is accepted today.
    if not api_key:
        raise HTTPException(
            status_code=403,
            detail=StandardResponse(
                success=False,
                error_code=1002,
                error_msg="Authorization failed"
            )
        )
    return api_key


# --- Endpoints ---

@router.post("/retrieval", response_model=StandardResponse)
async def dify_retrieval(
    payload: DifyRetrievalRequest,
    request: Request,
    authorization: str = Depends(verify_api_key),
    db: Session = Depends(get_db),
    conversation_id: Optional[str] = None
):
    """Run a vector search for ``payload.query`` and return the hits as records.

    Args:
        payload: Parsed request body; ``retrieval_setting.top_k`` bounds the search.
        request: The raw request (used only for logging headers).
        authorization: API key returned by ``verify_api_key``.
        db: Database session (not used directly by this function).
        conversation_id: Forwarded to ``TrunksService.search_by_vector``.

    Returns:
        ``StandardResponse(success=True, records=[...])``; ``records`` is empty
        when the search returns nothing.

    Raises:
        HTTPException: Re-raised unchanged if one occurs during the search;
            any other exception is wrapped in a 500 carrying its message.
    """
    logger.info("All headers: %s", _redacted_headers(request.headers))
    logger.info("Request body: %s", payload.model_dump())
    try:
        logger.info(
            "Starting retrieval for knowledge base %s with query: %s",
            payload.knowledge_id, payload.query,
        )
        trunks_service = TrunksService()
        # NOTE(review): synchronous call inside an async endpoint; if it does
        # blocking I/O it stalls the event loop -- consider run_in_threadpool.
        search_results = trunks_service.search_by_vector(
            payload.query,
            payload.retrieval_setting.top_k,
            conversation_id=conversation_id,
        )

        if not search_results:
            logger.warning("No results found for query: %s", payload.query)
            return StandardResponse(
                success=True,
                records=[]
            )

        records = [_format_record(result) for result in search_results]

        logger.info("Retrieval completed successfully for query: %s", payload.query)
        return StandardResponse(
            success=True,
            records=records
        )
    except HTTPException as e:
        logger.error("HTTPException occurred: %s", e)
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error would drop.
        logger.exception("Unexpected error occurred: %s", e)
        raise HTTPException(
            status_code=500,
            detail=StandardResponse(
                success=False,
                error_code=500,
                error_msg=str(e)
            )
        ) from e


@router.post("/chatflow_retrieval", response_model=StandardResponse)
async def dify_chatflow_retrieval(
    knowledge_id: str,
    query: str,
    top_k: int,
    score_threshold: float,
    conversation_id: str,
    request: Request,
    authorization: str = Depends(verify_api_key),
    db: Session = Depends(get_db)
):
    """Query-parameter variant of ``dify_retrieval``.

    Builds a ``DifyRetrievalRequest`` from the parameters and delegates to
    ``dify_retrieval``, forwarding ``conversation_id``.
    """
    payload = DifyRetrievalRequest(
        knowledge_id=knowledge_id,
        query=query,
        retrieval_setting=RetrievalSetting(top_k=top_k, score_threshold=score_threshold)
    )
    return await dify_retrieval(payload, request, authorization, db, conversation_id=conversation_id)


# Public alias under which this router is imported elsewhere.
dify_kb_router = router