فهرست منبع

增加验证工具

攻心小虫 3 هفته پیش
والد
کامیت
1c7c3778b4
2فایلهای تغییر یافته به همراه125 افزوده شده و 36 حذف شده
  1. 6 1
      src/knowledge/model/response.py
  2. 119 35
      src/knowledge/utils/license.py

+ 6 - 1
src/knowledge/model/response.py

@@ -5,5 +5,10 @@ class StandardResponse(BaseModel):
     success: bool
     requestId: Optional[str] = None
     errorCode: Optional[int] = None
+    #验证结果编码:0:成功;1:时间过期;2:并发过高;3:次数耗尽;9:系统异常
+    vaildCode: Optional[int] = None
     errorMsg: Optional[str] = None
-    data: Optional[Any] = None
+    data: Optional[Any] = None
+    
+    def getSuccess(self):
+        return self.success

+ 119 - 35
src/knowledge/utils/license.py

@@ -1,28 +1,66 @@
-import copy
-import logging
 import os
-
+import json
+import time
+import base64
+import logging
+from threading import Thread
 from cachetools import TTLCache
 from cryptography.hazmat.primitives.asymmetric import padding
 from cryptography.hazmat.primitives import hashes,serialization
-import json
-import time
-import traceback
 
+from ..db.session import get_db
+from ..model.dict_system import DictSystem
 from src.knowledge.config.site import SiteConfig
 from src.knowledge.model.response import StandardResponse
+
 logger = logging.getLogger(__name__)
-_cache = TTLCache(maxsize=1, ttl=60*5)
+
+#并发缓存
+concurrency_cache = TTLCache(maxsize=10, ttl=5)
+#证书验证缓存
+vaild_cache = {}
+#5分值之内触发一次保存
+need_save = False
+
+#验证入口
 def validate_license():
-    
-    # 先检查缓存
-    cache_key = f"validate_license"
-    if cache_key in _cache:
-        return copy.deepcopy(_cache[cache_key])
-    cached_result = _cache.get('license_result')
-    if cached_result is not None:
-        return cached_result
+    if vaild_cache == {}:
+        return StandardResponse(success=False,vaildCode=9, errorMsg="系统异常:系统未加载证书信息")
+    elif 'result' in vaild_cache:   #已经是失败状态就直接返回失败
+        return vaild_cache["result"]
+    else:
+        timestamp = int(time.time())
+        timestamp_str = str(timestamp)
+        print(concurrency_cache,'=concurrency_cache')
+        if timestamp_str in concurrency_cache:  # 检查并发是否过高
+            if concurrency_cache[timestamp_str] > vaild_cache["api_concurrent_limit"]:
+                logger.info("当前["+timestamp_str+"]请求的并发数为"+str(concurrency_cache[timestamp_str]))
+                return StandardResponse(success=False,vaildCode=2, errorMsg="并发数过高")
+            else:
+                concurrency_cache[timestamp_str] += 1
+        else:
+            concurrency_cache[timestamp_str] = 1
+        
+        if timestamp > vaild_cache["expiration_time"]: # 检查日期是否过期
+            logger.info("证书过期时间为"+str(vaild_cache["expiration_time"]))
+            vaild_cache["result"] = StandardResponse(success=False,vaildCode=1, errorMsg="许可证已过期")
+            return vaild_cache["result"]
         
+        if vaild_cache["api_use_times"] > vaild_cache["api_invoke_max_count"]:  # 检查次数
+            logger.info("请求使用次数为"+str(vaild_cache["api_concurrent_limit"]))
+            vaild_cache["result"] = StandardResponse(success=False,vaildCode=3, errorMsg="使用次数已经用完")
+            return vaild_cache["result"]
+        vaild_cache["api_use_times"] += 1
+
+        global need_save
+        if need_save == False:
+            Thread(target=save_vaild_cache).start()
+            need_save = True
+
+        return StandardResponse(success=True,vaildCode=0, data=vaild_cache)
+        
+#初始化验证参数
+def init_vaild_cache():
     try:
         config = SiteConfig()
         license_path = config.get_config('LICENSE_PATH')
@@ -30,9 +68,9 @@ def validate_license():
         public_key_pem = None
         license_json = None
         signature = None
+
         with open(os.path.join(license_path, "public.key"), "rb") as f:
             public_key_pem = f.read()
