# app.py
  1. import sys,os
  2. current_path = os.getcwd()
  3. sys.path.append(current_path)
  4. import streamlit as st
  5. import os
  6. import json
  7. import requests
  8. from dotenv import load_dotenv
  9. from langgraph.graph import StateGraph, MessagesState,START, END
  10. from langchain_core.messages import HumanMessage, SystemMessage, AIMessage,BaseMessage
  11. from langchain_openai import ChatOpenAI
  12. #from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
  13. from langchain_core.callbacks import AsyncCallbackHandler
  14. from typing import Any
  15. #streamlit run agent/app.py
  16. # 加载环境变量
  17. load_dotenv()
  18. from config.site import SiteConfig
  19. config = SiteConfig()
  20. GRAPH_API_URL = config.get_config("GRAPH_API_URL")
  21. DEEPSEEK_API_URL = config.get_config("DEEPSEEK_API_URL")
  22. DEEPSEEK_API_KEY = config.get_config("DEEPSEEK_API_KEY")
# System prompt for intent classification: the model must answer "是" (yes)
# when the user's question is about medical/health topics, otherwise "否" (no).
PROMPT_INTENSION = '''你是一个临床医学专家,你需要对用户的问题进行意图分析,判断用户的问题是否是关于医疗健康的问题。如果是医疗健康的问题,你需要输出"是",否则输出"否"。'''
# System prompt for entity extraction: the model must emit the patient record
# as JSON with the fields described below (name, sex, age, department, chief
# complaint, present illness, past history, physical exam, labs/imaging).
PROMPT_DATA_EXTRACTION = '''你是一个NLP专家,你需要对用户的问题进行数据抽取,抽取的结果输出为json格式。以下是json格式的样本
···json
{
"pat_name": "XXX", "pat_sex": "XXX","pat_age": 0,"clinical_department": "XXX","chief_complaint":["A","B","C"],
"present_illness":["A","B","C"],past_medical_history:["A","B","C"],physical_examination:["A","B","C"],lab_and_imaging:["A","B","C"]
},
···
其中字段的描述如下:
pat_name:患者名字,字符串,如"张三",无患者信息输出""
pat_sex:患者性别,字符串,如"男",无患者信息输出""
pat_age:患者年龄,数字,单位为年,如25岁,输出25,无年龄信息输出0
clinical_department:就诊科室,字符串,如"呼吸内科",无就诊科室信息输出""
chief_complaint:主诉,字符串列表,包括主要症状的列表,如["胸痛","发热"],无主诉输出[]
present_illness:现病史,字符串列表,包括症状发展过程、诱因、伴随症状(如疼痛性质、放射部位、缓解方式,无现病史信息输出[]
past_medical_history:既往病史,字符串列表,包括疾病史(如高血压、糖尿病)、手术史、药物过敏史、家族史等,无现病史信息输出[]
physical_examination:体格检查,字符串列表,如生命体征(血压、心率)、心肺腹部体征、实验室/影像学结果(如心电图异常、肌钙蛋白升高),无信息输出[]
lab_and_imaging:检验与检查,字符串列表,包括血常规、生化指标、心电图(ECG)、胸部X光、CT等检查项目,结果和报告等,无信息输出[]
'''
######################## langgraph
# Initialize the Deepseek model (served through an OpenAI-compatible API).
llm = ChatOpenAI(
    # deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-R1
    model="Pro/deepseek-ai/DeepSeek-V3",
    #model="Qwen/QwQ-32B",
    api_key=DEEPSEEK_API_KEY,
    base_url=DEEPSEEK_API_URL,
    # Stream tokens so graph nodes can forward partial output to the UI.
    streaming=True)
