浏览代码

代码提交

SGTY 1 月之前
父节点
当前提交
a2953b23dd

+ 0 - 0
build/lib/knowledge/config/__init__.py


+ 114 - 0
build/lib/knowledge/service/kg_edge_service2.py

@@ -0,0 +1,114 @@
+from sqlalchemy.orm import Session
+from sqlalchemy import or_
+from typing import Optional
+from model.kg_edges import KGEdge
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+from cachetools import TTLCache
+from cachetools.keys import hashkey
+
+logger = logging.getLogger(__name__)
+
+class KGEdgeService:
+    def __init__(self, db: Session):
+        self.db = db
+
+    _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
+    def get_edge(self, edge_id: int):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+        return edge
+
+    def create_edge(self, edge_data: dict):
+        try:
+            existing = self.db.query(KGEdge).filter(
+                KGEdge.src_id == edge_data['src_id'],
+                KGEdge.dest_id == edge_data['dest_id'],
+                KGEdge.name == edge_data['name'],
+                KGEdge.version == edge_data.get('version')
+            ).first()
+
+            if existing:
+                raise ValueError("Edge already exists")
+
+            new_edge = KGEdge(**edge_data)
+            self.db.add(new_edge)
+            self.db.commit()
+            return new_edge
+
+        except IntegrityError as e:
+            self.db.rollback()
+            logger.error(f"创建边失败: {str(e)}")
+            raise ValueError("Database integrity error")
+
+    def update_edge(self, edge_id: int, update_data: dict):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+
+        try:
+            for key, value in update_data.items():
+                setattr(edge, key, value)
+            self.db.commit()
+            return edge
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"更新边失败: {str(e)}")
+            raise ValueError("Update failed")
+
+    def delete_edge(self, edge_id: int):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+
+        try:
+            self.db.delete(edge)
+            self.db.commit()
+            return None
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"删除边失败: {str(e)}")
+            raise ValueError("Delete failed")
+
+    def get_edges_by_nodes(self, src_id: Optional[int]= None, dest_id: Optional[int]= None, category: Optional[str] = None):
+        cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{category}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+
+        if src_id is None and dest_id is None:
+            raise ValueError("至少需要提供一个有效的查询条件")
+        try:
+            filters = []
+            if src_id is not None:
+                filters.append(KGEdge.src_id == src_id)
+            if dest_id is not None:
+                filters.append(KGEdge.dest_id == dest_id)
+            if category is not None:
+                filters.append(KGEdge.category == category)
+            edges = self.db.query(KGEdge).filter(*filters).all()
+
+            from service.kg_node_service import KGNodeService
+            node_service = KGNodeService(self.db)
+            result = []
+            for edge in edges:
+                try:
+                    edge_info = {
+                        'id': edge.id,
+                        'src_id': edge.src_id,
+                        'dest_id': edge.dest_id,
+                        'name': edge.name,
+                        'version': edge.version,
+                        #'src_node': node_service.get_node(edge.src_id),
+                        'dest_node': node_service.get_node(edge.dest_id)
+                    }
+                    result.append(edge_info)
+                except ValueError as e:
+                    logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
+                    continue
+            self._cache[cache_key] = result
+            return result
+        except Exception as e:
+            logger.error(f"查询边失败: {str(e)}")
+            raise e

+ 258 - 0
build/lib/knowledge/service/kg_node_service2.py

