Bläddra i källkod

新增nodes表对应的代码

SGTY 3 månader sedan
förälder
incheckning
4efb6fb80b
3 ändrade filer med 305 tillägg och 0 borttagningar
  1. 24 0
      model/kg_node.py
  2. 133 0
      router/knowledge_saas.py
  3. 148 0
      service/kg_node_service.py

+ 24 - 0
model/kg_node.py

@@ -0,0 +1,24 @@
+from sqlalchemy import Column, Integer, String, text
+from sqlalchemy.dialects.postgresql import JSONB
+from db.base_class import Base
+
+class KGNode(Base):
+    __tablename__ = 'kg_nodes'
+
+    id = Column(Integer, primary_key=True, server_default=text("nextval('kg_ids_seq')"))
+    name = Column(String(64), nullable=False)
+    category = Column(String(64), nullable=False)
+    layout = Column(String(100))
+    version = Column(String(16))
+    status = Column(Integer, nullable=False, server_default=text('0'))
+    embedding = Column(JSONB)
+
+    __table_args__ = (
+        {'schema': 'public'},
+        {
+            'postgresql_partition_by': 'LIST (category)',
+            'postgresql_with': {
+                'fillfactor': '50'
+            }
+        }
+    )

+ 133 - 0
router/knowledge_saas.py

@@ -0,0 +1,133 @@
+from fastapi import APIRouter, Depends, HTTPException
+from typing import Optional
+from pydantic import BaseModel
+from model.response import StandardResponse
+from db.session import get_db
+from sqlalchemy.orm import Session
+from service.trunks_service import TrunksService
+import logging
+
+router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
+
+logger = logging.getLogger(__name__)
+
+class PaginatedSearchRequest(BaseModel):
+    keyword: Optional[str] = None
+    pageNo: int = 1
+    limit: int = 10
+    knowledge_ids: Optional[List[str]] = None
+
+class NodeCreateRequest(BaseModel):
+    name: str
+    category: str
+    layout: Optional[str] = None
+    version: Optional[str] = None
+    embedding: Optional[List[float]] = None
+
+class NodeUpdateRequest(BaseModel):
+    layout: Optional[str] = None
+    version: Optional[str] = None
+    status: Optional[int] = None
+    embedding: Optional[List[float]] = None
+
+@router.post("/paginated_search", response_model=StandardResponse)
+async def paginated_search(
+    payload: PaginatedSearchRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        service = KGNodeService()
+        search_params = {
+            'keyword': payload.keyword,
+            'pageNo': payload.pageNo,
+            'limit': payload.limit,
+            'knowledge_ids': payload.knowledge_ids
+        }
+        return service.paginated_search(search_params)
+    except Exception as e:
+        logger.error(f"分页查询失败: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=StandardResponse(
+                success=False,
+                error_code=500,
+                error_msg=str(e)
+            )
+        )
+    try:
+        trunks_service = TrunksService()
+        search_params = {
+            'keyword': payload.keyword,
+            'pageNo': payload.pageNo,
+            'limit': payload.limit,
+            'knowledge_ids': payload.knowledge_ids
+        }
+        result = trunks_service.paginated_search(search_params)
+        return StandardResponse(
+            success=True,
+            data={
+                'records': result['data'],
+                'pagination': result['pagination']
+            }
+        )
+    except Exception as e:
+        logger.error(f"分页查询失败: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=StandardResponse(
+                success=False,
+                error_code=500,
+                error_msg=str(e)
+            )
+        )
+
+@router.post("/nodes", response_model=StandardResponse)
+async def create_node(
+    payload: NodeCreateRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        service = TrunksService()
+        return service.create_node(payload.dict())
+    except Exception as e:
+        logger.error(f"创建节点失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.get("/nodes/{node_id}", response_model=StandardResponse)
+async def get_node(
+    node_id: int,
+    db: Session = Depends(get_db)
+):
+    try:
+        service = TrunksService()
+        return service.get_node(node_id)
+    except Exception as e:
+        logger.error(f"获取节点失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.put("/nodes/{node_id}", response_model=StandardResponse)
+async def update_node(
+    node_id: int,
+    payload: NodeUpdateRequest,
+    db: Session = Depends(get_db)
+):
+    try:
+        service = TrunksService()
+        return service.update_node(node_id, payload.dict(exclude_unset=True))
+    except Exception as e:
+        logger.error(f"更新节点失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+@router.delete("/nodes/{node_id}", response_model=StandardResponse)
+async def delete_node(
+    node_id: int,
+    db: Session = Depends(get_db)
+):
+    try:
+        service = TrunksService()
+        return service.delete_node(node_id)
+    except Exception as e:
+        logger.error(f"删除节点失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
+
+saas_kb_router = router

+ 148 - 0
service/kg_node_service.py

@@ -0,0 +1,148 @@
+from sqlalchemy.orm import Session
+from typing import Optional
+from model.trunks_model import KGNode
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+from schema.response import StandardResponse
+
+logger = logging.getLogger(__name__)
+
+class KGNodeService:
+    def __init__(self):
+        self.db = next(get_db())
+
+    def paginated_search(self, search_params: dict) -> StandardResponse:
+        keyword = search_params.get('keyword', '')
+        page_no = search_params.get('pageNo', 1)
+        limit = search_params.get('limit', 10)
+
+        if page_no < 1:
+            page_no = 1
+        if limit < 1:
+            limit = 10
+
+        embedding = Vectorizer.get_embedding(keyword)
+        offset = (page_no - 1) * limit
+
+        try:
+            total_count = self.db.query(func.count(KGNode.id)).scalar()
+
+            query = self.db.query(
+                KGNode.id,
+                KGNode.name,
+                KGNode.category,
+                KGNode.embedding.l2_distance(embedding).label('distance')
+            )
+        if search_params.get('knowledge_ids'):
+            query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
+        results = query.order_by('distance').offset(offset).limit(limit).all()
+
+            return StandardResponse(
+                success=True,
+                data={
+                    'records': [{
+                        'id': r.id,
+                        'name': r.name,
+                        'category': r.category,
+                        'distance': r.distance
+                    } for r in results],
+                    'pagination': {
+                        'total': total_count,
+                        'pageNo': page_no,
+                        'limit': limit,
+                        'totalPages': (total_count + limit - 1) // limit
+                    }
+                }
+            )
+        except Exception as e:
+            logger.error(f"分页查询失败: {str(e)}")
+            return StandardResponse(
+                success=False,
+                error_code=500,
+                error_msg=str(e)
+            )
+
+    def create_node(self, node_data: dict) -> StandardResponse:
+        try:
+            existing = self.db.query(KGNode).filter(
+                KGNode.name == node_data['name'],
+                KGNode.category == node_data['category'],
+                KGNode.version == node_data.get('version')
+            ).first()
+            
+            if existing:
+                return StandardResponse(
+                    success=False,
+                    error_code=409,
+                    error_msg="Node already exists"
+                )
+
+            new_node = KGNode(**node_data)
+            self.db.add(new_node)
+            self.db.commit()
+            return StandardResponse(success=True, data=new_node)
+
+        except IntegrityError as e:
+            self.db.rollback()
+            logger.error(f"创建节点失败: {str(e)}")
+            return StandardResponse(
+                success=False,
+                error_code=500,
+                error_msg="Database integrity error"
+            )
+
+    def get_node(self, node_id: int) -> StandardResponse:
+        node = self.db.query(KGNode).get(node_id)
+        if not node:
+            return StandardResponse(
+                success=False,
+                error_code=404,
+                error_msg="Node not found"
+            )
+        return StandardResponse(success=True, data=node)
+
+    def update_node(self, node_id: int, update_data: dict) -> StandardResponse:
+        node = self.db.query(KGNode).get(node_id)
+        if not node:
+            return StandardResponse(
+                success=False,
+                error_code=404,
+                error_msg="Node not found"
+            )
+
+        try:
+            for key, value in update_data.items():
+                setattr(node, key, value)
+            self.db.commit()
+            return StandardResponse(success=True, data=node)
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"更新节点失败: {str(e)}")
+            return StandardResponse(
+                success=False,
+                error_code=500,
+                error_msg="Update failed"
+            )
+
+    def delete_node(self, node_id: int) -> StandardResponse:
+        node = self.db.query(KGNode).get(node_id)
+        if not node:
+            return StandardResponse(
+                success=False,
+                error_code=404,
+                error_msg="Node not found"
+            )
+
+        try:
+            self.db.delete(node)
+            self.db.commit()
+            return StandardResponse(success=True)
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"删除节点失败: {str(e)}")
+            return StandardResponse(
+                success=False,
+                error_code=500,
+                error_msg="Delete failed"
+            )