Browse Source

代码提交

SGTY 3 weeks atrás
parent
commit
fab0dadf99

+ 1 - 2
src/knowledge/.env

@@ -3,6 +3,5 @@ DB_NAME=medkg
 DB_PORT=5432
 DB_USER=knowledge
 DB_PASSWORD=qwer1234.
-
-license=E:\project\knowledge\src\knowledge\utils\license_issued
+LICENSE_PATH=E:\project\knowledge2\src\knowledge\utils\license_issued
 EMBEDDING_MODEL=E:\project\bge-m3

+ 7 - 28
src/knowledge/config/site.py

@@ -11,34 +11,13 @@ class SiteConfig:
     
     def load_config(self):        
         self.config = {
-            "SITE_NAME": os.getenv("SITE_NAME", "DEMO"),
-            "SITE_DESCRIPTION": os.getenv("SITE_DESCRIPTION", "ChatGPT"),
-            "SITE_URL": os.getenv("SITE_URL", ""),
-            "SITE_LOGO": os.getenv("SITE_LOGO", ""),
-            "SITE_FAVICON": os.getenv("SITE_FAVICON"),
-            'ELASTICSEARCH_HOST': os.getenv("ELASTICSEARCH_HOST"),
-            'ELASTICSEARCH_USER': os.getenv("ELASTICSEARCH_USER"),
-            'ELASTICSEARCH_PWD': os.getenv("ELASTICSEARCH_PWD"),
-            'WORD_INDEX': os.getenv("WORD_INDEX"),
-            'TITLE_INDEX': os.getenv("TITLE_INDEX"),
-            'CHUNC_INDEX': os.getenv("CHUNC_INDEX"),
-            'DEEPSEEK_API_URL': os.getenv("DEEPSEEK_API_URL"),
-            'DEEPSEEK_API_KEY': os.getenv("DEEPSEEK_API_KEY"),
-            'CACHED_DATA_PATH': os.getenv("CACHED_DATA_PATH"),
-            'UPDATE_DATA_PATH': os.getenv("UPDATE_DATA_PATH"),
-            'FACTOR_DATA_PATH': os.getenv("FACTOR_DATA_PATH"),
-            'GRAPH_API_URL': os.getenv("GRAPH_API_URL"),
-            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL"),
-            'DOC_PATH': os.getenv("DOC_PATH"),
-            'DOC_STORAGE_PATH': os.getenv("DOC_STORAGE_PATH"),
-            'TRUNC_OUTPUT_PATH': os.getenv("TRUNC_OUTPUT_PATH"),
-            'DOC_ABSTRACT_OUTPUT_PATH': os.getenv("DOC_ABSTRACT_OUTPUT_PATH"),
-            'JIEBA_USER_DICT': os.getenv("JIEBA_USER_DICT"),
-            'JIEBA_STOP_DICT': os.getenv("JIEBA_STOP_DICT"),
-            'POSTGRESQL_HOST':  os.getenv("POSTGRESQL_HOST","localhost"),
-            'POSTGRESQL_DATABASE':  os.getenv("POSTGRESQL_DATABASE","kg"),
-            'POSTGRESQL_USER':  os.getenv("POSTGRESQL_USER","dify"),
-            'POSTGRESQL_PASSWORD':  os.getenv("POSTGRESQL_PASSWORD",quote("difyai123456")),
+            "LICENSE_PATH": os.getenv("LICENSE_PATH", ""),        
+            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL",""),
+            'DB_HOST':  os.getenv("DB_HOST",""),
+            'DB_NAME':  os.getenv("DB_NAME",""),
+            'DB_PORT':  os.getenv("DB_PORT",""),
+            'DB_USER': os.getenv("DB_USER", ""),
+            'DB_PASSWORD': os.getenv("DB_PASSWORD", ""),
         }
     def get_config(self, config_name, default=None): 
         config_name = config_name.upper()     

+ 0 - 1
src/knowledge/model/response.py

