add deepseek and kimi provider support
This commit is contained in:
parent
10c136f49c
commit
9745421555
|
|
@ -4,3 +4,6 @@ GOOGLE_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
XAI_API_KEY=
|
XAI_API_KEY=
|
||||||
OPENROUTER_API_KEY=
|
OPENROUTER_API_KEY=
|
||||||
|
DEEPSEEK_API_KEY=
|
||||||
|
KIMI_API_KEY=
|
||||||
|
MOONSHOT_API_KEY=
|
||||||
|
|
|
||||||
|
|
@ -141,6 +141,8 @@ export GOOGLE_API_KEY=... # Google (Gemini)
|
||||||
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
export ANTHROPIC_API_KEY=... # Anthropic (Claude)
|
||||||
export XAI_API_KEY=... # xAI (Grok)
|
export XAI_API_KEY=... # xAI (Grok)
|
||||||
export OPENROUTER_API_KEY=... # OpenRouter
|
export OPENROUTER_API_KEY=... # OpenRouter
|
||||||
|
export DEEPSEEK_API_KEY=... # DeepSeek
|
||||||
|
export KIMI_API_KEY=... # Kimi (Moonshot)
|
||||||
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
export ALPHA_VANTAGE_API_KEY=... # Alpha Vantage
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -178,7 +180,7 @@ An interface will appear showing results as they load, letting you track the age
|
||||||
|
|
||||||
### Implementation Details
|
### Implementation Details
|
||||||
|
|
||||||
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, OpenRouter, and Ollama.
|
We built TradingAgents with LangGraph to ensure flexibility and modularity. The framework supports multiple LLM providers: OpenAI, Google, Anthropic, xAI, DeepSeek, Kimi, OpenRouter, and Ollama.
|
||||||
|
|
||||||
### Python Usage
|
### Python Usage
|
||||||
|
|
||||||
|
|
@ -202,7 +204,7 @@ from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||||
from tradingagents.default_config import DEFAULT_CONFIG
|
from tradingagents.default_config import DEFAULT_CONFIG
|
||||||
|
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["llm_provider"] = "openai" # openai, google, anthropic, xai, openrouter, ollama
|
config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, kimi, openrouter, ollama
|
||||||
config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning
|
config["deep_think_llm"] = "gpt-5.4" # Model for complex reasoning
|
||||||
config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
|
config["quick_think_llm"] = "gpt-5.4-mini" # Model for quick tasks
|
||||||
config["max_debate_rounds"] = 2
|
config["max_debate_rounds"] = 2
|
||||||
|
|
|
||||||
|
|
@ -240,6 +240,8 @@ def select_llm_provider() -> tuple[str, str | None]:
|
||||||
("Google", None), # google-genai SDK manages its own endpoint
|
("Google", None), # google-genai SDK manages its own endpoint
|
||||||
("Anthropic", "https://api.anthropic.com/"),
|
("Anthropic", "https://api.anthropic.com/"),
|
||||||
("xAI", "https://api.x.ai/v1"),
|
("xAI", "https://api.x.ai/v1"),
|
||||||
|
("DeepSeek", "https://api.deepseek.com/v1"),
|
||||||
|
("Kimi", "https://api.moonshot.cn/v1"),
|
||||||
("Openrouter", "https://openrouter.ai/api/v1"),
|
("Openrouter", "https://openrouter.ai/api/v1"),
|
||||||
("Ollama", "http://localhost:11434/v1"),
|
("Ollama", "http://localhost:11434/v1"),
|
||||||
]
|
]
|
||||||
|
|
@ -261,7 +263,7 @@ def select_llm_provider() -> tuple[str, str | None]:
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
if choice is None:
|
if choice is None:
|
||||||
console.print("\n[red]no OpenAI backend selected. Exiting...[/red]")
|
console.print("\n[red]No LLM provider selected. Exiting...[/red]")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
display_name, url = choice
|
display_name, url = choice
|
||||||
|
|
|
||||||
1
main.py
1
main.py
|
|
@ -10,6 +10,7 @@ load_dotenv()
|
||||||
config = DEFAULT_CONFIG.copy()
|
config = DEFAULT_CONFIG.copy()
|
||||||
config["deep_think_llm"] = "gpt-5.4-mini" # Use a different model
|
config["deep_think_llm"] = "gpt-5.4-mini" # Use a different model
|
||||||
config["quick_think_llm"] = "gpt-5.4-mini" # Use a different model
|
config["quick_think_llm"] = "gpt-5.4-mini" # Use a different model
|
||||||
|
config["llm_provider"] = "openai" # openai, google, anthropic, xai, deepseek, kimi, openrouter, ollama
|
||||||
config["max_debate_rounds"] = 1 # Increase debate rounds
|
config["max_debate_rounds"] = 1 # Increase debate rounds
|
||||||
|
|
||||||
# Configure data vendors (default uses yfinance, no extra API keys needed)
|
# Configure data vendors (default uses yfinance, no extra API keys needed)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,68 @@
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from tradingagents.llm_clients.factory import create_llm_client
|
||||||
|
from tradingagents.llm_clients.openai_client import OpenAIClient
|
||||||
|
from tradingagents.llm_clients.validators import validate_model
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProviderSupportTests(unittest.TestCase):
|
||||||
|
def test_factory_supports_deepseek_and_kimi(self):
|
||||||
|
deepseek_client = create_llm_client("deepseek", "deepseek-chat")
|
||||||
|
kimi_client = create_llm_client("kimi", "kimi-latest")
|
||||||
|
|
||||||
|
self.assertIsInstance(deepseek_client, OpenAIClient)
|
||||||
|
self.assertIsInstance(kimi_client, OpenAIClient)
|
||||||
|
|
||||||
|
@patch("tradingagents.llm_clients.openai_client.NormalizedChatOpenAI")
|
||||||
|
def test_deepseek_uses_expected_base_url_and_key(self, mock_chat_openai):
|
||||||
|
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "deepseek-test-key"}, clear=False):
|
||||||
|
client = OpenAIClient("deepseek-chat", provider="deepseek")
|
||||||
|
client.get_llm()
|
||||||
|
|
||||||
|
kwargs = mock_chat_openai.call_args.kwargs
|
||||||
|
self.assertEqual(kwargs["base_url"], "https://api.deepseek.com/v1")
|
||||||
|
self.assertEqual(kwargs["api_key"], "deepseek-test-key")
|
||||||
|
self.assertEqual(kwargs["model"], "deepseek-chat")
|
||||||
|
|
||||||
|
@patch("tradingagents.llm_clients.openai_client.NormalizedChatOpenAI")
|
||||||
|
def test_kimi_prefers_kimi_api_key(self, mock_chat_openai):
|
||||||
|
with patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"KIMI_API_KEY": "kimi-test-key",
|
||||||
|
"MOONSHOT_API_KEY": "moonshot-test-key",
|
||||||
|
},
|
||||||
|
clear=False,
|
||||||
|
):
|
||||||
|
client = OpenAIClient("kimi-latest", provider="kimi")
|
||||||
|
client.get_llm()
|
||||||
|
|
||||||
|
kwargs = mock_chat_openai.call_args.kwargs
|
||||||
|
self.assertEqual(kwargs["base_url"], "https://api.moonshot.cn/v1")
|
||||||
|
self.assertEqual(kwargs["api_key"], "kimi-test-key")
|
||||||
|
self.assertEqual(kwargs["model"], "kimi-latest")
|
||||||
|
|
||||||
|
@patch("tradingagents.llm_clients.openai_client.NormalizedChatOpenAI")
|
||||||
|
def test_kimi_falls_back_to_moonshot_key(self, mock_chat_openai):
|
||||||
|
with patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{"KIMI_API_KEY": "", "MOONSHOT_API_KEY": "moonshot-test-key"},
|
||||||
|
clear=False,
|
||||||
|
):
|
||||||
|
client = OpenAIClient("kimi-thinking-preview", provider="kimi")
|
||||||
|
client.get_llm()
|
||||||
|
|
||||||
|
kwargs = mock_chat_openai.call_args.kwargs
|
||||||
|
self.assertEqual(kwargs["base_url"], "https://api.moonshot.cn/v1")
|
||||||
|
self.assertEqual(kwargs["api_key"], "moonshot-test-key")
|
||||||
|
self.assertEqual(kwargs["model"], "kimi-thinking-preview")
|
||||||
|
|
||||||
|
def test_validator_allows_fast_moving_compatibility_providers(self):
|
||||||
|
self.assertTrue(validate_model("deepseek", "any-model-name"))
|
||||||
|
self.assertTrue(validate_model("kimi", "any-model-name"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
@ -40,8 +40,8 @@ class ModelValidationTests(unittest.TestCase):
|
||||||
self.assertIn("not-a-real-openai-model", str(caught[0].message))
|
self.assertIn("not-a-real-openai-model", str(caught[0].message))
|
||||||
self.assertIn("openai", str(caught[0].message))
|
self.assertIn("openai", str(caught[0].message))
|
||||||
|
|
||||||
def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
|
def test_compatibility_providers_accept_custom_models_without_warning(self):
|
||||||
for provider in ("openrouter", "ollama"):
|
for provider in ("openrouter", "ollama", "deepseek", "kimi"):
|
||||||
client = DummyLLMClient(provider, "custom-model-name")
|
client = DummyLLMClient(provider, "custom-model-name")
|
||||||
|
|
||||||
with self.subTest(provider=provider):
|
with self.subTest(provider=provider):
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ DEFAULT_CONFIG = {
|
||||||
"dataflows/data_cache",
|
"dataflows/data_cache",
|
||||||
),
|
),
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"llm_provider": "openai",
|
"llm_provider": "openai", # openai, google, anthropic, xai, deepseek, kimi, openrouter, ollama
|
||||||
"deep_think_llm": "gpt-5.4",
|
"deep_think_llm": "gpt-5.4",
|
||||||
"quick_think_llm": "gpt-5.4-mini",
|
"quick_think_llm": "gpt-5.4-mini",
|
||||||
"backend_url": "https://api.openai.com/v1",
|
"backend_url": "https://api.openai.com/v1",
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ def create_llm_client(
|
||||||
"""Create an LLM client for the specified provider.
|
"""Create an LLM client for the specified provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
provider: LLM provider (openai, anthropic, google, xai, openrouter, deepseek, kimi, ollama)
|
||||||
model: Model name/identifier
|
model: Model name/identifier
|
||||||
base_url: Optional base URL for API endpoint
|
base_url: Optional base URL for API endpoint
|
||||||
**kwargs: Additional provider-specific arguments
|
**kwargs: Additional provider-specific arguments
|
||||||
|
|
@ -34,7 +34,7 @@ def create_llm_client(
|
||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in ("openai", "ollama", "openrouter"):
|
if provider_lower in ("openai", "ollama", "openrouter", "deepseek", "kimi"):
|
||||||
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
return OpenAIClient(model, base_url, provider=provider_lower, **kwargs)
|
||||||
|
|
||||||
if provider_lower == "xai":
|
if provider_lower == "xai":
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,28 @@ MODEL_OPTIONS: ProviderModeOptions = {
|
||||||
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
"deepseek": {
|
||||||
|
"quick": [
|
||||||
|
("DeepSeek Chat - Fast non-thinking mode", "deepseek-chat"),
|
||||||
|
("DeepSeek Reasoner - Strong reasoning mode", "deepseek-reasoner"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("DeepSeek Reasoner - Strong reasoning mode", "deepseek-reasoner"),
|
||||||
|
("DeepSeek Chat - Fast non-thinking mode", "deepseek-chat"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"kimi": {
|
||||||
|
"quick": [
|
||||||
|
("Kimi K2 Turbo (Preview) - High-throughput responses", "kimi-k2-turbo-preview"),
|
||||||
|
("Kimi Latest - Rolling latest default", "kimi-latest"),
|
||||||
|
("Moonshot V1 32K - Stable long-context", "moonshot-v1-32k"),
|
||||||
|
],
|
||||||
|
"deep": [
|
||||||
|
("Kimi Thinking (Preview) - Dedicated reasoning model", "kimi-thinking-preview"),
|
||||||
|
("Kimi K2 0905 (Preview) - Strong coding and agent tasks", "kimi-k2-0905-preview"),
|
||||||
|
("Moonshot V1 128K - Stable long-context", "moonshot-v1-128k"),
|
||||||
|
],
|
||||||
|
},
|
||||||
# OpenRouter models are fetched dynamically at CLI runtime.
|
# OpenRouter models are fetched dynamically at CLI runtime.
|
||||||
# No static entries needed; any model ID is accepted by the validator.
|
# No static entries needed; any model ID is accepted by the validator.
|
||||||
"ollama": {
|
"ollama": {
|
||||||
|
|
|
||||||
|
|
@ -26,19 +26,21 @@ _PASSTHROUGH_KWARGS = (
|
||||||
|
|
||||||
# Provider base URLs and API key env vars
|
# Provider base URLs and API key env vars
|
||||||
_PROVIDER_CONFIG = {
|
_PROVIDER_CONFIG = {
|
||||||
"xai": ("https://api.x.ai/v1", "XAI_API_KEY"),
|
"xai": ("https://api.x.ai/v1", ("XAI_API_KEY",)),
|
||||||
"openrouter": ("https://openrouter.ai/api/v1", "OPENROUTER_API_KEY"),
|
"openrouter": ("https://openrouter.ai/api/v1", ("OPENROUTER_API_KEY",)),
|
||||||
"ollama": ("http://localhost:11434/v1", None),
|
"deepseek": ("https://api.deepseek.com/v1", ("DEEPSEEK_API_KEY",)),
|
||||||
|
"kimi": ("https://api.moonshot.cn/v1", ("KIMI_API_KEY", "MOONSHOT_API_KEY")),
|
||||||
|
"ollama": ("http://localhost:11434/v1", ()),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class OpenAIClient(BaseLLMClient):
|
class OpenAIClient(BaseLLMClient):
|
||||||
"""Client for OpenAI, Ollama, OpenRouter, and xAI providers.
|
"""Client for OpenAI-compatible providers.
|
||||||
|
|
||||||
For native OpenAI models, uses the Responses API (/v1/responses) which
|
For native OpenAI models, uses the Responses API (/v1/responses) which
|
||||||
supports reasoning_effort with function tools across all model families
|
supports reasoning_effort with function tools across all model families
|
||||||
(GPT-4.1, GPT-5). Third-party compatible providers (xAI, OpenRouter,
|
(GPT-4.1, GPT-5). Third-party compatible providers (xAI, OpenRouter,
|
||||||
Ollama) use standard Chat Completions.
|
DeepSeek, Kimi, Ollama) use standard Chat Completions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -58,12 +60,14 @@ class OpenAIClient(BaseLLMClient):
|
||||||
|
|
||||||
# Provider-specific base URL and auth
|
# Provider-specific base URL and auth
|
||||||
if self.provider in _PROVIDER_CONFIG:
|
if self.provider in _PROVIDER_CONFIG:
|
||||||
base_url, api_key_env = _PROVIDER_CONFIG[self.provider]
|
base_url, api_key_envs = _PROVIDER_CONFIG[self.provider]
|
||||||
llm_kwargs["base_url"] = base_url
|
llm_kwargs["base_url"] = base_url
|
||||||
if api_key_env:
|
if api_key_envs:
|
||||||
api_key = os.environ.get(api_key_env)
|
for api_key_env in api_key_envs:
|
||||||
if api_key:
|
api_key = os.environ.get(api_key_env)
|
||||||
llm_kwargs["api_key"] = api_key
|
if api_key:
|
||||||
|
llm_kwargs["api_key"] = api_key
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
llm_kwargs["api_key"] = "ollama"
|
llm_kwargs["api_key"] = "ollama"
|
||||||
elif self.base_url:
|
elif self.base_url:
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,11 @@ VALID_MODELS = {
|
||||||
def validate_model(provider: str, model: str) -> bool:
|
def validate_model(provider: str, model: str) -> bool:
|
||||||
"""Check if model name is valid for the given provider.
|
"""Check if model name is valid for the given provider.
|
||||||
|
|
||||||
For ollama, openrouter - any model is accepted.
|
For compatibility providers with fast-moving model catalogs, any model is accepted.
|
||||||
"""
|
"""
|
||||||
provider_lower = provider.lower()
|
provider_lower = provider.lower()
|
||||||
|
|
||||||
if provider_lower in ("ollama", "openrouter"):
|
if provider_lower in ("ollama", "openrouter", "deepseek", "kimi"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if provider_lower not in VALID_MODELS:
|
if provider_lower not in VALID_MODELS:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue