123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
# --- Path setup: make the project root importable before local imports -----
import os
import sys

sys.path.append(os.getcwd())

import json
from typing import Any

import requests
import streamlit as st
from dotenv import load_dotenv
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph

# Run with: streamlit run agent/app.py

# Load environment variables from .env before any configuration is read.
load_dotenv()

from config.site import SiteConfig  # noqa: E402  (needs the sys.path tweak above)

# Endpoint / credential configuration for the graph API and the LLM backend.
config = SiteConfig()
GRAPH_API_URL = config.get_config("GRAPH_API_URL")
DEEPSEEK_API_URL = config.get_config("DEEPSEEK_API_URL")
DEEPSEEK_API_KEY = config.get_config("DEEPSEEK_API_KEY")
# System prompt for intent classification: the model must answer exactly
# "是" (yes) when the user's question concerns medicine/health, else "否" (no).
PROMPT_INTENSION = '''你是一个临床医学专家,你需要对用户的问题进行意图分析,判断用户的问题是否是关于医疗健康的问题。如果是医疗健康的问题,你需要输出"是",否则输出"否"。'''
# System prompt for clinical data extraction: the model must emit one JSON
# object with patient demographics, department, chief complaint, histories,
# physical-exam and lab/imaging findings, using "" / 0 / [] when a field is
# absent. recommed_check() later parses this JSON output.
PROMPT_DATA_EXTRACTION = '''你是一个NLP专家,你需要对用户的问题进行数据抽取,抽取的结果输出为json格式。以下是json格式的样本
···json
{
"pat_name": "XXX", "pat_sex": "XXX","pat_age": 0,"clinical_department": "XXX","chief_complaint":["A","B","C"],
"present_illness":["A","B","C"],past_medical_history:["A","B","C"],physical_examination:["A","B","C"],lab_and_imaging:["A","B","C"]
},
···
其中字段的描述如下:
pat_name:患者名字,字符串,如"张三",无患者信息输出""
pat_sex:患者性别,字符串,如"男",无患者信息输出""
pat_age:患者年龄,数字,单位为年,如25岁,输出25,无年龄信息输出0
clinical_department:就诊科室,字符串,如"呼吸内科",无就诊科室信息输出""
chief_complaint:主诉,字符串列表,包括主要症状的列表,如["胸痛","发热"],无主诉输出[]
present_illness:现病史,字符串列表,包括症状发展过程、诱因、伴随症状(如疼痛性质、放射部位、缓解方式,无现病史信息输出[]
past_medical_history:既往病史,字符串列表,包括疾病史(如高血压、糖尿病)、手术史、药物过敏史、家族史等,无现病史信息输出[]
physical_examination:体格检查,字符串列表,如生命体征(血压、心率)、心肺腹部体征、实验室/影像学结果(如心电图异常、肌钙蛋白升高),无信息输出[]
lab_and_imaging:检验与检查,字符串列表,包括血常规、生化指标、心电图(ECG)、胸部X光、CT等检查项目,结果和报告等,无信息输出[]
'''
######################## langgraph
# Initialize the DeepSeek chat model behind an OpenAI-compatible endpoint.
llm = ChatOpenAI(
    # deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-R1
    model="Pro/deepseek-ai/DeepSeek-V3",
    #model="Qwen/QwQ-32B",
    api_key=DEEPSEEK_API_KEY,
    base_url=DEEPSEEK_API_URL,
    streaming=True)  # stream tokens so graph nodes can forward chunks to the UI
