# deepseek_chat.py — FastAPI service streaming DeepSeek chat completions
  1. # coding=utf-8
  2. import os
  3. import gc
  4. from dotenv import load_dotenv
  5. import httpx
  6. from typing import List, Dict, AsyncGenerator
  7. from libs.text_processor import TextProcessor
  8. from utils.es import ElasticsearchOperations
  9. import json
  10. from fastapi import FastAPI, HTTPException, Request
  11. from fastapi.middleware.cors import CORSMiddleware
  12. from fastapi.responses import StreamingResponse
  13. from pydantic import BaseModel
  14. from functions.call import generate_response_with_function_call, parse_function_call
  15. from functions.basic_function import basic_functions
  16. from openai import OpenAI
  17. import codecs
  18. import psutil
  19. #import chardet
  20. # 加载环境变量
  21. load_dotenv()
  22. # DeepSeek API配置
  23. DEEPSEEK_API_URL = os.getenv("DEEPSEEK_API_URL")
  24. DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
  25. async def chat_with_openai(prompt: List) -> AsyncGenerator[List, None]:
  26. client = OpenAI(api_key="sk-3894637f410a4653bbdc27fc86ddebc8", base_url="https://api.deepseek.com")
  27. response = client.chat.completions.create(
  28. model="deepseek-chat",
  29. messages=prompt,
  30. stream=True
  31. )
  32. print(response)
  33. for item in response:
  34. yield "data: " + json.dumps(item)+"\n\n"
  35. print(item.choices[0].delta.content)
  36. async def chat_with_deepseek(prompt: List) -> AsyncGenerator[List, None]:
  37. print(">>> start chat_with_deepseek ")
  38. print(prompt)
  39. new_prompt = prompt
  40. """与DeepSeek模型进行流式对话"""
  41. headers = {
  42. "Authorization": f"Bearer {DEEPSEEK_API_KEY}",
  43. "Content-Type": "application/json; charset=utf-8"
  44. }
  45. data = {
  46. "model":"Pro/deepseek-ai/DeepSeek-V3", #deepseek-ai/DeepSeek-V3",
  47. #"model": "Pro/deepseek-ai/DeepSeek-V3",
  48. "messages": new_prompt,
  49. "temperature": 0.7,
  50. "max_tokens": 4000,
  51. "stream": True
  52. }
  53. try:
  54. async with httpx.AsyncClient() as client:
  55. async with client.stream("POST", DEEPSEEK_API_URL, json=data, headers=headers) as response:
  56. print(response)
  57. response.raise_for_status()
  58. async for chunk in response.aiter_lines():
  59. if chunk:
  60. yield chunk+"\n\n"
  61. # if chunk.startswith("data: "):
  62. # json_data = chunk[6:]
  63. # if json_data != "[DONE]":
  64. # try:
  65. # chunk_data = json.loads(json_data)
  66. # if "choices" in chunk_data and chunk_data["choices"]:
  67. # delta = chunk_data["choices"][0].get("delta", {})
  68. # if "content" in delta:
  69. # yield delta["content"]
  70. # except json.JSONDecodeError:
  71. # continue
  72. except httpx.RequestError as e:
  73. print(f"Error: ",e)
  74. del data
  75. del headers
  76. app = FastAPI()
  77. # 允许所有来源的跨域请求
  78. app.add_middleware(
  79. CORSMiddleware,
  80. allow_origins=["*"],
  81. allow_credentials=True,
  82. allow_methods=["*"],
  83. allow_headers=["*"],
  84. )
  85. class ChatMessage(BaseModel):
  86. role: str
  87. content: str
  88. class ChatRequest(BaseModel):
  89. messages: List[ChatMessage]
  90. @app.get("/")
  91. def hello():
  92. return "hello"
  93. @app.post("/chat")
  94. async def chat_endpoint(request: ChatRequest):
  95. gc.collect()
  96. process = psutil.Process()
  97. print(">>> start chat_endpoint ")
  98. user_input = ""
  99. user_messages = []
  100. for msg in request.messages:
  101. user_messages.append({'role':msg.role, 'content':msg.content})
  102. last_message = user_messages[-1]
  103. user_input = last_message['content']
  104. if not user_input or user_input.strip() == "":
  105. raise HTTPException(status_code=400, detail="Message cannot be empty")
  106. prompt_text = []
  107. prompt_text.append(user_input)
  108. print(">>> user_input ", prompt_text)
  109. # if len(user_input) > 4:
  110. # results = es_ops.search_similar_texts(user_input)
  111. # for result in results:
  112. # if result['score'] > 1.8:
  113. # prompt_text.append(result['text'])
  114. # if len(prompt_text) > 0:
  115. # prompt_text = "\n\n".join(prompt_text)
  116. # prompt_text ="'''doc\n"+ prompt_text + "'''\n"
  117. # prompt_text = f"请基于以下的文档内容回复问题\n\n{prompt_text}\n\n{user_input}"
  118. # else:
  119. # prompt_text = user_input
  120. print(process.memory_info().rss)
  121. first_response = generate_response_with_function_call(functions=basic_functions, user_input=prompt_text)
  122. print(process.memory_info().rss)
  123. if 'choices' in first_response.keys():
  124. if 'tool_calls' in first_response['choices'][0]['message'].keys():
  125. choice = first_response['choices'][0]
  126. print(">>> function call response : ",choice['message'])
  127. #user_messages = user_messages + [choice['message']]
  128. call_result = parse_function_call(first_response, user_messages)
  129. if call_result['result'] != "":
  130. result_text = codecs.encode(call_result['result'], "utf-8")
  131. result_text = codecs.decode(result_text, "utf-8")
  132. user_messages = [{
  133. "role": "user",
  134. "content": f"以下是你的参考内容,请基于这些内容进行问题回答问题:【{user_input}】\n\n```doc\n{result_text}\n```",
  135. #"tool_call_id":choice['message']['tool_calls'][0]['id']
  136. }]
  137. # user_messages.append({
  138. # "role": "user",
  139. # "content": f"以下是你的参考内容,请基于这些内容进行问题回答:\n```doc\n{result_text}\n```",
  140. # #"tool_call_id":choice['message']['tool_calls'][0]['id']
  141. # })
  142. print(process.memory_info().rss)
  143. async def generate_response():
  144. async for chunk in chat_with_deepseek(user_messages):
  145. print(">>> ", chunk)
  146. yield chunk
  147. #yield json.dumps({"content": chunk}) + "\n"
  148. return StreamingResponse(generate_response(), media_type="application/json")
  149. if __name__ == "__main__":
  150. import uvicorn
  151. uvicorn.run(app, host="0.0.0.0", port=8000)