|
@@ -1,18 +1,24 @@
|
|
|
from sqlalchemy.orm import Session
|
|
|
from typing import Optional
|
|
|
-from model.trunks_model import KGNode
|
|
|
+from model.kg_node import KGNode
|
|
|
from db.session import get_db
|
|
|
import logging
|
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
-from schema.response import StandardResponse
|
|
|
+from utils.vectorizer import Vectorizer
|
|
|
+from sqlalchemy import func
|
|
|
+from service.kg_prop_service import KGPropService
|
|
|
+from service.kg_edge_service import KGEdgeService
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class KGNodeService:
|
|
|
- def __init__(self):
|
|
|
- self.db = next(get_db())
|
|
|
+ def __init__(self, db: Session):
|
|
|
+ self.db = db
|
|
|
|
|
|
- def paginated_search(self, search_params: dict) -> StandardResponse:
|
|
|
+ def paginated_search(self, search_params: dict) -> dict:
|
|
|
+ load_props = search_params.get('load_props', False)
|
|
|
+ prop_service = KGPropService(self.db)
|
|
|
+ edge_service = KGEdgeService(self.db)
|
|
|
keyword = search_params.get('keyword', '')
|
|
|
page_no = search_params.get('pageNo', 1)
|
|
|
limit = search_params.get('limit', 10)
|
|
@@ -34,36 +40,31 @@ class KGNodeService:
|
|
|
KGNode.category,
|
|
|
KGNode.embedding.l2_distance(embedding).label('distance')
|
|
|
)
|
|
|
- if search_params.get('knowledge_ids'):
|
|
|
- query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
|
|
|
- results = query.order_by('distance').offset(offset).limit(limit).all()
|
|
|
-
|
|
|
- return StandardResponse(
|
|
|
- success=True,
|
|
|
- data={
|
|
|
- 'records': [{
|
|
|
- 'id': r.id,
|
|
|
- 'name': r.name,
|
|
|
- 'category': r.category,
|
|
|
- 'distance': r.distance
|
|
|
- } for r in results],
|
|
|
- 'pagination': {
|
|
|
- 'total': total_count,
|
|
|
- 'pageNo': page_no,
|
|
|
- 'limit': limit,
|
|
|
- 'totalPages': (total_count + limit - 1) // limit
|
|
|
- }
|
|
|
+ if search_params.get('knowledge_ids'):
|
|
|
+ query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
|
|
|
+ results = query.order_by('distance').offset(offset).limit(limit).all()
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'records': [{
|
|
|
+ 'id': r.id,
|
|
|
+ 'name': r.name,
|
|
|
+ 'category': r.category,
|
|
|
+ 'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
|
|
|
+ 'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
|
|
|
+ 'distance': r.distance
|
|
|
+ } for r in results],
|
|
|
+ 'pagination': {
|
|
|
+ 'total': total_count,
|
|
|
+ 'pageNo': page_no,
|
|
|
+ 'limit': limit,
|
|
|
+ 'totalPages': (total_count + limit - 1) // limit
|
|
|
}
|
|
|
- )
|
|
|
+ }
|
|
|
except Exception as e:
|
|
|
logger.error(f"分页查询失败: {str(e)}")
|
|
|
- return StandardResponse(
|
|
|
- success=False,
|
|
|
- error_code=500,
|
|
|
- error_msg=str(e)
|
|
|
- )
|
|
|
+ raise e
|
|
|
|
|
|
- def create_node(self, node_data: dict) -> StandardResponse:
|
|
|
+ def create_node(self, node_data: dict):
|
|
|
try:
|
|
|
existing = self.db.query(KGNode).filter(
|
|
|
KGNode.name == node_data['name'],
|
|
@@ -72,77 +73,80 @@ class KGNodeService:
|
|
|
).first()
|
|
|
|
|
|
if existing:
|
|
|
- return StandardResponse(
|
|
|
- success=False,
|
|
|
- error_code=409,
|
|
|
- error_msg="Node already exists"
|
|
|
- )
|
|
|
+ raise ValueError("Node already exists")
|
|
|
|
|
|
new_node = KGNode(**node_data)
|
|
|
self.db.add(new_node)
|
|
|
self.db.commit()
|
|
|
- return StandardResponse(success=True, data=new_node)
|
|
|
+ return new_node
|
|
|
|
|
|
except IntegrityError as e:
|
|
|
self.db.rollback()
|
|
|
logger.error(f"创建节点失败: {str(e)}")
|
|
|
- return StandardResponse(
|
|
|
- success=False,
|
|
|
- error_code=500,
|
|
|
- error_msg="Database integrity error"
|
|
|
- )
|
|
|
+ raise ValueError("Database integrity error")
|
|
|
|
|
|
- def get_node(self, node_id: int) -> StandardResponse:
|
|
|
+ def get_node(self, node_id: int):
|
|
|
node = self.db.query(KGNode).get(node_id)
|
|
|
if not node:
|
|
|
- return StandardResponse(
|
|
|
- success=False,
|
|
|
- error_code=404,
|
|
|
- error_msg="Node not found"
|
|
|
- )
|
|
|
- return StandardResponse(success=True, data=node)
|
|
|
-
|
|
|
- def update_node(self, node_id: int, update_data: dict) -> StandardResponse:
|
|
|
+ raise ValueError("Node not found")
|
|
|
+ return {
|
|
|
+ 'id': node.id,
|
|
|
+ 'name': node.name,
|
|
|
+ 'category': node.category,
|
|
|
+ 'version': node.version
|
|
|
+ }
|
|
|
+
|
|
|
+ def update_node(self, node_id: int, update_data: dict):
|
|
|
node = self.db.query(KGNode).get(node_id)
|
|
|
if not node:
|
|
|
- return StandardResponse(
|
|
|
- success=False,
|
|
|
- error_code=404,
|
|
|
- error_msg="Node not found"
|
|
|
- )
|
|
|
+ raise ValueError("Node not found")
|
|
|
|
|
|
try:
|
|
|
for key, value in update_data.items():
|
|
|
setattr(node, key, value)
|
|
|
self.db.commit()
|
|
|
- return StandardResponse(success=True, data=node)
|
|
|
+ return node
|
|
|
except Exception as e:
|
|
|
self.db.rollback()
|
|
|
logger.error(f"更新节点失败: {str(e)}")
|
|
|
- return StandardResponse(
|
|
|
- success=False,
|
|
|
- error_code=500,
|
|
|
- error_msg="Update failed"
|
|
|
- )
|
|
|
+ raise ValueError("Update failed")
|
|
|
|
|
|
- def delete_node(self, node_id: int) -> StandardResponse:
|
|
|
+ def delete_node(self, node_id: int):
|
|
|
node = self.db.query(KGNode).get(node_id)
|
|
|
if not node:
|
|
|
- return StandardResponse(
|
|
|
- success=False,
|
|
|
- error_code=404,
|
|
|
- error_msg="Node not found"
|
|
|
- )
|
|
|
+ raise ValueError("Node not found")
|
|
|
|
|
|
try:
|
|
|
self.db.delete(node)
|
|
|
self.db.commit()
|
|
|
- return StandardResponse(success=True)
|
|
|
+ return None
|
|
|
except Exception as e:
|
|
|
self.db.rollback()
|
|
|
logger.error(f"删除节点失败: {str(e)}")
|
|
|
- return StandardResponse(
|
|
|
- success=False,
|
|
|
- error_code=500,
|
|
|
- error_msg="Delete failed"
|
|
|
- )
|
|
|
+ raise ValueError("Delete failed")
|
|
|
+
|
|
|
+ def batch_process_er_nodes(self):
|
|
|
+ batch_size = 200
|
|
|
+ offset = 0
|
|
|
+
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ nodes = self.db.query(KGNode).filter(
|
|
|
+ KGNode.version == 'ER',
|
|
|
+ KGNode.embedding == None
|
|
|
+ ).offset(offset).limit(batch_size).all()
|
|
|
+
|
|
|
+ if not nodes:
|
|
|
+ break
|
|
|
+
|
|
|
+ for node in nodes:
|
|
|
+ if not node.embedding:
|
|
|
+ embedding = Vectorizer.get_embedding(node.name)
|
|
|
+ node.embedding = embedding
|
|
|
+ self.db.commit()
|
|
|
+
|
|
|
+ offset += batch_size
|
|
|
+ except Exception as e:
|
|
|
+ self.db.rollback()
|
|
|
+ logger.error(f"批量处理ER节点失败: {str(e)}")
|
|
|
+ raise ValueError("Batch process failed")
|