You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
1.6 KiB
Python
41 lines
1.6 KiB
Python
# my_llm.py
|
|
import os
|
|
from typing import Optional
|
|
from openai import OpenAI
|
|
from hello_agents import HelloAgentsLLM
|
|
|
|
class MyLLM(HelloAgentsLLM):
|
|
def __init__(
|
|
self,
|
|
model: Optional[str] = None,
|
|
api_key: Optional[str] = None,
|
|
base_url: Optional[str] = None,
|
|
provider: Optional[str] = "auto",
|
|
**kwargs
|
|
):
|
|
# 检查provider是否为我们想处理的'modelscope'
|
|
if provider == "modelscope":
|
|
print("正在使用自定义的 ModelScope Provider")
|
|
self.provider = "modelscope"
|
|
|
|
# 解析 ModelScope 的凭证
|
|
self.api_key = api_key or os.getenv("MODELSCOPE_API_KEY")
|
|
self.base_url = base_url or "https://api-inference.modelscope.cn/v1/"
|
|
|
|
# 验证凭证是否存在
|
|
if not self.api_key:
|
|
raise ValueError("ModelScope API key not found. Please set MODELSCOPE_API_KEY environment variable.")
|
|
|
|
# 设置默认模型和其他参数
|
|
self.model = model or os.getenv("LLM_MODEL_ID") or "Qwen/Qwen2.5-VL-72B-Instruct"
|
|
self.temperature = kwargs.get('temperature', 0.7)
|
|
self.max_tokens = kwargs.get('max_tokens')
|
|
self.timeout = kwargs.get('timeout', 60)
|
|
|
|
# 使用获取的参数创建OpenAI客户端实例
|
|
self._client = OpenAI(api_key=self.api_key, base_url=self.base_url, timeout=self.timeout)
|
|
|
|
else:
|
|
# 如果不是 modelscope, 则完全使用父类的原始逻辑来处理
|
|
super().__init__(model=model, api_key=api_key, base_url=base_url, provider=provider, **kwargs)
|