|
- from sqlalchemy.orm import Session
- from typing import Optional
- from model.kg_node import KGNode
- from db.session import get_db
- import logging
- from sqlalchemy.exc import IntegrityError
- from utils import vectorizer
- from utils.vectorizer import Vectorizer
- from sqlalchemy import func
- from service.kg_prop_service import KGPropService
- from service.kg_edge_service import KGEdgeService
- from cachetools import TTLCache
- from cachetools.keys import hashkey
# Module-level logger used by the service below to report failures.
logger = logging.getLogger(name=__name__)

# Default upper bound on the L2 distance between the keyword embedding and a
# node embedding; paginated_search falls back to it when "distance" is absent
# or None.
DISTANCE_THRESHOLD: float = 0.65
class KGNodeService:
    """CRUD, vector-search and cached lookup operations for ``KGNode`` rows.

    Title searches, lookups by id and lookups by (name, category) are memoised
    in a class-level TTL cache shared by every instance of this service.
    """

    # Shared by all instances (class attribute). Entries expire after
    # 60*60*24*30 seconds, i.e. 30 days.
    # NOTE(review): cachetools caches are not thread-safe; if instances are
    # used concurrently from several threads, guard access with a lock.
    _cache = TTLCache(maxsize=100000, ttl=60 * 60 * 24 * 30)

    def __init__(self, db: Session):
        """Bind the service to the SQLAlchemy session used by most methods."""
        self.db = db

    @classmethod
    def _cache_lookup(cls, cache_key: str):
        """Return the cached value for ``cache_key``, or ``None`` on a miss.

        A single ``__getitem__`` is used instead of ``in`` + ``[]`` so an entry
        that expires between the two calls cannot raise ``KeyError``. Every
        value this class caches is a list or dict, never ``None``.
        """
        try:
            return cls._cache[cache_key]
        except KeyError:
            return None

    @classmethod
    def _invalidate_node_cache(cls, node_id, name, category):
        """Drop the cached ``get_node`` and ``get_node_by_name_category`` entries.

        NOTE(review): ``search_title_index`` results are keyed by search
        keyword and are not invalidated here, so they may still reference
        updated/deleted nodes until they expire.
        """
        cls._cache.pop(f"get_node_{node_id}", None)
        cls._cache.pop(f"get_node_by_name_category_{name}:{category}", None)

    def search_title_index(self, index: str, keywrod: str, category: str, top_k: int = 3, distance: float = 0.3) -> list:
        """Return up to ``top_k`` active nodes in ``category`` close to ``keywrod``.

        ``keywrod`` (misspelling kept for keyword-argument callers) is embedded
        with ``Vectorizer.get_embedding`` and compared with ``KGNode.embedding``
        by L2 distance. Only rows with ``status == 0`` and distance
        ``<= distance`` are returned, nearest first. ``index`` only takes part
        in the cache key.

        Returns:
            A list of dicts with keys ``id``, ``title`` (node name), ``text``
            (node category) and ``score`` (``2.0 - distance``). The list is
            cached and the same object is returned on cache hits, so callers
            should not mutate it.
        """
        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
        cached = self._cache_lookup(cache_key)
        if cached is not None:
            return cached

        query_embedding = Vectorizer.get_embedding(keywrod)

        # This search uses its own session taken from the get_db() generator.
        # Closing the generator runs its cleanup (presumably closing the
        # session - confirm in db.session) instead of leaking it.
        db_gen = get_db()
        db = next(db_gen)
        try:
            distance_expr = KGNode.embedding.l2_distance(query_embedding)
            rows = (
                db.query(
                    KGNode.id,
                    KGNode.name,
                    KGNode.category,
                    distance_expr.label('distance'),
                )
                .filter(KGNode.status == 0)
                .filter(KGNode.category == category)
                # TODO: evaluate whether cosine distance would perform better.
                .filter(distance_expr <= distance)
                .order_by('distance')
                .limit(top_k)
                .all()
            )
        finally:
            db_gen.close()

        results = [
            {
                "id": row.id,
                "title": row.name,
                "text": row.category,
                "score": 2.0 - row.distance,
            }
            for row in rows
        ]
        self._cache[cache_key] = results
        return results

    def paginated_search(self, search_params: dict) -> dict:
        """Page through active nodes whose embedding is near ``keyword``.

        Recognised ``search_params`` keys:
            keyword (default ''): text embedded with ``Vectorizer``.
            category (optional): restrict to this category when truthy.
            knowledge_ids (optional): restrict to nodes whose ``version`` is in
                this collection when truthy.
            distance (default DISTANCE_THRESHOLD): maximum L2 distance;
                ``None`` also selects the default.
            pageNo (default 1): 1-based page; missing/None/< 1 becomes 1.
            limit (default 10): page size; missing/None/< 1 becomes 10.
            load_props (default False): attach each node's props via
                ``KGPropService.get_props_by_ref_id``.

        Returns:
            ``{'records': [...], 'pagination': {...}}``. Records are ordered by
            ascending distance. Pagination holds ``total``, ``pageNo``,
            ``limit`` and ``totalPages``.

        Raises:
            Re-raises any exception from the queries after logging it.
        """
        load_props = search_params.get('load_props', False)
        prop_service = KGPropService(self.db)
        keyword = search_params.get('keyword', '')
        category = search_params.get('category', None)
        knowledge_ids = search_params.get('knowledge_ids')

        # A missing or None distance falls back to the module default.
        distance = search_params.get('distance')
        if distance is None:
            distance = DISTANCE_THRESHOLD

        # `or` also covers an explicit None, which previously raised TypeError
        # in the comparisons below. 0 is mapped the same way the clamps would map it.
        page_no = search_params.get('pageNo') or 1
        limit = search_params.get('limit') or 10
        if page_no < 1:
            page_no = 1
        if limit < 1:
            limit = 10

        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit
        distance_expr = KGNode.embedding.l2_distance(embedding)

        # One filter list shared by the count and the page query, so the two
        # can never disagree about which rows match.
        filters = [KGNode.status == 0, distance_expr <= distance]
        if category:
            filters.append(KGNode.category == category)
        if knowledge_ids:
            filters.append(KGNode.version.in_(knowledge_ids))

        try:
            total_count = self.db.query(func.count(KGNode.id)).filter(*filters).scalar()
            rows = (
                self.db.query(
                    KGNode.id,
                    KGNode.name,
                    KGNode.category,
                    distance_expr.label('distance'),
                )
                .filter(*filters)
                .order_by('distance')
                .offset(offset)
                .limit(limit)
                .all()
            )
            records = [
                {
                    'id': row.id,
                    'name': row.name,
                    'category': row.category,
                    'props': prop_service.get_props_by_ref_id(row.id) if load_props else [],
                    'distance': round(row.distance, 3),
                }
                for row in rows
            ]
        except Exception as e:
            logger.exception(f"分页查询失败: {str(e)}")
            raise

        return {
            'records': records,
            'pagination': {
                'total': total_count,
                'pageNo': page_no,
                'limit': limit,
                'totalPages': (total_count + limit - 1) // limit,
            },
        }

    def create_node(self, node_data: dict):
        """Insert a ``KGNode`` built from ``node_data`` and commit it.

        Raises:
            ValueError: if a node with the same name, category and version
                already exists, or if the commit hits an integrity error. In
                the integrity-error case the session is rolled back first.
        """
        try:
            existing = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version')
            ).first()

            if existing:
                raise ValueError("Node already exists")
            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return new_node
        except IntegrityError as e:
            self.db.rollback()
            logger.error(f"创建节点失败: {str(e)}")
            raise ValueError("Database integrity error") from e

    def get_node(self, node_id: int):
        """Return ``{'id', 'name', 'category', 'version'}`` for an active node.

        Results are cached per ``node_id``.

        Raises:
            ValueError: if ``node_id`` is None or no node with ``status == 0``
                has that id.
        """
        if node_id is None:
            raise ValueError("Node ID is required")

        cache_key = f"get_node_{node_id}"
        cached = self._cache_lookup(cache_key)
        if cached is not None:
            return cached

        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
        if not node:
            raise ValueError("Node not found")

        node_data = {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version
        }
        self._cache[cache_key] = node_data
        return node_data

    def update_node(self, node_id: int, update_data: dict):
        """Set each ``update_data`` key as an attribute on the node and commit.

        On success, the cached ``get_node`` entry and both the old and new
        (name, category) lookups are invalidated.

        Raises:
            ValueError: if the node does not exist, or if applying or
                committing the change fails. In the failure case the session
                is rolled back.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        old_name, old_category = node.name, node.category
        try:
            for key, value in update_data.items():
                setattr(node, key, value)
            # Read before commit so no refresh query is needed afterwards.
            new_name, new_category = node.name, node.category
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.error(f"更新节点失败: {str(e)}")
            raise ValueError("Update failed") from e
        self._invalidate_node_cache(node_id, old_name, old_category)
        self._invalidate_node_cache(node_id, new_name, new_category)
        return node

    def delete_node(self, node_id: int):
        """Hard-delete the node, commit, and drop its cached lookups.

        Raises:
            ValueError: if the node does not exist, or if the delete or commit
                fails. In the failure case the session is rolled back.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        # Capture cache-key parts before the row is deleted.
        node_name, node_category = node.name, node.category
        try:
            self.db.delete(node)
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.error(f"删除节点失败: {str(e)}")
            raise ValueError("Delete failed") from e
        self._invalidate_node_cache(node_id, node_name, node_category)
        return None

    def batch_process_er_nodes(self):
        """Fill in ``embedding`` for every node whose embedding is NULL.

        Nodes are processed in batches of 200, ordered by id, with one commit
        per batch. Batches use keyset pagination (``id > last_id``) rather
        than OFFSET. Processed rows leave the ``IS NULL`` filter, so an
        advancing OFFSET would skip every other batch. A node whose embedding
        stays NULL cannot be re-fetched forever either.

        NOTE: a filter restricting this to ``version == 'ER'`` was previously
        commented out; all versions are processed.

        Raises:
            ValueError: if any batch fails. That batch is rolled back; earlier
                batches stay committed.
        """
        batch_size = 200
        last_id = None
        while True:
            try:
                query = self.db.query(KGNode).filter(KGNode.embedding.is_(None))
                if last_id is not None:
                    query = query.filter(KGNode.id > last_id)
                nodes = query.order_by(KGNode.id).limit(batch_size).all()
                if not nodes:
                    break

                updated = False
                for node in nodes:
                    # `is None` rather than truthiness: a vector value may be
                    # array-like, whose truth value can be ambiguous.
                    if node.embedding is None:
                        node.embedding = Vectorizer.get_embedding(node.name)
                        updated = True
                # Read before commit so no refresh query is needed afterwards.
                last_id = nodes[-1].id
                if updated:
                    self.db.commit()
            except Exception as e:
                self.db.rollback()
                logger.error(f"批量处理ER节点失败: {str(e)}")
                raise ValueError("Batch process failed") from e

    def get_node_by_name_category(self, name: str, category: str):
        """Return the first active node matching ``name`` and ``category``.

        Returns:
            ``{'id', 'name', 'category', 'version'}``, or ``None`` when no node
            matches. Only hits are cached.

        Raises:
            ValueError: if ``name`` or ``category`` is empty.
        """
        if not name or not category:
            raise ValueError("Name and category are required")

        cache_key = f"get_node_by_name_category_{name}:{category}"
        cached = self._cache_lookup(cache_key)
        if cached is not None:
            return cached

        node = self.db.query(KGNode).filter(
            KGNode.name == name,
            KGNode.category == category,
            KGNode.status == 0
        ).first()
        if not node:
            return None

        node_data = {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version
        }
        self._cache[cache_key] = node_data
        return node_data
|