Patch the Copilot client to sanitize malformed usage metadata before langchain_openai processes it
parent 1cb6d4e882
commit 57a2ae5438
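
The crash this patch guards against is plain arithmetic on a missing counter. A minimal standalone sketch of the failure mode, with illustrative values (this is not langchain_openai's actual code path, just the operation its docstring below describes):

# Illustrative sketch only: langchain_openai subtracts detail counters such
# as cached_tokens from the prompt/completion totals. When Copilot reports
# the counter as None, int minus None raises TypeError.
prompt_tokens = 42      # usage.prompt_tokens (value is illustrative)
cached_tokens = None    # usage.prompt_tokens_details.cached_tokens, malformed

try:
    uncached = prompt_tokens - cached_tokens
except TypeError as exc:
    print(exc)  # unsupported operand type(s) for -: 'int' and 'NoneType'
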
@@ -8,6 +8,7 @@ No env var or separate auth module needed — run ``gh auth login`` once.
 """
 
 import subprocess
+from copy import deepcopy
 from typing import Any, Optional
 
 import requests
@@ -117,10 +118,51 @@ def check_copilot_auth() -> bool:
 class NormalizedChatOpenAI(ChatOpenAI):
     """ChatOpenAI with normalized content output."""
 
+    def _create_chat_result(self, response, generation_info=None):
+        return super()._create_chat_result(
+            _sanitize_copilot_response(response), generation_info
+        )
+
     def invoke(self, input, config=None, **kwargs):
         return normalize_content(super().invoke(input, config, **kwargs))
 
 
+def _sanitize_copilot_response(response: Any) -> Any:
+    """Normalize Copilot token usage fields for langchain_openai.
+
+    Copilot can return ``service_tier`` along with ``None`` values in
+    ``cached_tokens`` or ``reasoning_tokens``. ``langchain_openai`` subtracts
+    those fields from the prompt/completion totals, which raises ``TypeError``
+    when the detail value is ``None``.
+    """
+    if isinstance(response, dict):
+        response_dict = deepcopy(response)
+    elif hasattr(response, "model_dump"):
+        response_dict = response.model_dump()
+    else:
+        return response
+
+    usage = response_dict.get("usage")
+    if not isinstance(usage, dict):
+        return response_dict
+
+    if response_dict.get("service_tier") not in {"priority", "flex"}:
+        return response_dict
+
+    prompt_details = usage.get("prompt_tokens_details")
+    if isinstance(prompt_details, dict) and prompt_details.get("cached_tokens") is None:
+        prompt_details["cached_tokens"] = 0
+
+    completion_details = usage.get("completion_tokens_details")
+    if (
+        isinstance(completion_details, dict)
+        and completion_details.get("reasoning_tokens") is None
+    ):
+        completion_details["reasoning_tokens"] = 0
+
+    return response_dict
+
+
 class CopilotClient(BaseLLMClient):
     """Client for GitHub Copilot inference API.
 
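
A quick way to see the sanitizer's effect, assuming ``_sanitize_copilot_response`` is importable from the patched module (the import path below is hypothetical; the payload shape mirrors the fields the function reads):

# Hypothetical import path; adjust to wherever the patched client lives.
from copilot_client import _sanitize_copilot_response

# Payload shaped like the malformed Copilot response described in the
# docstring above: a priority-tier response with None detail counters.
response = {
    "service_tier": "priority",
    "usage": {
        "prompt_tokens": 42,
        "completion_tokens": 7,
        "total_tokens": 49,
        "prompt_tokens_details": {"cached_tokens": None},
        "completion_tokens_details": {"reasoning_tokens": None},
    },
}

cleaned = _sanitize_copilot_response(response)
assert cleaned["usage"]["prompt_tokens_details"]["cached_tokens"] == 0
assert cleaned["usage"]["completion_tokens_details"]["reasoning_tokens"] == 0
# The dict input is deep-copied, so the caller's payload is left untouched.
assert response["usage"]["prompt_tokens_details"]["cached_tokens"] is None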