WARNING: THIS SITE IS A MIRROR OF GITHUB.COM / IT CANNOT LOG IN OR REGISTER ACCOUNTS / THE CONTENTS ARE PROVIDED AS-IS / THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS / IF YOU FIND SOMETHING THAT MAY BE INAPPROPRIATE, CONTACT THE ADMIN AT ilovescratch@foxmail.com
Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/google/adk/agents/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,15 @@ class InvocationContext(BaseModel):
canonical_tools_cache: Optional[list[BaseTool]] = None
"""The cache of canonical tools for this invocation."""

metadata: Optional[dict[str, Any]] = None
"""Per-request metadata passed from Runner.run_async().
This field allows passing arbitrary metadata that can be accessed during
the invocation lifecycle, particularly in callbacks like before_model_callback.
Common use cases include passing user_id, trace_id, memory context keys, or
other request-specific context that needs to be available during processing.
"""

_invocation_cost_manager: _InvocationCostManager = PrivateAttr(
default_factory=_InvocationCostManager
)
Expand Down
4 changes: 2 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def run_live(
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Runs the flow using live api."""
llm_request = LlmRequest()
llm_request = LlmRequest(metadata=invocation_context.metadata)
event_id = Event.new_id()

# Preprocess before calling the LLM.
Expand Down Expand Up @@ -380,7 +380,7 @@ async def _run_one_step_async(
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""One step means one LLM call."""
llm_request = LlmRequest()
llm_request = LlmRequest(metadata=invocation_context.metadata)

# Preprocess before calling the LLM.
async with Aclosing(
Expand Down
10 changes: 10 additions & 0 deletions src/google/adk/models/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import logging
from typing import Any
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -99,6 +100,15 @@ class LlmRequest(BaseModel):
the full history.
"""

metadata: Optional[dict[str, Any]] = None
"""Per-request metadata for callbacks and custom processing.

This field allows passing arbitrary metadata from the Runner.run_async()
call to callbacks like before_model_callback. This is useful for passing
request-specific context such as user_id, trace_id, or memory context keys
that need to be available during model invocation.
"""

def append_instructions(
self, instructions: Union[list[str], types.Content]
) -> list[types.Content]:
Expand Down
22 changes: 21 additions & 1 deletion src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ async def run_async(
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
metadata: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Event, None]:
"""Main entry method to run the agent in this runner.

Expand All @@ -417,6 +418,9 @@ async def run_async(
new_message: A new message to append to the session.
state_delta: Optional state changes to apply to the session.
run_config: The run config for the agent.
metadata: Optional per-request metadata that will be passed to callbacks.
This allows passing request-specific context such as user_id, trace_id,
or memory context keys to before_model_callback and other callbacks.
Comment on lines +434 to +436
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent potential subtle bugs, it's a good practice to clarify the copy behavior of the metadata dictionary in the docstring. Since a shallow copy is performed, modifications to nested mutable objects within a callback will affect the original object passed by the caller. Please add a note about this to help users of the API understand this behavior and avoid unexpected side effects. For example, you could add: Note: A shallow copy is made of this dictionary, so changes to nested mutable objects will affect the original object.


Yields:
The events generated by the agent.
Expand All @@ -426,13 +430,16 @@ async def run_async(
new_message are None.
"""
run_config = run_config or RunConfig()
# Create a shallow copy to isolate from caller's modifications
metadata = metadata.copy() if metadata else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check `if metadata` is truthiness-based, which means an empty dictionary `{}` passed by the user will be incorrectly converted to `None`. This can lead to unexpected behavior if the user intends to pass an empty, mutable metadata object that might be populated by callbacks. You should use `if metadata is not None` to correctly handle an empty dictionary while still creating a shallow copy.

Suggested change
metadata = metadata.copy() if metadata else None
metadata = metadata.copy() if metadata is not None else None


if new_message and not new_message.role:
new_message.role = 'user'

async def _run_with_trace(
new_message: Optional[types.Content] = None,
invocation_id: Optional[str] = None,
metadata: Optional[dict[str, Any]] = None,
) -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session(
Expand Down Expand Up @@ -463,6 +470,7 @@ async def _run_with_trace(
invocation_id=invocation_id,
run_config=run_config,
state_delta=state_delta,
metadata=metadata,
)
if invocation_context.end_of_agents.get(
invocation_context.agent.name
Expand All @@ -476,6 +484,7 @@ async def _run_with_trace(
new_message=new_message, # new_message is not None.
run_config=run_config,
state_delta=state_delta,
metadata=metadata,
)

async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
Expand All @@ -502,7 +511,9 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
self.app, session, self.session_service
)

async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen:
async with Aclosing(
_run_with_trace(new_message, invocation_id, metadata)
) as agen:
async for event in agen:
yield event

Expand Down Expand Up @@ -1186,6 +1197,7 @@ async def _setup_context_for_new_invocation(
new_message: types.Content,
run_config: RunConfig,
state_delta: Optional[dict[str, Any]],
metadata: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Sets up the context for a new invocation.

Expand All @@ -1194,6 +1206,7 @@ async def _setup_context_for_new_invocation(
new_message: The new message to process and append to the session.
run_config: The run config of the agent.
state_delta: Optional state changes to apply to the session.
metadata: Optional per-request metadata to pass to callbacks.

Returns:
The invocation context for the new invocation.
Expand All @@ -1203,6 +1216,7 @@ async def _setup_context_for_new_invocation(
session,
new_message=new_message,
run_config=run_config,
metadata=metadata,
)
# Step 2: Handle new message, by running callbacks and appending to
# session.
Expand All @@ -1225,6 +1239,7 @@ async def _setup_context_for_resumed_invocation(
invocation_id: Optional[str],
run_config: RunConfig,
state_delta: Optional[dict[str, Any]],
metadata: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Sets up the context for a resumed invocation.

Expand All @@ -1234,6 +1249,7 @@ async def _setup_context_for_resumed_invocation(
invocation_id: The invocation id to resume.
run_config: The run config of the agent.
state_delta: Optional state changes to apply to the session.
metadata: Optional per-request metadata to pass to callbacks.

Returns:
The invocation context for the resumed invocation.
Expand All @@ -1259,6 +1275,7 @@ async def _setup_context_for_resumed_invocation(
new_message=user_message,
run_config=run_config,
invocation_id=invocation_id,
metadata=metadata,
)
# Step 3: Maybe handle new message.
if new_message:
Expand Down Expand Up @@ -1303,6 +1320,7 @@ def _new_invocation_context(
new_message: Optional[types.Content] = None,
live_request_queue: Optional[LiveRequestQueue] = None,
run_config: Optional[RunConfig] = None,
metadata: Optional[dict[str, Any]] = None,
) -> InvocationContext:
"""Creates a new invocation context.

Expand All @@ -1312,6 +1330,7 @@ def _new_invocation_context(
new_message: The new message for the context.
live_request_queue: The live request queue for the context.
run_config: The run config for the context.
metadata: Optional per-request metadata for the context.

Returns:
The new invocation context.
Expand Down Expand Up @@ -1343,6 +1362,7 @@ def _new_invocation_context(
live_request_queue=live_request_queue,
run_config=run_config,
resumability_config=self.resumability_config,
metadata=metadata,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent accidental modification of the original metadata dictionary by the caller of run_async, it's a good practice to work with a copy of the metadata. Since dictionaries are mutable, any changes made to metadata within the runner's logic would also affect the caller's original dictionary. Creating a shallow copy here isolates the runner's execution context from the caller. This is especially important as run_async is an async generator, and the caller might modify the metadata dictionary while iterating over the yielded events.

Suggested change
metadata=metadata,
metadata=metadata.copy() if metadata is not None else None,

)

def _new_invocation_context_for_live(
Expand Down
139 changes: 139 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.cli.utils.agent_loader import AgentLoader
from google.adk.events.event import Event
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
Expand Down Expand Up @@ -1038,5 +1040,142 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent(
assert "actual_name" in runner._app_name_alignment_hint


class TestRunnerMetadata:
"""Tests for Runner metadata parameter functionality."""
Comment on lines +1103 to +1104
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test suite for metadata is comprehensive for data propagation. It would be beneficial to also add a test case that explicitly verifies the behavior of the shallow copy of the metadata dictionary.
Specifically, a test could:

  1. Pass a metadata dictionary with a nested mutable object (e.g., a dictionary or a list).
  2. Modify both a top-level value and a value within the nested object inside a callback.
  3. Assert that the top-level change does not affect the original dictionary passed by the caller.
  4. Assert that the change to the nested object does affect the original object.

This would ensure that the isolation behavior is well-understood and prevent future regressions.


def setup_method(self):
"""Set up test fixtures."""
self.session_service = InMemorySessionService()
self.artifact_service = InMemoryArtifactService()
self.root_agent = MockLlmAgent("root_agent")
self.runner = Runner(
app_name="test_app",
agent=self.root_agent,
session_service=self.session_service,
artifact_service=self.artifact_service,
)

def test_new_invocation_context_with_metadata(self):
"""Test that _new_invocation_context correctly passes metadata."""
mock_session = Session(
id=TEST_SESSION_ID,
app_name=TEST_APP_ID,
user_id=TEST_USER_ID,
events=[],
)

test_metadata = {"user_id": "test123", "trace_id": "trace456"}
invocation_context = self.runner._new_invocation_context(
mock_session, metadata=test_metadata
)

assert invocation_context.metadata == test_metadata
assert invocation_context.metadata["user_id"] == "test123"
assert invocation_context.metadata["trace_id"] == "trace456"

def test_new_invocation_context_without_metadata(self):
"""Test that _new_invocation_context works without metadata."""
mock_session = Session(
id=TEST_SESSION_ID,
app_name=TEST_APP_ID,
user_id=TEST_USER_ID,
events=[],
)

invocation_context = self.runner._new_invocation_context(mock_session)

assert invocation_context.metadata is None

@pytest.mark.asyncio
async def test_run_async_passes_metadata_to_invocation_context(self):
"""Test that run_async correctly passes metadata to before_model_callback."""
# Capture metadata received in callback
captured_metadata = None

def before_model_callback(callback_context, llm_request):
nonlocal captured_metadata
captured_metadata = llm_request.metadata
# Return a response to skip actual LLM call
return LlmResponse(
content=types.Content(
role="model", parts=[types.Part(text="Test response")]
)
)

# Create agent with before_model_callback
agent_with_callback = LlmAgent(
name="callback_agent",
model="gemini-2.0-flash",
before_model_callback=before_model_callback,
)

runner_with_callback = Runner(
app_name="test_app",
agent=agent_with_callback,
session_service=self.session_service,
artifact_service=self.artifact_service,
)

session = await self.session_service.create_session(
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
)

test_metadata = {"experiment_id": "exp-001", "variant": "B"}

async for event in runner_with_callback.run_async(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
new_message=types.Content(
role="user", parts=[types.Part(text="Hello")]
),
metadata=test_metadata,
):
pass

# Verify metadata was passed to before_model_callback
assert captured_metadata is not None
assert captured_metadata == test_metadata
assert captured_metadata["experiment_id"] == "exp-001"
assert captured_metadata["variant"] == "B"

def test_metadata_field_in_invocation_context(self):
"""Test that InvocationContext model accepts metadata field."""
mock_session = Session(
id=TEST_SESSION_ID,
app_name=TEST_APP_ID,
user_id=TEST_USER_ID,
events=[],
)

test_metadata = {"key1": "value1", "key2": 123}

# This should not raise a validation error
invocation_context = InvocationContext(
session_service=self.session_service,
invocation_id="test_inv_id",
agent=self.root_agent,
session=mock_session,
metadata=test_metadata,
)

assert invocation_context.metadata == test_metadata

def test_metadata_field_in_llm_request(self):
"""Test that LlmRequest model accepts metadata field."""
test_metadata = {"context_key": "ctx123", "user_info": {"name": "test"}}

llm_request = LlmRequest(metadata=test_metadata)

assert llm_request.metadata == test_metadata
assert llm_request.metadata["context_key"] == "ctx123"
assert llm_request.metadata["user_info"]["name"] == "test"

def test_llm_request_without_metadata(self):
"""Test that LlmRequest works without metadata."""
llm_request = LlmRequest()

assert llm_request.metadata is None


if __name__ == "__main__":
pytest.main([__file__])