123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- from fastapi import APIRouter, Depends, HTTPException
- from typing import Optional, List
- from pydantic import BaseModel
- from model.response import StandardResponse
- from db.session import get_db
- from sqlalchemy.orm import Session
- from service.kg_node_service import KGNodeService
- from service.trunks_service import TrunksService
- from service.kg_edge_service import KGEdgeService
- from service.kg_prop_service import KGPropService
- import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every endpoint below is mounted under the "/saas" prefix and grouped
# under a single OpenAPI tag.
router = APIRouter(
    prefix="/saas",
    tags=["SaaS Knowledge Base"],
)
# Request body for POST /saas/nodes/paginated_search2. Every field is copied
# into the search_params dict handed to KGNodeService.paginated_search
# (together with load_props=True).
class PaginatedSearchRequest(BaseModel):
    # Optional free-text filter, forwarded as 'keyword'.
    keyword: Optional[str] = None
    # Page number; defaults to 1 (presumably 1-based — confirm in KGNodeService).
    pageNo: int = 1
    # Page size.
    limit: int = 10
    # Optional knowledge ids to restrict the search to — exact filter
    # semantics live in KGNodeService.paginated_search (TODO confirm).
    knowledge_ids: Optional[List[str]] = None
# Request body for POST /saas/nodes/paginated_search. Fields are mapped into
# the search_params dict handed to KGNodeService.paginated_search.
class NodePaginatedSearchRequest(BaseModel):
    # Required; forwarded to the service under the key 'keyword'.
    name: str
    # Optional category filter, forwarded as 'category'.
    category: Optional[str] = None
    # Page number; defaults to 1 (presumably 1-based — confirm in KGNodeService).
    pageNo: int = 1
    # Page size.
    limit: int = 10
# Request body for POST /saas/nodes; the whole model (payload.dict()) is
# passed to TrunksService.create_node.
class NodeCreateRequest(BaseModel):
    # Required node name.
    name: str
    # Required node category.
    category: str
    # Optional layout string; its format is defined by TrunksService (TODO confirm).
    layout: Optional[str] = None
    # Optional version label.
    version: Optional[str] = None
    # Optional vector, presumably a precomputed embedding — verify whether
    # TrunksService computes one when this is None.
    embedding: Optional[List[float]] = None
# Request body for PUT /saas/nodes/{node_id}. All fields are optional; the
# endpoint serialises it with exclude_unset=True, so only fields the client
# actually sent are forwarded to TrunksService.update_node.
class NodeUpdateRequest(BaseModel):
    layout: Optional[str] = None
    version: Optional[str] = None
    # Status code; the meaning of each value is defined elsewhere (TODO confirm).
    status: Optional[int] = None
    embedding: Optional[List[float]] = None
# Request body for POST /saas/trunks/vector_search.
class VectorSearchRequest(BaseModel):
    # Query text passed to TrunksService.search_by_vector.
    text: str
    # Maximum number of results requested.
    limit: int = 10
    # Optional filter; sent to the service as {'type': type} only when truthy,
    # otherwise no filter (None) is passed. The name shadows the builtin but
    # is part of the public request schema, so it stays.
    type: Optional[str] = None
# Request model carrying a source node id. NOTE(review): no endpoint in this
# chunk uses it — GET /nodes/{src_id}/relationships takes src_id from the
# path instead; confirm whether it is used elsewhere before removing.
class NodeRelationshipRequest(BaseModel):
    src_id: int
# POST /saas/nodes/paginated_search2
# Paginated keyword search over knowledge-graph nodes, optionally restricted
# to a set of knowledge ids; node properties are loaded (load_props=True).
# Returns StandardResponse(data={'records': ..., 'pagination': ...}).
#
# This function used to be named `paginated_search`, the same as the handler
# for /nodes/paginated_search defined after it, so it was shadowed at module
# level (the route itself was still registered by the decorator). The
# explicit `name=` keeps the route name, and therefore the generated OpenAPI
# operationId, exactly as before the rename.
#
# NOTE(review): this is an `async def` that calls synchronous SQLAlchemy-backed
# services without awaiting, so each request blocks the event loop while the
# query runs; a plain `def` would let FastAPI run it in its threadpool.
@router.post(
    "/nodes/paginated_search2",
    response_model=StandardResponse,
    name="paginated_search",
)
async def paginated_search2(
    payload: PaginatedSearchRequest,
    db: Session = Depends(get_db)
):
    try:
        service = KGNodeService(db)
        search_params = {
            'keyword': payload.keyword,
            'pageNo': payload.pageNo,
            'limit': payload.limit,
            'knowledge_ids': payload.knowledge_ids,
            'load_props': True,
        }
        result = service.paginated_search(search_params)
        return StandardResponse(
            success=True,
            data={
                'records': result['records'],
                'pagination': result['pagination'],
            },
        )
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("分页查询失败: %s", e)
        # NOTE(review): FastAPI's default HTTPException handler JSON-encodes
        # `detail` with json.dumps, which cannot serialise a pydantic model;
        # unless the app installs a custom handler, this error body is lost.
        raise HTTPException(
            status_code=500,
            detail=StandardResponse(
                success=False,
                error_code=500,
                error_msg=str(e),
            ),
        ) from e
# POST /saas/nodes/paginated_search
# Paginated search over knowledge-graph nodes by name (sent to the service as
# 'keyword'), optionally filtered by category; node properties are loaded.
# Returns StandardResponse(data={'records': ..., 'pagination': ...}).
#
# NOTE(review): this is an `async def` that calls synchronous SQLAlchemy-backed
# services without awaiting, so each request blocks the event loop while the
# query runs; a plain `def` would let FastAPI run it in its threadpool.
@router.post("/nodes/paginated_search", response_model=StandardResponse)
async def paginated_search(
    payload: NodePaginatedSearchRequest,
    db: Session = Depends(get_db)
):
    try:
        service = KGNodeService(db)
        search_params = {
            'keyword': payload.name,
            'category': payload.category,
            'pageNo': payload.pageNo,
            'limit': payload.limit,
            'load_props': True,
        }
        result = service.paginated_search(search_params)
        return StandardResponse(
            success=True,
            data={
                'records': result['records'],
                'pagination': result['pagination'],
            },
        )
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("分页查询失败: %s", e)
        # NOTE(review): FastAPI's default HTTPException handler JSON-encodes
        # `detail` with json.dumps, which cannot serialise a pydantic model;
        # unless the app installs a custom handler, this error body is lost.
        raise HTTPException(
            status_code=500,
            detail=StandardResponse(
                success=False,
                error_code=500,
                error_msg=str(e),
            ),
        ) from e