@@ -0,0 +1,258 @@
+from sqlalchemy.orm import Session
+from typing import Optional
+from model.kg_node import KGNode
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+
+from utils import vectorizer
+from utils.vectorizer import Vectorizer
+from sqlalchemy import func
+from service.kg_prop_service import KGPropService
+from service.kg_edge_service import KGEdgeService
+from cachetools import TTLCache
+from cachetools.keys import hashkey
+logger = logging.getLogger(__name__)
+DISTANCE_THRESHOLD = 0.65
+class KGNodeService:
+    def __init__(self, db: Session):
+        self.db = db
+
+    _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
+
+    def search_title_index(self, index: str, keywrod: str,category: str, top_k: int = 3,distance: float = 0.3) -> Optional[int]:
+        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+
+        query_embedding = Vectorizer.get_embedding(keywrod)
+        db = next(get_db())
+        # 执行向量搜索
+        results = (
+            db.query(
+                KGNode.id,
+                KGNode.name,
+                KGNode.category,
+                KGNode.embedding.l2_distance(query_embedding).label('distance')
+            )
+            .filter(KGNode.status == 0)
+            .filter(KGNode.category == category)
+            #todo 是否能提高性能 改成余弦算法
+            .filter(KGNode.embedding.l2_distance(query_embedding) <= distance)
+            .order_by('distance').limit(top_k).all()
+        )
+        results = [
+            {
+                "id": node.id,
+               "title": node.name,
+               "text": node.category,
+               "score": 2.0-node.distance
+            }
+                for node in results
+            ]
+
+        self._cache[cache_key] = results
+        return results
+
+    def paginated_search(self, search_params: dict) -> dict:
+        load_props = search_params.get('load_props', False)
+        prop_service = KGPropService(self.db)
+        edge_service = KGEdgeService(self.db)
+        keyword = search_params.get('keyword', '')
+        category = search_params.get('category', None)
+        page_no = search_params.get('pageNo', 1)
+        #distance 为NONE或不存在时,使用默认值
+        if search_params.get('distance') is None:
+            distance = DISTANCE_THRESHOLD
+        else:
+            distance = search_params.get('distance')
+        limit = search_params.get('limit', 10)
+
+        if page_no < 1:
+            page_no = 1
+        if limit < 1:
+            limit = 10
+
+        embedding = Vectorizer.get_embedding(keyword)
+        offset = (page_no - 1) * limit
+
+        try:
+            # 构建基础查询条件
+            base_query = self.db.query(func.count(KGNode.id)).filter(
+                KGNode.status == 0,
+                KGNode.embedding.l2_distance(embedding) <= distance
+            )
+            # 如果有category,则添加额外过滤条件
+            if category:
+                base_query = base_query.filter(KGNode.category == category)
+            # 如果有knowledge_ids,则添加额外过滤条件
+            if search_params.get('knowledge_ids'):
+                total_count = base_query.filter(
+                    KGNode.version.in_(search_params['knowledge_ids'])
+                ).scalar()
+            else:
+                total_count = base_query.scalar()
+
+            query = self.db.query(
+                KGNode.id,
+                KGNode.name,
+                KGNode.category,
+                KGNode.embedding.l2_distance(embedding).label('distance')
+            )            
+            query = query.filter(KGNode.status == 0)
+            #category有值时,过滤掉category不等于category的节点
+            if category:
+                query = query.filter(KGNode.category == category)
+            if search_params.get('knowledge_ids'):
+                query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
+            query = query.filter(KGNode.embedding.l2_distance(embedding) <= distance)
+            results = query.order_by('distance').offset(offset).limit(limit).all()
+
+            return {
+                'records': [{
+                    'id': r.id,
+                    'name': r.name,
+                    'category': r.category,
+                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
+                    #'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
+                    'distance': round(r.distance, 3)
+                } for r in results],
+                'pagination': {
+                    'total': total_count,
+                    'pageNo': page_no,
+                    'limit': limit,
+                    'totalPages': (total_count + limit - 1) // limit
+                }
+            
+            }
+        except Exception as e:
+            logger.error(f"分页查询失败: {str(e)}")
+            raise e
+
+    def create_node(self, node_data: dict):
+        try:
+            existing = self.db.query(KGNode).filter(
+                KGNode.name == node_data['name'],
+                KGNode.category == node_data['category'],
+                KGNode.version == node_data.get('version')
+            ).first()
+            
+            if existing:
+                raise ValueError("Node already exists")
+
+            new_node = KGNode(**node_data)
+            self.db.add(new_node)
+            self.db.commit()
+            return new_node
+
+        except IntegrityError as e:
+            self.db.rollback()
+            logger.error(f"创建节点失败: {str(e)}")
+            raise ValueError("Database integrity error")
+
+    def get_node(self, node_id: int):
+        if node_id is None:
+            raise ValueError("Node ID is required")
+        
+        cache_key = f"get_node_{node_id}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+        
+        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
+        
+        if not node:
+            raise ValueError("Node not found")
+        
+        node_data = {
+            'id': node.id,
+            'name': node.name,
+            'category': node.category,
+            'version': node.version
+        }
+        self._cache[cache_key] = node_data
+        return node_data
+
+    def update_node(self, node_id: int, update_data: dict):
+        node = self.db.query(KGNode).get(node_id)
+        if not node:
+            raise ValueError("Node not found")
+
+        try:
+            for key, value in update_data.items():
+                setattr(node, key, value)
+            self.db.commit()
+            return node
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"更新节点失败: {str(e)}")
+            raise ValueError("Update failed")
+
+    def delete_node(self, node_id: int):
+        node = self.db.query(KGNode).get(node_id)
+        if not node:
+            raise ValueError("Node not found")
+
+        try:
+            self.db.delete(node)
+            self.db.commit()
+            return None
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"删除节点失败: {str(e)}")
+            raise ValueError("Delete failed")
+
+    def batch_process_er_nodes(self):
+        batch_size = 200
+        offset = 0
+
+        while True:
+            try:
+                #下面的查询语句,增加根据id排序,防止并发问题
+                nodes = self.db.query(KGNode).filter(
+                    #KGNode.version == 'ER',
+                    KGNode.embedding == None
+                ).order_by(KGNode.id).offset(offset).limit(batch_size).all()
+
+                if not nodes:
+                    break
+
+                updated_nodes = []
+                for node in nodes:
+                    if not node.embedding:
+                        embedding = Vectorizer.get_embedding(node.name)
+                        node.embedding = embedding
+                        updated_nodes.append(node)
+                if updated_nodes:
+                    self.db.commit()
+
+                offset += batch_size
+            except Exception as e:
+                self.db.rollback()
+                print(f"批量处理ER节点失败: {str(e)}")
+                raise ValueError("Batch process failed")
+
+    def get_node_by_name_category(self, name: str, category: str):
+        if not name or not category:
+            raise ValueError("Name and category are required")
+        
+        cache_key = f"get_node_by_name_category_{name}:{category}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+        
+        node = self.db.query(KGNode).filter(
+            KGNode.name == name,
+            KGNode.category == category,
+            KGNode.status == 0
+        ).first()
+        
+        if not node:
+            return None
+        
+        node_data = {
+            'id': node.id,
+            'name': node.name,
+            'category': node.category,
+            'version': node.version
+        }
+        self._cache[cache_key] = node_data
+        return node_data

