Browse Source

代码提交

SGTY 3 weeks ago
parent
commit
fab0dadf99

+ 1 - 2
src/knowledge/.env

@@ -3,6 +3,5 @@ DB_NAME=medkg
 DB_PORT=5432
 DB_PORT=5432
 DB_USER=knowledge
 DB_USER=knowledge
 DB_PASSWORD=qwer1234.
 DB_PASSWORD=qwer1234.
-
-license=E:\project\knowledge\src\knowledge\utils\license_issued
+LICENSE_PATH=E:\project\knowledge2\src\knowledge\utils\license_issued
 EMBEDDING_MODEL=E:\project\bge-m3
 EMBEDDING_MODEL=E:\project\bge-m3

+ 7 - 28
src/knowledge/config/site.py

@@ -11,34 +11,13 @@ class SiteConfig:
     
     
     def load_config(self):        
     def load_config(self):        
         self.config = {
         self.config = {
-            "SITE_NAME": os.getenv("SITE_NAME", "DEMO"),
-            "SITE_DESCRIPTION": os.getenv("SITE_DESCRIPTION", "ChatGPT"),
-            "SITE_URL": os.getenv("SITE_URL", ""),
-            "SITE_LOGO": os.getenv("SITE_LOGO", ""),
-            "SITE_FAVICON": os.getenv("SITE_FAVICON"),
-            'ELASTICSEARCH_HOST': os.getenv("ELASTICSEARCH_HOST"),
-            'ELASTICSEARCH_USER': os.getenv("ELASTICSEARCH_USER"),
-            'ELASTICSEARCH_PWD': os.getenv("ELASTICSEARCH_PWD"),
-            'WORD_INDEX': os.getenv("WORD_INDEX"),
-            'TITLE_INDEX': os.getenv("TITLE_INDEX"),
-            'CHUNC_INDEX': os.getenv("CHUNC_INDEX"),
-            'DEEPSEEK_API_URL': os.getenv("DEEPSEEK_API_URL"),
-            'DEEPSEEK_API_KEY': os.getenv("DEEPSEEK_API_KEY"),
-            'CACHED_DATA_PATH': os.getenv("CACHED_DATA_PATH"),
-            'UPDATE_DATA_PATH': os.getenv("UPDATE_DATA_PATH"),
-            'FACTOR_DATA_PATH': os.getenv("FACTOR_DATA_PATH"),
-            'GRAPH_API_URL': os.getenv("GRAPH_API_URL"),
-            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL"),
-            'DOC_PATH': os.getenv("DOC_PATH"),
-            'DOC_STORAGE_PATH': os.getenv("DOC_STORAGE_PATH"),
-            'TRUNC_OUTPUT_PATH': os.getenv("TRUNC_OUTPUT_PATH"),
-            'DOC_ABSTRACT_OUTPUT_PATH': os.getenv("DOC_ABSTRACT_OUTPUT_PATH"),
-            'JIEBA_USER_DICT': os.getenv("JIEBA_USER_DICT"),
-            'JIEBA_STOP_DICT': os.getenv("JIEBA_STOP_DICT"),
-            'POSTGRESQL_HOST':  os.getenv("POSTGRESQL_HOST","localhost"),
-            'POSTGRESQL_DATABASE':  os.getenv("POSTGRESQL_DATABASE","kg"),
-            'POSTGRESQL_USER':  os.getenv("POSTGRESQL_USER","dify"),
-            'POSTGRESQL_PASSWORD':  os.getenv("POSTGRESQL_PASSWORD",quote("difyai123456")),
+            "LICENSE_PATH": os.getenv("LICENSE_PATH", ""),        
+            'EMBEDDING_MODEL': os.getenv("EMBEDDING_MODEL",""),
+            'DB_HOST':  os.getenv("DB_HOST",""),
+            'DB_NAME':  os.getenv("DB_NAME",""),
+            'DB_PORT':  os.getenv("DB_PORT",""),
+            'DB_USER': os.getenv("DB_USER", ""),
+            'DB_PASSWORD': os.getenv("DB_PASSWORD", ""),
         }
         }
     def get_config(self, config_name, default=None): 
     def get_config(self, config_name, default=None): 
         config_name = config_name.upper()     
         config_name = config_name.upper()     

+ 0 - 1
src/knowledge/model/response.py

@@ -6,5 +6,4 @@ class StandardResponse(BaseModel):
     requestId: Optional[str] = None
     requestId: Optional[str] = None
     errorCode: Optional[int] = None
     errorCode: Optional[int] = None
     errorMsg: Optional[str] = None
     errorMsg: Optional[str] = None
