"""Streamlit front-end for the Med Graph Agent.

Usage: streamlit run agent/app.py

Pipeline (LangGraph):
    agent            -> classify whether the question is medical ("是"/"否")
    extract          -> LLM extracts structured case data as JSON (streamed to UI)
    recommend_check  -> query the knowledge-graph API for diagnoses/checks/drugs
"""
import json
import os
import sys
from typing import Any

current_path = os.getcwd()
# Make the working directory importable so that `config.site` resolves.
sys.path.append(current_path)

import requests
import streamlit as st
from dotenv import load_dotenv
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph

# Load environment variables (e.g. from a .env file) before reading config.
load_dotenv()

from config.site import SiteConfig  # noqa: E402  (needs the sys.path tweak above)

config = SiteConfig()
GRAPH_API_URL = config.get_config("GRAPH_API_URL")
DEEPSEEK_API_URL = config.get_config("DEEPSEEK_API_URL")
DEEPSEEK_API_KEY = config.get_config("DEEPSEEK_API_KEY")

# Seconds to wait for the knowledge-graph API; without a timeout a stalled
# server would block the Streamlit session indefinitely.
GRAPH_API_TIMEOUT = 30

# Shown when the graph produced no output (the question was judged non-medical).
NON_MEDICAL_REPLY = "抱歉,我只能回答医疗健康相关的问题。"

PROMPT_INTENSION = '''你是一个临床医学专家,你需要对用户的问题进行意图分析,判断用户的问题是否是关于医疗健康的问题。如果是医疗健康的问题,你需要输出"是",否则输出"否"。'''

PROMPT_DATA_EXTRACTION = '''你是一个NLP专家,你需要对用户的问题进行数据抽取,抽取的结果输出为json格式。以下是json格式的样本
```json
{
    "pat_name": "XXX",
    "pat_sex": "XXX",
    "pat_age": 0,
    "clinical_department": "XXX",
    "chief_complaint": ["A", "B", "C"],
    "present_illness": ["A", "B", "C"],
    "past_medical_history": ["A", "B", "C"],
    "physical_examination": ["A", "B", "C"],
    "lab_and_imaging": ["A", "B", "C"]
}
```
其中字段的描述如下:
pat_name:患者名字,字符串,如"张三",无患者信息输出""
pat_sex:患者性别,字符串,如"男",无患者信息输出""
pat_age:患者年龄,数字,单位为年,如25岁,输出25,无年龄信息输出0
clinical_department:就诊科室,字符串,如"呼吸内科",无就诊科室信息输出""
chief_complaint:主诉,字符串列表,包括主要症状的列表,如["胸痛","发热"],无主诉输出[]
present_illness:现病史,字符串列表,包括症状发展过程、诱因、伴随症状(如疼痛性质、放射部位、缓解方式),无现病史信息输出[]
past_medical_history:既往病史,字符串列表,包括疾病史(如高血压、糖尿病)、手术史、药物过敏史、家族史等,无既往病史信息输出[]
physical_examination:体格检查,字符串列表,如生命体征(血压、心率)、心肺腹部体征、实验室/影像学结果(如心电图异常、肌钙蛋白升高),无信息输出[]
lab_and_imaging:检验与检查,字符串列表,包括血常规、生化指标、心电图(ECG)、胸部X光、CT等检查项目,结果和报告等,无信息输出[]
'''

######################## langgraph

# OpenAI-compatible client pointed at the DeepSeek endpoint.
llm = ChatOpenAI(
    # Alternatives: deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-R1, Qwen/QwQ-32B
    model="Pro/deepseek-ai/DeepSeek-V3",
    api_key=DEEPSEEK_API_KEY,
    base_url=DEEPSEEK_API_URL,
    streaming=True,
)