# Async callback handler that mirrors streamed LLM tokens to stdout and the
# Streamlit page. Not wired into the graph below (its uses are commented
# out); kept for debugging.
class MyStreamingOutCallbackHandler(AsyncCallbackHandler):
    def __init__(self):
        super().__init__()
        self.content = ""  # accumulated streamed text

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Accumulate each streamed token and echo it to console and UI."""
        self.content += token
        print(token)
        st.write(token)

    async def on_llm_end(self, response, **kwargs) -> None:
        """Called when streaming completes; nothing to clean up."""
        pass

    async def on_llm_error(self, error, **kwargs) -> None:
        """Called on a streaming failure; errors are deliberately ignored."""
        pass

    async def on_chat_model_start(
            self,
            serialized: dict[str, Any],
            messages: list[list[BaseMessage]], **kwargs) -> None:
        """Log the prompt messages when the chat model starts."""
        print("on_chat_model_start", messages)
-
def agent_node(state):
    """Intent-classification node.

    Asks the LLM whether the latest user message is a medical/health
    question and appends the model's verdict ("是"/"否") to the
    conversation so the conditional edge can route on it.
    """
    print("agent_node", state)
    prompt = [
        SystemMessage(content=PROMPT_INTENSION),
        HumanMessage(content=state["messages"][-1].content),
    ]
    # Drain the stream, keeping only non-empty chunks; nothing is sent to
    # the UI here — the verdict is internal routing state.
    pieces = [chunk.content for chunk in llm.stream(prompt) if chunk.content]
    state["messages"].append(AIMessage(content="".join(pieces)))
    #response = llm.invoke(state["messages"], config={"callbacks":[MyStreamingOutCallbackHandler()]})
    return state
def entity_extraction_node(state):
    """Entity-extraction node.

    Drops the intent verdict appended by agent_node so the last message is
    the original user text again, then streams a JSON extraction of that
    text, forwarding every chunk to the UI callback as it arrives.
    """
    # Remove the trailing "是" verdict message.
    state["messages"] = state["messages"][:-1]

    print("entity_extraction_node", state["messages"][-1].content)
    print(state["messages"])
    prompt = [
        SystemMessage(content=PROMPT_DATA_EXTRACTION),
        HumanMessage(content=state["messages"][-1].content),
    ]
    pieces = []
    for chunk in llm.stream(prompt):
        piece = chunk.content or ""
        if not piece:
            continue
        state["callback"](piece)  # live-stream the chunk to the Streamlit UI
        pieces.append(piece)
    state["messages"].append(AIMessage(content="".join(pieces)))
    #response = llm.invoke(state["messages"], config={"callbacks":[MyStreamingOutCallbackHandler()]})
    return state
def _neighbor_search(keyword, neighbor_type, headers):
    """Query the graph API for neighbors of *keyword* of type *neighbor_type*.

    Uses requests' ``params=`` so the (typically Chinese) keyword is properly
    URL-encoded; the original hand-built URL also contained a malformed ``&&``
    separator. Returns the decoded JSON body; raises requests.HTTPError on a
    non-2xx response and requests.Timeout after 30s.
    """
    response = requests.get(
        f"{GRAPH_API_URL}/graph/nodes/neighbor_search",
        params={"keyword": keyword, "neighbor_type": neighbor_type},
        headers=headers,
        timeout=30,
    )
    response.raise_for_status()
    return response.json()


def recommed_check(state):
    """Recommendation node.

    Parses the patient JSON produced by entity_extraction_node, queries the
    graph API twice (checks, then drugs) using the chief complaint as the
    keyword, and streams markdown result lines to the UI callback.
    """
    print("recommed_check", state["messages"][-1].content)
    # The LLM wraps its JSON in a ```json fence; strip fence characters from
    # both ends. A char-set strip is sufficient because the payload itself
    # starts with '{' and ends with '}'.
    text_json = state["messages"][-1].content.strip("\n```json")
    json_data = json.loads(text_json)
    headers = {
        "Content-Type": "application/json"
    }
    keyword = " ".join(json_data.get("chief_complaint", []))
    emit = state["callback"]

    # --- probable diseases + recommended checks ----------------------------
    result = _neighbor_search(keyword, "Check", headers)
    emit("\n")
    records = result.get("records")
    if records is not None:
        emit("## 该病案可能的诊断包括\n")
        for node in (records.get("nodes") or []):
            print(node)
            # Only disease-typed nodes are shown as diagnoses.
            if node.get("type") == "Disease":
                emit("- " + node["id"] + "(" + node["type"] + ","
                     + str(round(node["count"] * 100, 2)) + "%)\n")

        emit("## 推荐的检查\n")
        for neighbor in (records.get("neighbors") or []):
            emit("- " + neighbor["id"] + "("
                 + str(round(neighbor["count"] * 100, 2)) + "%)\n")

    # --- recommended drugs --------------------------------------------------
    result = _neighbor_search(keyword, "Drug", headers)
    emit("\n")
    records = result.get("records")
    if records is not None:
        emit("## 推荐的药物\n")
        for neighbor in (records.get("neighbors") or []):
            emit("- " + neighbor["id"] + "("
                 + str(round(neighbor["count"] * 100, 2)) + "%)\n")

    print("recommed_check finished")

    #response = llm.invoke(state["messages"], config={"callbacks":[MyStreamingOutCallbackHandler()]})
    return state
def should_continue_2_entity_extraction(state):
    """Conditional edge after agent_node.

    Returns "continue" (proceed to entity extraction) when the model's
    verdict is exactly "是" after trimming whitespace, otherwise "end".
    """
    print("should_continue")
    verdict = (state["messages"][-1].content or "").strip()
    print("should_continue", verdict)
    return "continue" if verdict == "是" else "end"
def tool_node(state):
    """Placeholder tool node: returns a fixed marker payload.

    Currently unreachable — the agent->tools edge below is commented out,
    so only the tools->END edge exists.
    """
    print("tool_node")
    return {"tool_output": "Tool executed"}
class MyMessageState(MessagesState):
    # UI streaming callback: a callable taking one text chunk, invoked by the
    # extraction/recommendation nodes. NOTE(review): MessagesState is a
    # TypedDict, so the "= None" default here is not a dataclass-style
    # default — callers below always pass callback explicitly; verify.
    callback: Any = None

# ---- Build the LangGraph workflow -----------------------------------------
# START -> agent --(是)--> extract -> recommend_check -> END
#                 \-(否)--> END      ("tools" has no incoming edge)
workflow = StateGraph(MyMessageState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)
workflow.add_node("extract", entity_extraction_node)
workflow.add_node("recommend_check", recommed_check)
workflow.add_edge(START, "agent")
#workflow.add_edge("agent", "tools")
workflow.add_edge("extract", "recommend_check")
workflow.add_edge("recommend_check", END)
workflow.add_edge("tools", END)
workflow.add_conditional_edges( "agent", should_continue_2_entity_extraction, {"continue":"extract", "end":END})
app = workflow.compile()
def test_langgraph(user_input):
    """Manual smoke test: run the compiled graph once on *user_input* and
    print the final state. Not called by the Streamlit UI."""
    initial = {
        "messages": [HumanMessage(content=user_input)]
    }
    print(app.invoke(initial))
-
######################## networkx

# Streamlit UI
st.set_page_config(layout="wide")
# Module-global buffer accumulating the assistant's streamed answer for the
# current request; reset to "" before each graph invocation below.
tmp_text = ""
def st_callback(text):
    # Append the new chunk and re-render the accumulated text.
    # NOTE(review): st.info is called outside any placeholder, so each chunk
    # renders an additional info element during streaming — confirm intended.
    global tmp_text
    tmp_text = tmp_text + text
    st.info(tmp_text)
def submit_question(text):
    """Button handler for the commented-out two-column layout below; it
    only logs the submitted text."""
    print("submit_question", text)
# --- Main chat UI (Streamlit reruns this script top-to-bottom on each event)
user_input = None
submit_button = None
st.header("Med Graph Agent")
# Persist the chat transcript across reruns.
if "history" not in st.session_state:
    st.session_state.history = []

# Replay previous turns.
for message in st.session_state.history:
    with st.chat_message(message["role"]):
        st.write(message["content"])
user_input = st.chat_input("请输入您的问题:")
if user_input:
    messages = [
        HumanMessage(content=user_input)
    ]
    # st_callback streams chunks into the global tmp_text during app.invoke.
    state = MyMessageState(messages=messages, callback=st_callback)
    st.session_state.history.append({"role": "user", "content": user_input})
    st.chat_message("user").write(user_input)
    placeholder = st.empty()
    with placeholder:
        st.info("thinking...")
    tmp_text = ""  # reset the streaming buffer for this turn
    response = app.invoke(state)
    print(state["messages"])
    # tmp_text now holds everything st_callback accumulated during the run.
    message = {"role": "assistant", "content": tmp_text}
    placeholder.empty()

    #st.write("### 回答:")
    st.session_state.history.append(message)
    with st.chat_message(message["role"]):
        st.write(message["content"])
-
- # left_col, right_col = st.columns([1,2])
- # with left_col:
- # st.header("Med Graph Agent")
- # user_input = st.text_area("请输入您的问题:",height=200)
- # submit_button = st.button("提交", on_click=submit_question, args=[user_input])
-
- # with right_col:
- # with st.container(height=800):
- # if "history" not in st.session_state:
- # st.session_state.history = []
-
-
- # for message in st.session_state.history:
- # with st.chat_message(message["role"]):
- # st.write(message["content"])
- # if submit_button:
- # messages = [
- # HumanMessage(content=user_input)
- # ]
- # state = MyMessageState(messages=messages, callback=st_callback)
- # st.session_state.history.append({"role": "user", "content": user_input})
- # st.chat_message("user").write(user_input)
- # placeholder = st.empty()
- # with placeholder:
- # st.info("thinking...")
- # tmp_text = ""
- # response = app.invoke(state)
- # print(state["messages"])
- # message = {"role": "assistant", "content": tmp_text}
- # placeholder.empty()
-
- # #st.write("### 回答:")
- # st.session_state.history.append(message)
- # with st.chat_message(message["role"]):
- # st.write(message["content"])
|