You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

259 lines
8.6 KiB
Python

6 months ago
"""
Smart search assistant — a real search system built on LangGraph + the Tavily API.
1. Understand the user's request
2. Perform a real web search via the Tavily API
3. Generate an answer grounded in the search results
"""
import asyncio
from typing import TypedDict, Annotated
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import InMemorySaver
import os
from dotenv import load_dotenv
from tavily import TavilyClient
# 加载环境变量
load_dotenv()
# Shared state passed between the workflow nodes.
class SearchState(TypedDict):
    """State carried through the understand → search → answer graph."""
    # Conversation history; add_messages appends instead of replacing.
    messages: Annotated[list, add_messages]
    user_query: str  # LLM's analysis of what the user wants
    search_query: str  # optimized keyword string sent to Tavily
    search_results: str  # formatted Tavily search results
    final_answer: str  # answer produced by the final node
    step: str  # progress marker: start/understood/searched/search_failed/completed
# Initialize the chat model and the Tavily client.
# Both are configured entirely from environment variables (loaded above via
# load_dotenv) so deployments can swap providers without code changes.
llm = ChatOpenAI(
    model=os.getenv("LLM_MODEL_ID", "gpt-4o-mini"),
    api_key=os.getenv("LLM_API_KEY"),
    base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"),
    temperature=0.7
)
# Tavily web-search client; main() refuses to start if the key is missing.
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
def understand_query_node(state: SearchState) -> SearchState:
    """Step 1: understand the user's query and derive a search keyword.

    Asks the LLM to summarize the user's need and propose search keywords,
    then extracts the keyword from the loosely formatted response.

    Returns a partial state update: the LLM analysis (``user_query``), the
    extracted ``search_query``, the step marker, and a progress message.
    """
    # Find the most recent human message in the conversation.
    user_message = ""
    for msg in reversed(state["messages"]):
        if isinstance(msg, HumanMessage):
            user_message = msg.content
            break

    understand_prompt = f"""分析用户的查询:"{user_message}"
请完成两个任务
1. 简洁总结用户想要了解什么
2. 生成最适合搜索的关键词中英文均可要精准
格式
理解[用户需求总结]
搜索词[最佳搜索关键词]"""

    response = llm.invoke([SystemMessage(content=understand_prompt)])
    response_text = response.content

    # Fall back to the raw user message when no keyword marker is found.
    search_query = user_message
    for marker in ("搜索词:", "搜索关键词:"):
        if marker in response_text:
            # Bug fix: the old code kept EVERYTHING after the marker (often
            # several lines of follow-up text) as the search query. Only the
            # first non-empty line after the marker is the keyword.
            tail = response_text.split(marker, 1)[1].strip()
            if tail:
                search_query = tail.splitlines()[0].strip()
            break

    return {
        "user_query": response.content,
        "search_query": search_query,
        "step": "understood",
        "messages": [AIMessage(content=f"我理解您的需求:{response.content}")]
    }
def tavily_search_node(state: SearchState) -> SearchState:
    """Step 2: run a real web search through the Tavily API.

    Uses the search query produced by the understanding step. On success,
    Tavily's synthesized answer plus the top 3 hits are formatted into
    ``search_results``; on any failure the step is marked ``search_failed``
    so the answer node can fall back to the model's own knowledge.
    """
    search_query = state["search_query"]
    try:
        print(f"🔍 正在搜索: {search_query}")
        # Call the Tavily search API.
        response = tavily_client.search(
            query=search_query,
            search_depth="basic",
            include_answer=True,        # ask Tavily for a synthesized answer
            include_raw_content=False,  # full page bodies are not needed
            max_results=5
        )

        search_results = ""
        # Prefer Tavily's own synthesized answer when present.
        if response.get("answer"):
            search_results = f"综合答案:\n{response['answer']}\n\n"
        # Append the top individual results with their sources.
        if response.get("results"):
            search_results += "相关信息:\n"
            for i, result in enumerate(response["results"][:3], 1):
                title = result.get("title", "")
                content = result.get("content", "")
                url = result.get("url", "")
                search_results += f"{i}. {title}\n{content}\n来源:{url}\n\n"
        if not search_results:
            search_results = "抱歉,没有找到相关信息。"

        return {
            "search_results": search_results,
            "step": "searched",
            # Fix: literal had a redundant ``f`` prefix with no placeholders
            # (ruff F541); the emitted text is unchanged.
            "messages": [AIMessage(content="✅ 搜索完成!找到了相关信息,正在为您整理答案...")]
        }
    except Exception as e:
        error_msg = f"搜索时发生错误: {str(e)}"
        print(error_msg)  # fix: was print(f"{error_msg}") — pointless f-string wrapper
        return {
            "search_results": f"搜索失败:{error_msg}",
            "step": "search_failed",
            "messages": [AIMessage(content="❌ 搜索遇到问题,我将基于已有知识为您回答")]
        }
