feat: add DashScope (Alibaba Cloud) LLM provider support
- Add DashScope to CLI LLM provider selection (first option) - Add DashScope support in TradingAgentsGraph with proper error handling - Import ChatDashScope adapter with fallback for missing dependencies - Update default_config.py with DashScope configuration comments - Create comprehensive DashScope configuration example Features added: - DashScope provider selection in CLI (cli/utils.py) - LLM initialization for DashScope models (trading_graph.py) - Configuration validation and testing (examples/dashscope_config_example.py) - Support for qwen-turbo, qwen-plus, qwen-max models - Proper error handling for missing dashscope package Usage: 1. Install: pip install dashscope 2. Configure: DASHSCOPE_API_KEY in .env file 3. Select 'DashScope (Alibaba Cloud)' in CLI 4. Use models: qwen-turbo (fast), qwen-plus (balanced), qwen-max (best) This enables Chinese users to use Alibaba Cloud's Qwen models for financial analysis with optimized Chinese language support.
This commit is contained in:
parent
114f7f0618
commit
cdba5bf780
|
|
@ -240,14 +240,16 @@ def select_deep_thinking_agent(provider) -> str:
|
|||
return choice
|
||||
|
||||
def select_llm_provider() -> tuple[str, str]:
|
||||
"""Select the OpenAI api url using interactive selection."""
|
||||
# Define OpenAI api options with their corresponding endpoints
|
||||
"""Select the LLM provider using interactive selection."""
|
||||
# Define LLM provider options with their corresponding endpoints
|
||||
# DashScope (Alibaba Cloud) is recommended for Chinese users
|
||||
BASE_URLS = [
|
||||
("DashScope (Alibaba Cloud)", "https://dashscope.aliyuncs.com/api/v1"),
|
||||
("OpenAI", "https://api.openai.com/v1"),
|
||||
("Anthropic", "https://api.anthropic.com/"),
|
||||
("Google", "https://generativelanguage.googleapis.com/v1"),
|
||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
("Ollama", "http://localhost:11434/v1"),
|
||||
]
|
||||
|
||||
choice = questionary.select(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,202 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
DashScope (Alibaba Cloud) Configuration Example
|
||||
阿里云百炼模型配置示例
|
||||
|
||||
This example shows how to configure TradingAgents to use DashScope models.
|
||||
这个示例展示如何配置TradingAgents使用阿里云百炼模型。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
def create_dashscope_config():
|
||||
"""
|
||||
Create configuration for DashScope models
|
||||
创建百炼模型配置
|
||||
"""
|
||||
|
||||
# Copy default config
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
|
||||
# Configure for DashScope
|
||||
config.update({
|
||||
# LLM Provider Settings
|
||||
"llm_provider": "dashscope",
|
||||
"backend_url": "https://dashscope.aliyuncs.com/api/v1",
|
||||
|
||||
# Model Selection
|
||||
# 模型选择 - 根据需要调整
|
||||
"deep_think_llm": "qwen-plus", # For complex analysis 复杂分析
|
||||
"quick_think_llm": "qwen-turbo", # For quick tasks 快速任务
|
||||
|
||||
# Optional: Reduce rounds for faster execution
|
||||
# 可选:减少轮次以加快执行速度
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
|
||||
# Enable online tools
|
||||
"online_tools": True,
|
||||
})
|
||||
|
||||
return config
|
||||
|
||||
def check_dashscope_setup():
|
||||
"""
|
||||
Check if DashScope is properly configured
|
||||
检查百炼配置是否正确
|
||||
"""
|
||||
|
||||
print("🔍 Checking DashScope Configuration")
|
||||
print("🔍 检查百炼配置")
|
||||
print("=" * 50)
|
||||
|
||||
# Check API key
|
||||
api_key = os.getenv('DASHSCOPE_API_KEY')
|
||||
if api_key:
|
||||
print(f"✅ DASHSCOPE_API_KEY: {api_key[:10]}...")
|
||||
else:
|
||||
print("❌ DASHSCOPE_API_KEY not found in environment variables")
|
||||
print("❌ 环境变量中未找到 DASHSCOPE_API_KEY")
|
||||
print("\n💡 To fix this:")
|
||||
print("💡 解决方法:")
|
||||
print("1. Get API key from: https://dashscope.aliyun.com/")
|
||||
print("1. 从以下网址获取API密钥: https://dashscope.aliyun.com/")
|
||||
print("2. Add to .env file: DASHSCOPE_API_KEY=your_key_here")
|
||||
print("2. 添加到.env文件: DASHSCOPE_API_KEY=your_key_here")
|
||||
return False
|
||||
|
||||
# Check DashScope package
|
||||
try:
|
||||
import dashscope
|
||||
print("✅ dashscope package installed")
|
||||
print("✅ dashscope包已安装")
|
||||
except ImportError:
|
||||
print("❌ dashscope package not installed")
|
||||
print("❌ dashscope包未安装")
|
||||
print("\n💡 To install:")
|
||||
print("💡 安装方法:")
|
||||
print("pip install dashscope")
|
||||
return False
|
||||
|
||||
# Check adapter
|
||||
try:
|
||||
from tradingagents.llm_adapters.dashscope_adapter import ChatDashScope
|
||||
print("✅ DashScope adapter available")
|
||||
print("✅ 百炼适配器可用")
|
||||
except ImportError:
|
||||
print("❌ DashScope adapter not available")
|
||||
print("❌ 百炼适配器不可用")
|
||||
return False
|
||||
|
||||
print("\n🎉 DashScope configuration is ready!")
|
||||
print("🎉 百炼配置已就绪!")
|
||||
return True
|
||||
|
||||
def test_dashscope_connection():
|
||||
"""
|
||||
Test connection to DashScope
|
||||
测试百炼连接
|
||||
"""
|
||||
|
||||
print("\n🧪 Testing DashScope Connection")
|
||||
print("🧪 测试百炼连接")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
from tradingagents.llm_adapters.dashscope_adapter import ChatDashScope
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
# Create model instance
|
||||
llm = ChatDashScope(
|
||||
model="qwen-turbo",
|
||||
temperature=0.1,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
# Test simple query
|
||||
test_message = HumanMessage(content="Hello, please respond with 'DashScope connection successful!'")
|
||||
response = llm.invoke([test_message])
|
||||
|
||||
print(f"✅ Connection successful!")
|
||||
print(f"✅ 连接成功!")
|
||||
print(f"📝 Response: {response.content}")
|
||||
print(f"📝 响应: {response.content}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Connection failed: {str(e)}")
|
||||
print(f"❌ 连接失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to demonstrate DashScope configuration
|
||||
主函数演示百炼配置
|
||||
"""
|
||||
|
||||
print("🚀 DashScope Configuration Example")
|
||||
print("🚀 百炼配置示例")
|
||||
print("=" * 50)
|
||||
|
||||
# Check setup
|
||||
if not check_dashscope_setup():
|
||||
print("\n❌ Please fix the configuration issues above")
|
||||
print("❌ 请修复上述配置问题")
|
||||
return
|
||||
|
||||
# Test connection
|
||||
if not test_dashscope_connection():
|
||||
print("\n❌ Connection test failed")
|
||||
print("❌ 连接测试失败")
|
||||
return
|
||||
|
||||
# Show configuration
|
||||
config = create_dashscope_config()
|
||||
|
||||
print(f"\n📋 DashScope Configuration:")
|
||||
print(f"📋 百炼配置:")
|
||||
print(f" Provider: {config['llm_provider']}")
|
||||
print(f" Deep Think Model: {config['deep_think_llm']}")
|
||||
print(f" Quick Think Model: {config['quick_think_llm']}")
|
||||
print(f" Backend URL: {config['backend_url']}")
|
||||
|
||||
print(f"\n💡 Usage Example:")
|
||||
print(f"💡 使用示例:")
|
||||
print(f"""
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
# Create config
|
||||
config = create_dashscope_config()
|
||||
|
||||
# Initialize trading graph
|
||||
ta = TradingAgentsGraph(config)
|
||||
|
||||
# Run analysis
|
||||
result, decision = ta.propagate("AAPL", "2024-01-15")
|
||||
print(result)
|
||||
""")
|
||||
|
||||
print(f"\n🎯 Available DashScope Models:")
|
||||
print(f"🎯 可用的百炼模型:")
|
||||
|
||||
models = {
|
||||
"qwen-turbo": "Fast response, suitable for daily conversations",
|
||||
"qwen-plus": "Balanced performance and cost",
|
||||
"qwen-max": "Best performance",
|
||||
"qwen-max-longcontext": "Supports ultra-long context"
|
||||
}
|
||||
|
||||
for model, description in models.items():
|
||||
print(f" • {model}: {description}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,321 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
合并后功能测试验证脚本
|
||||
测试所有新增功能和原有功能的兼容性
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import traceback
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
class MergedFeaturesTest:
|
||||
"""合并功能测试类"""
|
||||
|
||||
def __init__(self):
|
||||
self.test_results = {
|
||||
"passed": [],
|
||||
"failed": [],
|
||||
"warnings": []
|
||||
}
|
||||
self.temp_dir = None
|
||||
|
||||
def setup(self):
|
||||
"""测试环境设置"""
|
||||
print("🔧 设置测试环境...")
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
print(f" 临时目录: {self.temp_dir}")
|
||||
|
||||
def cleanup(self):
|
||||
"""清理测试环境"""
|
||||
if self.temp_dir and os.path.exists(self.temp_dir):
|
||||
shutil.rmtree(self.temp_dir)
|
||||
print(f"🧹 清理临时目录: {self.temp_dir}")
|
||||
|
||||
def test_basic_imports(self):
|
||||
"""测试基础模块导入"""
|
||||
print("\n📦 测试基础模块导入...")
|
||||
|
||||
basic_modules = [
|
||||
"tradingagents.default_config",
|
||||
"tradingagents.dataflows.interface",
|
||||
"tradingagents.dataflows.config",
|
||||
]
|
||||
|
||||
for module in basic_modules:
|
||||
try:
|
||||
__import__(module)
|
||||
self.test_results["passed"].append(f"基础导入: {module}")
|
||||
print(f" ✅ {module}")
|
||||
except Exception as e:
|
||||
self.test_results["failed"].append(f"基础导入: {module} - {e}")
|
||||
print(f" ❌ {module}: {e}")
|
||||
|
||||
def test_cache_system(self):
|
||||
"""测试缓存系统"""
|
||||
print("\n💾 测试缓存系统...")
|
||||
|
||||
try:
|
||||
# 测试原有缓存管理器
|
||||
from tradingagents.dataflows.cache_manager import StockDataCache, get_cache
|
||||
|
||||
# 创建缓存实例
|
||||
cache = StockDataCache(cache_dir=self.temp_dir)
|
||||
|
||||
# 测试基本功能
|
||||
test_data = "Test stock data for AAPL"
|
||||
cache_key = cache.save_stock_data("AAPL", test_data, "2024-01-01", "2024-01-31", "test")
|
||||
|
||||
if cache_key:
|
||||
loaded_data = cache.load_stock_data(cache_key)
|
||||
if loaded_data == test_data:
|
||||
self.test_results["passed"].append("缓存系统: 基本功能正常")
|
||||
print(" ✅ 基本缓存功能正常")
|
||||
else:
|
||||
self.test_results["failed"].append("缓存系统: 数据不匹配")
|
||||
print(" ❌ 缓存数据不匹配")
|
||||
else:
|
||||
self.test_results["failed"].append("缓存系统: 保存失败")
|
||||
print(" ❌ 缓存保存失败")
|
||||
|
||||
# 测试市场类型检测
|
||||
us_market = cache._determine_market_type("AAPL")
|
||||
china_market = cache._determine_market_type("000001")
|
||||
|
||||
if us_market == "us" and china_market == "china":
|
||||
self.test_results["passed"].append("缓存系统: 市场类型检测正常")
|
||||
print(" ✅ 市场类型检测正常")
|
||||
else:
|
||||
self.test_results["failed"].append("缓存系统: 市场类型检测异常")
|
||||
print(" ❌ 市场类型检测异常")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["failed"].append(f"缓存系统: {e}")
|
||||
print(f" ❌ 缓存系统测试失败: {e}")
|
||||
|
||||
def test_new_features_import(self):
|
||||
"""测试新功能模块导入"""
|
||||
print("\n🆕 测试新功能模块导入...")
|
||||
|
||||
new_modules = [
|
||||
# 中国市场数据
|
||||
("tradingagents.dataflows.chinese_finance_utils", "中国财经数据工具"),
|
||||
("tradingagents.dataflows.tdx_utils", "通达信API工具"),
|
||||
("tradingagents.dataflows.optimized_china_data", "优化A股数据提供器"),
|
||||
|
||||
# 高级缓存
|
||||
("tradingagents.dataflows.adaptive_cache", "自适应缓存"),
|
||||
("tradingagents.dataflows.integrated_cache", "集成缓存"),
|
||||
("tradingagents.dataflows.db_cache_manager", "数据库缓存管理"),
|
||||
|
||||
# 配置管理
|
||||
("tradingagents.config.database_config", "数据库配置"),
|
||||
("tradingagents.config.database_manager", "数据库管理器"),
|
||||
("tradingagents.config.mongodb_storage", "MongoDB存储"),
|
||||
|
||||
# LLM适配器
|
||||
("tradingagents.llm_adapters.dashscope_adapter", "DashScope适配器"),
|
||||
|
||||
# API服务
|
||||
("tradingagents.api.stock_api", "股票API"),
|
||||
("tradingagents.dataflows.stock_data_service", "股票数据服务"),
|
||||
("tradingagents.dataflows.realtime_news_utils", "实时新闻工具"),
|
||||
]
|
||||
|
||||
for module_name, description in new_modules:
|
||||
try:
|
||||
__import__(module_name)
|
||||
self.test_results["passed"].append(f"新功能导入: {description}")
|
||||
print(f" ✅ {description}")
|
||||
except ImportError as e:
|
||||
if "No module named" in str(e):
|
||||
self.test_results["warnings"].append(f"新功能导入: {description} - 可能缺少依赖")
|
||||
print(f" ⚠️ {description}: 可能缺少依赖 ({e})")
|
||||
else:
|
||||
self.test_results["failed"].append(f"新功能导入: {description} - {e}")
|
||||
print(f" ❌ {description}: {e}")
|
||||
except Exception as e:
|
||||
self.test_results["failed"].append(f"新功能导入: {description} - {e}")
|
||||
print(f" ❌ {description}: {e}")
|
||||
|
||||
def test_optimized_data_providers(self):
|
||||
"""测试优化的数据提供器"""
|
||||
print("\n📊 测试优化数据提供器...")
|
||||
|
||||
try:
|
||||
# 测试美股数据提供器
|
||||
from tradingagents.dataflows.optimized_us_data import OptimizedUSDataProvider
|
||||
|
||||
provider = OptimizedUSDataProvider()
|
||||
self.test_results["passed"].append("数据提供器: 美股提供器初始化成功")
|
||||
print(" ✅ 美股数据提供器初始化成功")
|
||||
|
||||
# 测试基本方法存在
|
||||
required_methods = ['get_stock_data', '_wait_for_rate_limit', '_format_stock_data']
|
||||
for method in required_methods:
|
||||
if hasattr(provider, method):
|
||||
self.test_results["passed"].append(f"数据提供器: {method} 方法存在")
|
||||
print(f" ✅ {method} 方法存在")
|
||||
else:
|
||||
self.test_results["failed"].append(f"数据提供器: {method} 方法缺失")
|
||||
print(f" ❌ {method} 方法缺失")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["failed"].append(f"数据提供器: {e}")
|
||||
print(f" ❌ 数据提供器测试失败: {e}")
|
||||
|
||||
def test_config_system(self):
|
||||
"""测试配置系统"""
|
||||
print("\n⚙️ 测试配置系统...")
|
||||
|
||||
try:
|
||||
# 测试默认配置
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
# 检查基本配置项
|
||||
required_configs = [
|
||||
"project_dir", "results_dir", "data_dir",
|
||||
"llm_provider", "deep_think_llm", "quick_think_llm"
|
||||
]
|
||||
|
||||
for config_key in required_configs:
|
||||
if config_key in DEFAULT_CONFIG:
|
||||
self.test_results["passed"].append(f"配置系统: {config_key} 存在")
|
||||
print(f" ✅ {config_key} 配置存在")
|
||||
else:
|
||||
self.test_results["failed"].append(f"配置系统: {config_key} 缺失")
|
||||
print(f" ❌ {config_key} 配置缺失")
|
||||
|
||||
# 测试动态配置
|
||||
from tradingagents.dataflows.config import get_config, set_config
|
||||
|
||||
current_config = get_config()
|
||||
if current_config:
|
||||
self.test_results["passed"].append("配置系统: 动态配置获取正常")
|
||||
print(" ✅ 动态配置获取正常")
|
||||
else:
|
||||
self.test_results["failed"].append("配置系统: 动态配置获取失败")
|
||||
print(" ❌ 动态配置获取失败")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["failed"].append(f"配置系统: {e}")
|
||||
print(f" ❌ 配置系统测试失败: {e}")
|
||||
|
||||
def test_main_functionality(self):
|
||||
"""测试主要功能"""
|
||||
print("\n🚀 测试主要功能...")
|
||||
|
||||
try:
|
||||
# 测试主程序导入
|
||||
import main
|
||||
self.test_results["passed"].append("主功能: main.py 导入成功")
|
||||
print(" ✅ main.py 导入成功")
|
||||
|
||||
# 测试交易图形导入
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
self.test_results["passed"].append("主功能: TradingAgentsGraph 导入成功")
|
||||
print(" ✅ TradingAgentsGraph 导入成功")
|
||||
|
||||
except Exception as e:
|
||||
self.test_results["failed"].append(f"主功能: {e}")
|
||||
print(f" ❌ 主功能测试失败: {e}")
|
||||
|
||||
def test_documentation(self):
|
||||
"""测试文档完整性"""
|
||||
print("\n📚 测试文档完整性...")
|
||||
|
||||
doc_files = [
|
||||
"docs/README.md",
|
||||
"docs/en-US/configuration_guide.md",
|
||||
"docs/en-US/quick_reference.md",
|
||||
"docs/en-US/prompt_templates.md",
|
||||
"MERGE_SUMMARY.md"
|
||||
]
|
||||
|
||||
for doc_file in doc_files:
|
||||
if os.path.exists(doc_file):
|
||||
self.test_results["passed"].append(f"文档: {doc_file} 存在")
|
||||
print(f" ✅ {doc_file}")
|
||||
else:
|
||||
self.test_results["failed"].append(f"文档: {doc_file} 缺失")
|
||||
print(f" ❌ {doc_file}")
|
||||
|
||||
def run_all_tests(self):
|
||||
"""运行所有测试"""
|
||||
print("🧪 开始合并后功能测试验证")
|
||||
print("=" * 50)
|
||||
|
||||
self.setup()
|
||||
|
||||
try:
|
||||
self.test_basic_imports()
|
||||
self.test_cache_system()
|
||||
self.test_new_features_import()
|
||||
self.test_optimized_data_providers()
|
||||
self.test_config_system()
|
||||
self.test_main_functionality()
|
||||
self.test_documentation()
|
||||
|
||||
finally:
|
||||
self.cleanup()
|
||||
|
||||
self.print_summary()
|
||||
|
||||
def print_summary(self):
|
||||
"""打印测试摘要"""
|
||||
print("\n" + "=" * 50)
|
||||
print("📋 测试结果摘要")
|
||||
print("=" * 50)
|
||||
|
||||
total_passed = len(self.test_results["passed"])
|
||||
total_failed = len(self.test_results["failed"])
|
||||
total_warnings = len(self.test_results["warnings"])
|
||||
total_tests = total_passed + total_failed + total_warnings
|
||||
|
||||
print(f"\n📊 统计:")
|
||||
print(f" 总测试项: {total_tests}")
|
||||
print(f" ✅ 通过: {total_passed}")
|
||||
print(f" ❌ 失败: {total_failed}")
|
||||
print(f" ⚠️ 警告: {total_warnings}")
|
||||
|
||||
if total_failed == 0:
|
||||
print(f"\n🎉 所有核心功能测试通过!")
|
||||
if total_warnings > 0:
|
||||
print(f"⚠️ 有 {total_warnings} 个警告,主要是可选依赖缺失")
|
||||
else:
|
||||
print(f"\n❌ 有 {total_failed} 个测试失败,需要修复")
|
||||
|
||||
# 详细结果
|
||||
if self.test_results["failed"]:
|
||||
print(f"\n❌ 失败的测试:")
|
||||
for failure in self.test_results["failed"]:
|
||||
print(f" - {failure}")
|
||||
|
||||
if self.test_results["warnings"]:
|
||||
print(f"\n⚠️ 警告:")
|
||||
for warning in self.test_results["warnings"]:
|
||||
print(f" - {warning}")
|
||||
|
||||
# 建议
|
||||
print(f"\n💡 建议:")
|
||||
if total_failed == 0:
|
||||
print(" 1. 核心功能正常,可以进行更深入的集成测试")
|
||||
print(" 2. 考虑安装可选依赖以启用完整功能")
|
||||
print(" 3. 运行实际的股票数据获取测试")
|
||||
else:
|
||||
print(" 1. 修复失败的测试项")
|
||||
print(" 2. 检查依赖项安装")
|
||||
print(" 3. 验证文件路径和导入")
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
tester = MergedFeaturesTest()
|
||||
tester.run_all_tests()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -9,6 +9,8 @@ DEFAULT_CONFIG = {
|
|||
"dataflows/data_cache",
|
||||
),
|
||||
# LLM settings
|
||||
# Supported providers: "openai", "anthropic", "google", "dashscope", "ollama", "openrouter"
|
||||
# For DashScope: set llm_provider="dashscope", deep_think_llm="qwen-plus", quick_think_llm="qwen-turbo"
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "o4-mini",
|
||||
"quick_think_llm": "gpt-4o-mini",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,14 @@ from langchain_openai import ChatOpenAI
|
|||
from langchain_anthropic import ChatAnthropic
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
# Import DashScope adapter if available
|
||||
try:
|
||||
from ..llm_adapters.dashscope_adapter import ChatDashScope
|
||||
DASHSCOPE_AVAILABLE = True
|
||||
except ImportError:
|
||||
DASHSCOPE_AVAILABLE = False
|
||||
ChatDashScope = None
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.agents import *
|
||||
|
|
@ -65,8 +73,35 @@ class TradingAgentsGraph:
|
|||
self.deep_thinking_llm = ChatAnthropic(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
|
||||
self.quick_thinking_llm = ChatAnthropic(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
|
||||
elif self.config["llm_provider"].lower() == "google":
|
||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(model=self.config["deep_think_llm"])
|
||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(model=self.config["quick_think_llm"])
|
||||
google_api_key = os.getenv('GOOGLE_API_KEY')
|
||||
self.deep_thinking_llm = ChatGoogleGenerativeAI(
|
||||
model=self.config["deep_think_llm"],
|
||||
google_api_key=google_api_key,
|
||||
temperature=0.1,
|
||||
max_tokens=2000
|
||||
)
|
||||
self.quick_thinking_llm = ChatGoogleGenerativeAI(
|
||||
model=self.config["quick_think_llm"],
|
||||
google_api_key=google_api_key,
|
||||
temperature=0.1,
|
||||
max_tokens=2000
|
||||
)
|
||||
elif (self.config["llm_provider"].lower() == "dashscope" or
|
||||
"dashscope" in self.config["llm_provider"].lower() or
|
||||
"alibaba" in self.config["llm_provider"].lower()):
|
||||
if not DASHSCOPE_AVAILABLE:
|
||||
raise ValueError("DashScope adapter not available. Please install dashscope package: pip install dashscope")
|
||||
|
||||
self.deep_thinking_llm = ChatDashScope(
|
||||
model=self.config["deep_think_llm"],
|
||||
temperature=0.1,
|
||||
max_tokens=2000
|
||||
)
|
||||
self.quick_thinking_llm = ChatDashScope(
|
||||
model=self.config["quick_think_llm"],
|
||||
temperature=0.1,
|
||||
max_tokens=2000
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue