SGTY hai 1 mes
pai
achega
2ae5cfea28

+ 1 - 1
src/knowledge/.env

@@ -4,5 +4,5 @@ DB_PORT=5432
 DB_USER=knowledge
 DB_PASSWORD=qwer1234.
 
-license=E:\project\knowledge\license_issued
+license=E:\project\knowledge\src\knowledge\utils\license_issued
 EMBEDDING_MODEL=E:\project\bge-m3

+ 74 - 3
src/knowledge/middlewares/base.py

@@ -1,4 +1,5 @@
 # @Desc: { 模块描述 }
+import json
 import time
 
 from fastapi import Request, status
@@ -8,10 +9,12 @@ from fastapi.responses import Response
 from py_tools.logging import logger
 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 
+from ..config.site import SiteConfig
+from ..service.dict_system_service import DictSystemService
 from ..settings import auth_setting
 from ..utils.trace_util import TraceUtil
 from typing import Optional
-
+from cachetools import TTLCache
 
 class LoggingMiddleware(BaseHTTPMiddleware):
     """
@@ -157,6 +160,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
                 return True
         return False
 
+    _cache = TTLCache(1, 360)
     async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
         # 初始化请求上下文
         request.state.context = {
@@ -180,15 +184,82 @@ class AuthMiddleware(BaseHTTPMiddleware):
 
         # 初始化操作:将用户信息添加到请求状态中
         request.state.user = user_info
-
+        #cache["license_info"]是否存在license信息,如果不存在则验证证书
+        if not 'license_info' in self._cache or not self._cache["license_info"]:
+            lisence_detail = license_handle()
+            #判断license_info是否是个json,且存在key为expiration_time的字段,如果存在则判断是否过期
+            if lisence_detail and "expiration_time" in lisence_detail:
+                license_info = json.loads(lisence_detail)
+                expiration_time = license_info.get("expiration_time")
+                if expiration_time and expiration_time < time.time():
+                    return self.set_auth_err_resp("License expired")
+                else:
+                    self._cache["license_info"] = license_info
+            else:
+                return self.set_auth_err_resp(lisence_detail)
+
+            if hasattr(self, '_invok_info'):
+                current_count = self._invok_info.get('api_invoke_count', 0)
+                  # 从字典服务获取调用次数的枚举值
+                dict_service = DictSystemService(self.db)
+                try:
+                    call_count_dict = dict_service.get_dicts_by_name('api_invoke_count') 
+                    old_increment_value = int(call_count_dict.dict_value)
+                    
+                    # 更新调用次数
+                    updated_count = current_count + old_increment_value
+                    
+                    # 更新数据库中的调用次数
+                    dict_service.update_dict(call_count_dict['id'], {'dict_value': str(updated_count)})
+                    self._invok_info['api_invoke_count'] = 0
+
+                    api_invoke_max_count = license_info.get("content")[0].get("api_invoke_max_count") 
+                    if updated_count >= api_invoke_max_count:
+                        return self.set_auth_err_resp("调用次数已达上限!")
+                except Exception as e:
+                    logger.error(f"更新调用次数失败: {str(e)}")
+                
         # 继续处理请求
         response = await call_next(request)
+
+        # 统计调用次数
+        if not hasattr(self, '_invok_info'):
+            self._invok_info = TTLCache(maxsize=20)
+        
+        current_count = self._invok_info.get('api_invoke_count', 0)
+        self._invok_info['api_invoke_count'] = current_count + 1
         # 可以在返回前添加统一响应处理(如添加头信息)
         # response.headers["request-id"] = request.state.context["request_id"]
 
         return response
 
-
+def license_handle():
+    """Validate the on-disk license and return its JSON text.
+
+    Reads ``public.key``, ``license_issued.lic`` and ``license_issued.key``
+    from the directory stored under the ``license`` site-config key, then
+    checks the signature with ``validate_license``.
+
+    Returns:
+        str: On success, the license content re-serialised as JSON text with
+        sorted keys. On any failure, an error message instead (missing
+        configuration, failed validation, or the text of the exception that
+        was raised while reading/parsing the files).
+    """
+    license_dir = SiteConfig().get_config("license")
+
+    try:
+        if not license_dir:
+            return "license目录未配置"
+
+        with open(f"{license_dir}/public.key", "rb") as f:
+            public_key_pem = f.read()
+        with open(f"{license_dir}/license_issued.lic", "rb") as f:
+            # Re-serialise with sorted keys; these bytes are what the
+            # signature is verified against below.
+            data = json.loads(f.read())
+            license_json = json.dumps(data, sort_keys=True).encode()
+        with open(f"{license_dir}/license_issued.key", "rb") as f:
+            signature = f.read()
+
+        from ..utils.license import validate_license
+        if not validate_license(public_key_pem, license_json, signature):
+            return "license验证失败"
+
+        # Return text, not bytes: every failure path already returns str, and
+        # a caller doing `"expiration_time" in result` would raise TypeError
+        # on a bytes value. json.loads accepts the str form unchanged.
+        return license_json.decode()
+
+    except Exception as e:
+        # Failures are reported to the caller as a message string rather
+        # than raised.
+        return str(e)
 
 def register_middlewares():
     """注册中间件(逆序执行)"""

+ 0 - 17
src/knowledge/router/base.py

@@ -1,17 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-# @Author: Hui
-# @Desc: { 模块描述 }
-# @Date: 2023/11/16 14:10
-# import fastapi
-#
-# from ..settings import log_setting
-# from ..middlewares.api_route import LoggingAPIRoute
-
-
-# class BaseAPIRouter(fastapi.APIRouter):
-#     def __init__(self, *args, api_log=log_setting.server_access_log, **kwargs):
-#         super().__init__(*args, **kwargs)
-#         if api_log:
-#             # 开启api请求日志信息
-#             self.route_class = LoggingAPIRoute

+ 2 - 2
src/knowledge/server.py

@@ -5,6 +5,7 @@ from fastapi import FastAPI
 from py_tools.connections.http import AsyncHttpClient
 from py_tools.logging import logger
 
+from .config.site import SiteConfig
 from .middlewares.base import register_middlewares
 from .router.knowledge_nodes_api import knowledge_nodes_api_router
 from .utils import log_util
@@ -37,8 +38,7 @@ async def init_setup():
     """初始化项目配置"""
 
     log_util.setup_logger()
-
-
+    
 async def startup():
     """项目启动时准备环境"""
 

+ 0 - 114
src/knowledge/service/kg_edge_service2.py

@@ -1,114 +0,0 @@
-from sqlalchemy.orm import Session
-from sqlalchemy import or_
-from typing import Optional
-from model.kg_edges import KGEdge
-from db.session import get_db
-import logging
-from sqlalchemy.exc import IntegrityError
-from cachetools import TTLCache
-from cachetools.keys import hashkey
-
-logger = logging.getLogger(__name__)
-
-class KGEdgeService:
-    def __init__(self, db: Session):
-        self.db = db
-
-    _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
-    def get_edge(self, edge_id: int):
-        edge = self.db.query(KGEdge).get(edge_id)
-        if not edge:
-            raise ValueError("Edge not found")
-        return edge
-
-    def create_edge(self, edge_data: dict):
-        try:
-            existing = self.db.query(KGEdge).filter(
-                KGEdge.src_id == edge_data['src_id'],
-                KGEdge.dest_id == edge_data['dest_id'],
-                KGEdge.name == edge_data['name'],
-                KGEdge.version == edge_data.get('version')
-            ).first()
-
-            if existing:
-                raise ValueError("Edge already exists")
-
-            new_edge = KGEdge(**edge_data)
-            self.db.add(new_edge)
-            self.db.commit()
-            return new_edge
-
-        except IntegrityError as e:
-            self.db.rollback()
-            logger.error(f"创建边失败: {str(e)}")
-            raise ValueError("Database integrity error")
-
-    def update_edge(self, edge_id: int, update_data: dict):
-        edge = self.db.query(KGEdge).get(edge_id)
-        if not edge:
-            raise ValueError("Edge not found")
-
-        try:
-            for key, value in update_data.items():
-                setattr(edge, key, value)
-            self.db.commit()
-            return edge
-        except Exception as e:
-            self.db.rollback()
-            logger.error(f"更新边失败: {str(e)}")
-            raise ValueError("Update failed")
-
-    def delete_edge(self, edge_id: int):
-        edge = self.db.query(KGEdge).get(edge_id)
-        if not edge:
-            raise ValueError("Edge not found")
-
-        try:
-            self.db.delete(edge)
-            self.db.commit()
-            return None
-        except Exception as e:
-            self.db.rollback()
-            logger.error(f"删除边失败: {str(e)}")
-            raise ValueError("Delete failed")
-
-    def get_edges_by_nodes(self, src_id: Optional[int]= None, dest_id: Optional[int]= None, category: Optional[str] = None):
-        cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{category}"
-        if cache_key in self._cache:
-            return self._cache[cache_key]
-
-        if src_id is None and dest_id is None:
-            raise ValueError("至少需要提供一个有效的查询条件")
-        try:
-            filters = []
-            if src_id is not None:
-                filters.append(KGEdge.src_id == src_id)
-            if dest_id is not None:
-                filters.append(KGEdge.dest_id == dest_id)
-            if category is not None:
-                filters.append(KGEdge.category == category)
-            edges = self.db.query(KGEdge).filter(*filters).all()
-
-            from service.kg_node_service import KGNodeService
-            node_service = KGNodeService(self.db)
-            result = []
-            for edge in edges:
-                try:
-                    edge_info = {
-                        'id': edge.id,
-                        'src_id': edge.src_id,
-                        'dest_id': edge.dest_id,
-                        'name': edge.name,
-                        'version': edge.version,
-                        #'src_node': node_service.get_node(edge.src_id),
-                        'dest_node': node_service.get_node(edge.dest_id)
-                    }
-                    result.append(edge_info)
-                except ValueError as e:
-                    logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
-                    continue
-            self._cache[cache_key] = result
-            return result
-        except Exception as e:
-            logger.error(f"查询边失败: {str(e)}")
-            raise e

+ 0 - 258
src/knowledge/service/kg_node_service2.py

@@ -1,258 +0,0 @@
-from sqlalchemy.orm import Session
-from typing import Optional
-from model.kg_node import KGNode
-from db.session import get_db
-import logging
-from sqlalchemy.exc import IntegrityError
-
-from utils import vectorizer
-from utils.vectorizer import Vectorizer
-from sqlalchemy import func
-from service.kg_prop_service import KGPropService
-from service.kg_edge_service import KGEdgeService
-from cachetools import TTLCache
-from cachetools.keys import hashkey
-logger = logging.getLogger(__name__)
-DISTANCE_THRESHOLD = 0.65
-class KGNodeService:
-    def __init__(self, db: Session):
-        self.db = db
-
-    _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
-
-    def search_title_index(self, index: str, keywrod: str,category: str, top_k: int = 3,distance: float = 0.3) -> Optional[int]:
-        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
-        if cache_key in self._cache:
-            return self._cache[cache_key]
-
-        query_embedding = Vectorizer.get_embedding(keywrod)
-        db = next(get_db())
-        # 执行向量搜索
-        results = (
-            db.query(
-                KGNode.id,
-                KGNode.name,
-                KGNode.category,
-                KGNode.embedding.l2_distance(query_embedding).label('distance')
-            )
-            .filter(KGNode.status == 0)
-            .filter(KGNode.category == category)
-            #todo 是否能提高性能 改成余弦算法
-            .filter(KGNode.embedding.l2_distance(query_embedding) <= distance)
-            .order_by('distance').limit(top_k).all()
-        )
-        results = [
-            {
-                "id": node.id,
-               "title": node.name,
-               "text": node.category,
-               "score": 2.0-node.distance
-            }
-                for node in results
-            ]
-
-        self._cache[cache_key] = results
-        return results
-
-    def paginated_search(self, search_params: dict) -> dict:
-        load_props = search_params.get('load_props', False)
-        prop_service = KGPropService(self.db)
-        edge_service = KGEdgeService(self.db)
-        keyword = search_params.get('keyword', '')
-        category = search_params.get('category', None)
-        page_no = search_params.get('pageNo', 1)
-        #distance 为NONE或不存在时,使用默认值
-        if search_params.get('distance') is None:
-            distance = DISTANCE_THRESHOLD
-        else:
-            distance = search_params.get('distance')
-        limit = search_params.get('limit', 10)
-
-        if page_no < 1:
-            page_no = 1
-        if limit < 1:
-            limit = 10
-
-        embedding = Vectorizer.get_embedding(keyword)
-        offset = (page_no - 1) * limit
-
-        try:
-            # 构建基础查询条件
-            base_query = self.db.query(func.count(KGNode.id)).filter(
-                KGNode.status == 0,
-                KGNode.embedding.l2_distance(embedding) <= distance
-            )
-            # 如果有category,则添加额外过滤条件
-            if category:
-                base_query = base_query.filter(KGNode.category == category)
-            # 如果有knowledge_ids,则添加额外过滤条件
-            if search_params.get('knowledge_ids'):
-                total_count = base_query.filter(
-                    KGNode.version.in_(search_params['knowledge_ids'])
-                ).scalar()
-            else:
-                total_count = base_query.scalar()
-
-            query = self.db.query(
-                KGNode.id,
-                KGNode.name,
-                KGNode.category,
-                KGNode.embedding.l2_distance(embedding).label('distance')
-            )            
-            query = query.filter(KGNode.status == 0)
-            #category有值时,过滤掉category不等于category的节点
-            if category:
-                query = query.filter(KGNode.category == category)
-            if search_params.get('knowledge_ids'):
-                query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
-            query = query.filter(KGNode.embedding.l2_distance(embedding) <= distance)
-            results = query.order_by('distance').offset(offset).limit(limit).all()
-
-            return {
-                'records': [{
-                    'id': r.id,
-                    'name': r.name,
-                    'category': r.category,
-                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
-                    #'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
-                    'distance': round(r.distance, 3)
-                } for r in results],
-                'pagination': {
-                    'total': total_count,
-                    'pageNo': page_no,
-                    'limit': limit,
-                    'totalPages': (total_count + limit - 1) // limit
-                }
-            
-            }
-        except Exception as e:
-            logger.error(f"分页查询失败: {str(e)}")
-            raise e
-
-    def create_node(self, node_data: dict):
-        try:
-            existing = self.db.query(KGNode).filter(
-                KGNode.name == node_data['name'],
-                KGNode.category == node_data['category'],
-                KGNode.version == node_data.get('version')
-            ).first()
-            
-            if existing:
-                raise ValueError("Node already exists")
-
-            new_node = KGNode(**node_data)
-            self.db.add(new_node)
-            self.db.commit()
-            return new_node
-
-        except IntegrityError as e:
-            self.db.rollback()
-            logger.error(f"创建节点失败: {str(e)}")
-            raise ValueError("Database integrity error")
-
-    def get_node(self, node_id: int):
-        if node_id is None:
-            raise ValueError("Node ID is required")
-        
-        cache_key = f"get_node_{node_id}"
-        if cache_key in self._cache:
-            return self._cache[cache_key]
-        
-        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
-        
-        if not node:
-            raise ValueError("Node not found")
-        
-        node_data = {
-            'id': node.id,
-            'name': node.name,
-            'category': node.category,
-            'version': node.version
-        }
-        self._cache[cache_key] = node_data
-        return node_data
-
-    def update_node(self, node_id: int, update_data: dict):
-        node = self.db.query(KGNode).get(node_id)
-        if not node:
-            raise ValueError("Node not found")
-
-        try:
-            for key, value in update_data.items():
-                setattr(node, key, value)
-            self.db.commit()
-            return node
-        except Exception as e:
-            self.db.rollback()
-            logger.error(f"更新节点失败: {str(e)}")
-            raise ValueError("Update failed")
-
-    def delete_node(self, node_id: int):
-        node = self.db.query(KGNode).get(node_id)
-        if not node:
-            raise ValueError("Node not found")
-
-        try:
-            self.db.delete(node)
-            self.db.commit()
-            return None
-        except Exception as e:
-            self.db.rollback()
-            logger.error(f"删除节点失败: {str(e)}")
-            raise ValueError("Delete failed")
-
-    def batch_process_er_nodes(self):
-        batch_size = 200
-        offset = 0
-
-        while True:
-            try:
-                #下面的查询语句,增加根据id排序,防止并发问题
-                nodes = self.db.query(KGNode).filter(
-                    #KGNode.version == 'ER',
-                    KGNode.embedding == None
-                ).order_by(KGNode.id).offset(offset).limit(batch_size).all()
-
-                if not nodes:
-                    break
-
-                updated_nodes = []
-                for node in nodes:
-                    if not node.embedding:
-                        embedding = Vectorizer.get_embedding(node.name)
-                        node.embedding = embedding
-                        updated_nodes.append(node)
-                if updated_nodes:
-                    self.db.commit()
-
-                offset += batch_size
-            except Exception as e:
-                self.db.rollback()
-                print(f"批量处理ER节点失败: {str(e)}")
-                raise ValueError("Batch process failed")
-
-    def get_node_by_name_category(self, name: str, category: str):
-        if not name or not category:
-            raise ValueError("Name and category are required")
-        
-        cache_key = f"get_node_by_name_category_{name}:{category}"
-        if cache_key in self._cache:
-            return self._cache[cache_key]
-        
-        node = self.db.query(KGNode).filter(
-            KGNode.name == name,
-            KGNode.category == category,
-            KGNode.status == 0
-        ).first()
-        
-        if not node:
-            return None
-        
-        node_data = {
-            'id': node.id,
-            'name': node.name,
-            'category': node.category,
-            'version': node.version
-        }
-        self._cache[cache_key] = node_data
-        return node_data