|
@@ -1,258 +0,0 @@
|
|
-from sqlalchemy.orm import Session
|
|
|
|
-from typing import Optional
|
|
|
|
-from model.kg_node import KGNode
|
|
|
|
-from db.session import get_db
|
|
|
|
-import logging
|
|
|
|
-from sqlalchemy.exc import IntegrityError
|
|
|
|
-
|
|
|
|
-from utils import vectorizer
|
|
|
|
-from utils.vectorizer import Vectorizer
|
|
|
|
-from sqlalchemy import func
|
|
|
|
-from service.kg_prop_service import KGPropService
|
|
|
|
-from service.kg_edge_service import KGEdgeService
|
|
|
|
-from cachetools import TTLCache
|
|
|
|
-from cachetools.keys import hashkey
|
|
|
|
-logger = logging.getLogger(__name__)
|
|
|
|
-DISTANCE_THRESHOLD = 0.65
|
|
|
|
-class KGNodeService:
|
|
|
|
- def __init__(self, db: Session):
|
|
|
|
- self.db = db
|
|
|
|
-
|
|
|
|
- _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
|
|
|
|
-
|
|
|
|
- def search_title_index(self, index: str, keywrod: str,category: str, top_k: int = 3,distance: float = 0.3) -> Optional[int]:
|
|
|
|
- cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
|
|
|
|
- if cache_key in self._cache:
|
|
|
|
- return self._cache[cache_key]
|
|
|
|
-
|
|
|
|
- query_embedding = Vectorizer.get_embedding(keywrod)
|
|
|
|
- db = next(get_db())
|
|
|
|
- # 执行向量搜索
|
|
|
|
- results = (
|
|
|
|
- db.query(
|
|
|
|
- KGNode.id,
|
|
|
|
- KGNode.name,
|
|
|
|
- KGNode.category,
|
|
|
|
- KGNode.embedding.l2_distance(query_embedding).label('distance')
|
|
|
|
- )
|
|
|
|
- .filter(KGNode.status == 0)
|
|
|
|
- .filter(KGNode.category == category)
|
|
|
|
- #todo 是否能提高性能 改成余弦算法
|
|
|
|
- .filter(KGNode.embedding.l2_distance(query_embedding) <= distance)
|
|
|
|
- .order_by('distance').limit(top_k).all()
|
|
|
|
- )
|
|
|
|
- results = [
|
|
|
|
- {
|
|
|
|
- "id": node.id,
|
|
|
|
- "title": node.name,
|
|
|
|
- "text": node.category,
|
|
|
|
- "score": 2.0-node.distance
|
|
|
|
- }
|
|
|
|
- for node in results
|
|
|
|
- ]
|
|
|
|
-
|
|
|
|
- self._cache[cache_key] = results
|
|
|
|
- return results
|
|
|
|
-
|
|
|
|
- def paginated_search(self, search_params: dict) -> dict:
|
|
|
|
- load_props = search_params.get('load_props', False)
|
|
|
|
- prop_service = KGPropService(self.db)
|
|
|
|
- edge_service = KGEdgeService(self.db)
|
|
|
|
- keyword = search_params.get('keyword', '')
|
|
|
|
- category = search_params.get('category', None)
|
|
|
|
- page_no = search_params.get('pageNo', 1)
|
|
|
|
- #distance 为NONE或不存在时,使用默认值
|
|
|
|
- if search_params.get('distance') is None:
|
|
|
|
- distance = DISTANCE_THRESHOLD
|
|
|
|
- else:
|
|
|
|
- distance = search_params.get('distance')
|
|
|
|
- limit = search_params.get('limit', 10)
|
|
|
|
-
|
|
|
|
- if page_no < 1:
|
|
|
|
- page_no = 1
|
|
|
|
- if limit < 1:
|
|
|
|
- limit = 10
|
|
|
|
-
|
|
|
|
- embedding = Vectorizer.get_embedding(keyword)
|
|
|
|
- offset = (page_no - 1) * limit
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- # 构建基础查询条件
|
|
|
|
- base_query = self.db.query(func.count(KGNode.id)).filter(
|
|
|
|
- KGNode.status == 0,
|
|
|
|
- KGNode.embedding.l2_distance(embedding) <= distance
|
|
|
|
- )
|
|
|
|
- # 如果有category,则添加额外过滤条件
|
|
|
|
- if category:
|
|
|
|
- base_query = base_query.filter(KGNode.category == category)
|
|
|
|
- # 如果有knowledge_ids,则添加额外过滤条件
|
|
|
|
- if search_params.get('knowledge_ids'):
|
|
|
|
- total_count = base_query.filter(
|
|
|
|
- KGNode.version.in_(search_params['knowledge_ids'])
|
|
|
|
- ).scalar()
|
|
|
|
- else:
|
|
|
|
- total_count = base_query.scalar()
|
|
|
|
-
|
|
|
|
- query = self.db.query(
|
|
|
|
- KGNode.id,
|
|
|
|
- KGNode.name,
|
|
|
|
- KGNode.category,
|
|
|
|
- KGNode.embedding.l2_distance(embedding).label('distance')
|
|
|
|
- )
|
|
|
|
- query = query.filter(KGNode.status == 0)
|
|
|
|
- #category有值时,过滤掉category不等于category的节点
|
|
|
|
- if category:
|
|
|
|
- query = query.filter(KGNode.category == category)
|
|
|
|
- if search_params.get('knowledge_ids'):
|
|
|
|
- query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
|
|
|
|
- query = query.filter(KGNode.embedding.l2_distance(embedding) <= distance)
|
|
|
|
- results = query.order_by('distance').offset(offset).limit(limit).all()
|
|
|
|
-
|
|
|
|
- return {
|
|
|
|
- 'records': [{
|
|
|
|
- 'id': r.id,
|
|
|
|
- 'name': r.name,
|
|
|
|
- 'category': r.category,
|
|
|
|
- 'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
|
|
|
|
- #'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
|
|
|
|
- 'distance': round(r.distance, 3)
|
|
|
|
- } for r in results],
|
|
|
|
- 'pagination': {
|
|
|
|
- 'total': total_count,
|
|
|
|
- 'pageNo': page_no,
|
|
|
|
- 'limit': limit,
|
|
|
|
- 'totalPages': (total_count + limit - 1) // limit
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- }
|
|
|
|
- except Exception as e:
|
|
|
|
- logger.error(f"分页查询失败: {str(e)}")
|
|
|
|
- raise e
|
|
|
|
-
|
|
|
|
- def create_node(self, node_data: dict):
|
|
|
|
- try:
|
|
|
|
- existing = self.db.query(KGNode).filter(
|
|
|
|
- KGNode.name == node_data['name'],
|
|
|
|
- KGNode.category == node_data['category'],
|
|
|
|
- KGNode.version == node_data.get('version')
|
|
|
|
- ).first()
|
|
|
|
-
|
|
|
|
- if existing:
|
|
|
|
- raise ValueError("Node already exists")
|
|
|
|
-
|
|
|
|
- new_node = KGNode(**node_data)
|
|
|
|
- self.db.add(new_node)
|
|
|
|
- self.db.commit()
|
|
|
|
- return new_node
|
|
|
|
-
|
|
|
|
- except IntegrityError as e:
|
|
|
|
- self.db.rollback()
|
|
|
|
- logger.error(f"创建节点失败: {str(e)}")
|
|
|
|
- raise ValueError("Database integrity error")
|
|
|
|
-
|
|
|
|
- def get_node(self, node_id: int):
|
|
|
|
- if node_id is None:
|
|
|
|
- raise ValueError("Node ID is required")
|
|
|
|
-
|
|
|
|
- cache_key = f"get_node_{node_id}"
|
|
|
|
- if cache_key in self._cache:
|
|
|
|
- return self._cache[cache_key]
|
|
|
|
-
|
|
|
|
- node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
|
|
|
|
-
|
|
|
|
- if not node:
|
|
|
|
- raise ValueError("Node not found")
|
|
|
|
-
|
|
|
|
- node_data = {
|
|
|
|
- 'id': node.id,
|
|
|
|
- 'name': node.name,
|
|
|
|
- 'category': node.category,
|
|
|
|
- 'version': node.version
|
|
|
|
- }
|
|
|
|
- self._cache[cache_key] = node_data
|
|
|
|
- return node_data
|
|
|
|
-
|
|
|
|
- def update_node(self, node_id: int, update_data: dict):
|
|
|
|
- node = self.db.query(KGNode).get(node_id)
|
|
|
|
- if not node:
|
|
|
|
- raise ValueError("Node not found")
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- for key, value in update_data.items():
|
|
|
|
- setattr(node, key, value)
|
|
|
|
- self.db.commit()
|
|
|
|
- return node
|
|
|
|
- except Exception as e:
|
|
|
|
- self.db.rollback()
|
|
|
|
- logger.error(f"更新节点失败: {str(e)}")
|
|
|
|
- raise ValueError("Update failed")
|
|
|
|
-
|
|
|
|
- def delete_node(self, node_id: int):
|
|
|
|
- node = self.db.query(KGNode).get(node_id)
|
|
|
|
- if not node:
|
|
|
|
- raise ValueError("Node not found")
|
|
|
|
-
|
|
|
|
- try:
|
|
|
|
- self.db.delete(node)
|
|
|
|
- self.db.commit()
|
|
|
|
- return None
|
|
|
|
- except Exception as e:
|
|
|
|
- self.db.rollback()
|
|
|
|
- logger.error(f"删除节点失败: {str(e)}")
|
|
|
|
- raise ValueError("Delete failed")
|
|
|
|
-
|
|
|
|
- def batch_process_er_nodes(self):
|
|
|
|
- batch_size = 200
|
|
|
|
- offset = 0
|
|
|
|
-
|
|
|
|
- while True:
|
|
|
|
- try:
|
|
|
|
- #下面的查询语句,增加根据id排序,防止并发问题
|
|
|
|
- nodes = self.db.query(KGNode).filter(
|
|
|
|
- #KGNode.version == 'ER',
|
|
|
|
- KGNode.embedding == None
|
|
|
|
- ).order_by(KGNode.id).offset(offset).limit(batch_size).all()
|
|
|
|
-
|
|
|
|
- if not nodes:
|
|
|
|
- break
|
|
|
|
-
|
|
|
|
- updated_nodes = []
|
|
|
|
- for node in nodes:
|
|
|
|
- if not node.embedding:
|
|
|
|
- embedding = Vectorizer.get_embedding(node.name)
|
|
|
|
- node.embedding = embedding
|
|
|
|
- updated_nodes.append(node)
|
|
|
|
- if updated_nodes:
|
|
|
|
- self.db.commit()
|
|
|
|
-
|
|
|
|
- offset += batch_size
|
|
|
|
- except Exception as e:
|
|
|
|
- self.db.rollback()
|
|
|
|
- print(f"批量处理ER节点失败: {str(e)}")
|
|
|
|
- raise ValueError("Batch process failed")
|
|
|
|
-
|
|
|
|
- def get_node_by_name_category(self, name: str, category: str):
|
|
|
|
- if not name or not category:
|
|
|
|
- raise ValueError("Name and category are required")
|
|
|
|
-
|
|
|
|
- cache_key = f"get_node_by_name_category_{name}:{category}"
|
|
|
|
- if cache_key in self._cache:
|
|
|
|
- return self._cache[cache_key]
|
|
|
|
-
|
|
|
|
- node = self.db.query(KGNode).filter(
|
|
|
|
- KGNode.name == name,
|
|
|
|
- KGNode.category == category,
|
|
|
|
- KGNode.status == 0
|
|
|
|
- ).first()
|
|
|
|
-
|
|
|
|
- if not node:
|
|
|
|
- return None
|
|
|
|
-
|
|
|
|
- node_data = {
|
|
|
|
- 'id': node.id,
|
|
|
|
- 'name': node.name,
|
|
|
|
- 'category': node.category,
|
|
|
|
- 'version': node.version
|
|
|
|
- }
|
|
|
|
- self._cache[cache_key] = node_data
|
|
|
|
- return node_data
|
|
|