Browse Source

代码提交

SGTY 2 weeks ago
parent
commit
9a86e1057c

+ 8 - 28
build/lib/knowledge/config/site.py

@@ -11,34 +11,14 @@ class SiteConfig:
     
     def load_config(self):        
         self.config = {
-            "SITE_NAME": os.getenv("SITE_NAME", "DEMO"),
-            "SITE_DESCRIPTION": os.getenv("SITE_DESCRIPTION", "ChatGPT"),
-            "SITE_URL": os.getenv("SITE_URL", ""),
-            "SITE_LOGO": os.getenv("SITE_LOGO", ""),
-            "SITE_FAVICON": os.getenv("SITE_FAVICON"),
-            'ELASTICSEARCH_HOST': os.getenv("ELASTICSEARCH_HOST"),
-            'ELASTICSEARCH_USER': os.getenv("ELASTICSEARCH_USER"),
-            'ELASTICSEARCH_PWD': os.getenv("ELASTICSEARCH_PWD"),
-            'WORD_INDEX': os.getenv("WORD_INDEX"),
-            'TITLE_INDEX': os.getenv("TITLE_INDEX"),
-            'CHUNC_INDEX': os.getenv("CHUNC_INDEX"),
-            'DEEPSEEK_API_URL': os.getenv("DEEPSEEK_API_URL"),
-            'DEEPSEEK_API_KEY': os.getenv("DEEPSEEK_API_KEY"),
-            'CACHED_DATA_PATH': os.getenv("CACHED_DATA_PATH"),
-            'UPDATE_DATA_PATH': os.getenv("UPDATE_DATA_PATH"),
-            'FACTOR_DATA_PATH': os.getenv("FACTOR_DATA_PATH"),
-            'GRAPH_API_URL': os.getenv("GRAPH_API_URL"),
-            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL"),
-            'DOC_PATH': os.getenv("DOC_PATH"),
-            'DOC_STORAGE_PATH': os.getenv("DOC_STORAGE_PATH"),
-            'TRUNC_OUTPUT_PATH': os.getenv("TRUNC_OUTPUT_PATH"),
-            'DOC_ABSTRACT_OUTPUT_PATH': os.getenv("DOC_ABSTRACT_OUTPUT_PATH"),
-            'JIEBA_USER_DICT': os.getenv("JIEBA_USER_DICT"),
-            'JIEBA_STOP_DICT': os.getenv("JIEBA_STOP_DICT"),
-            'POSTGRESQL_HOST':  os.getenv("POSTGRESQL_HOST","localhost"),
-            'POSTGRESQL_DATABASE':  os.getenv("POSTGRESQL_DATABASE","kg"),
-            'POSTGRESQL_USER':  os.getenv("POSTGRESQL_USER","dify"),
-            'POSTGRESQL_PASSWORD':  os.getenv("POSTGRESQL_PASSWORD",quote("difyai123456")),
+            "LICENSE_PATH": os.getenv("LICENSE_PATH", ""),        
+            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL",""),
+            'DB_HOST':  os.getenv("DB_HOST",""),
+            'DB_NAME':  os.getenv("DB_NAME",""),
+            'DB_PORT':  os.getenv("DB_PORT",""),
+            'DB_USER': os.getenv("DB_USER", ""),
+            'DB_PASSWORD': os.getenv("DB_PASSWORD", ""),
+            'BOOKS': os.getenv("BOOKS", ""),
         }
     def get_config(self, config_name, default=None): 
         config_name = config_name.upper()     

+ 4 - 0
build/lib/knowledge/main.py

@@ -1,9 +1,13 @@
 # 导入FastAPI及相关模块
 import uvicorn
+#from fastapi_mcp import FastApiMCP
 from py_tools.logging import logger
 
 from .settings import base_setting
 from .server import app
+# mcp = FastApiMCP(app)
+# mcp.mount()
+# mcp.setup_server()
 
 def main():
     logger.info(f"project run {base_setting.server_host}:{base_setting.server_port}")

+ 12 - 12
build/lib/knowledge/middlewares/base.py

@@ -170,20 +170,20 @@ class AuthMiddleware(BaseHTTPMiddleware):
 
         path = request.url.path
 
-        # if not self.should_intercept(path):
-        #     return await call_next(request)
-        # 
-        # # 权限校验
-        # auth_header = request.headers.get("Authorization")
-        # if not auth_header:
-        #     return self.set_auth_err_resp("Missing Authorization header")
-        # 
-        # user_info = await self.verify_token(auth_header)
-        # if not user_info:
-        #     return self.set_auth_err_resp("Invalid token")
+        if not self.should_intercept(path):
+            return await call_next(request)
+
+        # 权限校验
+        auth_header = request.headers.get("Authorization")
+        if not auth_header:
+            return self.set_auth_err_resp("Missing Authorization header")
+
+        user_info = await self.verify_token(auth_header)
+        if not user_info:
+            return self.set_auth_err_resp("Invalid token")
 
         # 初始化操作:将用户信息添加到请求状态中
-        #request.state.user = user_info
+        request.state.user = user_info
         #cache["license_info"]是否存在license信息,如果不存在则验证证书
         # if not 'license_info' in self._cache or not self._cache["license_info"]:
         #     lisence_detail = license_handle()

+ 6 - 2
build/lib/knowledge/model/response.py

@@ -5,6 +5,10 @@ class StandardResponse(BaseModel):
     success: bool
     requestId: Optional[str] = None
     errorCode: Optional[int] = None
+    #验证结果编码:0:成功;1:时间过期;2:并发过高;3:次数耗尽;9:系统异常
+    vaildCode: Optional[int] = None
     errorMsg: Optional[str] = None
-    records: Optional[Any] = None
-    data: Optional[Any] = None
+    data: Optional[Any] = None
+    
+    def getSuccess(self):
+        return self.success

+ 3 - 3
build/lib/knowledge/router/knowledge_nodes_api.py

@@ -59,7 +59,7 @@ async def paginated_search(
             data=ObjectToJsonArrayConverter.convert(result)
         )
     except Exception as e:
-        logger.error(f"分页查询失败: {str(e)}")
+        logger.exception(f"分页查询失败: {str(e)}")
         raise HTTPException(
             status_code=500,
             detail=StandardResponse(
@@ -107,7 +107,7 @@ async def get_node_relationships_condition(
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
         )
     except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
+        logger.exception(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 
@@ -155,7 +155,7 @@ async def get_node_relationships(
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
         )
     except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
+        logger.exception(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 class GetNodeProperties(BaseModel):

+ 31 - 2
build/lib/knowledge/server.py

@@ -1,29 +1,37 @@
 from contextlib import asynccontextmanager
 from datetime import datetime
+from pathlib import Path
 
 from fastapi import FastAPI
 from py_tools.connections.http import AsyncHttpClient
 from py_tools.logging import logger
 
+from starlette.staticfiles import StaticFiles
+
 from .config.site import SiteConfig
 from .middlewares.base import register_middlewares
+
+from .router.graph_api import graph_router
 from .router.knowledge_nodes_api import knowledge_nodes_api_router
+from .router.knowledge_saas import saas_kb_router
+from .router.text_search import text_search_router
+
 from .utils import log_util
 
 
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     await startup()
     yield
     await shutdown()
 
-
+config = SiteConfig()
 app = FastAPI(
     description="知识图谱开放平台",
     lifespan=lifespan,
     middleware=register_middlewares(),  # 注册web中间件
 )
-
 @app.get("/health")
 async def health_check():
     """健康检查接口"""
@@ -34,6 +42,7 @@ async def health_check():
     }
 
 
+
 async def init_setup():
     """初始化项目配置"""
 
@@ -46,7 +55,27 @@ async def startup():
 
     # 加载路由
     app.include_router(knowledge_nodes_api_router)
+    app.include_router(text_search_router)
+    app.include_router(graph_router)
+    app.include_router(saas_kb_router)
 
+    # 挂载静态文件目录,将/books路径映射到本地books文件夹
+
+    books_path = Path(config.get_config("BOOKS"))
+    #books_path = Path("E:\\project\\knowledge\\books")
+
+    app.mount("/books", StaticFiles(directory=books_path), name="books")
+
+    # 需要拦截的 URL 列表(支持通配符)
+    INTERCEPT_URLS = {
+        "/v1/knowledge/*"
+    }
+
+    # 白名单 URL(不需要拦截的路径)
+    WHITE_LIST = {
+        "/books/*",
+        "/knowledge/*"
+    }
     logger.info("fastapi startup success")
 
 

+ 45 - 21
build/lib/knowledge/service/kg_node_service.py

@@ -35,8 +35,6 @@ class KGNodeService:
                 KGNode.embedding.l2_distance(query_embedding).label('distance')
             )
             .filter(KGNode.status == 0)
-            #过滤掉version不等于'er'的节点
-            .filter(KGNode.version != 'ER')
             .filter(KGNode.embedding.l2_distance(query_embedding) <= DISTANCE_THRESHOLD2)
             .order_by('distance').limit(top_k).all()
         )
@@ -55,22 +53,24 @@ class KGNodeService:
 
     def paginated_search(self, search_params: dict) -> dict:
         load_props = search_params.get('load_props', False)
+        prop_service = KGPropService(self.db)
+        edge_service = KGEdgeService(self.db)
         keyword = search_params.get('keyword', '')
-        category = search_params.get('category', '')
+        category = search_params.get('category', None)
         page_no = search_params.get('pageNo', 1)
-        distance = search_params.get('distance',DISTANCE_THRESHOLD)
+        # distance 为NONE或不存在时,使用默认值
+        if search_params.get('distance') is None:
+            distance = DISTANCE_THRESHOLD
+        else:
+            distance = search_params.get('distance')
+        if distance==0:
+            distance = 0.1
         limit = search_params.get('limit', 10)
 
         if page_no < 1:
             page_no = 1
         if limit < 1:
             limit = 10
-            
-        cache_key = f"paginated_search:{keyword}:{category}:{page_no}:{distance}:{limit}:{str(search_params.get('knowledge_ids', ''))}:{load_props}"
-        logger.debug(f"Cache key: {cache_key}")
-        if cache_key in self._cache:
-            cached_value = self._cache[cache_key]
-            return copy.deepcopy(cached_value)
 
         embedding = Vectorizer.get_instance().get_embedding(keyword)
         offset = (page_no - 1) * limit
@@ -79,7 +79,7 @@ class KGNodeService:
             # 构建基础查询条件
             base_query = self.db.query(func.count(KGNode.id)).filter(
                 KGNode.status == 0,
-                KGNode.embedding.l2_distance(embedding) < distance
+                KGNode.embedding.l2_distance(embedding) <= distance
             )
             # 如果有category,则添加额外过滤条件
             if category:
@@ -97,23 +97,23 @@ class KGNodeService:
                 KGNode.name,
                 KGNode.category,
                 KGNode.embedding.l2_distance(embedding).label('distance')
-            )            
+            )
             query = query.filter(KGNode.status == 0)
-            #category有值时,过滤掉category不等于category的节点
+            # category有值时,过滤掉category不等于category的节点
             if category:
                 query = query.filter(KGNode.category == category)
             if search_params.get('knowledge_ids'):
                 query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
-            query = query.filter(KGNode.embedding.l2_distance(embedding) < distance)
+            query = query.filter(KGNode.embedding.l2_distance(embedding) <= distance)
             results = query.order_by('distance').offset(offset).limit(limit).all()
-            #将results相同distance的category=疾病的放在前面
-            #results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))
 