class MyStreamingOutCallbackHandler(AsyncCallbackHandler):
    """LangChain callback that accumulates streamed tokens and echoes them.

    Currently not wired into the graph. Note that `st.write` adds a new
    Streamlit element per token.
    """

    def __init__(self):
        super().__init__()
        self.content = ""

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Called for every streamed token.
        self.content += token
        print(token)
        st.write(token)

    async def on_llm_end(self, response, **kwargs) -> None:
        # End-of-stream hook; nothing to do yet.
        pass

    async def on_llm_error(self, error, **kwargs) -> None:
        # Error hook; intentionally a no-op for now.
        pass

    async def on_chat_model_start(
        self, serialized: dict[str, Any], messages: list[list[BaseMessage]], **kwargs
    ) -> None:
        # Start-of-stream hook; logs the prompt messages.
        print("on_chat_model_start", messages)


def _get_callback(state):
    """Return the UI streaming callback stored in *state*, or a no-op.

    `test_langgraph` invokes the graph without a callback, so the key may be
    absent or None.
    """
    callback = state.get("callback")
    return callback if callback is not None else (lambda _text: None)


def _stream_completion(system_prompt, user_text, on_token=None):
    """Stream one chat completion and return its concatenated text.

    Empty chunks are skipped. If *on_token* is given, it is called with each
    non-empty chunk as it arrives (used to stream output into the UI).
    """
    messages = [SystemMessage(content=system_prompt), HumanMessage(content=user_text)]
    parts = []
    for chunk in llm.stream(messages):
        text = chunk.content or ""
        if not text:
            continue
        if on_token is not None:
            on_token(text)
        parts.append(text)
    return "".join(parts)


def agent_node(state):
    """Classify whether the latest user message is a medical/health question.

    Appends the model's verdict as an AIMessage, which
    `should_continue_2_entity_extraction` inspects to route the graph.
    """
    print("agent_node", state)
    verdict = _stream_completion(PROMPT_INTENSION, state["messages"][-1].content)
    # Return only the new message; the MessagesState reducer appends it.
    return {"messages": [AIMessage(content=verdict)]}


def entity_extraction_node(state):
    """Extract structured case data (JSON) from the user's question.

    The last message is the classifier's verdict, so the question is the one
    before it. The extraction is streamed to the UI and appended as an
    AIMessage for `recommed_check`.
    """
    question = state["messages"][-2].content
    print("entity_extraction_node", question)
    print(state["messages"])
    extraction = _stream_completion(
        PROMPT_DATA_EXTRACTION, question, on_token=_get_callback(state)
    )
    return {"messages": [AIMessage(content=extraction)]}


def _parse_json_response(text):
    """Parse the JSON object in an LLM reply, tolerating ```json fences or prose.

    Takes the span from the first '{' to the last '}'.

    Raises:
        ValueError: no JSON object could be decoded (JSONDecodeError included).
    """
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in model output")
    data = json.loads(text[start:end + 1])
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object")
    return data


