You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

448 lines
14 KiB
Python

6 months ago
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
智能文档问答助手 - 基于HelloAgents的智能文档问答系统
这是一个完整的PDF学习助手应用支持
- 加载PDF文档并构建知识库
- 智能问答基于RAG
- 学习历程记录基于Memory
- 学习回顾和报告生成
"""
import os
import time
import json
from datetime import datetime
from typing import Dict, List, Optional, Any, Tuple
from hello_agents.tools import MemoryTool, RAGTool
import gradio as gr
class PDFLearningAssistant:
    """Document question-answering assistant built on HelloAgents tools.

    Wraps a ``RAGTool`` (document indexing and retrieval) and a ``MemoryTool``
    (learning history), and keeps per-session counters for documents loaded,
    questions asked and notes added.
    """

    def __init__(self, user_id: str = "default_user"):
        """Create the assistant and its tools.

        Args:
            user_id: User identifier used to isolate different users' data.
                It is passed to ``MemoryTool`` and embedded in the RAG
                namespace ``pdf_<user_id>``.
        """
        # One timestamp, so session_id and session_start always agree.
        started = datetime.now()
        self.user_id = user_id
        # Second resolution; also used to name the saved report file.
        self.session_id = f"session_{started.strftime('%Y%m%d_%H%M%S')}"

        self.memory_tool = MemoryTool(user_id=user_id)
        self.rag_tool = RAGTool(rag_namespace=f"pdf_{user_id}")

        # Learning statistics for this session.
        self.stats = {
            "session_start": started,
            "documents_loaded": 0,
            "questions_asked": 0,
            "concepts_learned": 0,
        }

        # Basename of the most recently loaded document; None until one loads.
        self.current_document: Optional[str] = None

    def _elapsed_seconds(self) -> float:
        """Return the seconds elapsed since ``stats["session_start"]``."""
        return (datetime.now() - self.stats["session_start"]).total_seconds()

    def load_document(self, pdf_path: str) -> Dict[str, Any]:
        """Index a PDF into the RAG knowledge base and log it to memory.

        Args:
            pdf_path: Path to the PDF file.

        Returns:
            Dict with ``success`` and ``message``; on success it also holds
            ``document`` (the file's basename). Exceptions raised by either
            tool are caught and reported as ``success=False``.
        """
        # isfile rather than exists: a directory path is rejected up front
        # instead of being handed to the RAG tool.
        if not os.path.isfile(pdf_path):
            return {"success": False, "message": f"文件不存在: {pdf_path}"}

        start_time = time.perf_counter()
        try:
            # The tool's return value is not inspected; success is assumed
            # unless it raises.
            self.rag_tool.execute(
                "add_document",
                file_path=pdf_path,
                chunk_size=1000,
                chunk_overlap=200,
            )
            process_time = time.perf_counter() - start_time

            self.current_document = os.path.basename(pdf_path)
            self.stats["documents_loaded"] += 1

            # Record the load as an episodic memory of this session.
            self.memory_tool.execute(
                "add",
                content=f"加载了文档《{self.current_document}》",
                memory_type="episodic",
                importance=0.9,
                event_type="document_loaded",
                session_id=self.session_id,
            )

            return {
                "success": True,
                "message": f"加载成功!(耗时: {process_time:.1f}秒)",
                "document": self.current_document,
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"加载失败: {str(e)}",
            }

    def ask(self, question: str, use_advanced_search: bool = True) -> str:
        """Answer a question using the RAG knowledge base.

        The question is logged to working memory before retrieval, and the
        interaction is logged to episodic memory afterwards.

        Args:
            question: The user's question.
            use_advanced_search: Toggles the RAG tool's
                ``enable_advanced_search``, ``enable_mqe`` and
                ``enable_hyde`` flags together (MQE + HyDE retrieval).

        Returns:
            The RAG tool's answer, or a warning string if no document has
            been loaded yet.
        """
        if not self.current_document:
            return "⚠️ 请先加载文档!使用 load_document() 方法加载PDF文档。"

        self.memory_tool.execute(
            "add",
            content=f"提问: {question}",
            memory_type="working",
            importance=0.6,
            session_id=self.session_id,
        )

        answer = self.rag_tool.execute(
            "ask",
            question=question,
            limit=5,
            enable_advanced_search=use_advanced_search,
            enable_mqe=use_advanced_search,
            enable_hyde=use_advanced_search,
        )

        self.memory_tool.execute(
            "add",
            content=f"关于'{question}'的学习",
            memory_type="episodic",
            importance=0.7,
            event_type="qa_interaction",
            session_id=self.session_id,
        )

        self.stats["questions_asked"] += 1
        return answer

    def add_note(self, content: str, concept: Optional[str] = None) -> None:
        """Store a learning note as semantic memory.

        Args:
            content: Note text.
            concept: Related concept; stored as ``"general"`` when omitted.
        """
        self.memory_tool.execute(
            "add",
            content=content,
            memory_type="semantic",
            importance=0.8,
            concept=concept or "general",
            session_id=self.session_id,
        )
        self.stats["concepts_learned"] += 1

    def recall(self, query: str, limit: int = 5) -> str:
        """Search the memory store to review past learning.

        Args:
            query: Search query / keywords.
            limit: Maximum number of results to request.

        Returns:
            Whatever the memory tool's ``search`` action returns.
        """
        return self.memory_tool.execute("search", query=query, limit=limit)

    def get_stats(self) -> Dict[str, Any]:
        """Return display-ready session statistics (Chinese labels as keys)."""
        return {
            "会话时长": f"{self._elapsed_seconds():.0f}",
            "加载文档": self.stats["documents_loaded"],
            "提问次数": self.stats["questions_asked"],
            "学习笔记": self.stats["concepts_learned"],
            "当前文档": self.current_document or "未加载",
        }

    def generate_report(self, save_to_file: bool = True) -> Dict[str, Any]:
        """Build a learning report for this session.

        Args:
            save_to_file: When True, also write the report as JSON to
                ``learning_report_<session_id>.json`` (relative path, i.e.
                the current working directory).

        Returns:
            Dict with ``session_info``, ``learning_metrics``,
            ``memory_summary`` and ``rag_status``. When saving, it also holds
            ``report_file`` on success or ``save_error`` on failure.
        """
        memory_summary = self.memory_tool.execute("summary", limit=10)
        rag_stats = self.rag_tool.execute("stats")

        report = {
            "session_info": {
                "session_id": self.session_id,
                "user_id": self.user_id,
                "start_time": self.stats["session_start"].isoformat(),
                "duration_seconds": self._elapsed_seconds(),
            },
            "learning_metrics": {
                "documents_loaded": self.stats["documents_loaded"],
                "questions_asked": self.stats["questions_asked"],
                "concepts_learned": self.stats["concepts_learned"],
            },
            "memory_summary": memory_summary,
            "rag_status": rag_stats,
        }

        if save_to_file:
            report_file = f"learning_report_{self.session_id}.json"
            try:
                with open(report_file, "w", encoding="utf-8") as f:
                    # default=str: the tool results are opaque, so anything
                    # json cannot encode natively is stringified, not fatal.
                    json.dump(report, f, ensure_ascii=False, indent=2, default=str)
                report["report_file"] = report_file
            except Exception as e:
                # Best effort: a failed save is reported in the dict, not raised.
                report["save_error"] = str(e)

        return report
def create_gradio_ui():
    """Build the Gradio web UI wrapped around a ``PDFLearningAssistant``.

    The assistant is created by the "初始化助手" button and kept in a
    closure-level dict that every callback below reads.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    # Holder for the current assistant instance (None until initialised).
    # NOTE(review): it lives in this closure, so it appears to be shared by
    # every visitor of the app rather than per browser session; confirm
    # whether a per-session gr.State is wanted instead.
    assistant_state = {"assistant": None}

    # A chat message containing any of these words is routed to recall()
    # (review past learning) instead of ask() (question the document).
    recall_keywords = ("之前", "学过", "回顾", "历史", "记得")

    def init_assistant(user_id: str) -> str:
        """Create a fresh assistant; blank IDs fall back to ``web_user``."""
        user_id = (user_id or "").strip() or "web_user"
        assistant_state["assistant"] = PDFLearningAssistant(user_id=user_id)
        return f"✅ 助手已初始化 (用户: {user_id})"

    def load_pdf(pdf_file) -> str:
        """Load the uploaded PDF into the assistant's knowledge base."""
        if assistant_state["assistant"] is None:
            return "❌ 请先初始化助手"
        if pdf_file is None:
            return "❌ 请上传PDF文件"
        # gr.File(type="filepath") passes a path string; tempfile-like upload
        # objects (other configurations / Gradio versions) carry it in .name.
        if isinstance(pdf_file, (str, os.PathLike)):
            pdf_path = os.fspath(pdf_file)
        else:
            pdf_path = pdf_file.name
        result = assistant_state["assistant"].load_document(pdf_path)
        if result["success"]:
            return f"{result['message']}\n📄 文档: {result['document']}"
        return result["message"]

    def chat(message: str, history: List) -> Tuple[str, List]:
        """Answer or recall; returns ("", new history) to clear the input box."""
        # Work on a copy so both branches treat the incoming history the same.
        history = list(history or [])
        assistant = assistant_state["assistant"]
        if assistant is None:
            return "", history + [[message, "❌ 请先初始化助手并加载文档"]]
        if not (message or "").strip():
            return "", history

        if any(keyword in message for keyword in recall_keywords):
            response = f"🧠 **学习回顾**\n\n{assistant.recall(message)}"
        else:
            response = f"💡 **回答**\n\n{assistant.ask(message)}"
        return "", history + [[message, response]]

    def add_note_ui(note_content: str, concept: str) -> str:
        """Save a note; a blank concept is passed on as None."""
        if assistant_state["assistant"] is None:
            return "❌ 请先初始化助手"
        if not (note_content or "").strip():
            return "❌ 笔记内容不能为空"
        assistant_state["assistant"].add_note(note_content, (concept or "").strip() or None)
        # Only show an ellipsis when the preview is actually truncated.
        preview = note_content if len(note_content) <= 50 else note_content[:50] + "..."
        return f"✅ 笔记已保存: {preview}"

    def get_stats_ui() -> str:
        """Render the assistant's statistics as a Markdown bullet list."""
        if assistant_state["assistant"] is None:
            return "❌ 请先初始化助手"
        stats = assistant_state["assistant"].get_stats()
        return "📊 **学习统计**\n\n" + "".join(
            f"- **{key}**: {value}\n" for key, value in stats.items()
        )

    def generate_report_ui() -> str:
        """Generate (and save) a report, summarising it for display."""
        if assistant_state["assistant"] is None:
            return "❌ 请先初始化助手"
        report = assistant_state["assistant"].generate_report(save_to_file=True)
        info = report["session_info"]
        metrics = report["learning_metrics"]
        result = (
            "✅ 学习报告已生成\n\n"
            "**会话信息**\n"
            f"- 会话时长: {info['duration_seconds']:.0f}\n"
            f"- 加载文档: {metrics['documents_loaded']}\n"
            f"- 提问次数: {metrics['questions_asked']}\n"
            f"- 学习笔记: {metrics['concepts_learned']}\n"
        )
        if "report_file" in report:
            result += f"\n💾 报告已保存至: {report['report_file']}"
        elif "save_error" in report:
            # Surface save failures instead of silently omitting the file line.
            result += f"\n⚠️ 报告保存失败: {report['save_error']}"
        return result

    with gr.Blocks(title="智能文档问答助手", theme=gr.themes.Soft()) as demo:
        # Built line by line so source indentation cannot leak into the
        # Markdown (indented lines would render as a code block).
        gr.Markdown(
            "# 📚 智能文档问答助手\n"
            "基于HelloAgents的智能文档问答系统，支持：\n"
            "- 📄 加载PDF文档并构建知识库\n"
            "- 💬 智能问答（基于RAG）\n"
            "- 📝 学习笔记记录\n"
            "- 🧠 学习历程回顾\n"
            "- 📊 学习报告生成\n"
        )

        with gr.Tab("🏠 开始使用"):
            with gr.Row():
                user_id_input = gr.Textbox(
                    label="用户ID",
                    placeholder="输入你的用户ID(可选，默认为web_user)",
                    value="web_user",
                )
                init_btn = gr.Button("初始化助手", variant="primary")
            init_output = gr.Textbox(label="初始化状态", interactive=False)
            init_btn.click(init_assistant, inputs=[user_id_input], outputs=[init_output])

            gr.Markdown("### 📄 加载PDF文档")
            pdf_upload = gr.File(
                label="上传PDF文件",
                file_types=[".pdf"],
                type="filepath",
            )
            load_btn = gr.Button("加载文档", variant="primary")
            load_output = gr.Textbox(label="加载状态", interactive=False)
            load_btn.click(load_pdf, inputs=[pdf_upload], outputs=[load_output])

        with gr.Tab("💬 智能问答"):
            gr.Markdown("### 向文档提问或回顾学习历程")
            chatbot = gr.Chatbot(
                label="对话历史",
                height=400,
                bubble_full_width=False,
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    label="输入问题",
                    placeholder="例如：什么是Transformer? 或 我之前学过什么?",
                    scale=4,
                )
                send_btn = gr.Button("发送", variant="primary", scale=1)
            gr.Examples(
                examples=[
                    "什么是大语言模型?",
                    "Transformer架构有哪些核心组件?",
                    "如何训练大语言模型?",
                    "我之前学过什么内容?",
                    "回顾一下关于注意力机制的学习",
                ],
                inputs=msg_input,
            )
            msg_input.submit(chat, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot])
            send_btn.click(chat, inputs=[msg_input, chatbot], outputs=[msg_input, chatbot])

        with gr.Tab("📝 学习笔记"):
            gr.Markdown("### 记录学习心得和重要概念")
            note_content = gr.Textbox(
                label="笔记内容",
                placeholder="输入你的学习笔记...",
                lines=3,
            )
            concept_input = gr.Textbox(
                label="相关概念(可选)",
                placeholder="例如：transformer, attention",
            )
            note_btn = gr.Button("保存笔记", variant="primary")
            note_output = gr.Textbox(label="保存状态", interactive=False)
            note_btn.click(add_note_ui, inputs=[note_content, concept_input], outputs=[note_output])

        with gr.Tab("📊 学习统计"):
            gr.Markdown("### 查看学习进度和统计信息")
            stats_btn = gr.Button("刷新统计", variant="primary")
            stats_output = gr.Markdown()
            stats_btn.click(get_stats_ui, outputs=[stats_output])

            gr.Markdown("### 生成学习报告")
            report_btn = gr.Button("生成报告", variant="primary")
            report_output = gr.Textbox(label="报告状态", interactive=False)
            report_btn.click(generate_report_ui, outputs=[report_output])

    return demo
def main(server_name: str = "0.0.0.0", server_port: int = 7860, share: bool = False):
    """Print a startup banner and launch the Gradio web UI.

    Args:
        server_name: Interface to bind. The default "0.0.0.0" listens on all
            network interfaces; pass "127.0.0.1" to keep the app local-only.
        server_port: TCP port to serve on.
        share: Forwarded to ``demo.launch(share=...)``.
    """
    print("\n" + "=" * 60)
    # The banner emoji was a mis-encoded U+FFFD; use the UI's own title icon.
    print("📚 智能文档问答助手")
    print("=" * 60)
    print("正在启动Web界面...\n")

    demo = create_gradio_ui()
    demo.launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        show_error=True,
    )


if __name__ == "__main__":
    main()