-            finalResults = {
+            return {
                 'records': [{
                     'id': r.id,
                     'name': r.name,
                     'category': r.category,
+                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
+                    # 'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
                     'distance': round(r.distance, 3)
                 } for r in results],
                 'pagination': {
@@ -122,10 +122,8 @@ class KGNodeService:
                     'limit': limit,
                     'totalPages': (total_count + limit - 1) // limit
                 }
-            
+
             }
-            self._cache[cache_key] = copy.deepcopy(finalResults)
-            return finalResults
         except Exception as e:
             logger.error(f"分页查询失败: {str(e)}")
             raise e
@@ -231,4 +229,30 @@ class KGNodeService:
             except Exception as e:
                 self.db.rollback()
                 print(f"批量处理ER节点失败: {str(e)}")
-                raise ValueError("Batch process failed")
+                raise ValueError("Batch process failed")
+
+    def get_node_by_name_category(self, name: str, category: str):
+        if not name or not category:
+            raise ValueError("Name and category are required")
+
+        cache_key = f"get_node_by_name_category_{name}:{category}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+
+        node = self.db.query(KGNode).filter(
+            KGNode.name == name,
+            KGNode.category == category,
+            KGNode.status == 0
+        ).first()
+
+        if not node:
+            return None
+
+        node_data = {
+            'id': node.id,
+            'name': node.name,
+            'category': node.category,
+            'version': node.version
+        }
+        self._cache[cache_key] = node_data
+        return node_data

+ 17 - 0
build/lib/knowledge/service/kg_prop_service.py

@@ -32,7 +32,24 @@ class KGPropService:
     #     except Exception as e:
     #         logger.error(f"根据ref_id查询属性失败: {str(e)}")
     #         raise ValueError("查询失败")
+    def get_prop_by_id(self, id: int)-> dict:
+        try:
+            query = self.db.query(KGProp).filter(KGProp.id == id)
 
+            props = query.first()
+            if not props:
+                return None
+            return {
+                'id': props.id,
+                'category': props.category,
+                'prop_name': props.prop_name,
+                'prop_value': props.prop_value,
+                'prop_title': props.prop_title,
+                'type': props.type
+            }
+        except Exception as e:
+            logger.error(f"根据id查询属性失败: {str(e)}")
+            raise ValueError("查询失败")
 
     def get_props_by_ref_id(self, ref_id: int, prop_name: str = None) -> List[dict]:
 

+ 148 - 29
build/lib/knowledge/utils/license.py

@@ -1,39 +1,158 @@
-from cryptography.hazmat.primitives.asymmetric import padding
-from cryptography.hazmat.primitives import hashes,serialization
+import os
 import json
 import time
-import traceback
+import base64
+import logging
+from threading import Thread
+from cachetools import TTLCache
+from cryptography.hazmat.primitives.asymmetric import padding
+from cryptography.hazmat.primitives import hashes,serialization
+
+from ..db.session import get_db
+from ..model.dict_system import DictSystem
+from src.knowledge.config.site import SiteConfig
+from src.knowledge.model.response import StandardResponse
+
+logger = logging.getLogger(__name__)
+
+#并发缓存
+concurrency_cache = TTLCache(maxsize=10, ttl=5)
+#证书验证缓存
+vaild_cache = {}
+#5分值之内触发一次保存
+need_save = False
+
+#验证入口
+def validate_license():
+    if vaild_cache == {}:
+        return StandardResponse(success=False,vaildCode=9, errorMsg="系统异常:系统未加载证书信息")
+    elif 'result' in vaild_cache:   #已经是失败状态就直接返回失败
+        return vaild_cache["result"]
+    else:
+        timestamp = int(time.time())
+        timestamp_str = str(timestamp)
+        print(concurrency_cache,'=concurrency_cache')
+        if timestamp_str in concurrency_cache:  # 检查并发是否过高
+            if concurrency_cache[timestamp_str] > vaild_cache["api_concurrent_limit"]:
+                logger.info("当前["+timestamp_str+"]请求的并发数为"+str(concurrency_cache[timestamp_str]))
+                return StandardResponse(success=False,vaildCode=2, errorMsg="并发数过高")
+            else:
+                concurrency_cache[timestamp_str] += 1
+        else:
+            concurrency_cache[timestamp_str] = 1
+        
+        if timestamp > vaild_cache["expiration_time"]: # 检查日期是否过期
+            logger.info("证书过期时间为"+str(vaild_cache["expiration_time"]))
+            vaild_cache["result"] = StandardResponse(success=False,vaildCode=1, errorMsg="许可证已过期")
+            return vaild_cache["result"]
+        
+        if vaild_cache["api_use_times"] > vaild_cache["api_invoke_max_count"]:  # 检查次数
+            logger.info("请求使用次数为"+str(vaild_cache["api_concurrent_limit"]))
+            vaild_cache["result"] = StandardResponse(success=False,vaildCode=3, errorMsg="使用次数已经用完")
+            return vaild_cache["result"]
+        vaild_cache["api_use_times"] += 1
 
-def validate_license(public_key_pem, license_json, signature):
-    public_key = serialization.load_pem_public_key(public_key_pem)
+        global need_save
+        if need_save == False:
+            Thread(target=save_vaild_cache).start()
+            need_save = True
 
+        return StandardResponse(success=True,vaildCode=0, data=vaild_cache)
+        
+#初始化验证参数
+def init_vaild_cache():
     try:
-        public_key.verify(
-        signature,
-        license_json,
-        padding.PKCS1v15(),
-        hashes.SHA256()
-        )
+        config = SiteConfig()
+        license_path = config.get_config('LICENSE_PATH')
+
+        public_key_pem = None
+        license_json = None
+        signature = None
+
+        with open(os.path.join(license_path, "public.key"), "rb") as f:
+            public_key_pem = f.read()
+        with open(os.path.join(license_path, "license_issued.lic"), "rb") as f:
+            data = json.loads(f.read())
+            license_json = json.dumps(data, sort_keys=True).encode()
+        with open(os.path.join(license_path, "license_issued.key"), "rb") as f:
+            signature = f.read()
+            
+        public_key = serialization.load_pem_public_key(public_key_pem)
+        public_key.verify(signature,license_json,padding.PKCS1v15(),hashes.SHA256())
+        license_data = json.loads(license_json.decode())
+
+        content = license_data['content']
+        for c in content:
+            if c['name'] == 'api_invoke_max_count':
+                vaild_cache['api_invoke_max_count'] = c['value']
+            elif c['name'] == 'api_concurrent_limit':
+                vaild_cache['api_concurrent_limit'] = c['value']
+        vaild_cache['expiration_time'] = license_data['expiration_time']
+        vaild_cache['api_use_times'] = 0
+        
+        #从数据库加载 
+        db = next(get_db())
+        result = (db.query(DictSystem).filter_by(dict_name = 'vaild_str',is_deleted = 0).first())
+        print(result)
+        if result is not None: #说明还没有存储,不清楚是不是被删除,从证书上同步
+            dict_value = str(base64.b64decode(bytes(result.dict_value,encoding="utf-8") ),encoding='utf-8')
+            dict_value_json=json.loads(dict_value)
+            vaild_cache['api_use_times'] = dict_value_json['api_use_times']
+            vaild_cache['id'] = result.id
     except:
-        #打印异常信息
-        traceback.print_exc()
-        return False
+        vaild_cache["result"] =  StandardResponse(success=False,vaildCode=9, errorMsg="许可证签名验证失败:是否上传了许可证文件")
+        logger.exception("许可证签名验证失败")
+
+#保存验证的参数,有请求时延迟5分钟触发保存
+def save_vaild_cache():
+    time.sleep(10) #5分钟
+    global need_save
+    if need_save == False:
+        return
+    # time.sleep(5*60) #5分钟
+    if vaild_cache is not None:
+        #保存数据到数据库
+        tmp_cache = {}
+        tmp_cache['api_use_times'] = vaild_cache['api_use_times']
+        new_dict = DictSystem(
+            dict_name='vaild_str',
+            dict_value=str(base64.b64encode(json.dumps(tmp_cache, ensure_ascii=False, indent=1).encode("utf-8")),encoding="utf-8"),
+            is_deleted=0
+        )
+        db = next(get_db())
+        print(vaild_cache)
+        if 'id' in vaild_cache:
+            db.query(DictSystem).filter_by(id=vaild_cache['id']).update({DictSystem.dict_value: new_dict.dict_value})
+            new_dict.id = vaild_cache['id']
+        else:
+            db.add(new_dict)
+        db.commit()
+        vaild_cache['id'] = new_dict.id
+    
+    need_save = False
 
-    license_data=json.loads(license_json.decode())
-    # 检查是否过期
-    if time.time()>license_data["expiration_time"]:
-        return False
-    return True
+#重置验证参数
+def reset_vaild_cache():
+    global vaild_cache
+    if vaild_cache != {} and 'id' in vaild_cache:
+        db = next(get_db())
+        db.delete( db.query(DictSystem).filter_by(id=vaild_cache['id']).first()) #删除信息
+        db.commit()
+    vaild_cache={}
+    global need_save
+    need_save = False
+    init_vaild_cache()  #再初始化
 
 if __name__ == '__main__':
-    with open("license_issued/public.key","rb") as f:
-        public_key_pem = f.read()
-    with open("license_issued/license_issued.lic","rb") as f:
-        data = json.loads(f.read())
-        license_json = json.dumps(data, sort_keys=True).encode()
-    with open("license_issued/license_issued.key","rb") as f:
-        signature = f.read()
-    if validate_license(public_key_pem,license_json, signature):
-        print("许可证有效!")
+    init_vaild_cache()
+    response = validate_license()
+    if response.success:
+        print("1许可证有效!")
+    else:
+        print(response.errorMsg)
+    reset_vaild_cache()
+    response = validate_license()
+    if response.success:
+        print("3许可证有效!")
     else:
-        print("许可证无效或已过期!")
+        print(response.errorMsg)

BIN
dist/knowledge-1.0-py3-none-any.whl


BIN
dist/knowledge-1.0.tar.gz


+ 0 - 23
server.py

@@ -1,23 +0,0 @@
-# server.py
-from mcp.server.fastmcp import FastMCP
-
-# Create an MCP server
-mcp = FastMCP("Demo")
-
-
-# Add an addition tool
-@mcp.tool()
-def add(a: int, b: int) -> int:
-    """Add two numbers"""
-    return a + b
-
-
-# Add a dynamic greeting resource
-@mcp.resource("greeting://{name}")
-def get_greeting(name: str) -> str:
-    """Get a personalized greeting"""
-    return f"Hello, {name}!"
-
-if __name__ == "__main__":
-    print(111)
-    mcp.run()

+ 1 - 0
src/knowledge.egg-info/PKG-INFO

@@ -14,4 +14,5 @@ Requires-Dist: psycopg2-binary==2.9.10
 Requires-Dist: python-dotenv==1.0.0
 Requires-Dist: hui-tools[all]==0.5.8
 Requires-Dist: cachetools==6.1.0
+Requires-Dist: jieba==0.42.1
 Dynamic: requires-dist

+ 11 - 0
src/knowledge.egg-info/SOURCES.txt

@@ -22,23 +22,34 @@ src/knowledge/model/kg_edges.py
 src/knowledge/model/kg_node.py
 src/knowledge/model/kg_prop.py
 src/knowledge/model/response.py
+src/knowledge/model/trunks_model.py
 src/knowledge/router/__init__.py
+src/knowledge/router/graph_api.py
 src/knowledge/router/knowledge_nodes_api.py
+src/knowledge/router/knowledge_saas.py
+src/knowledge/router/text_search.py
 src/knowledge/service/__init__.py
 src/knowledge/service/dict_system_service.py
 src/knowledge/service/kg_edge_service.py
+src/knowledge/service/kg_graph_service.py
 src/knowledge/service/kg_node_service.py
 src/knowledge/service/kg_prop_service.py
+src/knowledge/service/trunks_service.py
 src/knowledge/settings/__init__.py
 src/knowledge/settings/auth_setting.py
 src/knowledge/settings/base_setting.py
 src/knowledge/settings/log_setting.py
+src/knowledge/utils/DeepseekUtil.py
 src/knowledge/utils/ObjectToJsonArrayConverter.py
 src/knowledge/utils/__init__.py
 src/knowledge/utils/context_util.py
 src/knowledge/utils/embed_helper.py
+src/knowledge/utils/json_to_text.py
 src/knowledge/utils/license.py
 src/knowledge/utils/log_util.py
+src/knowledge/utils/mcp_client.py
+src/knowledge/utils/sentence_util.py
+src/knowledge/utils/text_similarity.py
 src/knowledge/utils/trace_util.py
 src/knowledge/utils/vector_distance.py
 src/knowledge/utils/vectorizer.py

+ 1 - 0
src/knowledge.egg-info/requires.txt

@@ -11,3 +11,4 @@ psycopg2-binary==2.9.10
 python-dotenv==1.0.0
 hui-tools[all]==0.5.8
 cachetools==6.1.0
+jieba==0.42.1

+ 1 - 0
src/knowledge/.env

@@ -5,3 +5,4 @@ DB_USER=knowledge
 DB_PASSWORD=qwer1234.
 LICENSE_PATH=E:\project\knowledge2\src\knowledge\utils\license_issued
 EMBEDDING_MODEL=E:\project\bge-m3
+BOOKS=E:\project\knowledge\books

+ 1 - 0
src/knowledge/config/site.py

@@ -18,6 +18,7 @@ class SiteConfig:
             'DB_PORT':  os.getenv("DB_PORT",""),
             'DB_USER': os.getenv("DB_USER", ""),
             'DB_PASSWORD': os.getenv("DB_PASSWORD", ""),
+            'BOOKS': os.getenv("BOOKS", ""),
         }
     def get_config(self, config_name, default=None): 
         config_name = config_name.upper()     

