- #通过分析文章,生成分析结果
- import asyncio
- import os,sys
- current_path = os.getcwd()
- sys.path.append(current_path)
- import time
- import httpx
- import json
- from typing import List, Dict, AsyncGenerator
- import logging
- from dotenv import load_dotenv
- from typing import List, Dict, AsyncGenerator
- import re
- # 加载环境变量
- load_dotenv()
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
- # DeepSeek API配置
- DEEPSEEK_API_URL = os.getenv("DEEPSEEK_API_URL")
- DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
- # 这里配置了每一轮抽取2个切片就暂停一下,这样可以响应任务状态的变更
- MAX_REQUEST_COUNT = 2
def load_prompt(filename: str) -> str:
    """Load a prompt template and return the file's full text.

    Args:
        filename: Path to a UTF-8 encoded prompt file.

    Returns:
        The complete file content as a single string.
    """
    with open(filename, "r", encoding="utf-8") as f:
        # f.read() returns the whole file directly; the original
        # "".join(f.readlines()) built an intermediate list for nothing.
        return f.read()
-
async def chat_with_llm(prompt: str):
    """Stream a chat completion for *prompt* from the DeepSeek API.

    Sends a single user message and consumes the server-sent-event
    stream, yielding content fragments (str) as they arrive.

    On a transport error the error is logged and the generator simply
    ends, so callers receive a (possibly partial) stream rather than
    an exception.
    """
    logger.info("chat with llm start")
    messages = [{"role": "user", "content": prompt}]
    headers = {
        "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
        "Content-Type": "application/json; charset=utf-8",
    }
    data = {
        "model": "Pro/deepseek-ai/DeepSeek-V3",
        "messages": messages,
        "temperature": 0.7,
        "max_tokens": 2000,
        "stream": True,
    }
    logger.info("request llm")
    try:
        async with httpx.AsyncClient() as client:
            async with client.stream("POST", DEEPSEEK_API_URL, json=data,
                                     headers=headers, timeout=60) as response:
                response.raise_for_status()
                async for chunk in response.aiter_lines():
                    # SSE payload lines look like "data: {...}"; the stream
                    # is terminated by "data: [DONE]".
                    if not chunk or not chunk.startswith("data: "):
                        continue
                    json_data = chunk[6:]
                    if json_data == "[DONE]":
                        continue
                    try:
                        chunk_data = json.loads(json_data)
                    except json.JSONDecodeError:
                        # Ignore keep-alive / malformed fragments.
                        continue
                    if "choices" in chunk_data and chunk_data["choices"]:
                        delta = chunk_data["choices"][0].get("delta", {})
                        if "content" in delta:
                            yield delta["content"]
    except httpx.RequestError as e:
        # BUG FIX: the original passed `e` as a stray positional argument to
        # an f-string message with no placeholder, so the exception text was
        # never rendered. Use lazy %-formatting instead.
        logger.error("Request llm with error: %s", e)
def generate_tasks(chunks_path: str, kb_path: str) -> list:
    """Build (or reload) the list of extraction tasks.

    If <kb_path>/kb_extract.json already exists its content is returned
    unchanged, so a restarted job keeps the task indexes and statuses.
    Otherwise every ```txt ... ``` fenced chunk found in the *.txt files
    under *chunks_path* becomes one task, and the list is persisted to
    kb_extract.json before being returned.

    Args:
        chunks_path: Directory walked recursively for *.txt chunk files.
        kb_path: Output directory for kb_extract.json and result files.

    Returns:
        list[dict]: tasks with keys "index", "file", "chunk", "status".
    """
    task_file = os.path.join(kb_path, "kb_extract.json")
    # Resume path: reuse the previously generated task list if present.
    if os.path.exists(task_file):
        with open(task_file, "r", encoding="utf-8") as task_f:
            return json.loads(task_f.read())
    # BUG FIX: the original had a stray bare `return` here, which made the
    # whole generation path below unreachable on a fresh job (the function
    # returned None and the caller crashed iterating it).
    with open(task_file, "w", encoding="utf-8") as task_f:
        task_data = []
        index = 1
        for root, dirs, files in os.walk(chunks_path):
            for file in files:
                if not file.endswith(".txt"):
                    continue
                print(f"Processing {file}")
                buffer = []
                with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                    text = f.read()
                for line in text.split("\n"):
                    # Blank lines are dropped from chunks.
                    if line.strip() == "":
                        continue
                    if line.startswith("```txt"):
                        # Opening fence: start collecting a new chunk.
                        buffer = []
                        continue
                    if line.startswith("```"):
                        # Closing fence: flush the collected chunk as a task.
                        chunk_text = "\n".join(buffer)
                        buffer = []
                        task_data.append({"index": index, "file": file,
                                          "chunk": chunk_text, "status": "waiting"})
                        index = index + 1
                    buffer.append(line)
        task_f.write(json.dumps(task_data, ensure_ascii=False, indent=4))
    return task_data
def check_json_file_format(filename: str) -> bool:
    """Validate every ```json fenced block inside *filename*.

    A file containing no ```json fence at all counts as valid — result
    files normally hold only ```txt/```result fences, and the original
    code's always-True `found_json` flag never rejected them either.
    That (dead) flag and its unreachable `return False` are removed.

    Returns:
        True when every fenced JSON block parses; False on the first
        malformed block (or on an unterminated ```json block).
    """
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
        buffer = []
        json_started = False
        for line in content.split("\n"):
            if line.strip() == "":
                continue
            if line.startswith("```json"):
                # Opening fence: start collecting a JSON candidate.
                buffer = []
                json_started = True
                continue
            if line.startswith("```") and json_started:
                # Closing fence: the collected block must be valid JSON.
                json.loads("\n".join(buffer))
                json_started = False
            buffer.append(line)
        if json_started:
            # ROBUSTNESS FIX: an unterminated ```json block (e.g. a
            # truncated LLM response) was previously never validated and
            # slipped through as "valid"; parse it so the task is retried.
            json.loads("\n".join(buffer))
        return True
    except json.JSONDecodeError as e:
        # Same logger object as the module-level `logger`; spelled out so
        # the function is self-contained.
        logging.getLogger(__name__).info("JSON格式错误: %s", e)
        return False
-
- if __name__ == "__main__":
- if len(sys.argv) != 2:
- print("Usage: python standard_kb_extractor.py <path_of_job>")
- sys.exit(-1)
- #检查路径是否正确
- job_path = sys.argv[1]
- if not os.path.exists(job_path):
- print(f"job path not exists: {job_path}")
- sys.exit(-1)
- chunks_path = os.path.join(job_path,"chunks")
- if not os.path.exists(chunks_path):
- print(f"chunks path not exists: {chunks_path}")
- sys.exit(-1)
- kb_path = os.path.join(job_path,"kb_extract")
- os.makedirs(kb_path ,exist_ok=True)
- #初始化日志
- log_path = os.path.join(job_path,"logs")
- print(f"log path: {log_path}")
- handler = logging.FileHandler(f"{log_path}/kb_extractor.log", mode='a',encoding="utf-8")
- handler.setLevel(logging.INFO)
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- handler.setFormatter(formatter)
- logging.getLogger().addHandler(handler)
- logger = logging.getLogger(__name__)
- #加载提示词
- prompt_file = os.path.join("/".join(re.split(r"[\\/]",__file__)[:-1]),"prompt/standard_med.txt")
- logger.info(f"load prompt from {prompt_file}")
- prompt_template = load_prompt(prompt_file)
- #加载或者生成任务清单
- task_data = generate_tasks(chunks_path,kb_path)
- count_down = MAX_REQUEST_COUNT
- for item in task_data:
- result_file = os.path.join(kb_path,f"{item['index']}.txt")
- if os.path.exists(result_file):
- if check_json_file_format(filename=result_file):
- logger.info(f"{result_file} exists and format is valid, skip")
- continue
- else:
- logger.info(f"{result_file} exists but format is invalid, remove it and retry")
- os.remove(result_file)
- logger.info(f"Processing {item['file']}, index: {item['index']}")
- full_request = prompt_template + item["chunk"] #.format(text=chunk_text)
- try:
- buffer = []
- async def run_chat():
- async for content in chat_with_llm(full_request):
- buffer.append(content)
- print(content,end="")
- asyncio.run(run_chat())
- response_txt = "".join(buffer)
- with open(os.path.join(kb_path,f"{item['index']}.txt"),"w",encoding="utf-8") as f:
- f.write("```txt\n")
- f.write(item["chunk"])
- f.write("```\n\n")
- f.write("```result\n")
- f.write(response_txt)
- f.write("\n```\n")
- f.flush()
- count_down = count_down - 1
- if count_down == 0:
- logger.info("reach max request count, stop and wait for retry")
- sys.exit(1)
- # response_json = chat_with_llm(full_request)
- # if response_json is None:
- # logger.error("Error: response is None")
- # sys.exit(1)
- # for choice in response_json["choices"]:
- # response_txt = choice["message"]["content"]
- time.sleep(2)
- except Exception as e:
- logger.error(f"Error: {e}")
- sys.exit(1) #EXIT with RETRY CODE