# Define the LangGraph workflow
class MyStreamingOutCallbackHandler(AsyncCallbackHandler):
    """Async LangChain callback that accumulates streamed tokens and mirrors
    them to stdout and the Streamlit page.

    NOTE(review): only referenced from commented-out ``llm.invoke`` calls in
    this file — appears to be kept for debugging.
    """
    def __init__(self):
        super().__init__()
        # Accumulated text of all tokens streamed so far.
        self.content = ""
    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Token-streaming callback: collect the token and echo it.
        self.content += token
        print(token)
        st.write(token)
    async def on_llm_end(self, response, **kwargs) -> None:
        # End-of-stream callback; nothing to finalize.
        # (Cleanup logic for the end of streaming would go here.)
        pass
    async def on_llm_error(self, error, **kwargs) -> None:
        # Stream-error callback; errors are deliberately ignored here.
        pass
    async def on_chat_model_start(self,
                                  serialized: dict[str, Any],
                                  messages: list[list[BaseMessage]], **kwargs) -> None:
        # Chat-start callback: log the outgoing prompt messages.
        print("on_chat_model_start", messages)
  73. def agent_node(state):
  74. print("agent_node", state)
  75. messages = [
  76. SystemMessage(content=PROMPT_INTENSION),
  77. HumanMessage(content=state["messages"][-1].content)
  78. ]
  79. response = llm.stream(messages)
  80. collect_messages = []
  81. for chunk in response:
  82. text = chunk.content or ""
  83. if text == "":
  84. continue
  85. #st.write(text)
  86. #state["callback"](text)
  87. collect_messages.append(text)
  88. fully_response = ''.join(collect_messages)
  89. state["messages"].append(AIMessage(content=fully_response))
  90. #response = llm.invoke(state["messages"], config={"callbacks":[MyStreamingOutCallbackHandler()]})
  91. return state
  92. def entity_extraction_node(state):
  93. state["messages"] = state["messages"][:-1]
  94. print("entity_extraction_node",state["messages"][-1].content)
  95. print(state["messages"])
  96. messages = [
  97. SystemMessage(content=PROMPT_DATA_EXTRACTION),
  98. HumanMessage(content=state["messages"][-1].content)
  99. ]
  100. response = llm.stream(messages)
  101. collect_messages = []
  102. for chunk in response:
  103. text = chunk.content or ""
  104. if text == "":
  105. continue
  106. #st.write(text)
  107. state["callback"](text)
  108. collect_messages.append(text)
  109. fully_response = ''.join(collect_messages)
  110. state["messages"].append(AIMessage(content=fully_response))
  111. #response = llm.invoke(state["messages"], config={"callbacks":[MyStreamingOutCallbackHandler()]})
  112. return state
  113. def recommed_check(state):
  114. print("recommed_check",state["messages"][-1].content)
  115. text_json = state["messages"][-1].content
  116. text_json = text_json.strip("\n```json")
  117. json_data = json.loads(text_json)
  118. headers = {
  119. "Content-Type": "application/json"
  120. }
  121. data = {
  122. "q": " ".join(json_data["chief_complaint"]),
  123. "type":"Check"
  124. }
  125. response = requests.get(f"{GRAPH_API_URL}/graph/nodes/neighbor_search?keyword={data['q']}&&neighbor_type={data['type']}",
  126. headers=headers)
  127. response.raise_for_status()
  128. response = response.json()
  129. state["callback"]("\n")
  130. if "records" in response.keys():
  131. state["callback"]("## 该病案可能的诊断包括\n")
  132. if response["records"] and "nodes" in response["records"].keys():
  133. response_data = response["records"]["nodes"]
  134. for data in response_data:
  135. print(data)
  136. if "type" in data.keys():
  137. if data["type"] == "Disease":
  138. state["callback"]("- "+data["id"]+"("+data["type"]+","+str(round(data['count']*100,2))+"%)\n")
  139. state["callback"]("## 推荐的检查\n")
  140. if "neighbors" in response["records"].keys():
  141. response_data = response["records"]["neighbors"]
  142. for data in response_data:
  143. state["callback"]("- "+data["id"]+"("+str(round(data['count']*100,2))+"%)\n")
  144. data = {
  145. "q": " ".join(json_data["chief_complaint"]),
  146. "type":"Drug"
  147. }
  148. response = requests.get(f"{GRAPH_API_URL}/graph/nodes/neighbor_search?keyword={data['q']}&&neighbor_type={data['type']}",
  149. headers=headers)
  150. response.raise_for_status()
  151. response = response.json()
  152. state["callback"]("\n")
  153. if "records" in response.keys():
  154. state["callback"]("## 推荐的药物\n")
  155. if "neighbors" in response["records"].keys():
  156. response_data = response["records"]["neighbors"]
  157. for data in response_data:
  158. state["callback"]("- "+data["id"]+"("+str(round(data['count']*100,2))+"%)\n")
  159. print("recommed_check finished")
  160. #response = llm.invoke(state["messages"], config={"callbacks":[MyStreamingOutCallbackHandler()]})
  161. return state
  162. def should_continue_2_entity_extraction(state):
  163. print("should_continue")
  164. previous_message = state["messages"][-1]
  165. ai_resposne = previous_message.content or ""
  166. ai_resposne = ai_resposne.strip()
  167. print("should_continue",ai_resposne)
  168. if ai_resposne == "是":
  169. # 是医疗健康问题,继续执行工具节点
  170. return "continue"
  171. return "end"
  172. def tool_node(state):
  173. # 在此添加自定义工具逻辑
  174. print("tool_node")
  175. return {"tool_output": "Tool executed"}
class MyMessageState(MessagesState):
    # Optional sink for streamed tokens (the chat UI passes ``st_callback``
    # here); ``None`` when no live streaming target is attached.
    callback: Any = None
# Build the graph: agent (intent check) -> extract (entity extraction) ->
# recommend_check (graph-API recommendations). Non-medical questions end
# immediately after the agent node; the tools node is currently unreachable.
workflow = StateGraph(MyMessageState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", tool_node)
workflow.add_node("extract", entity_extraction_node)
workflow.add_node("recommend_check", recommed_check)
workflow.add_edge(START, "agent")
#workflow.add_edge("agent", "tools")
workflow.add_edge("extract", "recommend_check")
workflow.add_edge("recommend_check", END)
workflow.add_edge("tools", END)
# Route on the agent's "是"/"否" answer.
workflow.add_conditional_edges( "agent", should_continue_2_entity_extraction, {"continue":"extract", "end":END})
app = workflow.compile()
  190. def test_langgraph(user_input):
  191. messages = [
  192. HumanMessage(content=user_input)
  193. ]
  194. response = app.invoke({"messages":messages})
  195. print(response)
######################## networkx
# Streamlit UI
st.set_page_config(layout="wide")
# Module-level buffer holding the answer streamed so far (see st_callback).
tmp_text = ""
  200. def st_callback(text):
  201. global tmp_text
  202. tmp_text = tmp_text + text
  203. st.info(tmp_text)
  204. def submit_question(text):
  205. print("submit_question", text)
user_input = None
submit_button = None
st.header("Med Graph Agent")
# Replay the stored chat history so the conversation survives Streamlit reruns.
if "history" not in st.session_state:
    st.session_state.history = []
for message in st.session_state.history:
    with st.chat_message(message["role"]):
        st.write(message["content"])
user_input = st.chat_input("请输入您的问题:")
if user_input:
    messages = [
        HumanMessage(content=user_input)
    ]
    # Attach the Streamlit streaming sink so graph nodes can push tokens.
    state = MyMessageState(messages=messages, callback=st_callback)
    st.session_state.history.append({"role": "user", "content": user_input})
    st.chat_message("user").write(user_input)
    placeholder = st.empty()
    with placeholder:
        st.info("thinking...")
        # Reset the module-level stream buffer before each run; st_callback
        # appends every streamed token to it during app.invoke.
        tmp_text = ""
        response = app.invoke(state)
    print(state["messages"])
    # After the run, tmp_text holds the full streamed answer.
    message = {"role": "assistant", "content": tmp_text}
    placeholder.empty()
    #st.write("### 回答:")
    st.session_state.history.append(message)
    with st.chat_message(message["role"]):
        st.write(message["content"])
  234. # left_col, right_col = st.columns([1,2])
  235. # with left_col:
  236. # st.header("Med Graph Agent")
  237. # user_input = st.text_area("请输入您的问题:",height=200)
  238. # submit_button = st.button("提交", on_click=submit_question, args=[user_input])
  239. # with right_col:
  240. # with st.container(height=800):
  241. # if "history" not in st.session_state:
  242. # st.session_state.history = []
  243. # for message in st.session_state.history:
  244. # with st.chat_message(message["role"]):
  245. # st.write(message["content"])
  246. # if submit_button:
  247. # messages = [
  248. # HumanMessage(content=user_input)
  249. # ]
  250. # state = MyMessageState(messages=messages, callback=st_callback)
  251. # st.session_state.history.append({"role": "user", "content": user_input})
  252. # st.chat_message("user").write(user_input)
  253. # placeholder = st.empty()
  254. # with placeholder:
  255. # st.info("thinking...")
  256. # tmp_text = ""
  257. # response = app.invoke(state)
  258. # print(state["messages"])
  259. # message = {"role": "assistant", "content": tmp_text}
  260. # placeholder.empty()
  261. # #st.write("### 回答:")
  262. # st.session_state.history.append(message)
  263. # with st.chat_message(message["role"]):
  264. # st.write(message["content"])