main.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import logging
  2. import uuid
  3. from logging.handlers import RotatingFileHandler
  4. from fastapi import FastAPI, Request, Response, status
  5. from typing import Optional, Set
  6. # 导入FastAPI及相关模块
  7. import os
  8. import uvicorn
  9. from fastapi.staticfiles import StaticFiles
  10. from fastapi.middleware.cors import CORSMiddleware
  11. # from agent.cdss.capbility import CDSSCapability
  12. from router.knowledge_dify import dify_kb_router
  13. from router.knowledge_saas import saas_kb_router
  14. from router.text_search import text_search_router
  15. from router.graph_router import graph_router
  16. # from router.knowledge_nodes_api import knowledge_nodes_api_router
  17. # 配置日志
  18. logging.basicConfig(
  19. level=logging.ERROR,
  20. format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
  21. handlers=[
  22. logging.StreamHandler(),
  23. RotatingFileHandler('app.log', maxBytes=10485760, backupCount=5, encoding='utf-8')
  24. ]
  25. )
  26. logger = logging.getLogger(__name__)
  27. logger.propagate = True
  28. # 创建FastAPI应用
  29. app = FastAPI(title="知识图谱")
  30. app.include_router(dify_kb_router)
  31. app.include_router(saas_kb_router)
  32. app.include_router(text_search_router)
  33. app.include_router(graph_router)
  34. # app.include_router(knowledge_nodes_api_router)
  35. # 挂载静态文件目录,将/books路径映射到本地books文件夹
  36. app.mount("/books", StaticFiles(directory="books"), name="books")
  37. # 允许所有来源(仅用于测试,生产环境应限制)
  38. app.add_middleware(
  39. CORSMiddleware,
  40. allow_origins=["*"], # 允许所有来源(或指定 ["http://localhost:3000"])
  41. allow_credentials=True, # 允许携带 Cookie
  42. allow_methods=["*"], # 允许所有方法(或指定 ["GET", "POST"])
  43. allow_headers=["*"], # 允许所有请求头
  44. )
  45. # 需要拦截的 URL 列表(支持通配符)
  46. INTERCEPT_URLS = {
  47. "/v1/knowledge/*"
  48. }
  49. # 白名单 URL(不需要拦截的路径)
  50. WHITE_LIST = {
  51. "/api/public",
  52. "/admin/login"
  53. }
  54. async def verify_token(authorization: str) -> Optional[dict]:
  55. """
  56. 验证 token 有效性
  57. 返回:验证成功返回用户信息字典,失败返回 None
  58. """
  59. if not authorization.startswith("Bearer "):
  60. return None
  61. token = authorization[7:]
  62. # 这里添加实际的 token 验证逻辑
  63. # 示例:简单验证 token 是否等于 secret-token
  64. if token == "secret-token":
  65. return {"id": 1, "username": "admin", "role": "admin"}
  66. return None
  67. def should_intercept(path: str) -> bool:
  68. """
  69. 判断是否需要拦截当前路径
  70. """
  71. if path in WHITE_LIST:
  72. return False
  73. for pattern in INTERCEPT_URLS:
  74. # 处理通配符匹配
  75. if pattern.endswith("/*"):
  76. if path.startswith(pattern[:-1]):
  77. return True
  78. # 精确匹配
  79. elif path == pattern:
  80. return True
  81. return False
  82. @app.middleware("http")
  83. async def interceptor_middleware(request: Request, call_next):
  84. path = request.url.path
  85. if not should_intercept(path):
  86. return await call_next(request)
  87. # 权限校验
  88. auth_header = request.headers.get("Authorization")
  89. if not auth_header:
  90. return Response(
  91. content="Missing Authorization header",
  92. status_code=status.HTTP_401_UNAUTHORIZED
  93. )
  94. user_info = await verify_token(auth_header)
  95. if not user_info:
  96. return Response(
  97. content="Invalid token",
  98. status_code=status.HTTP_401_UNAUTHORIZED
  99. )
  100. # 初始化操作:将用户信息添加到请求状态中
  101. request.state.user = user_info
  102. # 添加请求上下文(示例)
  103. request.state.context = {
  104. "request_id": request.headers.get("request-id", str(uuid.uuid4())),
  105. "client_ip": request.client.host
  106. }
  107. # 继续处理请求
  108. response = await call_next(request)
  109. # 可以在返回前添加统一响应处理(如添加头信息)
  110. response.headers["request-id"]=request.state.context["request_id"]
  111. return response
  112. #capability = CDSSCapability()
  113. if __name__ == "__main__":
  114. logger.info('Starting uvicorn server...2222')
  115. #uvicorn main:app --host 0.0.0.0 --port 8000 --reload
  116. uvicorn.run("main:app", host="0.0.0.0", port=8001, reload=False)