Added test dapt files
This commit is contained in:
parent
287764cc1d
commit
948ad66343
|
|
@ -0,0 +1,361 @@
|
|||
absl-py==2.3.1
|
||||
accelerate==1.11.0
|
||||
aiofiles==24.1.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.1
|
||||
aiosignal==1.4.0
|
||||
akracer==0.0.14
|
||||
akshare==1.17.83
|
||||
annotated-doc==0.0.4
|
||||
annotated-types==0.7.0
|
||||
anthropic==0.72.1
|
||||
anyio==4.11.0
|
||||
argon2-cffi==25.1.0
|
||||
argon2-cffi-bindings==25.1.0
|
||||
arrow==1.4.0
|
||||
asttokens==3.0.0
|
||||
async-lru==2.0.5
|
||||
async-timeout==4.0.3
|
||||
asyncer==0.0.10
|
||||
attrs==25.4.0
|
||||
babel==2.17.0
|
||||
backoff==2.2.1
|
||||
backtrader==1.9.78.123
|
||||
bcrypt==5.0.0
|
||||
beautifulsoup4==4.14.2
|
||||
bidict==0.23.1
|
||||
bitsandbytes==0.48.1
|
||||
bleach==6.3.0
|
||||
bs4==0.0.2
|
||||
build==1.3.0
|
||||
cachetools==6.2.1
|
||||
certifi==2025.10.5
|
||||
cffi==2.0.0
|
||||
chainlit==2.9.0
|
||||
chardet==4.0.0
|
||||
charset-normalizer==3.4.4
|
||||
chevron==0.14.0
|
||||
chromadb==1.3.4
|
||||
click==8.3.0
|
||||
colorama==0.4.6
|
||||
coloredlogs==15.0.1
|
||||
comm==0.2.3
|
||||
contourpy==1.3.2
|
||||
cryptography==46.0.3
|
||||
cssselect==1.3.0
|
||||
cuid==0.4
|
||||
curl_cffi==0.13.0
|
||||
cycler==0.12.1
|
||||
dataclasses-json==0.6.7
|
||||
datasets==4.3.0
|
||||
debugpy==1.8.17
|
||||
decorator==5.2.1
|
||||
defusedxml==0.7.1
|
||||
Deprecated==1.3.1
|
||||
diffusers==0.35.2
|
||||
dill==0.4.0
|
||||
distro==1.9.0
|
||||
docstring_parser==0.17.0
|
||||
duckdb==1.4.2
|
||||
durationpy==0.10
|
||||
eodhd==1.0.32
|
||||
et_xmlfile==2.0.0
|
||||
exceptiongroup==1.3.0
|
||||
executing==2.2.1
|
||||
fastapi==0.121.1
|
||||
fastjsonschema==2.21.2
|
||||
feedparser==6.0.12
|
||||
filelock==3.0.12
|
||||
filetype==1.2.0
|
||||
finnhub-python==2.4.25
|
||||
flatbuffers==25.9.23
|
||||
fonttools==4.60.1
|
||||
fqdn==1.5.1
|
||||
frozendict==2.4.7
|
||||
frozenlist==1.8.0
|
||||
fsspec==2025.9.0
|
||||
google-ai-generativelanguage==0.9.0
|
||||
google-api-core==2.28.1
|
||||
google-auth==2.43.0
|
||||
googleapis-common-protos==1.72.0
|
||||
greenlet==3.2.4
|
||||
grpcio==1.76.0
|
||||
grpcio-status==1.76.0
|
||||
h11==0.16.0
|
||||
hf-xet==1.1.10
|
||||
html5lib==1.1
|
||||
httpcore==1.0.9
|
||||
httptools==0.7.1
|
||||
httpx==0.28.1
|
||||
httpx-sse==0.4.3
|
||||
huggingface-hub==0.35.3
|
||||
humanfriendly==10.0
|
||||
idna==2.10
|
||||
importlib_metadata==8.7.0
|
||||
importlib_resources==6.5.2
|
||||
inflection==0.5.1
|
||||
ipykernel==7.1.0
|
||||
ipython==8.37.0
|
||||
ipywidgets==8.1.7
|
||||
isoduration==20.11.0
|
||||
jedi==0.19.2
|
||||
Jinja2==3.1.6
|
||||
jiter==0.12.0
|
||||
jmespath==1.0.1
|
||||
joblib==1.5.2
|
||||
json5==0.12.1
|
||||
jsonpatch==1.33
|
||||
jsonpath==0.82.2
|
||||
jsonpointer==3.0.0
|
||||
jsonschema==4.25.1
|
||||
jsonschema-specifications==2025.9.1
|
||||
jupyter==1.1.1
|
||||
jupyter-console==6.6.3
|
||||
jupyter-events==0.12.0
|
||||
jupyter-lsp==2.3.0
|
||||
jupyter_client==8.6.3
|
||||
jupyter_core==5.9.1
|
||||
jupyter_server==2.17.0
|
||||
jupyter_server_terminals==0.5.3
|
||||
jupyterlab==4.4.10
|
||||
jupyterlab_pygments==0.3.0
|
||||
jupyterlab_server==2.28.0
|
||||
jupyterlab_widgets==3.0.15
|
||||
kiwisolver==1.4.9
|
||||
kubernetes==34.1.0
|
||||
langchain-anthropic==1.0.2
|
||||
langchain-classic==1.0.0
|
||||
langchain-community==0.4.1
|
||||
langchain-core==1.0.4
|
||||
langchain-experimental==0.4.0
|
||||
langchain-google-genai==3.0.2
|
||||
langchain-openai==1.0.2
|
||||
langchain-text-splitters==1.0.0
|
||||
langgraph==1.0.3
|
||||
langgraph-checkpoint==3.0.1
|
||||
langgraph-prebuilt==1.0.2
|
||||
langgraph-sdk==0.2.9
|
||||
langsmith==0.4.42
|
||||
lark==1.3.1
|
||||
Lazify==0.4.0
|
||||
literalai==0.1.201
|
||||
lxml==6.0.2
|
||||
Markdown==3.10
|
||||
markdown-it-py==4.0.0
|
||||
MarkupSafe==2.1.5
|
||||
marshmallow==3.26.1
|
||||
matplotlib==3.10.7
|
||||
matplotlib-inline==0.2.1
|
||||
mcp==1.21.0
|
||||
mdurl==0.1.2
|
||||
mistune==3.1.4
|
||||
mmh3==5.2.0
|
||||
monotonic==1.6
|
||||
mpmath==1.3.0
|
||||
multidict==6.7.0
|
||||
multiprocess==0.70.16
|
||||
multitasking==0.0.12
|
||||
mypy_extensions==1.1.0
|
||||
nbclient==0.10.2
|
||||
nbconvert==7.16.6
|
||||
nbformat==5.10.4
|
||||
nest-asyncio==1.6.0
|
||||
networkx==3.3
|
||||
nltk==3.9.2
|
||||
notebook==7.4.7
|
||||
notebook_shim==0.2.4
|
||||
numpy==2.2.6
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
nvidia-cuda-cupti-cu12==12.4.127
|
||||
nvidia-cuda-nvrtc-cu12==12.4.127
|
||||
nvidia-cuda-runtime-cu12==12.4.127
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
nvidia-cufft-cu12==11.2.1.3
|
||||
nvidia-curand-cu12==10.3.5.147
|
||||
nvidia-cusolver-cu12==11.6.1.9
|
||||
nvidia-cusparse-cu12==12.3.1.170
|
||||
nvidia-cusparselt-cu12==0.6.2
|
||||
nvidia-nccl-cu12==2.21.5
|
||||
nvidia-nvjitlink-cu12==12.4.127
|
||||
nvidia-nvtx-cu12==12.4.127
|
||||
oauthlib==3.3.1
|
||||
onnxruntime==1.23.2
|
||||
openai==2.7.2
|
||||
openpyxl==3.1.5
|
||||
opentelemetry-api==1.38.0
|
||||
opentelemetry-exporter-otlp==1.38.0
|
||||
opentelemetry-exporter-otlp-proto-common==1.38.0
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.38.0
|
||||
opentelemetry-exporter-otlp-proto-http==1.38.0
|
||||
opentelemetry-instrumentation==0.59b0
|
||||
opentelemetry-instrumentation-alephalpha==0.48.0
|
||||
opentelemetry-instrumentation-anthropic==0.48.0
|
||||
opentelemetry-instrumentation-bedrock==0.48.0
|
||||
opentelemetry-instrumentation-chromadb==0.48.0
|
||||
opentelemetry-instrumentation-cohere==0.48.0
|
||||
opentelemetry-instrumentation-crewai==0.48.0
|
||||
opentelemetry-instrumentation-google-generativeai==0.48.0
|
||||
opentelemetry-instrumentation-groq==0.48.0
|
||||
opentelemetry-instrumentation-haystack==0.48.0
|
||||
opentelemetry-instrumentation-lancedb==0.48.0
|
||||
opentelemetry-instrumentation-langchain==0.48.0
|
||||
opentelemetry-instrumentation-llamaindex==0.48.0
|
||||
opentelemetry-instrumentation-logging==0.59b0
|
||||
opentelemetry-instrumentation-marqo==0.48.0
|
||||
opentelemetry-instrumentation-mcp==0.48.0
|
||||
opentelemetry-instrumentation-milvus==0.48.0
|
||||
opentelemetry-instrumentation-mistralai==0.48.0
|
||||
opentelemetry-instrumentation-ollama==0.48.0
|
||||
opentelemetry-instrumentation-openai==0.48.0
|
||||
opentelemetry-instrumentation-openai-agents==0.48.0
|
||||
opentelemetry-instrumentation-pinecone==0.48.0
|
||||
opentelemetry-instrumentation-qdrant==0.48.0
|
||||
opentelemetry-instrumentation-redis==0.59b0
|
||||
opentelemetry-instrumentation-replicate==0.48.0
|
||||
opentelemetry-instrumentation-requests==0.59b0
|
||||
opentelemetry-instrumentation-sagemaker==0.48.0
|
||||
opentelemetry-instrumentation-sqlalchemy==0.59b0
|
||||
opentelemetry-instrumentation-threading==0.59b0
|
||||
opentelemetry-instrumentation-together==0.48.0
|
||||
opentelemetry-instrumentation-transformers==0.48.0
|
||||
opentelemetry-instrumentation-urllib3==0.59b0
|
||||
opentelemetry-instrumentation-vertexai==0.48.0
|
||||
opentelemetry-instrumentation-watsonx==0.48.0
|
||||
opentelemetry-instrumentation-weaviate==0.48.0
|
||||
opentelemetry-instrumentation-writer==0.48.0
|
||||
opentelemetry-proto==1.38.0
|
||||
opentelemetry-sdk==1.38.0
|
||||
opentelemetry-semantic-conventions==0.59b0
|
||||
opentelemetry-semantic-conventions-ai==0.4.13
|
||||
opentelemetry-util-http==0.59b0
|
||||
orjson==3.11.4
|
||||
ormsgpack==1.12.0
|
||||
overrides==7.7.0
|
||||
packaging==25.0
|
||||
pandas==2.3.3
|
||||
pandocfilters==1.5.1
|
||||
parsel==1.10.0
|
||||
parso==0.8.5
|
||||
peewee==3.18.3
|
||||
peft==0.17.1
|
||||
pexpect==4.9.0
|
||||
pillow==11.3.0
|
||||
platformdirs==4.5.0
|
||||
posthog==3.25.0
|
||||
praw==7.8.1
|
||||
prawcore==2.4.0
|
||||
prometheus_client==0.23.1
|
||||
prompt_toolkit==3.0.52
|
||||
propcache==0.4.1
|
||||
proto-plus==1.26.1
|
||||
protobuf==6.33.1
|
||||
psutil==7.1.2
|
||||
ptyprocess==0.7.0
|
||||
pure_eval==0.2.3
|
||||
py-mini-racer==0.6.0
|
||||
pyarrow==22.0.0
|
||||
pyasn1==0.6.1
|
||||
pyasn1_modules==0.4.2
|
||||
pybase64==1.4.2
|
||||
pycocotools==2.0.10
|
||||
pycparser==2.23
|
||||
pydantic==2.12.4
|
||||
pydantic-settings==2.12.0
|
||||
pydantic_core==2.41.5
|
||||
pyfiglet==1.0.4
|
||||
Pygments==2.19.2
|
||||
PyJWT==2.10.1
|
||||
pyparsing==3.2.5
|
||||
PyPika==0.48.9
|
||||
pyproject_hooks==1.2.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.2.1
|
||||
python-engineio==4.12.3
|
||||
python-json-logger==4.0.0
|
||||
python-multipart==0.0.20
|
||||
python-socketio==5.14.3
|
||||
pytz==2025.2
|
||||
PyYAML==6.0.3
|
||||
pyzmq==27.1.0
|
||||
questionary==2.1.1
|
||||
redis==7.0.1
|
||||
referencing==0.37.0
|
||||
regex==2025.9.18
|
||||
requests==2.32.5
|
||||
requests-oauthlib==2.0.0
|
||||
requests-toolbelt==1.0.0
|
||||
rfc3339-validator==0.1.4
|
||||
rfc3986-validator==0.1.1
|
||||
rfc3987-syntax==1.1.0
|
||||
rich==14.2.0
|
||||
rpds-py==0.28.0
|
||||
rsa==4.9.1
|
||||
safetensors==0.6.2
|
||||
scikit-learn==1.7.2
|
||||
scipy==1.15.3
|
||||
seaborn==0.13.2
|
||||
Send2Trash==1.8.3
|
||||
sentence-transformers==5.1.2
|
||||
sentencepiece==0.2.0
|
||||
sgmllib3k==1.0.0
|
||||
shellingham==1.5.4
|
||||
simple-websocket==1.1.0
|
||||
simplejson==3.20.2
|
||||
six==1.17.0
|
||||
sklearn==0.0
|
||||
sniffio==1.3.1
|
||||
soupsieve==2.8
|
||||
SQLAlchemy==2.0.44
|
||||
sse-starlette==3.0.3
|
||||
stack-data==0.6.3
|
||||
starlette==0.49.3
|
||||
stockstats==0.6.5
|
||||
sympy==1.13.1
|
||||
syncer==2.0.3
|
||||
tabulate==0.9.0
|
||||
tenacity==9.1.2
|
||||
tensorboard==2.20.0
|
||||
tensorboard-data-server==0.7.2
|
||||
terminado==0.18.1
|
||||
threadpoolctl==3.6.0
|
||||
tiktoken==0.12.0
|
||||
tinycss2==1.4.0
|
||||
tokenizers==0.22.1
|
||||
tomli==2.3.0
|
||||
torch==2.6.0+cu124
|
||||
torchaudio==2.6.0+cu124
|
||||
torchvision==0.21.0+cu124
|
||||
tornado==6.5.2
|
||||
tqdm==4.67.1
|
||||
traceloop-sdk==0.48.0
|
||||
traitlets==5.14.3
|
||||
transformers==4.57.1
|
||||
triton==3.2.0
|
||||
tushare==1.4.24
|
||||
typer==0.20.0
|
||||
typing-inspect==0.9.0
|
||||
typing-inspection==0.4.2
|
||||
typing_extensions==4.15.0
|
||||
tzdata==2025.2
|
||||
update-checker==0.18.0
|
||||
uri-template==1.3.0
|
||||
urllib3==1.26.20
|
||||
uvicorn==0.38.0
|
||||
uvloop==0.22.1
|
||||
w3lib==2.3.1
|
||||
watchfiles==0.24.0
|
||||
wcwidth==0.2.14
|
||||
webcolors==24.11.1
|
||||
webencodings==0.5.1
|
||||
websocket-client==1.9.0
|
||||
websockets==15.0.1
|
||||
Werkzeug==3.1.3
|
||||
widgetsnbextension==4.0.14
|
||||
wrapt==1.17.3
|
||||
wsproto==1.3.0
|
||||
xlrd==2.0.2
|
||||
xxhash==3.6.0
|
||||
yarl==1.22.0
|
||||
yfinance==0.2.66
|
||||
zipp==3.23.0
|
||||
zstandard==0.25.0
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["use_dapt_sentiment"] = True
|
||||
config["dapt_adapter_path"] = "/u/v/d/vdhanuka/llama3_8b_dapt_transcripts_lora" # <- set your absolute path
|
||||
config["llm_provider"] = "openai" # provider for the other agents; DAPT is used for News
|
||||
config["backend_url"] = "https://api.openai.com/v1" # unused if DAPT loads fine
|
||||
|
||||
graph = TradingAgentsGraph(selected_analysts=["news","market","social"], config=config, debug=True)
|
||||
_, decision = graph.propagate(company_name="AAPL", trade_date="2024-01-07")
|
||||
print(decision)
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
import time
|
||||
import json
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from datetime import datetime, timedelta
|
||||
from tradingagents.agents.utils.agent_utils import get_news, get_global_news
|
||||
from tradingagents.dataflows.config import get_config
|
||||
|
||||
|
|
@ -10,45 +9,51 @@ def create_news_analyst(llm):
|
|||
current_date = state["trade_date"]
|
||||
ticker = state["company_of_interest"]
|
||||
|
||||
tools = [
|
||||
get_news,
|
||||
get_global_news,
|
||||
# Compute 7-day lookback window
|
||||
try:
|
||||
end_dt = datetime.strptime(current_date, "%Y-%m-%d")
|
||||
start_dt = end_dt - timedelta(days=7)
|
||||
start_date = start_dt.strftime("%Y-%m-%d")
|
||||
except Exception:
|
||||
# Fallback: use the same day if parsing fails
|
||||
start_date = current_date
|
||||
|
||||
# Fetch company-specific news and global macro news via tools
|
||||
company_news = ""
|
||||
global_news = ""
|
||||
try:
|
||||
company_news = get_news.invoke({"ticker": ticker, "start_date": start_date, "end_date": current_date}) or ""
|
||||
except Exception:
|
||||
company_news = ""
|
||||
try:
|
||||
global_news = get_global_news.invoke({"curr_date": current_date, "look_back_days": 7, "limit": 8}) or ""
|
||||
except Exception:
|
||||
global_news = ""
|
||||
|
||||
# Build a data-grounded instruction and feed fetched data to the LLM
|
||||
system_instruction = (
|
||||
"You are a news researcher tasked with analyzing recent news and trends over the past week. "
|
||||
"Write a comprehensive, data-grounded report relevant for trading and macroeconomics. "
|
||||
"Use the provided fetched news data as primary evidence. "
|
||||
"Do not simply state that trends are mixed. Provide detailed and nuanced insights with implications. "
|
||||
"Append a concise Markdown table at the end summarizing key points.\n\n"
|
||||
f"Context:\n"
|
||||
f"- Current date: {current_date}\n"
|
||||
f"- Company: {ticker}\n\n"
|
||||
f"Fetched company news ({ticker}, {start_date} to {current_date}):\n{company_news}\n\n"
|
||||
f"Fetched global/macro news (last 7 days):\n{global_news}\n"
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=system_instruction),
|
||||
HumanMessage(content=f"Produce the final report for {ticker} using the fetched data above."),
|
||||
]
|
||||
|
||||
system_message = (
|
||||
"You are a news researcher tasked with analyzing recent news and trends over the past week. Please write a comprehensive report of the current state of the world that is relevant for trading and macroeconomics. Use the available tools: get_news(query, start_date, end_date) for company-specific or targeted news searches, and get_global_news(curr_date, look_back_days, limit) for broader macroeconomic news. Do not simply state the trends are mixed, provide detailed and finegrained analysis and insights that may help traders make decisions."
|
||||
+ """ Make sure to append a Markdown table at the end of the report to organize key points in the report, organized and easy to read."""
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"You are a helpful AI assistant, collaborating with other assistants."
|
||||
" Use the provided tools to progress towards answering the question."
|
||||
" If you are unable to fully answer, that's OK; another assistant with different tools"
|
||||
" will help where you left off. Execute what you can to make progress."
|
||||
" If you or any other assistant has the FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** or deliverable,"
|
||||
" prefix your response with FINAL TRANSACTION PROPOSAL: **BUY/HOLD/SELL** so the team knows to stop."
|
||||
" You have access to the following tools: {tool_names}.\n{system_message}"
|
||||
"For your reference, the current date is {current_date}. We are looking at the company {ticker}",
|
||||
),
|
||||
MessagesPlaceholder(variable_name="messages"),
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompt.partial(system_message=system_message)
|
||||
prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
|
||||
prompt = prompt.partial(current_date=current_date)
|
||||
prompt = prompt.partial(ticker=ticker)
|
||||
|
||||
chain = prompt | llm.bind_tools(tools)
|
||||
result = chain.invoke(state["messages"])
|
||||
result = llm.invoke(messages)
|
||||
|
||||
report = ""
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
# Use the generated content as the report
|
||||
report = getattr(result, "content", "") or ""
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
|
|
|
|||
|
|
@ -121,21 +121,27 @@ class DAPTLlamaChatModel(BaseChatModel):
|
|||
)
|
||||
return prompt
|
||||
else:
|
||||
# Fallback to manual formatting
|
||||
formatted_parts = []
|
||||
# Fallback to Llama 3 style manual formatting using header tokens
|
||||
# <|begin_of_text|>
|
||||
# <|start_header_id|>system<|end_header_id|>
|
||||
# {content}<|eot_id|> ...
|
||||
bos = "<|begin_of_text|>"
|
||||
start_header = "<|start_header_id|>"
|
||||
end_header = "<|end_header_id|>"
|
||||
eot = "<|eot_id|>"
|
||||
parts: List[str] = [bos]
|
||||
for msg in messages:
|
||||
if isinstance(msg, SystemMessage):
|
||||
formatted_parts.append(f"<|system|>\n{msg.content}<|end|>\n")
|
||||
parts.append(f"{start_header}system{end_header}\n{msg.content}{eot}\n")
|
||||
elif isinstance(msg, HumanMessage):
|
||||
formatted_parts.append(f"<|user|>\n{msg.content}<|end|>\n")
|
||||
parts.append(f"{start_header}user{end_header}\n{msg.content}{eot}\n")
|
||||
elif isinstance(msg, AIMessage):
|
||||
formatted_parts.append(f"<|assistant|>\n{msg.content}<|end|>\n")
|
||||
parts.append(f"{start_header}assistant{end_header}\n{msg.content}{eot}\n")
|
||||
else:
|
||||
formatted_parts.append(f"{msg.content}\n")
|
||||
|
||||
# Add assistant prompt
|
||||
formatted_parts.append("<|assistant|>\n")
|
||||
return "".join(formatted_parts)
|
||||
parts.append(f"{start_header}user{end_header}\n{str(msg.content)}{eot}\n")
|
||||
# Add assistant header to cue generation
|
||||
parts.append(f"{start_header}assistant{end_header}\n")
|
||||
return "".join(parts)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
|
|
@ -219,5 +225,4 @@ class DAPTLlamaWithTools(Runnable):
|
|||
else:
|
||||
messages = [input] if isinstance(input, BaseMessage) else [HumanMessage(content=str(input))]
|
||||
|
||||
return self.llm._invoke(messages, **kwargs)
|
||||
|
||||
return self.llm._invoke(messages, **kwargs)
|
||||
|
|
@ -35,6 +35,8 @@ def get_stock_news_openai(query, start_date, end_date):
|
|||
top_p=1,
|
||||
store=True,
|
||||
)
|
||||
# print("2:OpenAI call completed")
|
||||
# print(response)
|
||||
|
||||
return response.output[1].content[0].text
|
||||
|
||||
|
|
@ -92,7 +94,8 @@ def get_global_news_openai(curr_date, look_back_days=7, limit=5):
|
|||
max_output_tokens=4096,
|
||||
store=False,
|
||||
)
|
||||
|
||||
# print("1:OpenAI call completed")
|
||||
# print(resp)
|
||||
return _extract_text(resp)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,9 +25,10 @@ DEFAULT_CONFIG = {
|
|||
# Sentiment analysis model (DAPTed Llama 3.1 8B)
|
||||
"use_dapt_sentiment": True, # Use DAPTed model for sentiment analysis (set False to use OpenAI backup)
|
||||
# Path to DAPT PEFT adapter (dynamically uses current username)
|
||||
"dapt_adapter_path": f"/u/v/d/{os.getenv('USER', 'negi3')}/llama3_8b_dapt_transcripts_lora",
|
||||
"dapt_adapter_path": f"/u/v/d/{os.getenv('USER', 'vdhanuka')}/llama3_8b_dapt_transcripts_lora",
|
||||
|
||||
# Fallback: OpenAI model if DAPT is unavailable
|
||||
"sentiment_fallback_llm": "o1-mini", # OpenAI model for fallback
|
||||
"sentiment_fallback_llm": "o4-mini", # OpenAI model for fallback
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
|
|
@ -42,8 +43,8 @@ DEFAULT_CONFIG = {
|
|||
},
|
||||
# Tool-level configuration (takes precedence over category-level)
|
||||
"tool_vendors": {
|
||||
# "get_stock_data": "alpha_vantage", # Override category default
|
||||
"get_news": "alpha_vantage",
|
||||
"get_stock_data": "openai", # Override category default
|
||||
"get_news": "openai",
|
||||
"get_global_news" :"openai" # Override category default
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class TradingAgentsGraph:
|
|||
# Initialize sentiment analysis LLM (DAPTed Llama 3.1 8B)
|
||||
if self.config.get("use_dapt_sentiment", True):
|
||||
# Use DAPTed model directly
|
||||
dapt_path = self.config.get("dapt_adapter_path", f"/u/v/d/{os.getenv('USER', 'negi3')}/llama3_8b_dapt_transcripts_lora")
|
||||
dapt_path = self.config.get("dapt_adapter_path", f"/u/v/d/{os.getenv('USER', 'vdhanuka')}/llama3_8b_dapt_transcripts_lora")
|
||||
# Convert relative path to absolute if needed
|
||||
if not os.path.isabs(dapt_path):
|
||||
dapt_path = os.path.join(self.config["project_dir"], dapt_path)
|
||||
|
|
|
|||
Loading…
Reference in New Issue