# POST /saas/nodes
# Creates a node from the request body via TrunksService.create_node and
# wraps the service result in StandardResponse(data=...).
#
# `db` is not used by this handler (TrunksService is built without it), but it
# stays in the signature for compatibility; FastAPI still resolves get_db for
# every request. NOTE(review): the service call is synchronous inside an
# `async def`, so it blocks the event loop.
@router.post("/nodes", response_model=StandardResponse)
async def create_node(
    payload: NodeCreateRequest,
    db: Session = Depends(get_db)
):
    try:
        service = TrunksService()
        result = service.create_node(payload.dict())
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("创建节点失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
# GET /saas/nodes/{node_id}
# Fetches a node via TrunksService.get_node and returns whatever the service
# yields as StandardResponse(data=...). NOTE(review): a missing node is not
# mapped to 404 here — confirm what get_node returns/raises in that case.
#
# `db` is unused (TrunksService is built without it) but kept for signature
# compatibility; FastAPI still resolves get_db for every request.
@router.get("/nodes/{node_id}", response_model=StandardResponse)
async def get_node(
    node_id: int,
    db: Session = Depends(get_db)
):
    try:
        service = TrunksService()
        result = service.get_node(node_id)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("获取节点失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
# PUT /saas/nodes/{node_id}
# Partially updates a node: exclude_unset=True forwards only the fields the
# client actually sent to TrunksService.update_node.
#
# `db` is unused (TrunksService is built without it) but kept for signature
# compatibility; FastAPI still resolves get_db for every request.
@router.put("/nodes/{node_id}", response_model=StandardResponse)
async def update_node(
    node_id: int,
    payload: NodeUpdateRequest,
    db: Session = Depends(get_db)
):
    try:
        service = TrunksService()
        result = service.update_node(node_id, payload.dict(exclude_unset=True))
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("更新节点失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
# DELETE /saas/nodes/{node_id}
# Deletes a node via TrunksService.delete_node; responds with a bare
# StandardResponse(success=True) (no data payload).
#
# `db` is unused (TrunksService is built without it) but kept for signature
# compatibility; FastAPI still resolves get_db for every request.
@router.delete("/nodes/{node_id}", response_model=StandardResponse)
async def delete_node(
    node_id: int,
    db: Session = Depends(get_db)
):
    try:
        service = TrunksService()
        service.delete_node(node_id)
        return StandardResponse(success=True)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("删除节点失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
# POST /saas/trunks/vector_search
# Runs TrunksService.search_by_vector(text, limit, filter). The filter is
# {'type': payload.type} when a (truthy) type was given, otherwise None.
#
# `db` is unused (TrunksService is built without it) but kept for signature
# compatibility; FastAPI still resolves get_db for every request.
@router.post('/trunks/vector_search', response_model=StandardResponse)
async def vector_search(
    payload: VectorSearchRequest,
    db: Session = Depends(get_db)
):
    try:
        service = TrunksService()
        result = service.search_by_vector(
            payload.text,
            payload.limit,
            {'type': payload.type} if payload.type else None,
        )
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("向量搜索失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
# GET /saas/trunks/{trunk_id}
# Returns TrunksService.get_trunk_by_id(trunk_id) wrapped in
# StandardResponse(data=...). NOTE(review): a missing trunk is not mapped to
# 404 here — confirm what the service returns/raises in that case.
#
# `db` is unused (TrunksService is built without it) but kept for signature
# compatibility; FastAPI still resolves get_db for every request.
@router.get('/trunks/{trunk_id}', response_model=StandardResponse)
async def get_trunk(
    trunk_id: int,
    db: Session = Depends(get_db)
):
    try:
        service = TrunksService()
        result = service.get_trunk_by_id(trunk_id)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("获取trunk详情失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
def _title_value_pairs(props) -> List[dict]:
    """Reduce property records to dicts holding only 'prop_title' and 'prop_value'."""
    return [
        {'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
        for p in props
    ]


# GET /saas/nodes/{src_id}/relationships
# Lists the edges returned for source node `src_id`. Each entry carries:
#   - the edge name and its properties;
#   - a summary of the target node: category, id (as a string), name and
#     properties.
# Response: StandardResponse(data={"relationships": [...]}).
@router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
async def get_node_relationships(
    src_id: int,
    db: Session = Depends(get_db)
):
    try:
        edge_service = KGEdgeService(db)
        prop_service = KGPropService(db)

        # dest_id=None: presumably "any destination", i.e. all outgoing edges
        # of src_id — confirm in KGEdgeService.get_edges_by_nodes.
        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)

        relationships = []
        for edge in edges:
            dest_node = edge['dest_node']
            # NOTE(review): two property lookups per edge (N+1 queries).
            # A batched lookup would help for high-degree nodes.
            # NOTE(review): node and edge properties are both fetched by bare
            # id; confirm the prop store cannot confuse a node id with an
            # edge id.
            # Target-node props are fetched before edge props, matching the
            # original query order.
            dest_props = _title_value_pairs(
                prop_service.get_props_by_ref_id(dest_node['id'])
            )
            edge_props = _title_value_pairs(
                prop_service.get_props_by_ref_id(edge['id'])
            )

            relationships.append({
                "name": edge['name'],
                "properties": edge_props,
                "targetNode": {
                    "category": dest_node['category'],
                    "id": str(dest_node['id']),
                    "name": dest_node['name'],
                    "properties": dest_props,
                },
            })

        return StandardResponse(
            success=True,
            data={"relationships": relationships},
        )
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped.
        logger.exception("获取节点关系失败: %s", e)
        raise HTTPException(500, detail=StandardResponse.error(str(e))) from e
# Public alias of this module's router, presumably the name the application
# imports and mounts (e.g. via include_router) — verify at the call site.
saas_kb_router = router
|