# Analyze article chunks and generate extraction results via an LLM.
import asyncio
import os
import sys

current_path = os.getcwd()
sys.path.append(current_path)

import json
import logging
import re
import time
from typing import AsyncGenerator, Dict, List

import httpx
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# DeepSeek API configuration
DEEPSEEK_API_URL = os.getenv("DEEPSEEK_API_URL")
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")

# Process at most this many chunks per run, then exit with a retry code so an
# external scheduler can observe task-status changes between batches.
MAX_REQUEST_COUNT = 2


def load_prompt(filename: str) -> str:
    """Load and return the prompt template text from *filename* (UTF-8)."""
    with open(filename, "r", encoding="utf-8") as f:
        return f.read()


async def chat_with_llm(prompt: str) -> AsyncGenerator[str, None]:
    """Stream the LLM's reply to *prompt*, yielding incremental content deltas.

    Sends a single user message to the DeepSeek chat-completions endpoint with
    ``stream=True`` and yields each ``delta.content`` fragment as it arrives.
    On a transport error the generator logs and ends without raising.
    """
    logger.info("chat with llm start")
    messages = [{"role": "user", "content": prompt}]
    headers = {
        "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
        "Content-Type": "application/json; charset=utf-8",
    }
    data = {
        "model": "Pro/deepseek-ai/DeepSeek-V3",
        "messages": messages,
        "temperature": 0.7,
        "max_tokens": 2000,
        "stream": True,
    }
    logger.info("request llm")
    try:
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST", DEEPSEEK_API_URL, json=data, headers=headers, timeout=60
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_lines():
                    # Server-sent-events framing: payload lines are "data: {...}".
                    if not chunk or not chunk.startswith("data: "):
                        continue
                    json_data = chunk[6:]
                    if json_data == "[DONE]":
                        continue
                    try:
                        chunk_data = json.loads(json_data)
                    except json.JSONDecodeError:
                        # Skip keep-alive / malformed fragments.
                        continue
                    if "choices" in chunk_data and chunk_data["choices"]:
                        delta = chunk_data["choices"][0].get("delta", {})
                        if "content" in delta:
                            yield delta["content"]
    except httpx.RequestError as e:
        # BUG FIX: was logger.error(f"...: ", e) — the exception was passed as
        # an unused argument to a format string with no placeholder, so the
        # actual error never appeared in the log.
        logger.error("Request llm with error: %s", e)


def generate_tasks(chunks_path: str, kb_path: str) -> List[Dict]:
    """Build (or reload) the list of chunk-extraction tasks.

    If ``kb_path/kb_extract.json`` already exists its contents are returned
    unchanged so a restarted job keeps its previous task statuses. Otherwise
    every ``*.txt`` file under *chunks_path* is scanned for fenced blocks of
    the form ```` ```txt ... ``` ```` and each block becomes one task entry
    ``{"index", "file", "chunk", "status"}``; the list is persisted to the
    task file and returned.
    """
    task_file = os.path.join(kb_path, "kb_extract.json")
    # If a task file already exists, return its data directly.
    if os.path.exists(task_file):
        with open(task_file, "r", encoding="utf-8") as task_f:
            return json.loads(task_f.read())

    task_data: List[Dict] = []
    index = 1
    for root, dirs, files in os.walk(chunks_path):
        for file in files:
            if not file.endswith(".txt"):
                continue
            print(f"Processing {file}")
            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                text = f.read()
            buffer: List[str] = []
            chunk_started = False
            for line in text.split("\n"):
                if line.strip() == "":
                    continue
                if line.startswith("```txt"):
                    # Opening fence: start collecting a new chunk.
                    buffer = []
                    chunk_started = True
                    continue
                if line.startswith("```"):
                    # Closing fence: flush the collected chunk as one task.
                    # BUG FIX: only emit when a ```txt fence was actually open,
                    # and `continue` so the fence line itself is not buffered
                    # (the original appended the closing ``` into the next
                    # chunk's buffer).
                    if chunk_started:
                        task_data.append({
                            "index": index,
                            "file": file,
                            "chunk": "\n".join(buffer),
                            "status": "waiting",
                        })
                        index += 1
                    buffer = []
                    chunk_started = False
                    continue
                # BUG FIX: the original appended every line unconditionally,
                # so text outside any fence leaked into chunks.
                if chunk_started:
                    buffer.append(line)

    with open(task_file, "w", encoding="utf-8") as task_f:
        task_f.write(json.dumps(task_data, ensure_ascii=False, indent=4))
    return task_data