+ 9 - 6
src/knowledge/main.py

@@ -1,18 +1,21 @@
 # 导入FastAPI及相关模块
 import uvicorn
-from fastapi_mcp import FastApiMCP
+#from fastapi_mcp import FastApiMCP
 from py_tools.logging import logger
 
 from .settings import base_setting
 from .server import app
-mcp = FastApiMCP(app)
-mcp.mount()
-mcp.setup_server()
-if __name__ == "__main__":
-    logger.info(f"project run {base_setting.server_host}:{base_setting.server_port}")
+# mcp = FastApiMCP(app)
+# mcp.mount()
+# mcp.setup_server()
 
+def main():
+    logger.info(f"project run {base_setting.server_host}:{base_setting.server_port}")
     uvicorn.run(
         app=app, host=base_setting.server_host, port=base_setting.server_port, log_level=base_setting.server_log_level,
         access_log=False
     )
 
+if __name__ == "__main__":
+    main()
+

+ 29 - 11
src/knowledge/server.py

@@ -1,25 +1,23 @@
 from contextlib import asynccontextmanager
 from datetime import datetime
-from typing import Optional
+from pathlib import Path
 
-from fastapi import FastAPI, Depends, Security, HTTPException, Query
+from fastapi import FastAPI
 from py_tools.connections.http import AsyncHttpClient
 from py_tools.logging import logger