def _neighbor_search(keyword, neighbor_type):
    """Query the knowledge-graph neighbour-search endpoint; return the JSON body.

    Raises:
        requests.RequestException: on connection errors, timeouts or HTTP errors.
    """
    response = requests.get(
        f"{GRAPH_API_URL}/graph/nodes/neighbor_search",
        # `params` URL-encodes user text (spaces, '&', '#', CJK) correctly.
        params={"keyword": keyword, "neighbor_type": neighbor_type},
        headers={"Content-Type": "application/json"},
        timeout=GRAPH_API_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()


def _format_item(item, show_type=False):
    """Render a graph node as a markdown bullet, e.g. "- X(Disease,12.5%)".

    `count` is multiplied by 100 and shown as a percentage; it is omitted when
    the API does not return it.
    """
    details = []
    if show_type:
        details.append(str(item.get("type", "")))
    if item.get("count") is not None:
        details.append(f"{round(item['count'] * 100, 2)}%")
    suffix = f"({','.join(details)})" if details else ""
    return f"- {item['id']}{suffix}\n"


def recommed_check(state):
    """Look up likely diagnoses, checks and drugs for the extracted case.

    Parses the extraction JSON in the last message, queries the knowledge
    graph with the chief complaints, and streams a markdown summary through
    the state's callback. The graph state is returned unchanged.
    """
    emit = _get_callback(state)
    raw_json = state["messages"][-1].content
    print("recommed_check", raw_json)
    try:
        case = _parse_json_response(raw_json)
    except ValueError:
        # Malformed model output should not crash the whole page.
        emit("\n\n未能解析抽取结果,无法给出推荐。\n")
        return state
    keyword = " ".join(str(symptom) for symptom in case.get("chief_complaint") or [])

    checks = _neighbor_search(keyword, "Check")
    emit("\n")
    if isinstance(checks, dict) and "records" in checks:
        # "records" may be present but null, so fall back to an empty dict.
        records = checks["records"] or {}
        emit("## 该病案可能的诊断包括\n")
        for node in records.get("nodes") or []:
            print(node)
            if node.get("type") == "Disease":
                emit(_format_item(node, show_type=True))
        emit("## 推荐的检查\n")
        for neighbor in records.get("neighbors") or []:
            emit(_format_item(neighbor))

    drugs = _neighbor_search(keyword, "Drug")
    emit("\n")
    if isinstance(drugs, dict) and "records" in drugs:
        records = drugs["records"] or {}
        emit("## 推荐的药物\n")
        for neighbor in records.get("neighbors") or []:
            emit(_format_item(neighbor))

    print("recommed_check finished")
    return state


def should_continue_2_entity_extraction(state):
    """Route to extraction when the classifier judged the question medical.

    Returns "continue" or "end" (keys of the conditional-edge mapping).
    """
    verdict = (state["messages"][-1].content or "").strip()
    print("should_continue", verdict)
    # Match the prefix: the model may add punctuation such as "是。".
    return "continue" if verdict.startswith("是") else "end"


def tool_node(state):
    """Placeholder for custom tool logic (currently unreachable in the graph).

    NOTE(review): "tool_output" is not a key of MyMessageState; confirm how
    LangGraph treats this update before wiring the node in.
    """
    print("tool_node")
    return {"tool_output": "Tool executed"}


class MyMessageState(MessagesState):
    # UI streaming callback: called with text chunks to display.
    callback: Any = None


workflow = StateGraph(MyMessageState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)
workflow.add_node("extract", entity_extraction_node)
workflow.add_node("recommend_check", recommed_check)
workflow.add_edge(START, "agent")
workflow.add_edge("extract", "recommend_check")
workflow.add_edge("recommend_check", END)
workflow.add_edge("tools", END)
workflow.add_conditional_edges(
    "agent", should_continue_2_entity_extraction, {"continue": "extract", "end": END}
)
app = workflow.compile()


def test_langgraph(user_input):
    """Run the graph once without the UI and print the final state."""
    messages = [HumanMessage(content=user_input)]
    response = app.invoke({"messages": messages})
    print(response)


######################## Streamlit UI

st.set_page_config(layout="wide")

# Text streamed during the current run; reset before each `app.invoke`.
tmp_text = ""


def st_callback(text):
    """Append *text* to the running answer and re-render it.

    Called inside `with placeholder:` (an `st.empty()` container), so each
    `st.info` replaces the previous box instead of stacking a new one.
    """
    global tmp_text
    tmp_text = tmp_text + text
    st.info(tmp_text)


def submit_question(text):
    """Button handler kept for the alternative form-based layout."""
    print("submit_question", text)


st.header("Med Graph Agent")

if "history" not in st.session_state:
    st.session_state.history = []

# Streamlit reruns the script on every interaction, so replay the history.
for message in st.session_state.history:
    with st.chat_message(message["role"]):
        st.write(message["content"])

user_input = st.chat_input("请输入您的问题:")
if user_input:
    state = MyMessageState(messages=[HumanMessage(content=user_input)], callback=st_callback)
    st.session_state.history.append({"role": "user", "content": user_input})
    st.chat_message("user").write(user_input)

    placeholder = st.empty()
    with placeholder:
        st.info("thinking...")
        tmp_text = ""
        try:
            response = app.invoke(state)
            print(response["messages"])
        except requests.RequestException as exc:
            # Keep what was already streamed and report the failure instead of
            # aborting the page with a traceback.
            tmp_text += f"\n\n知识图谱服务请求失败:{exc}\n"
    placeholder.empty()

    # An empty run means the classifier routed straight to END.
    message = {"role": "assistant", "content": tmp_text or NON_MEDICAL_REPLY}
    st.session_state.history.append(message)
    with st.chat_message(message["role"]):
        st.write(message["content"])