import os
import sys

current_path = os.getcwd()
sys.path.append(current_path)

from fastapi import APIRouter, Depends, Query
from typing import Optional, List

from libs.graph_helper import GraphHelper
from saas.models.response import StandardResponse
from saas.models.request import GraphFilterRequest

router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
graph_helper = GraphHelper()


@router.get("/nodes/search", response_model=StandardResponse)
async def search_nodes(
    keyword: str = Query(..., min_length=2),
    limit: int = Query(10, ge=1, le=100),
    node_type: Optional[str] = Query(None),
    min_degree: Optional[int] = Query(None)
):
    """
    Search graph nodes by keyword and attribute filters.
    """
    try:
        results = graph_helper.node_search(
            keyword,
            limit=limit,
            node_type=node_type,
            min_degree=min_degree
        )
        community_report_results = graph_helper.community_report_search(keyword)
        return StandardResponse(
            success=True,
            records={"nodes": results, "community_report": community_report_results},
            error_code=0,
            error_msg=f"Found {len(results)} nodes"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            error_code=500,
            error_msg=str(e)
        )


@router.get("/nodes/neighbor_search", response_model=StandardResponse)
async def neighbor_search(
    keyword: str = Query(..., min_length=2),
    limit: int = Query(10, ge=1, le=100),
    node_type: Optional[str] = Query(None),
    neighbor_type: Optional[str] = Query(None),
    min_degree: Optional[int] = Query(None)
):
    """
    Search graph nodes by keyword, then expand to their neighbors, filtered by attributes.
    """
    try:
        scores_factor = 1.7  # minimum similarity score for a match to be kept
        results = []
        diseases = {}
        # results = graph_helper.node_search(
        #     keyword,
        #     limit=limit,
        #     node_type=node_type,
        #     min_degree=min_degree
        # )
        print("Initial search result count:", len(results))
        # If nothing was found, split the keyword into tokens and search each token
        # separately (symptoms only). If the split search also finds nothing, an empty
        # result is returned; otherwise the expanded results are used.
        has_good_result = False
        # new_results = []
        # for item in results:
        #     if item["score"] > scores_factor:
        #         has_good_result = True
        #         new_results.append(item)
        # results = new_results
        print("Results remaining after similarity filtering:", len(results))
        if not has_good_result:
            keywords = keyword.split(" ")
            new_results = []
            for item in keywords:
                if len(item) > 1:
                    results = graph_helper.node_search(
                        item,
                        limit=limit,
                        node_type=node_type,
                        min_degree=min_degree
                    )
                    for result_item in results:
                        if result_item["score"] > scores_factor:
                            new_results.append(result_item)
                            if result_item["type"] == "Disease":
                                if result_item["id"] not in diseases:
                                    diseases[result_item["id"]] = {
                                        "id": result_item["id"],
                                        "type": result_item["type"],
                                        "count": 1
                                    }
                                else:
                                    diseases[result_item["id"]]["count"] += 1
                            has_good_result = True
            results = new_results
            print("Expanded search result count:", len(results))

        neighbors_data = {}
        for item in results:
            entities, relations = graph_helper.neighbor_search(item["id"], 1)
            # Nodes such as 发热 (fever) are linked to many diseases, so cap the
            # expansion to keep the search bounded.
            max_disease_expansions = 20
            for neighbor in entities:
                if neighbor["type"] == neighbor_type:
                    # The neighbor already has the requested node type.
                    if neighbor["id"] not in neighbors_data:
                        neighbors_data[neighbor["id"]] = {
                            "id": neighbor["id"],
                            "type": neighbor["type"],
                            "count": 1
                        }
                    else:
                        neighbors_data[neighbor["id"]]["count"] += 1
                else:
                    # If the neighbor is a disease, expand one more hop and look for
                    # nodes of the requested type.
                    if neighbor["type"] == "Disease":
                        if neighbor["id"] not in diseases:
                            diseases[neighbor["id"]] = {
                                "id": neighbor["id"],
                                "type": neighbor["type"],
                                "count": 1
                            }
                        else:
                            diseases[neighbor["id"]]["count"] += 1
                        disease_entities, relations = graph_helper.neighbor_search(neighbor["id"], 1)
                        for disease_neighbor in disease_entities:
                            # Some edges point to nodes whose data is missing, so check
                            # that the node record actually carries a type.
                            if "type" in disease_neighbor.keys():
                                if disease_neighbor["type"] == neighbor_type:
                                    if disease_neighbor["id"] not in neighbors_data:
                                        neighbors_data[disease_neighbor["id"]] = {
                                            "id": disease_neighbor["id"],
                                            "type": disease_neighbor["type"],
                                            "count": 1
                                        }
                                    else:
                                        neighbors_data[disease_neighbor["id"]]["count"] += 1
                        # Expand at most max_disease_expansions diseases.
                        max_disease_expansions -= 1
                        if max_disease_expansions == 0:
                            break

        disease_data = [diseases[k] for k in diseases]
        disease_data = sorted(disease_data, key=lambda x: x["count"], reverse=True)
        data = [neighbors_data[k] for k in neighbors_data]
        data = sorted(data, key=lambda x: x["count"], reverse=True)

        # Keep the ten highest-count neighbors and normalize counts to shares of the total.
        if len(data) > 10:
            data = data[:10]
        total = sum(item["count"] for item in data)
        for item in data:
            item["count"] = item["count"] / total

        # Same for the candidate diseases.
        if len(disease_data) > 10:
            disease_data = disease_data[:10]
        total = sum(item["count"] for item in disease_data)
        for item in disease_data:
            item["count"] = item["count"] / total

        return StandardResponse(
            success=True,
            records={"nodes": disease_data, "neighbors": data},
            error_code=0,
            error_msg=f"Found {len(results)} nodes"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            error_code=500,
            error_msg=str(e)
        )


@router.post("/nodes/filter", response_model=StandardResponse)
async def filter_nodes(request: GraphFilterRequest):
    """
    Filter nodes by a combination of conditions.
    """
    try:
        results = graph_helper.filter_nodes(
            node_types=request.node_types,
            min_degree=request.min_degree,
            min_community_size=request.min_community_size,
            attributes=request.attributes
        )
        return StandardResponse(
            success=True,
            data={"nodes": results},
            message=f"Filtered {len(results)} nodes"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            message=str(e)
        )


@router.get("/statistics", response_model=StandardResponse)
async def get_graph_statistics():
    """
    Retrieve graph-level statistics.
    """
    try:
        stats = graph_helper.get_graph_statistics()
        return StandardResponse(
            success=True,
            data=stats,
            message="Graph statistics retrieved"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            message=str(e)
        )


@router.get("/community/{community_id}", response_model=StandardResponse)
async def get_community_details(
    community_id: int,
    min_size: int = Query(3, ge=2)
):
    """
    Retrieve details for the specified community.
    """
    try:
        community_info = graph_helper.get_community_details(community_id, min_size)
        return StandardResponse(
            success=True,
            data=community_info,
            message="Community details retrieved"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            message=str(e)
        )


# # Graph path-analysis endpoint (currently disabled)
# @router.post("/path-analysis", response_model=StandardResponse)
# async def analyze_paths(request: GraphSearchRequest):
#     """
#     Analyze potential relation paths between nodes.
#     """
#     try:
#         paths = graph_helper.find_paths(
#             source_id=request.source_id,
#             target_id=request.target_id,
#             max_depth=request.max_depth
#         )
#         return StandardResponse(
#             success=True,
#             data={"paths": paths},
#             message=f"Found {len(paths)} possible paths"
#         )
#     except Exception as e:
#         return StandardResponse(
#             success=False,
#             message=str(e)
#         )
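
# Illustrative sketch (not used by the routes above): the neighbor aggregation in
# neighbor_search boils down to counting how often each neighbor id is reached,
# keeping the highest-count entries, and rescaling the counts into relative weights.
# The helper name and sample values below are hypothetical and only mirror that logic.
def _normalized_top_counts_sketch(counts, top_n=10):
    """Return the top_n highest-count entries with counts rescaled to sum to 1."""
    ranked = sorted(counts.values(), key=lambda x: x["count"], reverse=True)[:top_n]
    total = sum(item["count"] for item in ranked)
    for item in ranked:
        item["count"] = item["count"] / total if total else 0.0
    return ranked

# Example (hypothetical data): counts of 5, 3 and 2 become weights 0.5, 0.3 and 0.2:
# _normalized_top_counts_sketch({
#     "Aspirin": {"id": "Aspirin", "type": "Drug", "count": 5},
#     "Ibuprofen": {"id": "Ibuprofen", "type": "Drug", "count": 3},
#     "Paracetamol": {"id": "Paracetamol", "type": "Drug", "count": 2},
# })
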
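# Sketch only: the disabled /path-analysis endpoint above expects a GraphSearchRequest
# carrying source_id, target_id and max_depth (the fields referenced in the commented
# code). The real model is presumably defined in saas.models.request alongside
# GraphFilterRequest; the Pydantic definition below is an assumption and is kept
# commented out so it cannot shadow the real one if the endpoint is re-enabled.
# from pydantic import BaseModel
#
# class GraphSearchRequest(BaseModel):
#     source_id: str
#     target_id: str
#     max_depth: int = 3  # default depth is an assumption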
@router.get("/node/{node_id}", response_model=StandardResponse)
async def get_node_details(
    node_id: str,
    with_relations: bool = False,
    relation_types: List[str] = Query(None),
    relation_limit: int = Query(10, ge=1, le=100)
):
    """
    Retrieve detailed information for a node.

    - relation_types: only return relations of the given types
    - relation_limit: maximum number of relations to return
    """
    try:
        node_info = graph_helper.get_node_details(
            node_id,
            include_relations=with_relations,
            relation_types=relation_types,
            relation_limit=relation_limit
        )
        return StandardResponse(
            success=True,
            data=node_info,
            message="Node details retrieved"
        )
    except ValueError as e:
        return StandardResponse(
            success=False,
            message=str(e)
        )


graph_router = router
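

# Usage sketch (an illustration, not part of the service): the router can be mounted
# on any FastAPI app and exercised with FastAPI's TestClient. The app object and the
# example keywords below are assumptions; a real deployment would include this router
# from its main application module and query it over HTTP instead.
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()
    app.include_router(graph_router)
    client = TestClient(app)

    # Keyword-based node search; "发热" (fever) is a hypothetical example term.
    resp = client.get("/graph/nodes/search", params={"keyword": "发热", "limit": 5})
    print(resp.status_code, resp.json())

    # Neighbor expansion restricted to a node type; values are hypothetical examples.
    resp = client.get(
        "/graph/nodes/neighbor_search",
        params={"keyword": "发热 咳嗽", "neighbor_type": "Drug", "limit": 5},
    )
    print(resp.status_code, resp.json())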