SGTY hai 3 meses
pai
achega
f9d25222a5

+ 6 - 0
.idea/encodings.xml

@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Encoding">
+    <file url="file://$PROJECT_DIR$/app.log" charset="ISO-8859-1" />
+  </component>
+</project>

+ 16 - 0
model/kg_edges.py

@@ -0,0 +1,16 @@
+from sqlalchemy import Column, Integer, String, text
+from db.base_class import Base
+
+class KGEdge(Base):
+    __tablename__ = 'kg_edges'
+
+    id = Column(Integer, primary_key=True, index=True)
+    category = Column(String(64), nullable=False)
+    src_id = Column(Integer, nullable=False)
+    dest_id = Column(Integer, nullable=False)
+    name = Column(String(64), nullable=False)
+    version = Column(String(16))
+    status = Column(Integer, nullable=False, default=0)
+
+    def __repr__(self):
+        return f"<KGEdge(id={self.id}, name={self.name}, category={self.category})>"

+ 8 - 16
model/kg_node.py

@@ -1,24 +1,16 @@
 from sqlalchemy import Column, Integer, String, text
 from sqlalchemy.dialects.postgresql import JSONB
 from db.base_class import Base
+from pgvector.sqlalchemy import Vector
 
 class KGNode(Base):
     __tablename__ = 'kg_nodes'
 
-    id = Column(Integer, primary_key=True, server_default=text("nextval('kg_ids_seq')"))
-    name = Column(String(64), nullable=False)
-    category = Column(String(64), nullable=False)
-    layout = Column(String(100))
-    version = Column(String(16))
-    status = Column(Integer, nullable=False, server_default=text('0'))
-    embedding = Column(JSONB)
+    id = Column(Integer, primary_key=True, index=True)
+    name = Column(String(255), nullable=False)
+    category = Column(String(255), nullable=False)
+    embedding = Column(Vector(1024))
+    version = Column(String(50))
 
-    __table_args__ = (
-        {'schema': 'public'},
-        {
-            'postgresql_partition_by': 'LIST (category)',
-            'postgresql_with': {
-                'fillfactor': '50'
-            }
-        }
-    )
+    def __repr__(self):
+        return f"<KGNode(id={self.id}, name={self.name}, category={self.category})>"

+ 18 - 0
model/kg_prop.py

@@ -0,0 +1,18 @@
+from sqlalchemy import Column, Integer, String, Text
+from db.base_class import Base
+
+class KGProp(Base):
+    __tablename__ = 'kg_props'
+
+    id = Column(Integer, primary_key=True, index=True)
+    category = Column(Integer, nullable=False, default=0)
+    ref_id = Column(Integer, nullable=False)
+    prop_name = Column(String(64), nullable=False)
+    prop_value = Column(Text)
+    prop_title = Column(String(64))
+    type = Column(Integer)
+
+    def __repr__(self):
+        return f"<KGProp(id={self.id}, prop_name={self.prop_name}, ref_id={self.ref_id})>"
+
+# 类型:空或1-实体属性,2-关系属性

+ 1 - 1
model/trunks_model.py

@@ -13,4 +13,4 @@ class Trunks(Base):
     content_tsvector = Column(TSVECTOR)
 
     def __repr__(self):
-        return f"<Trunks(id={self.id}, file_path={self.file_path})>"
+        return f"<Trunks(id={self.id}, file_path={self.file_path})>"

+ 16 - 28
router/knowledge_saas.py

@@ -1,9 +1,11 @@
 from fastapi import APIRouter, Depends, HTTPException
-from typing import Optional
+from typing import Optional, List
 from pydantic import BaseModel
 from model.response import StandardResponse
 from db.session import get_db
 from sqlalchemy.orm import Session
+
+from service.kg_node_service import KGNodeService
 from service.trunks_service import TrunksService
 import logging
 
@@ -36,37 +38,19 @@ async def paginated_search(
     db: Session = Depends(get_db)
 ):
     try:
-        service = KGNodeService()
-        search_params = {
-            'keyword': payload.keyword,
-            'pageNo': payload.pageNo,
-            'limit': payload.limit,
-            'knowledge_ids': payload.knowledge_ids
-        }
-        return service.paginated_search(search_params)
-    except Exception as e:
-        logger.error(f"分页查询失败: {str(e)}")
-        raise HTTPException(
-            status_code=500,
-            detail=StandardResponse(
-                success=False,
-                error_code=500,
-                error_msg=str(e)
-            )
-        )
-    try:
-        trunks_service = TrunksService()
+        service = KGNodeService(db)
         search_params = {
             'keyword': payload.keyword,
             'pageNo': payload.pageNo,
             'limit': payload.limit,
-            'knowledge_ids': payload.knowledge_ids
+            'knowledge_ids': payload.knowledge_ids,
+            'load_props': True
         }
-        result = trunks_service.paginated_search(search_params)
+        result = service.paginated_search(search_params)
         return StandardResponse(
             success=True,
             data={
-                'records': result['data'],
+                'records': result['records'],
                 'pagination': result['pagination']
             }
         )
@@ -88,7 +72,8 @@ async def create_node(
 ):
     try:
         service = TrunksService()
-        return service.create_node(payload.dict())
+        result = service.create_node(payload.dict())
+        return StandardResponse(success=True, data=result)
     except Exception as e:
         logger.error(f"创建节点失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
@@ -100,7 +85,8 @@ async def get_node(
 ):
     try:
         service = TrunksService()
-        return service.get_node(node_id)
+        result = service.get_node(node_id)
+        return StandardResponse(success=True, data=result)
     except Exception as e:
         logger.error(f"获取节点失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
@@ -113,7 +99,8 @@ async def update_node(
 ):
     try:
         service = TrunksService()
-        return service.update_node(node_id, payload.dict(exclude_unset=True))
+        result = service.update_node(node_id, payload.dict(exclude_unset=True))
+        return StandardResponse(success=True, data=result)
     except Exception as e:
         logger.error(f"更新节点失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
@@ -125,7 +112,8 @@ async def delete_node(
 ):
     try:
         service = TrunksService()
-        return service.delete_node(node_id)
+        service.delete_node(node_id)
+        return StandardResponse(success=True)
     except Exception as e:
         logger.error(f"删除节点失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))

+ 99 - 0
service/kg_edge_service.py

@@ -0,0 +1,99 @@
+from sqlalchemy.orm import Session
+from sqlalchemy import or_
+from typing import Optional
+from model.kg_edges import KGEdge
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+
+logger = logging.getLogger(__name__)
+
+class KGEdgeService:
+    def __init__(self, db: Session):
+        self.db = db
+
+    def get_edge(self, edge_id: int):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+        return edge
+
+    def create_edge(self, edge_data: dict):
+        try:
+            existing = self.db.query(KGEdge).filter(
+                KGEdge.src_id == edge_data['src_id'],
+                KGEdge.dest_id == edge_data['dest_id'],
+                KGEdge.name == edge_data['name'],
+                KGEdge.version == edge_data.get('version')
+            ).first()
+
+            if existing:
+                raise ValueError("Edge already exists")
+
+            new_edge = KGEdge(**edge_data)
+            self.db.add(new_edge)
+            self.db.commit()
+            return new_edge
+
+        except IntegrityError as e:
+            self.db.rollback()
+            logger.error(f"创建边失败: {str(e)}")
+            raise ValueError("Database integrity error")
+
+    def update_edge(self, edge_id: int, update_data: dict):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+
+        try:
+            for key, value in update_data.items():
+                setattr(edge, key, value)
+            self.db.commit()
+            return edge
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"更新边失败: {str(e)}")
+            raise ValueError("Update failed")
+
+    def delete_edge(self, edge_id: int):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+
+        try:
+            self.db.delete(edge)
+            self.db.commit()
+            return None
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"删除边失败: {str(e)}")
+            raise ValueError("Delete failed")
+
+    def get_edges_by_nodes(self, src_id: Optional[int], dest_id: Optional[int], and_logic: bool = True):
+        if src_id is None and dest_id is None:
+            raise ValueError("至少需要提供一个有效的查询条件")
+        try:
+            filters = []
+            if src_id is not None:
+                filters.append(KGEdge.src_id == src_id)
+            if dest_id is not None:
+                filters.append(KGEdge.dest_id == dest_id)
+
+            if and_logic:
+                edges = self.db.query(KGEdge).filter(*filters).all()
+            else:
+                edges = self.db.query(KGEdge).filter(or_(*filters)).all()
+            from service.kg_node_service import KGNodeService
+            node_service = KGNodeService(self.db)
+            return [{
+                'id': edge.id,
+                'src_id': edge.src_id,
+                'dest_id': edge.dest_id,
+                'name': edge.name,
+                'version': edge.version,
+                'src_node': node_service.get_node(edge.src_id),
+                'dest_node': node_service.get_node(edge.dest_id)
+            } for edge in edges]
+        except Exception as e:
+            logger.error(f"查询边失败: {str(e)}")
+            raise e

+ 78 - 74
service/kg_node_service.py

@@ -1,18 +1,24 @@
 from sqlalchemy.orm import Session
 from typing import Optional
-from model.trunks_model import KGNode
+from model.kg_node import KGNode
 from db.session import get_db
 import logging
 from sqlalchemy.exc import IntegrityError
-from schema.response import StandardResponse
+from utils.vectorizer import Vectorizer
+from sqlalchemy import func
+from service.kg_prop_service import KGPropService
+from service.kg_edge_service import KGEdgeService
 
 logger = logging.getLogger(__name__)
 
 class KGNodeService:
-    def __init__(self):
-        self.db = next(get_db())
+    def __init__(self, db: Session):
+        self.db = db
 
-    def paginated_search(self, search_params: dict) -> StandardResponse:
+    def paginated_search(self, search_params: dict) -> dict:
+        load_props = search_params.get('load_props', False)
+        prop_service = KGPropService(self.db)
+        edge_service = KGEdgeService(self.db)
         keyword = search_params.get('keyword', '')
         page_no = search_params.get('pageNo', 1)
         limit = search_params.get('limit', 10)
@@ -34,36 +40,31 @@ class KGNodeService:
                 KGNode.category,
                 KGNode.embedding.l2_distance(embedding).label('distance')
             )
-        if search_params.get('knowledge_ids'):
-            query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
-        results = query.order_by('distance').offset(offset).limit(limit).all()
-
-            return StandardResponse(
-                success=True,
-                data={
-                    'records': [{
-                        'id': r.id,
-                        'name': r.name,
-                        'category': r.category,
-                        'distance': r.distance
-                    } for r in results],
-                    'pagination': {
-                        'total': total_count,
-                        'pageNo': page_no,
-                        'limit': limit,
-                        'totalPages': (total_count + limit - 1) // limit
-                    }
+            if search_params.get('knowledge_ids'):
+                query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
+            results = query.order_by('distance').offset(offset).limit(limit).all()
+
+            return {
+                'records': [{
+                    'id': r.id,
+                    'name': r.name,
+                    'category': r.category,
+                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
+                    'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
+                    'distance': r.distance
+                } for r in results],
+                'pagination': {
+                    'total': total_count,
+                    'pageNo': page_no,
+                    'limit': limit,
+                    'totalPages': (total_count + limit - 1) // limit
                 }
-            )
+            }
         except Exception as e:
             logger.error(f"分页查询失败: {str(e)}")
-            return StandardResponse(
-                success=False,
-                error_code=500,
-                error_msg=str(e)
-            )
+            raise e
 
-    def create_node(self, node_data: dict) -> StandardResponse:
+    def create_node(self, node_data: dict):
         try:
             existing = self.db.query(KGNode).filter(
                 KGNode.name == node_data['name'],
@@ -72,77 +73,80 @@ class KGNodeService:
             ).first()
             
             if existing:
-                return StandardResponse(
-                    success=False,
-                    error_code=409,
-                    error_msg="Node already exists"
-                )
+                raise ValueError("Node already exists")
 
             new_node = KGNode(**node_data)
             self.db.add(new_node)
             self.db.commit()
-            return StandardResponse(success=True, data=new_node)
+            return new_node
 
         except IntegrityError as e:
             self.db.rollback()
             logger.error(f"创建节点失败: {str(e)}")
-            return StandardResponse(
-                success=False,
-                error_code=500,
-                error_msg="Database integrity error"
-            )
+            raise ValueError("Database integrity error")
 
-    def get_node(self, node_id: int) -> StandardResponse:
+    def get_node(self, node_id: int):
         node = self.db.query(KGNode).get(node_id)
         if not node:
-            return StandardResponse(
-                success=False,
-                error_code=404,
-                error_msg="Node not found"
-            )
-        return StandardResponse(success=True, data=node)
-
-    def update_node(self, node_id: int, update_data: dict) -> StandardResponse:
+            raise ValueError("Node not found")
+        return {
+            'id': node.id,
+            'name': node.name,
+            'category': node.category,
+            'version': node.version
+        }
+
+    def update_node(self, node_id: int, update_data: dict):
         node = self.db.query(KGNode).get(node_id)
         if not node:
-            return StandardResponse(
-                success=False,
-                error_code=404,
-                error_msg="Node not found"
-            )
+            raise ValueError("Node not found")
 
         try:
             for key, value in update_data.items():
                 setattr(node, key, value)
             self.db.commit()
-            return StandardResponse(success=True, data=node)
+            return node
         except Exception as e:
             self.db.rollback()
             logger.error(f"更新节点失败: {str(e)}")
-            return StandardResponse(
-                success=False,
-                error_code=500,
-                error_msg="Update failed"
-            )
+            raise ValueError("Update failed")
 
-    def delete_node(self, node_id: int) -> StandardResponse:
+    def delete_node(self, node_id: int):
         node = self.db.query(KGNode).get(node_id)
         if not node:
-            return StandardResponse(
-                success=False,
-                error_code=404,
-                error_msg="Node not found"
-            )
+            raise ValueError("Node not found")
 
         try:
             self.db.delete(node)
             self.db.commit()
-            return StandardResponse(success=True)
+            return None
         except Exception as e:
             self.db.rollback()
             logger.error(f"删除节点失败: {str(e)}")
-            return StandardResponse(
-                success=False,
-                error_code=500,
-                error_msg="Delete failed"
-            )
+            raise ValueError("Delete failed")
+
+    def batch_process_er_nodes(self):
+        batch_size = 200
+        offset = 0
+
+        while True:
+            try:
+                nodes = self.db.query(KGNode).filter(
+                    KGNode.version == 'ER',
+                    KGNode.embedding == None
+                ).offset(offset).limit(batch_size).all()
+
+                if not nodes:
+                    break
+
+                for node in nodes:
+                    if not node.embedding:
+                        embedding = Vectorizer.get_embedding(node.name)
+                        node.embedding = embedding
+                        self.db.commit()
+
+                offset += batch_size
+            except Exception as e:
+                self.db.rollback()
+                logger.error(f"批量处理ER节点失败: {str(e)}")
+                raise ValueError("Batch process failed")

+ 67 - 0
service/kg_prop_service.py

@@ -0,0 +1,67 @@
+from sqlalchemy.orm import Session
+from typing import List
+from model.kg_prop import KGProp
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+
+logger = logging.getLogger(__name__)
+
+class KGPropService:
+    def __init__(self, db: Session):
+        self.db = db
+
+    def get_props_by_ref_id(self, ref_id: int) -> List[dict]:
+        try:
+            props = self.db.query(KGProp).filter(KGProp.ref_id == ref_id).all()
+            return [{
+                'id': p.id,
+                'category': p.category,
+                'prop_name': p.prop_name,
+                'prop_value': p.prop_value,
+                'prop_title': p.prop_title,
+                'type': p.type
+            } for p in props]
+        except Exception as e:
+            logger.error(f"根据ref_id查询属性失败: {str(e)}")
+            raise ValueError("查询失败")
+
+    def create_prop(self, prop_data: dict) -> KGProp:
+        try:
+            new_prop = KGProp(**prop_data)
+            self.db.add(new_prop)
+            self.db.commit()
+            return new_prop
+        except IntegrityError as e:
+            self.db.rollback()
+            logger.error(f"创建属性失败: {str(e)}")
+            raise ValueError("数据库完整性错误")
+
+    def update_prop(self, prop_id: int, update_data: dict) -> KGProp:
+        prop = self.db.query(KGProp).get(prop_id)
+        if not prop:
+            raise ValueError("属性未找到")
+
+        try:
+            for key, value in update_data.items():
+                setattr(prop, key, value)
+            self.db.commit()
+            return prop
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"更新属性失败: {str(e)}")
+            raise ValueError("更新失败")
+
+    def delete_prop(self, prop_id: int) -> None:
+        prop = self.db.query(KGProp).get(prop_id)
+        if not prop:
+            raise ValueError("属性未找到")
+
+        try:
+            self.db.delete(prop)
+            self.db.commit()
+            return None
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"删除属性失败: {str(e)}")
+            raise ValueError("删除失败")