-    records: Optional[Any] = None
     data: Optional[Any] = None
     data: Optional[Any] = None

+ 3 - 3
src/knowledge/router/knowledge_nodes_api.py

@@ -59,7 +59,7 @@ async def paginated_search(
             data=ObjectToJsonArrayConverter.convert(result)
             data=ObjectToJsonArrayConverter.convert(result)
         )
         )
     except Exception as e:
     except Exception as e:
-        logger.error(f"分页查询失败: {str(e)}")
+        logger.exception(f"分页查询失败: {str(e)}")
         raise HTTPException(
         raise HTTPException(
             status_code=500,
             status_code=500,
             detail=StandardResponse(
             detail=StandardResponse(
@@ -107,7 +107,7 @@ async def get_node_relationships_condition(
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
         )
         )
     except Exception as e:
     except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
+        logger.exception(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 
 
 
@@ -155,7 +155,7 @@ async def get_node_relationships(
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
             data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
         )
         )
     except Exception as e:
     except Exception as e:
-        logger.error(f"获取节点关系失败: {str(e)}")
+        logger.exception(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 
 class GetNodeProperties(BaseModel):
 class GetNodeProperties(BaseModel):

+ 48 - 17
src/knowledge/utils/license.py

@@ -1,13 +1,45 @@
+import copy
+import logging
+import os
+
+from cachetools import TTLCache
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives import hashes,serialization
 from cryptography.hazmat.primitives import hashes,serialization
 import json
 import json
 import time
 import time
 import traceback
 import traceback
 
 
-def validate_license(public_key_pem, license_json, signature):
-    public_key = serialization.load_pem_public_key(public_key_pem)
-
+from src.knowledge.config.site import SiteConfig
+from src.knowledge.model.response import StandardResponse
+logger = logging.getLogger(__name__)
+_cache = TTLCache(maxsize=1, ttl=60*5)
+def validate_license():
+    
+    # 先检查缓存
+    cache_key = f"validate_license"
+    if cache_key in _cache:
+        return copy.deepcopy(_cache[cache_key])
+    cached_result = _cache.get('license_result')
+    if cached_result is not None:
+        return cached_result
+        
     try:
     try:
+        config = SiteConfig()
+        license_path = config.get_config('LICENSE_PATH')
+
+        public_key_pem = None
+        license_json = None
+        signature = None
+        with open(os.path.join(license_path, "public.key"), "rb") as f:
+            public_key_pem = f.read()
+            
+        with open(os.path.join(license_path, "license_issued.lic"), "rb") as f:
+            data = json.loads(f.read())
+            license_json = json.dumps(data, sort_keys=True).encode()
+        with open(os.path.join(license_path, "license_issued.key"), "rb") as f:
+            signature = f.read()
+            
+        public_key = serialization.load_pem_public_key(public_key_pem)
         public_key.verify(
         public_key.verify(
         signature,
         signature,
         license_json,
         license_json,
@@ -15,25 +47,24 @@ def validate_license(public_key_pem, license_json, signature):
         hashes.SHA256()
         hashes.SHA256()
         )
         )
     except:
     except:
-        #打印异常信息
-        traceback.print_exc()
-        return False
+        logger.exception("许可证签名验证失败")
+        result = StandardResponse(success=False, errorMsg="许可证签名验证失败:是否上传了许可证文件")
+        _cache['cache_key'] = copy.deepcopy(result)
+        return result
 
 
     license_data=json.loads(license_json.decode())
     license_data=json.loads(license_json.decode())
     # 检查是否过期
     # 检查是否过期
     if time.time()>license_data["expiration_time"]:
     if time.time()>license_data["expiration_time"]:
-        return False
-    return True
+        result = StandardResponse(success=False, errorMsg="许可证已过期")
+        _cache['cache_key'] = copy.deepcopy(result)
+        return result
+    result = StandardResponse(success=True, data=license_json)
+    _cache['cache_key'] = copy.deepcopy(result)
+    return result
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
-    with open("license_issued/public.key","rb") as f:
-        public_key_pem = f.read()
-    with open("license_issued/license_issued.lic","rb") as f:
-        data = json.loads(f.read())
-        license_json = json.dumps(data, sort_keys=True).encode()
-    with open("license_issued/license_issued.key","rb") as f:
-        signature = f.read()
-    if validate_license(public_key_pem,license_json, signature):
+    response = validate_license()
+    if response.success:
         print("许可证有效!")
         print("许可证有效!")
     else:
     else:
-        print("许可证无效或已过期!")
+        print(response.errorMsg)