# standard_kb_extractor.py
  1. #通过分析文章,生成分析结果
  2. import asyncio
  3. import os,sys
  4. current_path = os.getcwd()
  5. sys.path.append(current_path)
  6. import time
  7. import httpx
  8. import json
  9. from typing import List, Dict, AsyncGenerator
  10. import logging
  11. from dotenv import load_dotenv
  12. from typing import List, Dict, AsyncGenerator
  13. import re
  14. # 加载环境变量
  15. load_dotenv()
  16. logging.basicConfig(level=logging.INFO)
  17. logger = logging.getLogger(__name__)
  18. # DeepSeek API配置
  19. DEEPSEEK_API_URL = os.getenv("DEEPSEEK_API_URL")
  20. DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
  21. # 这里配置了每一轮抽取2个切片就暂停一下,这样可以响应任务状态的变更
  22. MAX_REQUEST_COUNT = 2
  23. def load_prompt(filename):
  24. '''加载提示词'''
  25. with open(filename, "r", encoding="utf-8") as f:
  26. return "".join(f.readlines())
  27. async def chat_with_llm(prompt: str):
  28. logger.info("chat with llm start")
  29. messages = []
  30. #messages.append({"role": "system", "content": prompt})
  31. messages.append({"role": "user", "content": prompt})
  32. headers = {
  33. "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
  34. "Content-Type": "application/json; charset=utf-8"
  35. }
  36. data = {
  37. "model": "Pro/deepseek-ai/DeepSeek-V3", #deepseek-ai/DeepSeek-V3",
  38. "messages": messages,
  39. "temperature": 0.7,
  40. "max_tokens": 2000,
  41. # "tools":functions,
  42. # "tool_choice": "auto",
  43. "stream": True
  44. }
  45. logger.info(f"request llm")
  46. try:
  47. async with httpx.AsyncClient() as client:
  48. async with client.stream("POST", DEEPSEEK_API_URL, json=data, headers=headers, timeout=60) as response:
  49. response.raise_for_status()
  50. async for chunk in response.aiter_lines():
  51. if chunk:
  52. if chunk.startswith("data: "):
  53. json_data = chunk[6:]
  54. if json_data != "[DONE]":
  55. try:
  56. chunk_data = json.loads(json_data)
  57. if "choices" in chunk_data and chunk_data["choices"]:
  58. delta = chunk_data["choices"][0].get("delta", {})
  59. if "content" in delta:
  60. yield delta["content"]
  61. except json.JSONDecodeError:
  62. continue
  63. except httpx.RequestError as e:
  64. logger.error(f"Request llm with error: ",e)
  65. def generate_tasks(chunks_path: str, kb_path: str):
  66. #如果有任务文件,直接返回里面的数据
  67. if os.path.exists(os.path.join(kb_path,"kb_extract.json")):
  68. with open(os.path.join(kb_path,"kb_extract.json"),"r",encoding="utf-8") as task_f:
  69. task_data = json.loads(task_f.read())
  70. return task_data
  71. return
  72. with open(os.path.join(kb_path,"kb_extract.json"),"w",encoding="utf-8") as task_f:
  73. task_data = []
  74. index = 1
  75. for root,dirs,files in os.walk(chunks_path):
  76. for file in files:
  77. if file.endswith(".txt"):
  78. print(f"Processing {file}")
  79. buffer = []
  80. text = ""
  81. with open(os.path.join(root,file),"r",encoding="utf-8") as f:
  82. text = f.read()
  83. chunk_started = False
  84. for line in text.split("\n"):
  85. if line.strip()=="":
  86. continue
  87. if line.startswith("```txt"):
  88. text = line[6:]
  89. buffer = []
  90. chunk_started = True
  91. continue
  92. if line.startswith("```"):
  93. chunk_started = False
  94. chunk_text = "\n".join(buffer)
  95. buffer = []
  96. task_data.append({"index":index, "file":file,"chunk":chunk_text,"status":"waiting"})
  97. index = index + 1
  98. buffer.append(line)
  99. task_f.write(json.dumps(task_data, ensure_ascii=False,indent=4))
  100. return task_data
  101. def check_json_file_format(filename: str):
  102. """检查JSON文件格式是否正确"""
  103. try:
  104. with open(filename, 'r', encoding='utf-8') as f:
  105. content = f.read()
  106. buffer = []
  107. json_started = False
  108. found_json = True
  109. for line in content.split("\n"):
  110. if line.strip()=="":
  111. continue
  112. if line.startswith("```json"):
  113. buffer = []
  114. found_json = True
  115. json_started = True
  116. continue
  117. if line.startswith("```"):
  118. if json_started:
  119. json.loads("\n".join(buffer))
  120. json_started = False
  121. buffer.append(line)
  122. if found_json:
  123. return True
  124. return False
  125. except json.JSONDecodeError as e:
  126. logger.info(f"JSON格式错误: {e}")
  127. return False
  128. if __name__ == "__main__":
  129. if len(sys.argv) != 2:
  130. print("Usage: python standard_kb_extractor.py <path_of_job>")
  131. sys.exit(-1)
  132. #检查路径是否正确
  133. job_path = sys.argv[1]
  134. if not os.path.exists(job_path):
  135. print(f"job path not exists: {job_path}")
  136. sys.exit(-1)
  137. chunks_path = os.path.join(job_path,"chunks")
  138. if not os.path.exists(chunks_path):
  139. print(f"chunks path not exists: {chunks_path}")
  140. sys.exit(-1)
  141. kb_path = os.path.join(job_path,"kb_extract")
  142. os.makedirs(kb_path ,exist_ok=True)
  143. #初始化日志
  144. log_path = os.path.join(job_path,"logs")
  145. print(f"log path: {log_path}")
  146. handler = logging.FileHandler(f"{log_path}/kb_extractor.log", mode='a',encoding="utf-8")
  147. handler.setLevel(logging.INFO)
  148. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  149. handler.setFormatter(formatter)
  150. logging.getLogger().addHandler(handler)
  151. logger = logging.getLogger(__name__)
  152. #加载提示词
  153. prompt_file = os.path.join("/".join(re.split(r"[\\/]",__file__)[:-1]),"prompt/standard_med.txt")
  154. logger.info(f"load prompt from {prompt_file}")
  155. prompt_template = load_prompt(prompt_file)
  156. #加载或者生成任务清单
  157. task_data = generate_tasks(chunks_path,kb_path)
  158. count_down = MAX_REQUEST_COUNT
  159. for item in task_data:
  160. result_file = os.path.join(kb_path,f"{item['index']}.txt")
  161. if os.path.exists(result_file):
  162. if check_json_file_format(filename=result_file):
  163. logger.info(f"{result_file} exists and format is valid, skip")
  164. continue
  165. else:
  166. logger.info(f"{result_file} exists but format is invalid, remove it and retry")
  167. os.remove(result_file)
  168. logger.info(f"Processing {item['file']}, index: {item['index']}")
  169. full_request = prompt_template + item["chunk"] #.format(text=chunk_text)
  170. try:
  171. buffer = []
  172. async def run_chat():
  173. async for content in chat_with_llm(full_request):
  174. buffer.append(content)
  175. print(content,end="")
  176. asyncio.run(run_chat())
  177. response_txt = "".join(buffer)
  178. with open(os.path.join(kb_path,f"{item['index']}.txt"),"w",encoding="utf-8") as f:
  179. f.write("```txt\n")
  180. f.write(item["chunk"])
  181. f.write("```\n\n")
  182. f.write("```result\n")
  183. f.write(response_txt)
  184. f.write("\n```\n")
  185. f.flush()
  186. count_down = count_down - 1
  187. if count_down == 0:
  188. logger.info("reach max request count, stop and wait for retry")
  189. sys.exit(1)
  190. # response_json = chat_with_llm(full_request)
  191. # if response_json is None:
  192. # logger.error("Error: response is None")
  193. # sys.exit(1)
  194. # for choice in response_json["choices"]:
  195. # response_txt = choice["message"]["content"]
  196. time.sleep(2)
  197. except Exception as e:
  198. logger.error(f"Error: {e}")
  199. sys.exit(1) #EXIT with RETRY CODE