def generate_answer_node(state: SearchState) -> SearchState:
    """Step 3: produce the final answer from the search results.

    If the search step failed, the prompt asks the model to answer from its
    own knowledge; otherwise the answer is grounded in the search results.
    Either way the LLM is invoked exactly once.
    """
    # Pick the prompt depending on whether the search succeeded.
    if state["step"] == "search_failed":
        prompt = f"""搜索API暂时不可用请基于您的知识回答用户的问题
用户问题{state['user_query']}
请提供一个有用的回答并说明这是基于已有知识的回答"""
    else:
        prompt = f"""基于以下搜索结果为用户提供完整、准确的答案:
用户问题{state['user_query']}
搜索结果
{state['search_results']}
请要求
1. 综合搜索结果提供准确有用的回答
2. 如果是技术问题提供具体的解决方案或代码
3. 引用重要信息的来源
4. 回答要结构清晰易于理解
5. 如果搜索结果不够完整请说明并提供补充建议"""

    reply = llm.invoke([SystemMessage(content=prompt)])
    return {
        "final_answer": reply.content,
        "step": "completed",
        "messages": [AIMessage(content=reply.content)]
    }
# 构建搜索工作流
def create_search_assistant():
    """Build and compile the linear understand → search → answer workflow."""
    graph = StateGraph(SearchState)

    # Register the three pipeline stages.
    for name, node in (
        ("understand", understand_query_node),
        ("search", tavily_search_node),
        ("answer", generate_answer_node),
    ):
        graph.add_node(name, node)

    # Wire them strictly in sequence: START → understand → search → answer → END.
    for src, dst in (
        (START, "understand"),
        ("understand", "search"),
        ("search", "answer"),
        ("answer", END),
    ):
        graph.add_edge(src, dst)

    # Compile with an in-memory checkpointer so each thread_id keeps its state.
    return graph.compile(checkpointer=InMemorySaver())
async def main():
    """Interactive loop: read questions, stream the workflow, echo each stage."""
    # The search node needs a Tavily key — refuse to start without one.
    if not os.getenv("TAVILY_API_KEY"):
        print("❌ 错误:请在.env文件中配置TAVILY_API_KEY")
        return

    assistant = create_search_assistant()
    print("🔍 智能搜索助手启动!")
    print("我会使用Tavily API为您搜索最新、最准确的信息")
    print("支持各种问题:新闻、技术、知识问答等")
    print("(输入 'quit' 退出)\n")

    # Progress labels for the intermediate nodes; "answer" is printed specially.
    stage_labels = {
        "understand": "🧠 理解阶段",
        "search": "🔍 搜索阶段",
    }

    session_count = 0
    while True:
        user_input = input("🤔 您想了解什么: ").strip()
        if user_input.lower() in ['quit', 'q', '退出', 'exit']:
            print("感谢使用!再见!👋")
            break
        if not user_input:
            continue

        session_count += 1
        # A fresh thread_id per question keeps checkpointed histories separate.
        config = {"configurable": {"thread_id": f"search-session-{session_count}"}}
        initial_state = {
            "messages": [HumanMessage(content=user_input)],
            "user_query": "",
            "search_query": "",
            "search_results": "",
            "final_answer": "",
            "step": "start",
        }

        try:
            print("\n" + "=" * 60)
            # Stream node outputs as they complete and echo each AI message.
            async for output in assistant.astream(initial_state, config=config):
                for node_name, node_output in output.items():
                    msgs = node_output.get("messages") or []
                    if not msgs:
                        continue
                    latest = msgs[-1]
                    if not isinstance(latest, AIMessage):
                        continue
                    if node_name == "answer":
                        print(f"\n💡 最终回答:\n{latest.content}")
                    elif node_name in stage_labels:
                        print(f"{stage_labels[node_name]}: {latest.content}")
            print("\n" + "=" * 60 + "\n")
        except Exception as e:
            print(f"❌ 发生错误: {e}")
            print("请重新输入您的问题。\n")
# Script entry point: start the interactive assistant's event loop.
if __name__ == "__main__":
    asyncio.run(main())