def check_json_file_format(filename: str) -> bool:
    """Return True iff *filename* contains at least one valid ```json block.

    Every fenced ```` ```json ... ``` ```` section must parse as JSON; a
    parse failure (or no ```json block at all) yields False so the caller
    regenerates the result file.
    """
    try:
        with open(filename, "r", encoding="utf-8") as f:
            content = f.read()
        buffer: List[str] = []
        json_started = False
        # BUG FIX: was initialized to True, which made any file — even one
        # containing no ```json block at all — count as "valid" and be
        # skipped forever.
        found_json = False
        for line in content.split("\n"):
            if line.strip() == "":
                continue
            if line.startswith("```json"):
                buffer = []
                json_started = True
                continue
            if line.startswith("```"):
                if json_started:
                    # Raises JSONDecodeError (caught below) if malformed.
                    json.loads("\n".join(buffer))
                    found_json = True
                json_started = False
                continue
            if json_started:
                buffer.append(line)
        return found_json
    except json.JSONDecodeError as e:
        logger.info(f"JSON格式错误: {e}")
        return False


if __name__ == "__main__":
    if len(sys.argv) != 2:
        # BUG FIX: usage message was missing the argument placeholder.
        print("Usage: python standard_kb_extractor.py <job_path>")
        sys.exit(-1)

    # Validate the expected job directory layout.
    job_path = sys.argv[1]
    if not os.path.exists(job_path):
        print(f"job path not exists: {job_path}")
        sys.exit(-1)
    chunks_path = os.path.join(job_path, "chunks")
    if not os.path.exists(chunks_path):
        print(f"chunks path not exists: {chunks_path}")
        sys.exit(-1)
    kb_path = os.path.join(job_path, "kb_extract")
    os.makedirs(kb_path, exist_ok=True)

    # Initialize file logging.
    log_path = os.path.join(job_path, "logs")
    # BUG FIX: the logs directory was never created, so FileHandler crashed
    # on a fresh job directory (kb_extract got makedirs but logs did not).
    os.makedirs(log_path, exist_ok=True)
    print(f"log path: {log_path}")
    handler = logging.FileHandler(f"{log_path}/kb_extractor.log", mode='a', encoding="utf-8")
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logging.getLogger().addHandler(handler)
    logger = logging.getLogger(__name__)

    # Load the prompt template from next to this script.
    # (More robust than the original re.split on path separators.)
    prompt_file = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                               "prompt", "standard_med.txt")
    logger.info(f"load prompt from {prompt_file}")
    prompt_template = load_prompt(prompt_file)

    # Load or generate the task list.
    task_data = generate_tasks(chunks_path, kb_path)
    count_down = MAX_REQUEST_COUNT
    for item in task_data:
        result_file = os.path.join(kb_path, f"{item['index']}.txt")
        if os.path.exists(result_file):
            if check_json_file_format(filename=result_file):
                logger.info(f"{result_file} exists and format is valid, skip")
                continue
            logger.info(f"{result_file} exists but format is invalid, remove it and retry")
            os.remove(result_file)
        logger.info(f"Processing {item['file']}, index: {item['index']}")
        full_request = prompt_template + item["chunk"]
        try:
            buffer = []

            async def run_chat():
                # Collect the streamed reply while echoing it to stdout.
                async for content in chat_with_llm(full_request):
                    buffer.append(content)
                    print(content, end="")

            asyncio.run(run_chat())
            response_txt = "".join(buffer)
            with open(result_file, "w", encoding="utf-8") as f:
                f.write("```txt\n")
                f.write(item["chunk"])
                # BUG FIX: the chunk does not end with a newline, so the
                # closing fence was glued to the last chunk line
                # ("lastline```"), breaking any line-based fence parsing.
                f.write("\n```\n\n")
                f.write("```result\n")
                f.write(response_txt)
                f.write("\n```\n")
                f.flush()
            count_down -= 1
            if count_down == 0:
                # Exit with the retry code so the scheduler re-launches us.
                logger.info("reach max request count, stop and wait for retry")
                sys.exit(1)
            # Be gentle on the API between requests.
            time.sleep(2)
        except Exception as e:
            logger.error(f"Error: {e}")
            sys.exit(1)  # EXIT with RETRY CODE