+ 1 - 0
service/trunks_service.py

@@ -1,5 +1,6 @@
 from sqlalchemy import func
 from sqlalchemy.orm import Session
+from db.session import get_db
 from typing import List, Optional
 from model.trunks_model import Trunks
 from db.session import SessionLocal

+ 53 - 0
tests/service/test_kg_node_service.py

@@ -0,0 +1,53 @@
+import pytest
+from service.kg_node_service import KGNodeService
+from model.trunks_model import KGNode
+from sqlalchemy.exc import IntegrityError
+
+@pytest.fixture(scope="module")
+def kg_node_service():
+    return KGNodeService()
+
+@pytest.fixture
+def test_node_data():
+    return {
+        "name": "测试节点",
+        "category": "测试类别",
+        "version": "1.0"
+    }
+
+class TestKGNodeServiceCRUD:
+    def test_create_and_get_node(self, kg_node_service, test_node_data):
+        created = kg_node_service.create_node(test_node_data)
+        assert created.id is not None
+        retrieved = kg_node_service.get_node(created.id)
+        assert retrieved.name == test_node_data['name']
+
+    def test_update_node(self, kg_node_service, test_node_data):
+        node = kg_node_service.create_node(test_node_data)
+        updated = kg_node_service.update_node(node.id, {"name": "更新后的节点"})
+        assert updated.name == "更新后的节点"
+
+    def test_delete_node(self, kg_node_service, test_node_data):
+        node = kg_node_service.create_node(test_node_data)
+        assert kg_node_service.delete_node(node.id) is None
+        with pytest.raises(ValueError):
+            kg_node_service.get_node(node.id)
+
+    def test_duplicate_node(self, kg_node_service, test_node_data):
+        kg_node_service.create_node(test_node_data)
+        with pytest.raises(ValueError):
+            kg_node_service.create_node(test_node_data)
+
+class TestPaginatedSearch:
+    def test_paginated_search(self, kg_node_service, test_node_data):
+        results = kg_node_service.paginated_search({
+            'keyword': '感染性',
+            'pageNo': 1,
+            'limit': 10
+        })
+        assert len(results['records']) > 0
+        assert results['pagination']['pageNo'] == 1
+
+class TestBatchProcess:
+    def test_batch_process_er_nodes(self, kg_node_service, test_node_data):
+        kg_node_service.batch_process_er_nodes()