SGTY пре 2 недеља
родитељ
комит
a1a1d11eb8

+ 4 - 3
src/knowledge/router/graph_api.py

@@ -1,9 +1,10 @@
 from fastapi import APIRouter, Depends
 from pydantic import BaseModel
-from typing import Optional
+
 
 from ..db.session import get_db
 from ..model.response import StandardResponse
+
 from ..service.kg_graph_service import KgGraphService
 
 router = APIRouter()
@@ -15,8 +16,8 @@ class KgQuery(BaseModel):
 @router.post("/knowledge/getGraph")
 async def get_graph(kg_query: KgQuery) -> StandardResponse:
     db = next(get_db())
-    kg_graph_service = KgGraphService(db)
-    graph_data = kg_graph_service.get_graph_fac(kg_query.label_name,kg_query.input_str)
+    kg_graph_service = KgGraphService()
+    graph_data = kg_graph_service.get_graph_fac({"label_name": kg_query.label_name, "input_str": kg_query.input_str})
     return StandardResponse(
         success=True,
         data=graph_data

+ 2 - 9
src/knowledge/router/knowledge_saas.py

@@ -97,15 +97,8 @@ async def paginated_search(
             }
         )
     except Exception as e:
-        logger.error(f"分页查询失败: {str(e)}")
-        raise HTTPException(
-            status_code=500,
-            detail=StandardResponse(
-                success=False,
-                error_code=500,
-                error_msg=str(e)
-            )
-        )
+        logger.exception(f"分页查询失败: {str(e)}")
+        raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 @router.post("/nodes", response_model=StandardResponse)
 async def create_node(

+ 5 - 5
src/knowledge/router/text_search.py

@@ -567,7 +567,7 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
                             symptom_node = node_service.get_node_by_name_category(symptom, '症状')
                             # 获取症状相关同义词(包括1.0和2.0版本)
                             for category in ['症状同义词', '症状同义词2.0']:
-                                edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], category=category)
+                                edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], name=category)
                                 if edges:
                                     # 添加同义词
                                     for edge in edges:
@@ -611,7 +611,7 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
                     symptom_node = node_service.get_node_by_name_category(symptom, '症状')
                     # 获取症状相关同义词(包括1.0和2.0版本)
                     for category in ['症状同义词', '症状同义词2.0']:
-                        edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], category=category)
+                        edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], name=category)
                         if edges:
                             # 添加同义词
                             for edge in edges:
@@ -762,9 +762,9 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
             # 更新answer中的index
             if "answer" in prop_result:
                 for sentence in prop_result["answer"]:
-                    if sentence["index"]:
+                    if sentence["flag"]:
                         for ref in prop_result["references"]:
-                            if ref["index"].endswith(f"-{sentence['index']}"):
+                            if ref["index"].endswith(f"-{sentence['flag']}"):
                                 sentence["flag"] = ref["index"]
                                 break
 
