async_stream_service.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import asyncio
  2. import json
  3. import uuid
  4. import traceback
  5. from typing import AsyncIterable, Optional
  6. from langchain_core.messages import HumanMessage
  7. from core.callback import AsyncVentCallback
  8. from utils.pdf_builder import markdown_to_pdf, get_pdf_download_url
  9. from config.settings import (
  10. MAX_AGENT_ITERATIONS,
  11. REPORT_OUTPUT_DIR,
  12. STREAM_BATCH_SIZE,
  13. STREAM_BATCH_DELAY,
  14. )
  15. from core.agent_builder import review_agent, calculation_agent
  16. from service.session_manager import session_manager
  17. from service.history_manager import get_or_create_title, append_line
  18. async def _stream_text_chunks(queue: asyncio.Queue, text: str, event_type: str = "model_thinking") -> None:
  19. """Split text into chunks and push them to the async queue as SSE data frames."""
  20. for start in range(0, len(text), STREAM_BATCH_SIZE):
  21. chunk = text[start:start + STREAM_BATCH_SIZE]
  22. data = json.dumps({"type": event_type, "content": chunk}, ensure_ascii=False)
  23. await queue.put(f"data: {data}\n\n")
  24. await asyncio.sleep(STREAM_BATCH_DELAY)
  25. async def _yield_with_heartbeat(queue: asyncio.Queue, timeout_seconds: int = 120) -> AsyncIterable[str]:
  26. """Yield items from the queue, sending a heartbeat on timeout."""
  27. while True:
  28. try:
  29. yield await asyncio.wait_for(queue.get(), timeout=timeout_seconds)
  30. except asyncio.TimeoutError:
  31. timeout_data = json.dumps({"type": "timeout", "content": "处理中..."}, ensure_ascii=False)
  32. yield f"data: {timeout_data}\n\n"
  33. async def stream_review_async(message: str, file_name: str, parse_data: dict, session_id: Optional[str]) -> AsyncIterable[str]:
  34. queue = asyncio.Queue(1000)
  35. session_id = session_id or str(uuid.uuid4())
  36. user_content = f"{message} {file_name} \n {parse_data}"
  37. # 不存在文件→大模型生成标题写入首行,存在直接读取标题
  38. get_or_create_title(session_id, user_content)
  39. callback = AsyncVentCallback(queue, session_id)
  40. all_responses = []
  41. full_report = ""
  42. # 用户消息落地日志
  43. append_line(session_id, "user_message", f"{message} {file_name}")
  44. async def run():
  45. nonlocal full_report
  46. try:
  47. state = session_manager.get_or_create(session_id)
  48. state.messages.append({"role": "user", "content": user_content})
  49. config = {
  50. "callbacks": [callback],
  51. "max_iterations": MAX_AGENT_ITERATIONS,
  52. "recursion_limit": 20
  53. }
  54. async for chunk in review_agent.astream(
  55. {"messages": [HumanMessage(content=user_content)]},
  56. stream_mode="updates",
  57. config=config
  58. ):
  59. for node, update in chunk.items():
  60. for msg in update.get("messages", []):
  61. if msg.type == "ai" and msg.content:
  62. all_responses.append(msg.content)
  63. await _stream_text_chunks(queue, msg.content)
  64. append_line(session_id, "model_thinking", msg.content)
  65. if all_responses:
  66. full_report = all_responses[-1].strip()
  67. state.messages.append({"role": "assistant", "content": full_report})
  68. session_manager.update(session_id, state)
  69. if full_report:
  70. output_filename = f"{uuid.uuid4()}.pdf"
  71. output_path = f"{REPORT_OUTPUT_DIR}/{output_filename}"
  72. markdown_to_pdf(full_report, output_path)
  73. download_url = get_pdf_download_url(output_filename)
  74. pdf_data = json.dumps({"type": "pdf_download", "content": download_url}, ensure_ascii=False)
  75. append_line(session_id, "pdf_download", download_url)
  76. await queue.put(f"data: {pdf_data}\n\n")
  77. done_data = json.dumps({"type": "done", "content": "审查完成", "session_id": session_id}, ensure_ascii=False)
  78. await queue.put(f"data: {done_data}\n\n")
  79. except Exception as e:
  80. err_msg = str(e) + "\n" + traceback.format_exc()[:800]
  81. append_line(session_id, "error", err_msg)
  82. err_data = json.dumps({"type": "error", "content": err_msg}, ensure_ascii=False)
  83. print(err_msg)
  84. await queue.put(f"data: {err_data}\n\n")
  85. asyncio.create_task(run())
  86. async for item in _yield_with_heartbeat(queue):
  87. yield item
  88. async def stream_calculation_async(user_msg: str, session_id: Optional[str]) -> AsyncIterable[str]:
  89. queue = asyncio.Queue(1000)
  90. session_id = session_id or str(uuid.uuid4())
  91. get_or_create_title(session_id, user_msg)
  92. callback = AsyncVentCallback(queue, session_id)
  93. last_ai_content = ""
  94. append_line(session_id, "user_message", user_msg)
  95. async def run():
  96. nonlocal last_ai_content
  97. try:
  98. state = session_manager.get_or_create(session_id)
  99. state.messages.append({"role": "user", "content": user_msg})
  100. config = {
  101. "callbacks": [callback],
  102. "max_iterations": MAX_AGENT_ITERATIONS,
  103. "recursion_limit": 20
  104. }
  105. async for chunk in calculation_agent.astream(
  106. {"messages": [HumanMessage(content=user_msg)]},
  107. stream_mode="updates",
  108. config=config
  109. ):
  110. for node, update in chunk.items():
  111. for msg in update.get("messages", []):
  112. if msg.type == "ai" and msg.content:
  113. last_ai_content = msg.content
  114. await _stream_text_chunks(queue, msg.content)
  115. append_line(session_id, "model_thinking", msg.content)
  116. state.messages.append({"role": "assistant", "content": last_ai_content})
  117. session_manager.update(session_id, state)
  118. done_data = json.dumps({"type": "done", "content": "对话结束", "session_id": session_id}, ensure_ascii=False)
  119. await queue.put(f"data: {done_data}\n\n")
  120. except Exception as e:
  121. err_msg = str(e) + "\n" + traceback.format_exc()[:800]
  122. append_line(session_id, "error", err_msg)
  123. err_data = json.dumps({"type": "error", "content": err_msg}, ensure_ascii=False)
  124. print(err_msg)
  125. await queue.put(f"data: {err_data}\n\n")
  126. asyncio.create_task(run())
  127. async for item in _yield_with_heartbeat(queue):
  128. yield item