from sqlalchemy.orm import Session
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
import logging

from model.kg_node import KGNode
from db.session import get_db
from utils.vectorizer import Vectorizer
from service.kg_prop_service import KGPropService
from service.kg_edge_service import KGEdgeService
logger = logging.getLogger(__name__)

# L2-distance cutoffs for the vector searches: results farther than these are discarded.
DISTANCE_THRESHOLD = 0.65   # used by paginated_search
DISTANCE_THRESHOLD2 = 0.3   # stricter cutoff used by search_title_index


class KGNodeService:
    # Class-level cache shared by all instances; it grows unbounded for the
    # lifetime of the process.
    _cache = {}

    def __init__(self, db: Session):
        self.db = db
    def search_title_index(self, index: str, title: str, top_k: int = 3):
        cache_key = f"{index}:{title}:{top_k}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        query_embedding = Vectorizer.get_embedding(title)
        # Vector search: nearest nodes by L2 distance, skipping deleted nodes
        # and excluding nodes whose version is 'ER'.
        results = (
            self.db.query(
                KGNode.id,
                KGNode.name,
                KGNode.category,
                KGNode.embedding.l2_distance(query_embedding).label('distance')
            )
            .filter(KGNode.status == 0)
            .filter(KGNode.version != 'ER')
            .filter(KGNode.embedding.l2_distance(query_embedding) <= DISTANCE_THRESHOLD2)
            .order_by('distance')
            .limit(top_k)
            .all()
        )
        results = [
            {
                "id": node.id,
                "title": node.name,
                "text": node.category,
                # Turn the distance into a score: smaller distance, higher score.
                "score": 2.0 - node.distance
            }
            for node in results
        ]
        self._cache[cache_key] = results
        return results

    def paginated_search(self, search_params: dict) -> dict:
        load_props = search_params.get('load_props', False)
        prop_service = KGPropService(self.db)
        edge_service = KGEdgeService(self.db)

        keyword = search_params.get('keyword', '')
        page_no = search_params.get('pageNo', 1)
        limit = search_params.get('limit', 10)
        if page_no < 1:
            page_no = 1
        if limit < 1:
            limit = 10

        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit

        try:
            # Shared filters so the total count and the page query stay consistent.
            filters = [
                KGNode.status == 0,
                KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD
            ]
            if search_params.get('knowledge_ids'):
                filters.append(KGNode.version.in_(search_params['knowledge_ids']))

            total_count = self.db.query(func.count(KGNode.id)).filter(*filters).scalar()

            results = (
                self.db.query(
                    KGNode.id,
                    KGNode.name,
                    KGNode.category,
                    KGNode.embedding.l2_distance(embedding).label('distance')
                )
                .filter(*filters)
                .order_by('distance')
                .offset(offset)
                .limit(limit)
                .all()
            )
            # Among results with the same distance, list nodes whose category is
            # '疾病' (disease) first.
            results = sorted(results, key=lambda x: (x.distance, x.category != '疾病'))

            return {
                'records': [{
                    'id': r.id,
                    'name': r.name,
                    'category': r.category,
                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
                    # 'edges': edge_service.get_edges_by_nodes(r.id, r.id, False) if load_props else [],
                    'distance': round(r.distance, 3)
                } for r in results],
                'pagination': {
                    'total': total_count,
                    'pageNo': page_no,
                    'limit': limit,
                    'totalPages': (total_count + limit - 1) // limit
                }
            }
        except Exception as e:
            logger.error(f"Paginated search failed: {str(e)}")
            raise

    def create_node(self, node_data: dict):
        try:
            existing = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version')
            ).first()
            if existing:
                raise ValueError("Node already exists")

            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return new_node
        except IntegrityError as e:
            self.db.rollback()
            logger.error(f"Failed to create node: {str(e)}")
            raise ValueError("Database integrity error") from e

    def get_node(self, node_id: int):
        if node_id is None:
            raise ValueError("Node ID is required")

        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
        if not node:
            raise ValueError("Node not found")
        return {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version
        }

    def update_node(self, node_id: int, update_data: dict):
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        try:
            for key, value in update_data.items():
                setattr(node, key, value)
            self.db.commit()
            return node
        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to update node: {str(e)}")
            raise ValueError("Update failed") from e

    def delete_node(self, node_id: int):
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        try:
            self.db.delete(node)
            self.db.commit()
            return None
        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to delete node: {str(e)}")
            raise ValueError("Delete failed") from e

    def batch_process_er_nodes(self):
        # Backfill embeddings for nodes that do not have one yet, in batches.
        batch_size = 200
        while True:
            try:
                # Re-query from the start each iteration: committed rows drop out of
                # the embedding-is-NULL filter, so a growing offset would skip rows.
                nodes = self.db.query(KGNode).filter(
                    # KGNode.version == 'ER',
                    KGNode.embedding.is_(None)
                ).limit(batch_size).all()
                if not nodes:
                    break
                updated = 0
                for node in nodes:
                    embedding = Vectorizer.get_embedding(node.name)
                    if embedding is not None:
                        node.embedding = embedding
                        updated += 1
                if updated:
                    self.db.commit()
                else:
                    # No node in this batch could be embedded; stop instead of spinning.
                    break
            except Exception as e:
                self.db.rollback()
                logger.error(f"Batch processing of ER nodes failed: {str(e)}")
                raise ValueError("Batch process failed") from e