-            
         with open(os.path.join(license_path, "license_issued.lic"), "rb") as f:
             data = json.loads(f.read())
             license_json = json.dumps(data, sort_keys=True).encode()
@@ -40,31 +78,77 @@ def validate_license():
             signature = f.read()
             
         public_key = serialization.load_pem_public_key(public_key_pem)
-        public_key.verify(
-        signature,
-        license_json,
-        padding.PKCS1v15(),
-        hashes.SHA256()
-        )
+        public_key.verify(signature,license_json,padding.PKCS1v15(),hashes.SHA256())
+        license_data = json.loads(license_json.decode())
+
+        content = license_data['content']
+        for c in content:
+            if c['name'] == 'api_invoke_max_count':
+                vaild_cache['api_invoke_max_count'] = c['value']
+            elif c['name'] == 'api_concurrent_limit':
+                vaild_cache['api_concurrent_limit'] = c['value']
+        vaild_cache['expiration_time'] = license_data['expiration_time']
+        vaild_cache['api_use_times'] = 0
+        
+        #从数据库加载 
+        db = next(get_db())
+        result = (db.query(DictSystem).filter_by(dict_name = 'vaild_str',is_deleted = 0).first())
+        print(result)
+        if result is not None: #说明还没有存储,不清楚是不是被删除,从证书上同步
+            dict_value = str(base64.b64decode(bytes(result.dict_value,encoding="utf-8") ),encoding='utf-8')
+            dict_value_json=json.loads(dict_value)
+            vaild_cache['api_use_times'] = dict_value_json['api_use_times']
+            vaild_cache['id'] = result.id
     except:
+        vaild_cache["result"] =  StandardResponse(success=False,vaildCode=9, errorMsg="许可证签名验证失败:是否上传了许可证文件")
         logger.exception("许可证签名验证失败")
-        result = StandardResponse(success=False, errorMsg="许可证签名验证失败:是否上传了许可证文件")
-        _cache['cache_key'] = copy.deepcopy(result)
-        return result
 
-    license_data=json.loads(license_json.decode())
-    # 检查是否过期
-    if time.time()>license_data["expiration_time"]:
-        result = StandardResponse(success=False, errorMsg="许可证已过期")
-        _cache['cache_key'] = copy.deepcopy(result)
-        return result
-    result = StandardResponse(success=True, data=license_json)
-    _cache['cache_key'] = copy.deepcopy(result)
-    return result
+#保存验证的参数,有请求时延迟5分钟触发保存
+def save_vaild_cache():
+    time.sleep(10) #5分钟
+    # time.sleep(5*60) #5分钟
+    if vaild_cache is not None:
+        #保存数据到数据库
+        tmp_cache = {}
+        tmp_cache['api_use_times'] = vaild_cache['api_use_times']
+        new_dict = DictSystem(
+            dict_name='vaild_str',
+            dict_value=str(base64.b64encode(json.dumps(tmp_cache, ensure_ascii=False, indent=1).encode("utf-8")),encoding="utf-8"),
+            is_deleted=0
+        )
+        db = next(get_db())
+        print(vaild_cache)
+        if 'id' in vaild_cache:
+            db.query(DictSystem).filter_by(id=vaild_cache['id']).update({DictSystem.dict_value: new_dict.dict_value})
+            new_dict.id = vaild_cache['id']
+        else:
+            db.add(new_dict)
+        db.commit()
+        vaild_cache['id'] = new_dict.id
+    global need_save
+    need_save = False
+
+#重置验证参数
+def reset_vaild_cache():
+    global vaild_cache
+    if vaild_cache != {} and 'id' in vaild_cache:
+        db = next(get_db())
+        db.delete( db.query(DictSystem).filter_by(id=vaild_cache['id']).first()) #删除信息
+    vaild_cache={}
+    global need_save
+    need_save = False
+    init_vaild_cache()  #再初始化
 
 if __name__ == '__main__':
+    init_vaild_cache()
+    response = validate_license()
+    if response.success:
+        print("1许可证有效!")
+    else:
+        print(response.errorMsg)
+    reset_vaild_cache()
     response = validate_license()
     if response.success:
-        print("许可证有效!")
+        print("3许可证有效!")
     else:
         print(response.errorMsg)