SGTY 1 месяц назад
Родитель
Сommit
a9912e30c2

+ 56 - 3
build/lib/knowledge/router/knowledge_nodes_api.py

@@ -22,6 +22,10 @@ class PaginatedSearchRequest(BaseModel):
     pageNo: int = 1
     limit: int = 10
 
+class GetNodeRelationshipsRequest(BaseModel):
+    relation_name: Optional[str] = None
+
+
 async def get_request_id(request: Request):
     return request.state.context["request_id"]
 
@@ -69,9 +73,10 @@ async def paginated_search(
             )
         )
 
-@router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
-async def get_node_relationships(
+@router.post("/nodes/{src_id}/relationships", response_model=StandardResponse)
+async def get_node_relationships_condition(
     src_id: int,
+    payload: GetNodeRelationshipsRequest,
     db: Session = Depends(get_db),
     request_id: str = Depends(get_request_id),
     api_key: str = Security(api_key_header)
@@ -80,7 +85,7 @@ async def get_node_relationships(
         edge_service = KGEdgeService(db)
         prop_service = KGPropService(db)
         
-        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
+        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None,name=payload.relation_name if payload and payload.relation_name else None)
         relationships = []
                
         #count = 0
@@ -117,4 +122,52 @@ async def get_node_relationships(
         logger.error(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
+
+@router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
+async def get_node_relationships(
+        src_id: int,
+        db: Session = Depends(get_db),
+        request_id: str = Depends(get_request_id),
+        api_key: str = Security(api_key_header)
+):
+    try:
+        edge_service = KGEdgeService(db)
+        prop_service = KGPropService(db)
+
+        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
+        relationships = []
+
+        # count = 0
+        for edge in edges:
+            # if count >= 2:
+            # break
+            dest_node = edge['dest_node']
+            dest_props = []
+            edge_props = []
+            # count += 1
+            # dest_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
+            #              for p in prop_service.get_props_by_ref_id(dest_node['id'])]
+
+            # edge_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
+            #             for p in prop_service.get_props_by_ref_id(edge['id'])]
+
+            relationships.append({
+                "name": edge['name'],
+                "props": edge_props,
+                "destNode": {
+                    "category": dest_node['category'],
+                    "id": str(dest_node['id']),
+                    "name": dest_node['name'],
+                    "props": dest_props
+                }
+            })
+
+        return StandardResponse(
+            success=True,
+            requestId=request_id,
+            data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
+        )
+    except Exception as e:
+        logger.error(f"获取节点关系失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
 knowledge_nodes_api_router = router

+ 16 - 7
build/lib/knowledge/service/kg_edge_service.py

@@ -1,13 +1,17 @@
+import copy
+
 from sqlalchemy.orm import Session
 from sqlalchemy import or_
 from typing import Optional
 from ..model.kg_edges import KGEdge
 import logging
 from sqlalchemy.exc import IntegrityError
+from cachetools import TTLCache
 
 logger = logging.getLogger(__name__)
 
 class KGEdgeService:
+    _cache = TTLCache(maxsize=10000, ttl=60*60*24)
     def __init__(self, db: Session):
         self.db = db
 
@@ -68,7 +72,12 @@ class KGEdgeService:
             logger.error(f"删除边失败: {str(e)}")
             raise ValueError("Delete failed")
 
-    def get_edges_by_nodes(self, src_id: Optional[int], dest_id: Optional[int], and_logic: bool = True):
+    def get_edges_by_nodes(self, src_id: Optional[int] = None, dest_id: Optional[int] = None,
+                           name: Optional[str] = None):
+        cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{name}"
+        if cache_key in self._cache:
+            return copy.deepcopy(self._cache[cache_key])
+
         if src_id is None and dest_id is None:
             raise ValueError("至少需要提供一个有效的查询条件")
         try:
@@ -77,11 +86,10 @@ class KGEdgeService:
                 filters.append(KGEdge.src_id == src_id)
             if dest_id is not None:
                 filters.append(KGEdge.dest_id == dest_id)
+            if name:
+                filters.append(KGEdge.name == name)
+            edges = self.db.query(KGEdge).filter(*filters).all()
 
-            if and_logic:
-                edges = self.db.query(KGEdge).filter(*filters).all()
-            else:
-                edges = self.db.query(KGEdge).filter(or_(*filters)).all()
             from ..service.kg_node_service import KGNodeService
             node_service = KGNodeService(self.db)
             result = []
@@ -93,14 +101,15 @@ class KGEdgeService:
                         'dest_id': edge.dest_id,
                         'name': edge.name,
                         'version': edge.version,
-                        'src_node': node_service.get_node(edge.src_id),
+                        # 'src_node': node_service.get_node(edge.src_id),
                         'dest_node': node_service.get_node(edge.dest_id)
                     }
                     result.append(edge_info)
                 except ValueError as e:
                     logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
                     continue
+            self._cache[cache_key] = copy.deepcopy(result)
             return result
         except Exception as e:
             logger.error(f"查询边失败: {str(e)}")
-            raise e
+            raise e

+ 28 - 10
build/lib/knowledge/service/kg_node_service.py

@@ -1,9 +1,11 @@
+import copy
+
 from sqlalchemy.orm import Session
 from ..model.kg_node import KGNode
 from ..db.session import get_db
 import logging
 from sqlalchemy.exc import IntegrityError
-
+from cachetools import TTLCache
 from ..utils.vectorizer import Vectorizer
 from sqlalchemy import func
 from ..service.kg_prop_service import KGPropService
@@ -13,11 +15,10 @@ logger = logging.getLogger(__name__)
 DISTANCE_THRESHOLD = 0.65
 DISTANCE_THRESHOLD2 = 0.3
 class KGNodeService:
+    _cache = TTLCache(maxsize=10000, ttl=60*60*24)
     def __init__(self, db: Session):
         self.db = db
-
-    _cache = {}
-
+        
     def search_title_index(self, index: str, title: str, top_k: int = 3):
         cache_key = f"{index}:{title}:{top_k}"
         if cache_key in self._cache:
@@ -66,6 +67,12 @@ class KGNodeService:
             page_no = 1
         if limit < 1:
             limit = 10
+            
+        cache_key = f"paginated_search:{keyword}:{category}:{page_no}:{distance}:{limit}:{str(search_params.get('knowledge_ids', ''))}:{load_props}"
+        logger.debug(f"Cache key: {cache_key}")
+        if cache_key in self._cache:
+            cached_value = self._cache[cache_key]
+            return copy.deepcopy(cached_value)
 
         embedding = Vectorizer.get_instance().get_embedding(keyword)
         offset = (page_no - 1) * limit
@@ -102,8 +109,9 @@ class KGNodeService:
             query = query.filter(KGNode.embedding.l2_distance(embedding) < distance)
             results = query.order_by('distance').offset(offset).limit(limit).all()
             #将results相同distance的category=疾病的放在前面
-            results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))
-            return {
+            #results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))
+
+            finalResults = {
                 'records': [{
                     'id': r.id,
                     'name': r.name,
@@ -120,6 +128,8 @@ class KGNodeService:
                 }
             
             }
+            self._cache[cache_key] = copy.deepcopy(finalResults)
+            return finalResults
         except Exception as e:
             logger.error(f"分页查询失败: {str(e)}")
             raise e
@@ -146,20 +156,28 @@ class KGNodeService:
             raise ValueError("Database integrity error")
 
     def get_node(self, node_id: int):
-   
         if node_id is None:
             raise ValueError("Node ID is required")
-     
-        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()     
+
+        cache_key = f"get_node_{node_id}"
+        if cache_key in self._cache:
+            return copy.deepcopy(self._cache[cache_key])
+
+        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
 
         if not node:
             raise ValueError("Node not found")
-        return {
+
+        node_data = {
             'id': node.id,
             'name': node.name,
             'category': node.category,
             'version': node.version
         }
+        #node_data深拷贝
+        node_data = node_data.copy()
+        self._cache[cache_key] = copy.deepcopy(node_data)
+        return node_data
 
     def update_node(self, node_id: int, update_data: dict):
         node = self.db.query(KGNode).get(node_id)

+ 52 - 4
src/knowledge/router/knowledge_nodes_api.py

@@ -23,7 +23,7 @@ class PaginatedSearchRequest(BaseModel):
     limit: int = 10
 
 class GetNodeRelationshipsRequest(BaseModel):
-    relation_name: str
+    relation_name: Optional[str] = None
 
 
 async def get_request_id(request: Request):
@@ -73,8 +73,8 @@ async def paginated_search(
             )
         )
 
-@router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
-async def get_node_relationships(
+@router.post("/nodes/{src_id}/relationships", response_model=StandardResponse)
+async def get_node_relationships_condition(
     src_id: int,
     payload: GetNodeRelationshipsRequest,
     db: Session = Depends(get_db),
@@ -85,7 +85,7 @@ async def get_node_relationships(
         edge_service = KGEdgeService(db)
         prop_service = KGPropService(db)
         
-        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
+        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None,name=payload.relation_name if payload and payload.relation_name else None)
         relationships = []
                
         #count = 0
@@ -122,4 +122,52 @@ async def get_node_relationships(
         logger.error(f"获取节点关系失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
+
+@router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
+async def get_node_relationships(
+        src_id: int,
+        db: Session = Depends(get_db),
+        request_id: str = Depends(get_request_id),
+        api_key: str = Security(api_key_header)
+):
+    try:
+        edge_service = KGEdgeService(db)
+        prop_service = KGPropService(db)
+
+        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
+        relationships = []
+
+        # count = 0
+        for edge in edges:
+            # if count >= 2:
+            # break
+            dest_node = edge['dest_node']
+            dest_props = []
+            edge_props = []
+            # count += 1
+            # dest_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
+            #              for p in prop_service.get_props_by_ref_id(dest_node['id'])]
+
+            # edge_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
+            #             for p in prop_service.get_props_by_ref_id(edge['id'])]
+
+            relationships.append({
+                "name": edge['name'],
+                "props": edge_props,
+                "destNode": {
+                    "category": dest_node['category'],
+                    "id": str(dest_node['id']),
+                    "name": dest_node['name'],
+                    "props": dest_props
+                }
+            })
+
+        return StandardResponse(
+            success=True,
+            requestId=request_id,
+            data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
+        )
+    except Exception as e:
+        logger.error(f"获取节点关系失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
 knowledge_nodes_api_router = router

+ 8 - 6
src/knowledge/service/kg_edge_service.py

@@ -1,3 +1,5 @@
+import copy
+
 from sqlalchemy.orm import Session
 from sqlalchemy import or_
 from typing import Optional
@@ -71,10 +73,10 @@ class KGEdgeService:
             raise ValueError("Delete failed")
 
     def get_edges_by_nodes(self, src_id: Optional[int] = None, dest_id: Optional[int] = None,
-                           category: Optional[str] = None):
-        cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{category}"
+                           name: Optional[str] = None):
+        cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{name}"
         if cache_key in self._cache:
-            return self._cache[cache_key]
+            return copy.deepcopy(self._cache[cache_key])
 
         if src_id is None and dest_id is None:
             raise ValueError("至少需要提供一个有效的查询条件")
@@ -84,8 +86,8 @@ class KGEdgeService:
                 filters.append(KGEdge.src_id == src_id)
             if dest_id is not None:
                 filters.append(KGEdge.dest_id == dest_id)
-            if category is not None:
-                filters.append(KGEdge.category == category)
+            if name:
+                filters.append(KGEdge.name == name)
             edges = self.db.query(KGEdge).filter(*filters).all()
 
             from ..service.kg_node_service import KGNodeService
@@ -106,7 +108,7 @@ class KGEdgeService:
                 except ValueError as e:
                     logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
                     continue
-            self._cache[cache_key] = result
+            self._cache[cache_key] = copy.deepcopy(result)
             return result
         except Exception as e:
             logger.error(f"查询边失败: {str(e)}")

+ 6 - 5
src/knowledge/service/kg_node_service.py

@@ -1,3 +1,5 @@
+import copy
+
 from sqlalchemy.orm import Session
 from ..model.kg_node import KGNode
 from ..db.session import get_db
@@ -70,8 +72,7 @@ class KGNodeService:
         logger.debug(f"Cache key: {cache_key}")
         if cache_key in self._cache:
             cached_value = self._cache[cache_key]
-            print(cached_value)
-            return cached_value
+            return copy.deepcopy(cached_value)
 
         embedding = Vectorizer.get_instance().get_embedding(keyword)
         offset = (page_no - 1) * limit
@@ -127,7 +128,7 @@ class KGNodeService:
                 }
             
             }
-            self._cache[cache_key] = finalResults
+            self._cache[cache_key] = copy.deepcopy(finalResults)
             return finalResults
         except Exception as e:
             logger.error(f"分页查询失败: {str(e)}")
@@ -160,7 +161,7 @@ class KGNodeService:
 
         cache_key = f"get_node_{node_id}"
         if cache_key in self._cache:
-            return self._cache[cache_key]
+            return copy.deepcopy(self._cache[cache_key])
 
         node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
 
@@ -175,7 +176,7 @@ class KGNodeService:
         }
         #node_data深拷贝
         node_data = node_data.copy()
-        self._cache[cache_key] = node_data
+        self._cache[cache_key] = copy.deepcopy(node_data)
         return node_data
 
     def update_node(self, node_id: int, update_data: dict):