192 changes: 124 additions & 68 deletions atomic-agents/atomic_agents/context/chat_history.py
@@ -1,16 +1,109 @@
from __future__ import annotations

import json
import uuid
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any

from instructor.multimodal import PDF, Image, Audio
from instructor.processing.multimodal import PDF, Image, Audio
from pydantic import BaseModel, Field

from atomic_agents.base.base_io_schema import BaseIOSchema

if TYPE_CHECKING:
from typing import Type


MULTIMODAL_TYPES = (Image, Audio, PDF)


@dataclass
class MultimodalContent:
"""Result of extracting multimodal content from nested structures."""

objects: list = field(default_factory=list)
json_data: Any = None

@property
def has_multimodal(self) -> bool:
return len(self.objects) > 0


INSTRUCTOR_MULTIMODAL_TYPES = (Image, Audio, PDF)
def _extract_multimodal_content(obj: Any, _seen: set[int] | None = None) -> MultimodalContent:
"""
Single-pass extraction of multimodal content from nested structures.

Returns both the multimodal objects and a JSON-serializable representation
with multimodal content removed.
"""
if _seen is None:
_seen = set()

match obj:
case Image() | Audio() | PDF():
return MultimodalContent(objects=[obj], json_data=None)

case list():
if id(obj) in _seen:
return MultimodalContent()
_seen.add(id(obj))

all_objects = []
json_items = []
for item in obj:
result = _extract_multimodal_content(item, _seen)
all_objects.extend(result.objects)
if result.json_data is not None:
json_items.append(result.json_data)

return MultimodalContent(
objects=all_objects,
json_data=json_items or None,
)

case dict():
if id(obj) in _seen:
return MultimodalContent()
_seen.add(id(obj))

all_objects = []
json_dict = {}
for key, value in obj.items():
result = _extract_multimodal_content(value, _seen)
all_objects.extend(result.objects)
if result.json_data is not None:
json_dict[key] = result.json_data

return MultimodalContent(
objects=all_objects,
json_data=json_dict or None,
)

case BaseModel():
if id(obj) in _seen:
return MultimodalContent()
_seen.add(id(obj))

all_objects = []
json_dict = {}
for field_name in type(obj).model_fields:
result = _extract_multimodal_content(getattr(obj, field_name), _seen)
all_objects.extend(result.objects)
if result.json_data is not None:
json_dict[field_name] = result.json_data

return MultimodalContent(
objects=all_objects,
json_data=json_dict or None,
)

case _ if hasattr(obj, "model_dump"):
return MultimodalContent(json_data=obj.model_dump())

case _:
return MultimodalContent(json_data=obj)
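
# --- Illustrative sketch (not part of this diff) ---------------------------
# Rough usage of _extract_multimodal_content on a nested schema, to show how
# multimodal objects are separated from the JSON-serializable remainder. The
# schema name, field names, and the Image.from_path(...) helper are assumptions
# for illustration; the instructor Image/Audio/PDF constructors may differ in
# your installed version.
#
#   class ReportSchema(BaseIOSchema):
#       """Example schema mixing text and an image."""
#       title: str = Field(..., description="Report title")
#       image: Image = Field(..., description="Attached image")
#
#   report = ReportSchema(title="Q3 summary", image=Image.from_path("chart.png"))
#   extracted = _extract_multimodal_content(report)
#   # extracted.objects        -> [Image(...)]            (kept as live objects)
#   # extracted.json_data      -> {"title": "Q3 summary"} (image stripped out)
#   # extracted.has_multimodal -> True
# ---------------------------------------------------------------------------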


class Message(BaseModel):
@@ -25,30 +118,29 @@ class Message(BaseModel):

role: str
content: BaseIOSchema
turn_id: Optional[str] = None
turn_id: str | None = None


class ChatHistory:
"""
Manages the chat history for an AI agent.

Attributes:
history (List[Message]): A list of messages representing the chat history.
max_messages (Optional[int]): Maximum number of messages to keep in history.
current_turn_id (Optional[str]): The ID of the current turn.
history: A list of messages representing the chat history.
max_messages: Maximum number of messages to keep in history.
current_turn_id: The ID of the current turn.
"""

def __init__(self, max_messages: Optional[int] = None):
def __init__(self, max_messages: int | None = None):
"""
Initializes the ChatHistory with an empty history and optional constraints.

Args:
max_messages (Optional[int]): Maximum number of messages to keep in history.
When exceeded, oldest messages are removed first.
max_messages: Maximum number of messages to keep. Oldest removed first.
"""
self.history: List[Message] = []
self.history: list[Message] = []
self.max_messages = max_messages
self.current_turn_id: Optional[str] = None
self.current_turn_id: str | None = None

def initialize_turn(self) -> None:
"""
@@ -87,71 +179,35 @@ def _manage_overflow(self) -> None:
while len(self.history) > self.max_messages:
self.history.pop(0)
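
# --- Illustrative sketch (not part of this diff) ---------------------------
# Overflow behaviour with max_messages=2: adding a third message drops the
# oldest one first. add_message(role, content) and the placeholder message
# objects are assumed here for illustration; only get_history() and the
# constructor appear in this diff.
#
#   history = ChatHistory(max_messages=2)
#   history.add_message("user", first_msg)
#   history.add_message("assistant", second_msg)
#   history.add_message("user", third_msg)   # first_msg is trimmed
#   assert len(history.get_history()) == 2
# ---------------------------------------------------------------------------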

def get_history(self) -> List[Dict]:
def get_history(self) -> list[dict]:
"""
Retrieves the chat history, handling both regular and multimodal content.

This method supports multimodal content (Image, Audio, PDF) including when
nested within other schemas. Multimodal objects are kept separate from
the JSON serialization to allow Instructor to handle them appropriately.

Returns:
List[Dict]: The list of messages in the chat history as dictionaries.
Each dictionary has 'role' and 'content' keys, where 'content' contains
either a single JSON string or a mixed array of JSON and multimodal objects.

Note:
This method supports multimodal content by keeping multimodal objects
separate while generating cohesive JSON for text-based fields.
This method supports nested multimodal content by recursively detecting
and extracting multimodal objects from any level of the schema hierarchy.
"""
history = []
for message in self.history:
input_content = message.content

# Check if content has any multimodal fields
multimodal_objects = []
has_multimodal = False

# Extract multimodal content first
for field_name, field in input_content.__class__.model_fields.items():
field_value = getattr(input_content, field_name)

if isinstance(field_value, list):
for item in field_value:
if isinstance(item, INSTRUCTOR_MULTIMODAL_TYPES):
multimodal_objects.append(item)
has_multimodal = True
elif isinstance(field_value, INSTRUCTOR_MULTIMODAL_TYPES):
multimodal_objects.append(field_value)
has_multimodal = True

if has_multimodal:
# For multimodal content: create mixed array with JSON + multimodal objects
processed_content = []

# Add single cohesive JSON for all non-multimodal fields
non_multimodal_data = {}
for field_name, field in input_content.__class__.model_fields.items():
field_value = getattr(input_content, field_name)

if isinstance(field_value, list):
# Only include non-multimodal items from lists
non_multimodal_items = [
item for item in field_value if not isinstance(item, INSTRUCTOR_MULTIMODAL_TYPES)
]
if non_multimodal_items:
non_multimodal_data[field_name] = non_multimodal_items
elif not isinstance(field_value, INSTRUCTOR_MULTIMODAL_TYPES):
non_multimodal_data[field_name] = field_value

# Add single JSON string if there are non-multimodal fields
if non_multimodal_data:
processed_content.append(json.dumps(non_multimodal_data, ensure_ascii=False))

# Add all multimodal objects
processed_content.extend(multimodal_objects)

history.append({"role": message.role, "content": processed_content})
extracted = _extract_multimodal_content(message.content)

if extracted.has_multimodal:
content = []
if extracted.json_data:
content.append(json.dumps(extracted.json_data, ensure_ascii=False))
content.extend(extracted.objects)
history.append({"role": message.role, "content": content})
else:
# No multimodal content: generate single cohesive JSON string
content_json = input_content.model_dump_json()
history.append({"role": message.role, "content": content_json})
history.append({"role": message.role, "content": message.content.model_dump_json()})

return history
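
# --- Illustrative sketch (not part of this diff) ---------------------------
# Approximate shape of get_history() output under the new extraction path,
# assuming one multimodal message and one plain message; role values and field
# names are abbreviated examples, not taken from this diff.
#
#   [
#       {"role": "user", "content": ['{"instruction_text": "Describe this"}',
#                                    Image(...)]},
#       {"role": "assistant", "content": '{"answer": "A bar chart."}'},
#   ]
#
# Multimodal objects stay as live Image/Audio/PDF instances inside the content
# list so Instructor can serialize them for the target provider, while all
# text fields collapse into a single JSON string.
# ---------------------------------------------------------------------------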

@@ -167,7 +223,7 @@ def copy(self) -> "ChatHistory":
new_history.current_turn_id = self.current_turn_id
return new_history

def get_current_turn_id(self) -> Optional[str]:
def get_current_turn_id(self) -> str | None:
"""
Returns the current turn ID.

@@ -352,7 +408,7 @@ class MultimodalSchema(BaseIOSchema):
"""Schema for testing multimodal content"""

instruction_text: str = Field(..., description="The instruction text")
images: List[instructor.Image] = Field(..., description="The images to analyze")
images: list[instructor.Image] = Field(..., description="The images to analyze")

# Create and populate the original history with complex data
original_history = ChatHistory(max_messages=10)
@@ -409,8 +465,8 @@ class MultimodalSchema(BaseIOSchema):
print(f"Turn ID: {message.turn_id}")
print(f"Content type: {type(message.content).__name__}")
print("Content:")
for field, value in message.content.model_dump().items():
print(f" {field}: {value}")
for field_name, value in message.content.model_dump().items():
print(f" {field_name}: {value}")

# Final verification
print("\nFinal verification:")