+ 1 - 0
src/knowledge.egg-info/PKG-INFO

@@ -13,4 +13,5 @@ Requires-Dist: uvicorn==0.34.0
 Requires-Dist: psycopg2-binary==2.9.10
 Requires-Dist: psycopg2-binary==2.9.10
 Requires-Dist: python-dotenv==1.0.0
 Requires-Dist: python-dotenv==1.0.0
 Requires-Dist: hui-tools[all]==0.5.8
 Requires-Dist: hui-tools[all]==0.5.8
+Requires-Dist: cachetools==6.1.0
 Dynamic: requires-dist
 Dynamic: requires-dist

+ 2 - 0
src/knowledge.egg-info/SOURCES.txt

@@ -26,7 +26,9 @@ src/knowledge/router/base.py
 src/knowledge/router/knowledge_nodes_api.py
 src/knowledge/router/knowledge_nodes_api.py
 src/knowledge/service/__init__.py
 src/knowledge/service/__init__.py
 src/knowledge/service/kg_edge_service.py
 src/knowledge/service/kg_edge_service.py
+src/knowledge/service/kg_edge_service2.py
 src/knowledge/service/kg_node_service.py
 src/knowledge/service/kg_node_service.py
+src/knowledge/service/kg_node_service2.py
 src/knowledge/service/kg_prop_service.py
 src/knowledge/service/kg_prop_service.py
 src/knowledge/settings/__init__.py
 src/knowledge/settings/__init__.py
 src/knowledge/settings/auth_setting.py
 src/knowledge/settings/auth_setting.py

+ 1 - 0
src/knowledge.egg-info/requires.txt

@@ -10,3 +10,4 @@ uvicorn==0.34.0
 psycopg2-binary==2.9.10
 psycopg2-binary==2.9.10
 python-dotenv==1.0.0
 python-dotenv==1.0.0
 hui-tools[all]==0.5.8
 hui-tools[all]==0.5.8
+cachetools==6.1.0