@@ -6,5 +6,4 @@ class StandardResponse(BaseModel):
     requestId: Optional[str] = None
     errorCode: Optional[int] = None
     errorMsg: Optional[str] = None
-    records: Optional[Any] = None
     data: Optional[Any] = None

+ 3 - 3
src/knowledge/router/knowledge_nodes_api.py

@@ -59,7 +59,7 @@ async def paginated_search(
             data=ObjectToJsonArrayConverter.convert(result)
         )
     except Exception as e:
-        logger.error(f"分页查询失败: {str(e)}")
+        logger.exception(f"分页查询失败: {str(e)}")
         raise HTTPException(
             status_code=500,
             detail=StandardResponse(
@@ -107,7 +107,7 @@ async def get_node_relationships_condition(
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
         )
     except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
+        logger.exception(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 
@@ -155,7 +155,7 @@ async def get_node_relationships(
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
         )
     except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
+        logger.exception(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 class GetNodeProperties(BaseModel):

+ 48 - 17
src/knowledge/utils/license.py

@@ -1,13 +1,45 @@
+import copy
+import logging
+import os
+
+from cachetools import TTLCache
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives import hashes,serialization
 import json
 import time
 import traceback
 
-def validate_license(public_key_pem, license_json, signature):
-    public_key = serialization.load_pem_public_key(public_key_pem)
-
+from src.knowledge.config.site import SiteConfig
+from src.knowledge.model.response import StandardResponse
+logger = logging.getLogger(__name__)
+_cache = TTLCache(maxsize=1, ttl=60*5)
+def validate_license():
+    
+    # 先检查缓存
+    cache_key = f"validate_license"
+    if cache_key in _cache:
+        return copy.deepcopy(_cache[cache_key])
+    cached_result = _cache.get('license_result')
+    if cached_result is not None:
+        return cached_result
+        
     try:
+        config = SiteConfig()
+        license_path = config.get_config('LICENSE_PATH')
+
+        public_key_pem = None
+        license_json = None
+        signature = None
+        with open(os.path.join(license_path, "public.key"), "rb") as f:
+            public_key_pem = f.read()
+            
+        with open(os.path.join(license_path, "license_issued.lic"), "rb") as f:
+            data = json.loads(f.read())
+            license_json = json.dumps(data, sort_keys=True).encode()
+        with open(os.path.join(license_path, "license_issued.key"), "rb") as f:
+            signature = f.read()
+            
+        public_key = serialization.load_pem_public_key(public_key_pem)
         public_key.verify(
         signature,
         license_json,
@@ -15,25 +47,24 @@ def validate_license(public_key_pem, license_json, signature):
         hashes.SHA256()
         )
     except:
-        #打印异常信息
-        traceback.print_exc()
-        return False
+        logger.exception("许可证签名验证失败")
+        result = StandardResponse(success=False, errorMsg="许可证签名验证失败:是否上传了许可证文件")
+        _cache['cache_key'] = copy.deepcopy(result)
+        return result
 
     license_data=json.loads(license_json.decode())
     # 检查是否过期
     if time.time()>license_data["expiration_time"]:
-        return False
-    return True
+        result = StandardResponse(success=False, errorMsg="许可证已过期")
+        _cache['cache_key'] = copy.deepcopy(result)
+        return result
+    result = StandardResponse(success=True, data=license_json)
+    _cache['cache_key'] = copy.deepcopy(result)
+    return result
 
 if __name__ == '__main__':
-    with open("license_issued/public.key","rb") as f:
-        public_key_pem = f.read()
-    with open("license_issued/license_issued.lic","rb") as f:
-        data = json.loads(f.read())
-        license_json = json.dumps(data, sort_keys=True).encode()
-    with open("license_issued/license_issued.key","rb") as f:
-        signature = f.read()
-    if validate_license(public_key_pem,license_json, signature):
+    response = validate_license()
+    if response.success:
         print("许可证有效!")
     else:
-        print("许可证无效或已过期!")
+        print(response.errorMsg)