WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOGIN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FOUND SOMETHING MAY NOT GOOD FOR EVERYONE, CONTACT ADMIN AT ilovescratch@foxmail.com
Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 74 additions & 2 deletions libs/partners/xai/langchain_xai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast

import openai
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.utils import secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from langchain_openai.chat_models.base import (
BaseChatOpenAI,
_convert_from_v1_to_chat_completions,
)
from langchain_openai.chat_models.base import (
_convert_message_to_dict as _openai_convert_message_to_dict,
)
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

Expand Down Expand Up @@ -37,6 +43,36 @@ def _get_default_model_profile(model_name: str) -> ModelProfile:
return default.copy()


def _convert_message_to_dict_xai(
message: BaseMessage,
api: Literal["chat/completions", "responses"] = "chat/completions",
) -> dict:
"""Convert a LangChain message to dictionary format expected by xAI.

xAI's API requires that all messages have at least an empty content field,
unlike OpenAI which allows None for messages with tool_calls.

Args:
message: The LangChain message to convert.
api: The API format to use. Defaults to "chat/completions".

Returns:
Dictionary representation of the message compatible with xAI's API.
"""
message_dict = _openai_convert_message_to_dict(message, api=api)

# xAI requires content to be at least an empty string, not None
# This is especially important for AIMessages with tool_calls but no content
if (
isinstance(message, AIMessage)
and message_dict.get("content") is None
and ("tool_calls" in message_dict or "function_call" in message_dict)
):
message_dict["content"] = ""

return message_dict


class ChatXAI(BaseChatOpenAI): # type: ignore[override]
r"""ChatXAI chat model.

Expand Down Expand Up @@ -547,6 +583,42 @@ def _default_params(self) -> dict[str, Any]:

return params

def _get_request_payload(
self,
input_: LanguageModelInput,
*,
stop: list[str] | None = None,
**kwargs: Any, # noqa: ANN401
) -> dict:
"""Prepare the request payload for xAI's API.

Overrides the base implementation to use xAI-specific message conversion
that ensures all messages have at least an empty content field.

Args:
input_: The input to convert into messages.
stop: List of stop sequences. Defaults to None.
**kwargs: Additional keyword arguments to pass to the API.

Returns:
Dictionary containing the request payload for xAI's API.
"""
messages = self._convert_input(input_).to_messages()
if stop is not None:
kwargs["stop"] = stop

payload = {**self._default_params, **kwargs}

# Use xAI-specific message converter for all message conversion
# xAI requires all messages to have at least empty content field
payload["messages"] = [
_convert_message_to_dict_xai(_convert_from_v1_to_chat_completions(m))
if isinstance(m, AIMessage)
else _convert_message_to_dict_xai(m)
for m in messages
]
return payload

def _create_chat_result(
self,
response: dict | openai.BaseModel,
Expand Down
Loading