@@ -791,7 +791,7 @@ async def node_props_search(request: NodePropsSearchRequest, db: Session = Depen
 
         return StandardResponse(success=True, data=result)
     except Exception as e:
-        logger.error(f"Node props search failed: {str(e)}")
+        logger.exception(f"Node props search failed: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
 class FindSimilarTexts(BaseModel):

+ 2 - 1
src/knowledge/server.py

@@ -14,6 +14,7 @@ from .middlewares.base import register_middlewares
 from .model.response import StandardResponse
 from .router.graph_api import graph_router
 from .router.knowledge_nodes_api import knowledge_nodes_api_router, get_request_id, api_key_header
+from .router.knowledge_saas import saas_kb_router
 from .router.text_search import text_search_router
 from .service.kg_edge_service import KGEdgeService
 from .service.kg_node_service import KGNodeService
@@ -56,7 +57,7 @@ async def startup():
     app.include_router(knowledge_nodes_api_router)
     app.include_router(text_search_router)
     app.include_router(graph_router)
-    
+    app.include_router(saas_kb_router)
     logger.info("fastapi startup success")
 
 

+ 267 - 172
src/knowledge/service/kg_graph_service.py

@@ -1,195 +1,290 @@
-from typing import List, Dict, Optional
+from typing import List, Dict, Any
 from sqlalchemy.orm import Session
 from sqlalchemy import text
 from ..db.session import get_db
 
 class KgGraphService:
-    def __init__(self, db: Session):
-        self.db = db
+    def __init__(self):
+        self.db = next(get_db())
 
-    def get_node_properties(self, node_id: int) -> Dict[str, str]:
-        """查询节点属性"""
-        prop_sql = "SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id"
-        result = self.db.execute(text(prop_sql), {'node_id': node_id}).fetchall()
-        return {row[0]: row[1] for row in result}
+    def _get_node_properties(self, node_id: int) -> Dict[str, Any]:
+        prop_sql = text("SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id")
+        prop_results = self.db.execute(prop_sql, {"node_id": node_id}).fetchall()
+
+        properties = {}
+        for prop in prop_results:
+            properties[prop.prop_title] = prop.prop_value
+        return properties
+
+    def get_graph(self, kg_query: Dict[str, str]) -> List[Dict[str, Any]]:
+        graph_dto_list = []
+        graph_map = {}
+        rtype_map = {}
 
-    def get_graph(self, label_name: str, input_str: str) -> List[Dict]:
-        """查询知识图谱关系"""
         # 1. 查询中心节点
-        node_sql = """
+        node_sql = text("""
         SELECT n.id, n.name, n.category 
         FROM kg_nodes n 
         WHERE n.category = :label_name AND n.name = :input_str 
-        AND n.status = '0'"""
-        
-        center_node = self.db.execute(
-            text(node_sql), 
-            {'label_name': label_name, 'input_str': input_str}
-        ).fetchone()
-        
-        if not center_node:
-            return []
-            
-        # 获取中心节点属性
-        s_id, s_name, s_label = center_node
-        s_prop = self.get_node_properties(s_id)
-        s_prop['name'] = s_name
-        
-        # 2. 查询关联的边和目标节点
+        AND n.status = '0'
+        """)
+        center_nodes = self.db.execute(node_sql, {
+            'label_name': kg_query['label_name'],
+            'input_str': kg_query['input_str']
+        }).fetchall()
+
+        for row in center_nodes:
+            s_name = row[1]
+            s_label = row[2]
+            s_id = row[0]
+            s_prop = self._get_node_properties(s_id)
+            s_prop['name'] = s_name
+
+            graph_dto = {
+                'name': s_name,
+                'label': s_label,
+                'id': s_id,
+                'properties': s_prop
+            }
+            graph_map[f"{s_name}_{s_label}"] = graph_dto
+
+        # 2. 查询关联的边和目标节点,每种关系最多查询50条
         relation_sql = """
         WITH RankedRelations AS (
-            SELECT e.name as rType, m.id as target_id, m.name as target_name, 
-                   m.category as target_label, 
-                   (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount, 
-                   ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn 
-            FROM kg_nodes n 
-            JOIN kg_edges e ON n.id = e.src_id 
-            JOIN kg_nodes m ON e.dest_id = m.id 
-            WHERE n.category = :label_name AND n.name = :input_str 
-            AND n.status = '0' 
-        ) 
-        SELECT rType, target_id, target_name, target_label, pCount 
-        FROM RankedRelations 
-        WHERE rn <= 50 
-        ORDER BY rType"""
-        
-        relations = self.db.execute(
-            text(relation_sql),
-            {'label_name': label_name, 'input_str': input_str}
-        ).fetchall()
-        
-        # 3. 组装返回结果
-        graph_data = {
-            'name': s_name,
-            'label': s_label,
-            'id': s_id,
-            'properties': s_prop,
-            'relations': []
-        }
-        
-        for r in relations:
-            r_type, e_id, e_name, e_label, p_count = r
-            e_prop = self.get_node_properties(e_id)
+            SELECT e.name as rType, m.id as target_id, m.name as target_name,
+                   m.category as target_label,
+                   (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount,
+                   ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
+            FROM kg_nodes n
+            JOIN kg_edges e ON n.id = e.src_id
+            JOIN kg_nodes m ON e.dest_id = m.id
+            WHERE n.category = %s AND n.name = %s
+            AND n.status = '0'
+        )
+        SELECT rType, target_id, target_name, target_label, pCount
+        FROM RankedRelations
+        WHERE rn <= 50
+        ORDER BY rType
+        """
+        relation_sql = text("""
+        WITH RankedRelations AS (
+            SELECT e.name as rType, m.id as target_id, m.name as target_name,
+                   m.category as target_label,
+                   (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount,
+                   ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
+            FROM kg_nodes n
+            JOIN kg_edges e ON n.id = e.src_id
+            JOIN kg_nodes m ON e.dest_id = m.id
+            WHERE n.category = :label_name AND n.name = :input_str
+            AND n.status = '0'
+        )
+        SELECT rType, target_id, target_name, target_label, pCount
+        FROM RankedRelations
+        WHERE rn <= 50
+        ORDER BY rType
+        """)
+        relations = self.db.execute(relation_sql, {
+            'label_name': kg_query['label_name'],
+            'input_str': kg_query['input_str']
+        }).fetchall()
+
+        for row in relations:
+            r_type = row[0]
+            e_name = row[2]
+            e_label = row[3]
+            e_id = row[1]
+            e_prop = self._get_node_properties(e_id)
             e_prop['name'] = e_name
-            
-            relation_data = {
-                'rType': r_type,
-                'target': {
-                    'id': e_id,
-                    'name': e_name,
-                    'label': e_label,
-                    'pCount': p_count,
-                    'properties': e_prop
-                }
+            p_count = row[4]
+
+            next_node_dto = {
+                'name': e_name,
+                'label': e_label,
+                'pCount': p_count,
+                'id': e_id,
+                'properties': e_prop
             }
-            graph_data['relations'].append(relation_data)
-        
-        return [graph_data]
-        
-    def get_graph_fac(self, label_name: str, input_str: str) -> Dict:
-        """获取知识图谱前端展示数据"""
-        graph_data = {
-            'categories': [],
-            'node': [],
-            'links': []
-        }
-        
-        # 获取原始图谱数据
-        res = self.get_graph(label_name, input_str)
+
+            if r_type not in rtype_map:
+                rtype_map[r_type] = []
+            rtype_map[r_type].append(next_node_dto)
+
+
+
+        # 3. 组装返回结果
+        if graph_map:
+            graph_dto = list(graph_map.values())[0]
+            base_node_rsdtos = []
+
+            for key, value in rtype_map.items():
+                node_rsdto = {
+                    'rType': key,
+                    'ENodeDTOS': value
+                }
+                base_node_rsdtos.append(node_rsdto)
+
+            graph_dto['ENodeRSDTOS'] = base_node_rsdtos
+            graph_dto_list.append(graph_dto)
+
+        return graph_dto_list
+
+
+    def get_graph_fac(self, kg_query):
+        graph_label_dto = GraphLabelDTO()
+        categories = []
+        node_list = []
+        links = []
+        res = self.get_graph(kg_query)
+
         if not res:
-            return graph_data
-            
-        item_style_map = {"display": True}
-        node_id = 0
+            return
+        else:
+            item_style_map = {"display": True}
+            node_id = 0
+            categories.append(CategorieDTO("中心词"))
+            categories.append(CategorieDTO("关系"))
+            c_map = {"中心词": 0, "关系": 1}
+
+            graph_dto = res[0]
+            g_node_dto = GNodeDTO(
+                type=graph_dto["label"],
+                label=graph_dto["name"],
+                category=0,
+                name="0",
+                #id=node_id,
+                symbol="circle",
+                symbol_size=50,
+                properties=graph_dto["properties"],
+                nodeId=graph_dto["id"],
+                item_style=item_style_map
+            )
+
+            node_id += 1
+            node_list.append(g_node_dto)
+
+            if graph_dto["ENodeRSDTOS"]:
+                rs_id = 2
+                for base_node_r_s_dto in graph_dto["ENodeRSDTOS"]:
+                    if base_node_r_s_dto["rType"] not in c_map:
+                        c_map[base_node_r_s_dto["rType"]] = rs_id
+                        categories.append(CategorieDTO(base_node_r_s_dto["rType"]))
+                        rs_id += 1
+                    n_node_dto = GNodeDTO(
+                        type=graph_dto["label"],
+                        category=1,
+                        label="",
+                        name=node_id,
+                        symbol="diamond",
+                        symbol_size=10,
+                        properties=graph_dto["properties"],
+                        nodeId=graph_dto["id"],
+                        item_style=item_style_map
+                    )
+
+                    node_list.append(n_node_dto)
+                    links.append(LinkDTO(
+                        source=g_node_dto.name,
+                        target=n_node_dto.name,
+                        value=base_node_r_s_dto["rType"],
+                        relationShipType=base_node_r_s_dto["rType"]
+                    ))
+                    node_id += 1
+
+                    if base_node_r_s_dto["ENodeDTOS"]:
+                        for base_node_dto in base_node_r_s_dto["ENodeDTOS"]:
+                            children_item_style_map = {"display": True}
+                            if base_node_dto["pCount"] == 0:
+                                children_item_style_map["display"] = False
+                            e_node_dto = GNodeDTO(
+                                type=base_node_dto["label"],
+                                category=c_map[base_node_r_s_dto["rType"]],
+                                label=base_node_dto["name"],
+                                name=node_id,
+                                symbol="circle",
+                                symbol_size=28,
+                                properties=base_node_dto["properties"],
+                                nodeId=base_node_dto["id"],
+                                item_style=children_item_style_map
+                            )
+
+                            node_id += 1
+                            node_list.append(e_node_dto)
+                            links.append(LinkDTO(
+                                source=n_node_dto.name,
+                                target=e_node_dto.name,
+                                value="",
+                                relationShipType=base_node_r_s_dto["rType"]
+                            ))
+
+            graph_label_dto.categories = categories
+            graph_label_dto.node = node_list
+            graph_label_dto.links = links
+
+        return graph_label_dto.to_dict()
+
+class CategorieDTO:
+    def __init__(self, name: str):
+        self.name = name
         
-        # 添加分类
-        graph_data['categories'].append({"name": "中心词"})
-        graph_data['categories'].append({"name": "关系"})
-        c_map = {"中心词": 0, "关系": 1}
+    def to_dict(self):
+        return {
+            "name": self.name
+        }
+
+class GNodeDTO:
+    def __init__(self, type: str, category: int, label: str,name: str, symbol: str, symbol_size: int,
+                 properties: Dict[str, Any], nodeId: int, item_style: Dict[str, Any]):
+        self.type = type
+        self.category = category
+        self.label = label
+        self.name = name
+        self.symbol = symbol
+        self.symbol_size = symbol_size
+        self.properties = properties
+        self.nodeId = nodeId
+        self.item_style = item_style
         
-        # 处理中心节点
-        graph_dto = res[0]
-        g_node_dto = {
-            "label": graph_dto["name"],
-            'type': graph_dto["label"],
-            "category": 0,
-            "name": "0",
-            # "id": graph_dto["id"],
-            "symbol": "circle",
-            "symbolSize": 50,
-            "properties": graph_dto["properties"],
-            "nodeId": graph_dto["id"],
-            "itemStyle": {"display": True}
+    def to_dict(self):
+        return {
+            "type": self.type,
+            "category": self.category,
+            "label": self.label,
+            "name": self.name,
+            "symbol": self.symbol,
+            "symbolSize": self.symbol_size,
+            "properties": self.properties,
+            "nodeId": self.nodeId,
+            "itemStyle": self.item_style
         }
 
-        node_id += 1
-        graph_data['node'].append(g_node_dto)
+
+
+from typing import List
+
+class GraphLabelDTO:
+    def __init__(self):
+        self.categories: List[CategorieDTO] = []
+        self.node: List[GNodeDTO] = []
+        self.links: List[LinkDTO] = []
         
-        # 处理关系节点
-        if graph_dto['relations']:
-            rs_id = 2
-            for relation in graph_dto['relations']:
-                r_type = relation['rType']
-                if r_type not in c_map:
-                    c_map[r_type] = rs_id
-                    graph_data['categories'].append({"name": r_type})
-                    rs_id += 1
-                
-                # 添加关系节点
-                n_node_dto = {
-                    "label": "",
-                    'type': graph_dto["label"],
-                    "category": 1,
-                    "name": node_id,
-                    # "id": len(nodes),
-                    "symbol": "diamond",
-                    "symbolSize": 10,
-                    "properties": graph_dto["properties"],
-                    "nodeId": graph_dto["label"],
-                    "itemStyle": item_style_map
-                }
-                graph_data['node'].append(n_node_dto)
-                
-                # 添加关系链接
-                graph_data['links'].append({
-                    'source': str(g_node_dto['name']),
-                    'target': str(n_node_dto['name']),
-                    'value': r_type,
-                    'relationShipType': r_type
-                })
-                
-                node_id += 1
-                
-                # 处理目标节点
-                target = relation['target']
-                children_item_style_map = {"display": True}
-                symbol = "circle"
-                
-                if target['pCount'] == 0:
-                    children_item_style_map["display"] = False
-                
-                e_node_dto = {
-                    "label": target["name"],
-                    "type": graph_dto["label"],
-                    "category": c_map[r_type],
-                    "name": node_id,
-                    # "id": e_node["Id"],
-                    "symbol": "circle",
-                    "symbolSize": 28,
-                    "properties": target["properties"],
-                    "nodeId": target["id"],
-                    "itemStyle": children_item_style_map
-                }
+    def to_dict(self):
+        return {
+            "categories": [category.to_dict() for category in self.categories],
+            "node": [node.to_dict() for node in self.node],
+            "links": [link.to_dict() for link in self.links]
+        }
 
-                node_id += 1
-                graph_data['node'].append(e_node_dto)
-                
-                # 添加目标节点链接
-                graph_data['links'].append({
-                    'source': str(n_node_dto['name']),
-                    'target': str(e_node_dto['name']),
-                    'value': "",
-                    'relationShipType': r_type
-                })
+class LinkDTO:
+    def __init__(self, source: str, target: str, value: str, relationShipType: str):
+        self.source = source
+        self.target = target
+        self.value = value
+        self.relationShipType = relationShipType
         
-        return graph_data
+    def to_dict(self):
+        return {
+            "source": self.source,
+            "target": self.target,
+            "value": self.value,
+            "relationShipType": self.relationShipType
+        }

+ 45 - 19
src/knowledge/service/kg_node_service.py

@@ -53,22 +53,24 @@ class KGNodeService:
 
     def paginated_search(self, search_params: dict) -> dict:
         load_props = search_params.get('load_props', False)
+        prop_service = KGPropService(self.db)
+        edge_service = KGEdgeService(self.db)
         keyword = search_params.get('keyword', '')
-        category = search_params.get('category', '')
+        category = search_params.get('category', None)
         page_no = search_params.get('pageNo', 1)
-        distance = search_params.get('distance',DISTANCE_THRESHOLD)
+        # distance 为NONE或不存在时,使用默认值
+        if search_params.get('distance') is None:
+            distance = DISTANCE_THRESHOLD
+        else:
+            distance = search_params.get('distance')
+        if distance==0:
+            distance = 0.1
         limit = search_params.get('limit', 10)
 
         if page_no < 1:
             page_no = 1
         if limit < 1:
             limit = 10
-            
-        cache_key = f"paginated_search:{keyword}:{category}:{page_no}:{distance}:{limit}:{str(search_params.get('knowledge_ids', ''))}:{load_props}"
-        logger.debug(f"Cache key: {cache_key}")
-        if cache_key in self._cache:
-            cached_value = self._cache[cache_key]
-            return copy.deepcopy(cached_value)
 
         embedding = Vectorizer.get_instance().get_embedding(keyword)
         offset = (page_no - 1) * limit
@@ -77,7 +79,7 @@ class KGNodeService:
             # 构建基础查询条件
             base_query = self.db.query(func.count(KGNode.id)).filter(
                 KGNode.status == 0,
-                KGNode.embedding.l2_distance(embedding) < distance
+                KGNode.embedding.l2_distance(embedding) <= distance
             )
             # 如果有category,则添加额外过滤条件
             if category:
@@ -95,23 +97,23 @@ class KGNodeService:
                 KGNode.name,
                 KGNode.category,
                 KGNode.embedding.l2_distance(embedding).label('distance')
-            )            
+            )
             query = query.filter(KGNode.status == 0)
-            #category有值时,过滤掉category不等于category的节点
+            # category有值时,过滤掉category不等于category的节点
             if category:
                 query = query.filter(KGNode.category == category)
             if search_params.get('knowledge_ids'):
                 query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
-            query = query.filter(KGNode.embedding.l2_distance(embedding) < distance)
+            query = query.filter(KGNode.embedding.l2_distance(embedding) <= distance)
             results = query.order_by('distance').offset(offset).limit(limit).all()
-            #将results相同distance的category=疾病的放在前面
-            #results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))
 
-            finalResults = {
+            return {
                 'records': [{
                     'id': r.id,
                     'name': r.name,
                     'category': r.category,
+                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
+                    # 'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
                     'distance': round(r.distance, 3)
                 } for r in results],
                 'pagination': {
@@ -120,10 +122,8 @@ class KGNodeService:
                     'limit': limit,
                     'totalPages': (total_count + limit - 1) // limit
                 }
-            
+
             }
-            self._cache[cache_key] = copy.deepcopy(finalResults)
-            return finalResults
         except Exception as e:
             logger.error(f"分页查询失败: {str(e)}")
             raise e
@@ -229,4 +229,30 @@ class KGNodeService:
             except Exception as e:
                 self.db.rollback()
                 print(f"批量处理ER节点失败: {str(e)}")
-                raise ValueError("Batch process failed")
+                raise ValueError("Batch process failed")
+
+    def get_node_by_name_category(self, name: str, category: str):
+        if not name or not category:
+            raise ValueError("Name and category are required")
+
+        cache_key = f"get_node_by_name_category_{name}:{category}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+
+        node = self.db.query(KGNode).filter(
+            KGNode.name == name,
+            KGNode.category == category,
+            KGNode.status == 0
+        ).first()
+
+        if not node:
+            return None
+
+        node_data = {
+            'id': node.id,
+            'name': node.name,
+            'category': node.category,
+            'version': node.version
+        }
+        self._cache[cache_key] = node_data
+        return node_data

+ 17 - 0
src/knowledge/service/kg_prop_service.py

@@ -32,7 +32,24 @@ class KGPropService:
     #     except Exception as e:
     #         logger.error(f"根据ref_id查询属性失败: {str(e)}")
     #         raise ValueError("查询失败")
+    def get_prop_by_id(self, id: int)-> dict:
+        try:
+            query = self.db.query(KGProp).filter(KGProp.id == id)
 
+            props = query.first()
+            if not props:
+                return None
+            return {
+                'id': props.id,
+                'category': props.category,
+                'prop_name': props.prop_name,
+                'prop_value': props.prop_value,
+                'prop_title': props.prop_title,
+                'type': props.type
+            }
+        except Exception as e:
+            logger.error(f"根据id查询属性失败: {str(e)}")
+            raise ValueError("查询失败")
 
     def get_props_by_ref_id(self, ref_id: int, prop_name: str = None) -> List[dict]: