123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- import sys,os
- from community.graph_helper2 import GraphHelper
- from model.response import StandardResponse
- current_path = os.getcwd()
- sys.path.append(current_path)
- import time
- from fastapi import APIRouter, Depends, Query
- from typing import Optional, List
- import sys
- sys.path.append('..')
- from utils.agent import call_chat_api,get_conversation_id
- import json
# Router for the knowledge-graph endpoints; every route registered on it is served under the /graph prefix.
router = APIRouter(tags=["Knowledge Graph"], prefix="/graph")
# Single GraphHelper instance shared by the endpoints below; it is constructed once, at import time.
graph_helper = GraphHelper()
@router.get("/nodes/recommend", response_model=StandardResponse)
async def recommend(
    chief: str
):
    """Recommend checks and candidate diagnoses for a free-text chief complaint.

    The complaint ``chief`` is sent to a fixed chat app. Its reply is parsed as
    JSON, and the entries of ``chief_complaint`` are joined with single spaces
    into a keyword string. That string is passed to ``neighbor_search`` with
    ``neighbor_type='Check'`` and ``limit=10``.

    Returns:
        Whatever ``neighbor_search`` returns. If the chat reply is not valid JSON,
        or lacks a joinable ``chief_complaint`` value, returns
        ``StandardResponse(success=False, error_code=500, ...)`` instead of raising.
    """
    app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
    # NOTE(review): these helpers are called synchronously inside an async handler.
    # If they perform blocking network I/O they stall the event loop; confirm.
    conversation_id = get_conversation_id(app_id)
    result = call_chat_api(app_id, conversation_id, chief)
    try:
        json_data = json.loads(result)
        keyword = " ".join(json_data["chief_complaint"])
    except (ValueError, TypeError, KeyError) as e:
        # ValueError covers json.JSONDecodeError. TypeError covers a None/non-str
        # reply, a non-dict payload, or non-string list items. KeyError covers a
        # missing "chief_complaint" key.
        print(e)
        return StandardResponse(
            success=False,
            error_code=500,
            error_msg=f"Invalid chat API response: {e}"
        )
    # neighbor_search is a FastAPI endpoint. When it is called directly (not via
    # HTTP), its Query(...) defaults are NOT resolved and would be passed through
    # as FieldInfo objects, so every optional parameter is given explicitly.
    return await neighbor_search(
        keyword=keyword,
        limit=10,
        node_type=None,
        neighbor_type='Check',
        min_degree=None
    )
def _bump_count(counter, node_id, node_type):
    """Increment the tally for ``node_id`` in ``counter``, creating it on first sight.

    A new entry is ``{"id": node_id, "type": node_type, "count": 1}``. For an
    existing entry only ``count`` is incremented; its ``type`` is kept.
    """
    entry = counter.get(node_id)
    if entry is None:
        counter[node_id] = {"id": node_id, "type": node_type, "count": 1}
    else:
        entry["count"] += 1


def _search_seed_nodes(keyword, limit, node_type, min_degree, diseases, score_threshold=1.7):
    """Search each keyword token and return the hits scoring above ``score_threshold``.

    ``keyword`` is split on single spaces. Tokens of length <= 1 are skipped,
    including the empty strings produced by repeated spaces. Kept hits of type
    ``Disease`` are tallied into ``diseases``. A node found by several tokens
    appears once per token, so it is later expanded (and weighted) that many times.
    """
    seeds = []
    for token in keyword.split(" "):
        if len(token) <= 1:
            continue
        hits = graph_helper.node_search2(
            token,
            limit=limit,
            node_type=node_type,
            min_degree=min_degree
        )
        for hit in hits:
            if hit["score"] > score_threshold:
                seeds.append(hit)
                if hit["type"] == "Disease":
                    _bump_count(diseases, hit["id"], 1)
    return seeds


def _expand_neighbors(seeds, neighbor_type, diseases, max_disease_expansions=20):
    """Tally nodes of type ``neighbor_type`` within two hops of each seed node.

    Direct neighbors of type ``neighbor_type`` are tallied. When ``neighbor_type``
    is not ``Disease``, direct ``Disease`` neighbors are handled differently:
    they are tallied into ``diseases`` and expanded one more hop, and their
    ``neighbor_type`` neighbors are tallied.

    At most ``max_disease_expansions`` diseases are expanded per seed. The cap
    exists because hub symptoms such as fever link to very many diseases.

    Returns a dict ``node_id -> {"id", "type", "count"}``. Every entry's
    ``type`` equals ``neighbor_type``.
    """
    tallies = {}
    # Per-request cache. The same disease is often reached from several seeds.
    # Reusing its neighbor list assumes the graph does not change mid-request.
    expanded = {}
    for seed in seeds:
        entities, _ = graph_helper.neighbor_search(seed["id"], 1)
        budget = max_disease_expansions
        for neighbor in entities:
            # Links can reference nodes whose data is missing; such entries carry no "type".
            if "type" not in neighbor:
                continue
            if neighbor["type"] == neighbor_type:
                _bump_count(tallies, neighbor["id"], neighbor["type"])
            elif neighbor["type"] == "Disease":
                disease_id = neighbor["id"]
                _bump_count(diseases, disease_id, "Disease")
                if disease_id not in expanded:
                    second_hop, _ = graph_helper.neighbor_search(disease_id, 1)
                    expanded[disease_id] = list(second_hop)
                for node in expanded[disease_id]:
                    if "type" in node and node["type"] == neighbor_type:
                        _bump_count(tallies, node["id"], node["type"])
                budget -= 1
                if budget <= 0:
                    break
    return tallies