-from pydantic import BaseModel
-from requests import Session
+
+from starlette.staticfiles import StaticFiles
 
 from .config.site import SiteConfig
-from .db.session import get_db
 from .middlewares.base import register_middlewares
-from .model.response import StandardResponse
+
 from .router.graph_api import graph_router
-from .router.knowledge_nodes_api import knowledge_nodes_api_router, get_request_id, api_key_header
+from .router.knowledge_nodes_api import knowledge_nodes_api_router
 from .router.knowledge_saas import saas_kb_router
 from .router.text_search import text_search_router
-from .service.kg_edge_service import KGEdgeService
-from .service.kg_node_service import KGNodeService
+
 from .utils import log_util
-from .utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
+
 
 
 @asynccontextmanager
@@ -28,7 +26,7 @@ async def lifespan(app: FastAPI):
     yield
     await shutdown()
 
-
+config = SiteConfig()
 app = FastAPI(
     description="知识图谱开放平台",
     lifespan=lifespan,
@@ -43,6 +41,8 @@ async def health_check():
         "service": "knowledge-graph"
     }
 
+
+
 async def init_setup():
     """初始化项目配置"""
 
@@ -58,6 +58,24 @@ async def startup():
     app.include_router(text_search_router)
     app.include_router(graph_router)
     app.include_router(saas_kb_router)
+
+    # 挂载静态文件目录,将/books路径映射到本地books文件夹
+
+    books_path = Path(config.get_config("BOOKS"))
+    #books_path = Path("E:\\project\\knowledge\\books")
+
+    app.mount("/books", StaticFiles(directory=books_path), name="books")
+
+    # 需要拦截的 URL 列表(支持通配符)
+    INTERCEPT_URLS = {
+        "/v1/knowledge/*"
+    }
+
+    # 白名单 URL(不需要拦截的路径)
+    WHITE_LIST = {
+        "/books/*",
+        "/knowledge/*"
+    }
     logger.info("fastapi startup success")