def _rank_and_rate(entries, output_type, top_k=10):
    """Return the ``top_k`` highest-count entries, annotated with their share of the total.

    Sorting is stable, so tied entries keep first-seen order. Each returned
    entry is mutated in place:

    - ``count`` becomes its fraction of the summed counts of the returned entries;
    - ``type`` becomes ``output_type``;
    - ``name`` becomes the node id;
    - ``rate`` is the fraction as a percentage, rounded to two decimals.

    NOTE(review): this is a plain share of the total with no rank-based decay.
    If a per-position discount (e.g. 0.9**rank) was intended, apply it here; confirm.
    """
    ranked = sorted(entries, key=lambda entry: entry["count"], reverse=True)[:top_k]
    total = sum(entry["count"] for entry in ranked)
    for entry in ranked:
        entry["count"] = entry["count"] / total if total else 0.0
        entry["type"] = output_type
        entry["name"] = entry["id"]
        entry["rate"] = round(entry["count"] * 100, 2)
    return ranked


@router.get("/nodes/neighbor_search", response_model=StandardResponse)
async def neighbor_search(
    keyword: str = Query(..., min_length=2),
    limit: int = Query(10, ge=1, le=100),
    node_type: Optional[str] = Query(None),
    neighbor_type: Optional[str] = Query(None),
    min_degree: Optional[int] = Query(None),
):
    """Search graph nodes by keyword and attribute filters, then rank related nodes.

    1. Each space-separated token of ``keyword`` longer than one character is
       searched with ``graph_helper.node_search2`` (filtered by ``limit``,
       ``node_type`` and ``min_degree``). Hits scoring above 1.7 become seed
       nodes. Seeds of type ``Disease`` count as candidate diagnoses.
    2. For each seed, its direct neighbors of type ``neighbor_type`` are
       tallied. Its ``Disease`` neighbors are tallied as candidate diagnoses and
       expanded one more hop, at most 20 per seed.
    3. The ten most frequent diagnoses and neighbor nodes are returned. Each
       carries its share of the top-ten total (``count``) and a percentage
       (``rate``).

    Note: when called directly from Python rather than through FastAPI, pass
    ``node_type`` and ``min_degree`` explicitly. Their ``Query(...)`` defaults are
    only resolved by FastAPI.

    Returns:
        ``StandardResponse`` with ``records={"可能诊断": diagnoses (type 1),
        "推荐检验": neighbor_type nodes (type 3)}``. On any exception it returns
        ``success=False``, ``error_code=500`` and ``error_msg=str(exception)``.
    """
    try:
        start_time = time.time()
        print(f"开始执行neighbor_search,参数:keyword={keyword}, limit={limit}, node_type={node_type}, neighbor_type={neighbor_type}, min_degree={min_degree}")
        diseases = {}
        results = _search_seed_nodes(keyword, limit, node_type, min_degree, diseases)
        print("扩展搜索的结果数量:", len(results))
        neighbors_data = _expand_neighbors(results, neighbor_type, diseases)
        disease_data = _rank_and_rate(diseases.values(), 1)
        # Every tallied neighbor already has type == neighbor_type, so no further filter is needed.
        data = _rank_and_rate(neighbors_data.values(), 3)
        end_time = time.time()
        print(f"neighbor_search执行完成,耗时:{end_time - start_time:.2f}秒")
        return StandardResponse(
            success=True,
            records={"可能诊断": disease_data, "推荐检验": data},
            error_code=0,
            error_msg=f"Found {len(results)} nodes"
        )
    except Exception as e:
        # Endpoint boundary: log the full traceback, then report the failure in the
        # declared response model. Previously a stray `raise` made this return unreachable.
        import traceback
        print(e)
        traceback.print_exc()
        return StandardResponse(
            success=False,
            error_code=500,
            error_msg=str(e)
        )
# Module-level alias of the router above.
# Presumably imported by the application entry point under this name and mounted there; confirm against the caller.
graph_router = router
|