From 7b16af0fba78ec161a72beccf65db8774b227a62 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 2 Dec 2025 11:41:20 -0600 Subject: [PATCH 01/10] feat(go): Add cancelAction and early trace ID to headers in go reflection --- go/ai/evaluator.go | 2 +- go/ai/generate.go | 2 +- go/core/action.go | 12 +- go/core/api/action.go | 7 +- go/core/flow.go | 6 +- go/core/schemas.config | 3 + go/core/tracing/tracing.go | 8 + go/genkit/reflection.go | 224 ++++++++++++-- go/genkit/reflection_test.go | 327 +++++++++++++++++++- go/internal/cmd/jsonschemagen/jsonschema.go | 2 +- go/samples/flow-sample1/main.go | 117 +++++++ 11 files changed, 673 insertions(+), 37 deletions(-) diff --git a/go/ai/evaluator.go b/go/ai/evaluator.go index aa536fac9b..bf941f20fc 100644 --- a/go/ai/evaluator.go +++ b/go/ai/evaluator.go @@ -201,7 +201,7 @@ func NewEvaluator(name string, opts *EvaluatorOptions, fn EvaluatorFunc) Evaluat Type: "evaluator", Subtype: "evaluator", } - _, err := tracing.RunInNewSpan(ctx, spanMetadata, datapoint, + _, err := tracing.RunInNewSpan(ctx, spanMetadata, datapoint, nil, func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) { traceId := trace.SpanContextFromContext(ctx).TraceID().String() spanId := trace.SpanContextFromContext(ctx).SpanID().String() diff --git a/go/ai/generate.go b/go/ai/generate.go index d0743de7ac..61bfe14be0 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -330,7 +330,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Subtype: "util", } - return tracing.RunInNewSpan(ctx, spanMetadata, req, func(ctx context.Context, req *ModelRequest) (*ModelResponse, error) { + return tracing.RunInNewSpan(ctx, spanMetadata, req, nil, func(ctx context.Context, req *ModelRequest) (*ModelResponse, error) { var wrappedCb ModelStreamCallback currentRole := RoleModel currentIndex := messageIndex diff --git a/go/core/action.go b/go/core/action.go index 45d71e3177..cf598b19f3 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -176,7 +176,7 @@ func (a *ActionDef[In, Out, Stream]) Name() string { return a.desc.Name } // Run executes the Action's function in a new trace span. func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { - r, err := a.runWithTelemetry(ctx, input, cb) + r, err := a.runWithTelemetry(ctx, input, cb, nil) if err != nil { return base.Zero[Out](), err } @@ -184,7 +184,7 @@ func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb Strea } // Run executes the Action's function in a new trace span. -func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { +func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream], telemetryCb func(traceID, spanID string)) (output api.ActionRunResult[Out], err error) { inputBytes, _ := json.Marshal(input) logger.FromContext(ctx).Debug("Action.Run", "name", a.Name(), @@ -215,7 +215,7 @@ func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input var traceID string var spanID string - o, err := tracing.RunInNewSpan(ctx, spanMetadata, input, + o, err := tracing.RunInNewSpan(ctx, spanMetadata, input, telemetryCb, func(ctx context.Context, input In) (Out, error) { traceInfo := tracing.SpanTraceInfo(ctx) traceID = traceInfo.TraceID @@ -253,7 +253,7 @@ func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input // RunJSON runs the action with a JSON input, and returns a JSON result. func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { - r, err := a.RunJSONWithTelemetry(ctx, input, cb) + r, err := a.RunJSONWithTelemetry(ctx, input, cb, nil) if err != nil { return nil, err } @@ -261,7 +261,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.Raw } // RunJSON runs the action with a JSON input, and returns a JSON result along with telemetry info. -func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { +func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], telemetryCb api.TelemetryCallback) (*api.ActionRunResult[json.RawMessage], error) { i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema) if err != nil { return nil, NewError(INVALID_ARGUMENT, err.Error()) @@ -278,7 +278,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, i } } - r, err := a.runWithTelemetry(ctx, i, scb) + r, err := a.runWithTelemetry(ctx, i, scb, telemetryCb) if err != nil { return &api.ActionRunResult[json.RawMessage]{ TraceId: r.TraceId, diff --git a/go/core/api/action.go b/go/core/api/action.go index 3cfd2689f1..51b2d804e9 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -27,6 +27,10 @@ type ActionRunResult[T any] struct { SpanId string } +// TelemetryCallback is called when telemetry information becomes available. +// It receives the trace ID and span ID as soon as the span is created. +type TelemetryCallback func(traceID, spanID string) + // Action is the interface that all Genkit primitives (e.g. flows, models, tools) have in common. type Action interface { Registerable @@ -35,7 +39,8 @@ type Action interface { // RunJSON runs the action with the given JSON input and streaming callback and returns the output as JSON. RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) // RunJSONWithTelemetry runs the action with the given JSON input and streaming callback and returns the output as JSON along with telemetry info. - RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (*ActionRunResult[json.RawMessage], error) + // The telemetryCb callback, if provided, is called as soon as the trace span is created with the trace ID and span ID. + RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, telemetryCb TelemetryCallback) (*ActionRunResult[json.RawMessage], error) // Desc returns a descriptor of the action. Desc() ActionDesc } diff --git a/go/core/flow.go b/go/core/flow.go index b5311bbbf3..ce8dd9eea8 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -93,7 +93,7 @@ func Run[Out any](ctx context.Context, name string, fn func() (Out, error)) (Out Type: "flowStep", Subtype: "flowStep", } - return tracing.RunInNewSpan(ctx, spanMetadata, nil, func(ctx context.Context, _ any) (Out, error) { + return tracing.RunInNewSpan(ctx, spanMetadata, nil, nil, func(ctx context.Context, _ any) (Out, error) { o, err := fn() if err != nil { return base.Zero[Out](), err @@ -113,8 +113,8 @@ func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessa } // RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON. -func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { - return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb) +func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], telemetryCb api.TelemetryCallback) (*api.ActionRunResult[json.RawMessage], error) { + return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb, telemetryCb) } // Desc returns the descriptor of the flow. diff --git a/go/core/schemas.config b/go/core/schemas.config index 7598011f17..fba6b938ec 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -26,6 +26,9 @@ SpanStatus omit TimeEvent omit TimeEventAnnotation omit TraceData omit +SpanStartEvent omit +SpanEndEvent omit +TraceEvent omit GenerationCommonConfig.maxOutputTokens type int GenerationCommonConfig.topK type int diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index bcd05aac19..bf7025d549 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -173,10 +173,12 @@ type SpanMetadata struct { // RunInNewSpan runs f on input in a new span with the provided metadata. // The metadata contains all span configuration including name, type, labels, etc. +// If telemetryCb is provided, it will be called with the trace ID and span ID as soon as the span is created. func RunInNewSpan[I, O any]( ctx context.Context, metadata *SpanMetadata, input I, + telemetryCb func(traceID, spanID string), f func(context.Context, I) (O, error), ) (O, error) { // TODO: support span links. @@ -239,6 +241,12 @@ func RunInNewSpan[I, O any]( TraceID: span.SpanContext().TraceID().String(), SpanID: span.SpanContext().SpanID().String(), } + + // Fire telemetry callback immediately if provided + if telemetryCb != nil { + telemetryCb(sm.TraceInfo.TraceID, sm.TraceInfo.SpanID) + } + defer span.End() defer func() { span.SetAttributes(sm.attributes()...) }() ctx = spanMetaKey.NewContext(ctx, sm) diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index bb32c79bff..b5970070ba 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -19,6 +19,7 @@ package genkit import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "net" @@ -28,6 +29,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" "github.com/firebase/genkit/go/core" @@ -52,7 +54,46 @@ type runtimeFileData struct { // reflectionServer encapsulates everything needed to serve the Reflection API. type reflectionServer struct { *http.Server - RuntimeFilePath string // Path to the runtime file that was written at startup. + RuntimeFilePath string // Path to the runtime file that was written at startup. + activeActions *activeActionsMap // Tracks active actions for cancellation support. +} + +// activeAction represents an in-flight action that can be cancelled. +type activeAction struct { + cancel context.CancelFunc + startTime time.Time + traceID string +} + +// activeActionsMap safely manages active actions. +type activeActionsMap struct { + mu sync.RWMutex + actions map[string]*activeAction +} + +func newActiveActionsMap() *activeActionsMap { + return &activeActionsMap{ + actions: make(map[string]*activeAction), + } +} + +func (m *activeActionsMap) Set(traceID string, action *activeAction) { + m.mu.Lock() + defer m.mu.Unlock() + m.actions[traceID] = action +} + +func (m *activeActionsMap) Get(traceID string) (*activeAction, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + action, ok := m.actions[traceID] + return action, ok +} + +func (m *activeActionsMap) Delete(traceID string) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.actions, traceID) } func (s *reflectionServer) runtimeID() string { @@ -102,6 +143,7 @@ func startReflectionServer(ctx context.Context, g *Genkit, errCh chan<- error, s Server: &http.Server{ Addr: addr, }, + activeActions: newActiveActionsMap(), } s.Handler = serveMux(g, s) @@ -258,8 +300,9 @@ func serveMux(g *Genkit, s *reflectionServer) *http.ServeMux { w.WriteHeader(http.StatusOK) }) mux.HandleFunc("GET /api/actions", wrapReflectionHandler(handleListActions(g))) - mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g))) + mux.HandleFunc("POST /api/runAction", wrapReflectionHandler(handleRunAction(g, s.activeActions))) mux.HandleFunc("POST /api/notify", wrapReflectionHandler(handleNotify())) + mux.HandleFunc("POST /api/cancelAction", wrapReflectionHandler(handleCancelAction(s.activeActions))) return mux } @@ -290,7 +333,7 @@ func wrapReflectionHandler(h func(w http.ResponseWriter, r *http.Request) error) // handleRunAction looks up an action by name in the registry, runs it with the // provided JSON input, and writes back the JSON-marshaled request. -func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) error { +func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.ResponseWriter, r *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() @@ -312,11 +355,54 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err logger.FromContext(ctx).Debug("running action", "key", body.Key, "stream", stream) + // Create cancellable context for this action + actionCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Track whether headers have been sent + headersSent := false + var traceID string + var mu sync.Mutex + + // Set up telemetry callback to capture and send trace ID early + // This is used for BOTH streaming and non-streaming to match JS behavior + telemetryCb := func(tid string, sid string) { + mu.Lock() + defer mu.Unlock() + + if !headersSent { + traceID = tid + + // Track active action for cancellation + activeActions.Set(traceID, &activeAction{ + cancel: cancel, + startTime: time.Now(), + traceID: traceID, + }) + + // Send headers immediately with trace ID + w.Header().Set("X-Genkit-Trace-Id", traceID) + w.Header().Set("X-Genkit-Span-Id", sid) + w.Header().Set("X-Genkit-Version", "go/"+internal.Version) + + if stream { + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Transfer-Encoding", "chunked") + } else { + w.Header().Set("Content-Type", "application/json") + } + + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + headersSent = true + } + } + + // Set up streaming callback if needed var cb streamingCallback[json.RawMessage] if stream { - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Transfer-Encoding", "chunked") - // Stream results are newline-separated JSON. cb = func(ctx context.Context, msg json.RawMessage) error { _, err := fmt.Fprintf(w, "%s\n", msg) if err != nil { @@ -334,35 +420,127 @@ func handleRunAction(g *Genkit) func(w http.ResponseWriter, r *http.Request) err json.Unmarshal(body.Context, &contextMap) } - resp, err := runAction(ctx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap) + // Run the action with telemetry callback + resp, err := runAction(actionCtx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap, telemetryCb) + + // Clean up active action if we have a trace ID + if traceID != "" { + activeActions.Delete(traceID) + } else if resp != nil && resp.Telemetry.TraceID != "" { + // If we didn't get trace ID from callback (non-streaming), track and clean up + traceID = resp.Telemetry.TraceID + activeActions.Set(traceID, &activeAction{ + cancel: cancel, + startTime: time.Now(), + traceID: traceID, + }) + defer activeActions.Delete(traceID) + } + if err != nil { - if stream { - refErr := core.ToReflectionError(err) - refErr.Details.TraceID = &resp.Telemetry.TraceID - reflectErr, err := json.Marshal(refErr) - if err != nil { - return err + // Check if context was cancelled + if errors.Is(err, context.Canceled) { + errorResponse := map[string]interface{}{ + "error": map[string]interface{}{ + "code": 1, // gRPC CANCELLED (matches JS) + "message": "Action was cancelled", + "details": map[string]interface{}{ + "traceId": traceID, + }, + }, } - _, err = fmt.Fprintf(w, "{\"error\": %s }", reflectErr) - if err != nil { - return err + if stream { + // For streaming, write error as final chunk + json.NewEncoder(w).Encode(errorResponse) + } else { + // For non-streaming, return error response + if !headersSent { + w.WriteHeader(499) + } + json.NewEncoder(w).Encode(errorResponse) } + return nil + } - if f, ok := w.(http.Flusher); ok { - f.Flush() + // Handle other errors + if stream { + refErr := core.ToReflectionError(err) + if resp != nil && resp.Telemetry.TraceID != "" { + refErr.Details.TraceID = &resp.Telemetry.TraceID } + + errorResp := map[string]interface{}{ + "error": refErr, + } + json.NewEncoder(w).Encode(errorResp) return nil } + + // Non-streaming error errorResponse := core.ToReflectionError(err) - if resp != nil { + if resp != nil && resp.Telemetry.TraceID != "" { errorResponse.Details.TraceID = &resp.Telemetry.TraceID } - w.WriteHeader(errorResponse.Code) + + if !headersSent { + w.WriteHeader(errorResponse.Code) + } return writeJSON(ctx, w, errorResponse) } - return writeJSON(ctx, w, resp) + // Success case + if stream { + // For streaming, write the final chunk with result and telemetry + // This matches JS: response.write(JSON.stringify({result, telemetry})) + finalResponse := map[string]interface{}{ + "result": resp.Result, + "telemetry": map[string]interface{}{ + "traceId": resp.Telemetry.TraceID, + }, + } + json.NewEncoder(w).Encode(finalResponse) + } else { + // For non-streaming, headers were already sent via telemetry callback + // Response already includes telemetry.traceId in body + return writeJSON(ctx, w, resp) + } + + return nil + } +} + +// handleCancelAction cancels an in-flight action by trace ID. +func handleCancelAction(activeActions *activeActionsMap) func(w http.ResponseWriter, r *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + var body struct { + TraceID string `json:"traceId"` + } + + defer r.Body.Close() + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + return core.NewError(core.INVALID_ARGUMENT, err.Error()) + } + + if body.TraceID == "" { + return core.NewError(core.INVALID_ARGUMENT, "traceId is required") + } + + action, exists := activeActions.Get(body.TraceID) + if !exists { + w.WriteHeader(http.StatusNotFound) + return writeJSON(r.Context(), w, map[string]string{ + "error": "Action not found or already completed", + }) + } + + // Cancel the action's context + action.cancel() + activeActions.Delete(body.TraceID) + + return writeJSON(r.Context(), w, map[string]string{ + "message": "Action cancelled", + }) } } @@ -462,7 +640,7 @@ type telemetry struct { TraceID string `json:"traceId"` } -func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { +func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any, telemetryCb api.TelemetryCallback) (*runActionResponse, error) { action := g.reg.ResolveAction(key) if action == nil { return nil, core.NewError(core.NOT_FOUND, "action %q not found", key) @@ -483,7 +661,7 @@ func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage // Run the action and capture trace ID. We need to ensure there's a valid trace context. var traceID string output, err := func() (json.RawMessage, error) { - r, err := action.RunJSONWithTelemetry(ctx, input, cb) + r, err := action.RunJSONWithTelemetry(ctx, input, cb, telemetryCb) if r != nil { traceID = r.TraceId } diff --git a/go/genkit/reflection_test.go b/go/genkit/reflection_test.go index d47a10a027..7b11914348 100644 --- a/go/genkit/reflection_test.go +++ b/go/genkit/reflection_test.go @@ -21,6 +21,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "os" @@ -91,7 +92,8 @@ func TestServeMux(t *testing.T) { core.DefineAction(g.reg, "test/dec", api.ActionTypeCustom, nil, nil, dec) s := &reflectionServer{ - Server: &http.Server{}, + Server: &http.Server{}, + activeActions: newActiveActionsMap(), } ts := httptest.NewServer(serveMux(g, s)) s.Addr = strings.TrimPrefix(ts.URL, "http://") @@ -290,3 +292,326 @@ func TestServeMux(t *testing.T) { } }) } + +// TestEarlyTraceIDTransmission verifies that trace ID headers are sent BEFORE the action completes. +// +// The key thing we're testing: headers arrive while the action is still running, not after. +// This allows clients to get the trace ID immediately for cancellation or logging. +func TestEarlyTraceIDTransmission(t *testing.T) { + g := Init(context.Background()) + tc := tracing.NewTestOnlyTelemetryClient() + tracing.WriteTelemetryImmediate(tc) + + actionStarted := make(chan struct{}) + actionCanProceed := make(chan struct{}) + + // Action that waits for permission to complete - this lets us check headers while it's running + core.DefineAction(g.reg, "test/slow", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, input any) (any, error) { + close(actionStarted) // Signal we've started + <-actionCanProceed // Wait for test to say we can finish + return "completed", nil + }) + + s := &reflectionServer{Server: &http.Server{}, activeActions: newActiveActionsMap()} + ts := httptest.NewServer(serveMux(g, s)) + defer ts.Close() + + t.Run("headers arrive before body completes", func(t *testing.T) { + // Channel to receive headers as soon as they arrive + type headerResult struct { + traceID string + spanID string + version string + } + gotHeaders := make(chan headerResult) + + go func() { + req, _ := http.NewRequest("POST", ts.URL+"/api/runAction", + strings.NewReader(`{"key":"/custom/test/slow","input":null}`)) + req.Header.Set("Content-Type", "application/json") + + // Do() returns as soon as headers are received (before body is read) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + // Send headers immediately - body isn't done yet! + gotHeaders <- headerResult{ + traceID: resp.Header.Get("X-Genkit-Trace-Id"), + spanID: resp.Header.Get("X-Genkit-Span-Id"), + version: resp.Header.Get("X-Genkit-Version"), + } + + // Now read body (which will block until action completes) + io.ReadAll(resp.Body) + }() + + // Wait for action to start + <-actionStarted + + // Check headers arrived WHILE action is still running + select { + case h := <-gotHeaders: + if h.traceID == "" { + t.Error("Expected X-Genkit-Trace-Id header") + } + if h.spanID == "" { + t.Error("Expected X-Genkit-Span-Id header") + } + if !strings.HasPrefix(h.version, "go/") { + t.Errorf("Expected X-Genkit-Version to start with 'go/', got %q", h.version) + } + t.Logf("Got headers while action running: traceID=%s", h.traceID) + case <-time.After(1 * time.Second): + t.Fatal("Headers did not arrive while action was still running") + } + + // Let action complete + close(actionCanProceed) + }) + + // Backwards compatability + t.Run("trace ID in headers matches body", func(t *testing.T) { + // Reset channels for this subtest + actionStarted = make(chan struct{}) + actionCanProceed = make(chan struct{}) + + // Re-register action for this subtest + core.DefineAction(g.reg, "test/slow2", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, input any) (any, error) { + close(actionStarted) + <-actionCanProceed + return "completed", nil + }) + + req, _ := http.NewRequest("POST", ts.URL+"/api/runAction", + strings.NewReader(`{"key":"/custom/test/slow2","input":null}`)) + req.Header.Set("Content-Type", "application/json") + + // Start request in background + type result struct { + headerTraceID string + bodyTraceID string + } + done := make(chan result) + + go func() { + resp, err := http.DefaultClient.Do(req) + if err != nil { + done <- result{} + return + } + defer resp.Body.Close() + headerTraceID := resp.Header.Get("X-Genkit-Trace-Id") + + var body map[string]interface{} + json.NewDecoder(resp.Body).Decode(&body) + bodyTraceID := "" + if tel, ok := body["telemetry"].(map[string]interface{}); ok { + bodyTraceID, _ = tel["traceId"].(string) + } + done <- result{headerTraceID, bodyTraceID} + }() + + <-actionStarted + close(actionCanProceed) + + r := <-done + if r.headerTraceID == "" { + t.Error("No trace ID in headers") + } + if r.bodyTraceID == "" { + t.Error("No trace ID in body") + } + if r.headerTraceID != r.bodyTraceID { + t.Errorf("Trace ID mismatch: header=%q, body=%q", r.headerTraceID, r.bodyTraceID) + } + }) +} + +// TestActionCancellation verifies that running actions can be cancelled via /api/cancelAction. +// +// Flow: +// 1. Start a long-running action that sends its trace ID via channel when it starts +// 2. Call POST /api/cancelAction with that trace ID +// 3. Verify: cancel endpoint returns 200, action's ctx.Done() fires, response has error code 1 (gRPC CANCELLED) +func TestActionCancellation(t *testing.T) { + g := Init(context.Background()) + tc := tracing.NewTestOnlyTelemetryClient() + tracing.WriteTelemetryImmediate(tc) + + gotTraceID := make(chan string, 1) + gotCancelled := make(chan struct{}) + + // Long-running action that respects cancellation + core.DefineStreamingAction(g.reg, "test/cancellable", api.ActionTypeCustom, nil, nil, + func(ctx context.Context, input any, cb func(context.Context, any) error) (any, error) { + // Send trace ID so test can cancel us + gotTraceID <- tracing.SpanTraceInfo(ctx).TraceID + + for i := 0; i < 100; i++ { + select { + case <-ctx.Done(): + if ctx.Err() != context.Canceled { + return nil, fmt.Errorf("expected context.Canceled, got %v", ctx.Err()) + } + close(gotCancelled) + return nil, ctx.Err() + case <-time.After(50 * time.Millisecond): + if cb != nil && i%10 == 0 { + cb(ctx, fmt.Sprintf("progress: %d", i)) + } + } + } + return "completed", nil + }) + + s := &reflectionServer{Server: &http.Server{}, activeActions: newActiveActionsMap()} + ts := httptest.NewServer(serveMux(g, s)) + defer ts.Close() + + // Start action in background + actionDone := make(chan string) // receives response body when done + go func() { + req, _ := http.NewRequest("POST", ts.URL+"/api/runAction?stream=true", + strings.NewReader(`{"key":"/custom/test/cancellable","input":null}`)) + req.Header.Set("Content-Type", "application/json") + resp, _ := http.DefaultClient.Do(req) + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + actionDone <- string(body) + }() + + // Wait for action to start + traceID := <-gotTraceID + time.Sleep(50 * time.Millisecond) // ensure it's tracked + + // Cancel it + cancelReq, _ := http.NewRequest("POST", ts.URL+"/api/cancelAction", + strings.NewReader(fmt.Sprintf(`{"traceId":"%s"}`, traceID))) + cancelReq.Header.Set("Content-Type", "application/json") + cancelResp, err := http.DefaultClient.Do(cancelReq) + if err != nil { + t.Fatal(err) + } + defer cancelResp.Body.Close() + + if cancelResp.StatusCode != http.StatusOK { + t.Fatalf("Cancel failed with status %d", cancelResp.StatusCode) + } + + // Verify action acknowledged cancellation + select { + case <-gotCancelled: + case <-time.After(1 * time.Second): + t.Fatal("Action did not acknowledge cancellation") + } + + // Verify response indicates cancellation + responseBody := <-actionDone + if !strings.Contains(responseBody, "\"code\":1") { + t.Errorf("Expected error code 1 (gRPC CANCELLED) in response, got: %s", responseBody) + } + if !strings.Contains(responseBody, "Action was cancelled") { + t.Errorf("Expected 'Action was cancelled' message in response, got: %s", responseBody) + } +} + +func TestCancelActionEndpoint(t *testing.T) { + g := Init(context.Background()) + + s := &reflectionServer{ + Server: &http.Server{}, + activeActions: newActiveActionsMap(), + } + ts := httptest.NewServer(serveMux(g, s)) + defer ts.Close() + + t.Run("cancel non-existent action", func(t *testing.T) { + cancelReq, _ := http.NewRequest("POST", ts.URL+"/api/cancelAction", + strings.NewReader(`{"traceId":"non-existent-trace-id"}`)) + cancelReq.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(cancelReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected 404 for non-existent action, got %d", resp.StatusCode) + } + + var result map[string]interface{} + json.NewDecoder(resp.Body).Decode(&result) + if error, ok := result["error"].(string); !ok || error != "Action not found or already completed" { + t.Errorf("Unexpected error message: %v", result) + } + }) + + t.Run("cancel with missing traceId", func(t *testing.T) { + cancelReq, _ := http.NewRequest("POST", ts.URL+"/api/cancelAction", + strings.NewReader(`{}`)) + cancelReq.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(cancelReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest && resp.StatusCode != http.StatusInternalServerError { + t.Errorf("Expected 400 or 500 for missing traceId, got %d", resp.StatusCode) + } + }) + + t.Run("cancel active action", func(t *testing.T) { + // Manually add an action to activeActions + testTraceID := "test-trace-id-12345" + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.activeActions.Set(testTraceID, &activeAction{ + cancel: cancel, + startTime: time.Now(), + traceID: testTraceID, + }) + + // Send cancel request + cancelReq, _ := http.NewRequest("POST", ts.URL+"/api/cancelAction", + strings.NewReader(fmt.Sprintf(`{"traceId":"%s"}`, testTraceID))) + cancelReq.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(cancelReq) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected 200 for successful cancellation, got %d", resp.StatusCode) + } + + var result map[string]interface{} + json.NewDecoder(resp.Body).Decode(&result) + if message, ok := result["message"].(string); !ok || message != "Action cancelled" { + t.Errorf("Expected 'Action cancelled' message, got: %v", result) + } + + // Verify action was removed from activeActions + if action, exists := s.activeActions.Get(testTraceID); exists { + t.Errorf("Action should have been removed from activeActions, but still exists: %v", action) + } + + // Verify context was cancelled + select { + case <-ctx.Done(): + // Good, context was cancelled + default: + t.Error("Context should have been cancelled") + } + }) +} diff --git a/go/internal/cmd/jsonschemagen/jsonschema.go b/go/internal/cmd/jsonschemagen/jsonschema.go index 2ee6930954..0afbde5e74 100644 --- a/go/internal/cmd/jsonschemagen/jsonschema.go +++ b/go/internal/cmd/jsonschemagen/jsonschema.go @@ -33,7 +33,7 @@ type Schema struct { Description string `json:"description,omitempty"` Properties map[string]*Schema `json:"properties,omitempty"` AdditionalProperties *Schema `json:"additionalProperties,omitempty"` - Const bool `json:"const,omitempty"` + Const any `json:"const,omitempty"` Required []string `json:"required,omitempty"` Items *Schema `json:"items,omitempty"` Enum []string `json:"enum,omitempty"` diff --git a/go/samples/flow-sample1/main.go b/go/samples/flow-sample1/main.go index a6f567042b..9b37943d7c 100644 --- a/go/samples/flow-sample1/main.go +++ b/go/samples/flow-sample1/main.go @@ -41,6 +41,7 @@ import ( "log" "net/http" "strconv" + "time" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/genkit" @@ -111,6 +112,122 @@ func main() { return fmt.Sprintf("done: %d, streamed: %d times", count, i), nil }) + // Long-running flow for testing early trace ID transmission and cancellation. + // Each step takes ~5 seconds with nested sub-steps. + // + // Test with: + // curl -d '{"key":"/flow/longRunning/longRunning", "input":{"start": {"input":3}}}' \ + // http://localhost:3100/api/runAction?stream=true + // + // To test cancellation, note the X-Genkit-Trace-Id header and call: + // curl -d '{"traceId":""}' http://localhost:3100/api/cancelAction + type stepResult struct { + Step int `json:"step"` + Timestamp string `json:"timestamp"` + Elapsed int64 `json:"elapsed_ms"` + } + + type longRunningResult struct { + TotalDuration int64 `json:"total_duration_ms"` + StepsCompleted int `json:"steps_completed"` + Timeline []stepResult `json:"timeline"` + } + + genkit.DefineStreamingFlow(g, "longRunning", + func(ctx context.Context, steps int, cb func(context.Context, stepResult) error) (longRunningResult, error) { + if steps <= 0 { + steps = 3 + } + startTime := time.Now() + timeline := make([]stepResult, 0, steps) + + log.Printf("🚀 Starting long-running flow: %d steps × 5s = ~%ds", steps, steps*5) + + for i := 1; i <= steps; i++ { + stepStart := time.Now() + + // Check for cancellation before each step + select { + case <-ctx.Done(): + log.Printf("❌ Cancelled at step %d/%d", i, steps) + return longRunningResult{ + TotalDuration: time.Since(startTime).Milliseconds(), + StepsCompleted: i - 1, + Timeline: timeline, + }, ctx.Err() + default: + } + + log.Printf("[%s] 🔄 Step %d/%d starting...", time.Now().Format(time.RFC3339), i, steps) + + // Nested sub-steps (like the TS version) + _, err := core.Run(ctx, fmt.Sprintf("step-%d-fetch", i), func() (string, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1500 * time.Millisecond): + } + log.Printf(" 📡 Fetched data for step %d", i) + return fmt.Sprintf("fetch-%d", i), nil + }) + if err != nil { + return longRunningResult{TotalDuration: time.Since(startTime).Milliseconds(), StepsCompleted: i - 1, Timeline: timeline}, err + } + + _, err = core.Run(ctx, fmt.Sprintf("step-%d-process", i), func() (string, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1500 * time.Millisecond): + } + log.Printf(" ⚙️ Processed data for step %d", i) + return fmt.Sprintf("process-%d", i), nil + }) + if err != nil { + return longRunningResult{TotalDuration: time.Since(startTime).Milliseconds(), StepsCompleted: i - 1, Timeline: timeline}, err + } + + _, err = core.Run(ctx, fmt.Sprintf("step-%d-save", i), func() (string, error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1500 * time.Millisecond): + } + log.Printf(" 💾 Saved results for step %d", i) + return fmt.Sprintf("save-%d", i), nil + }) + if err != nil { + return longRunningResult{TotalDuration: time.Since(startTime).Milliseconds(), StepsCompleted: i - 1, Timeline: timeline}, err + } + + elapsed := time.Since(stepStart).Milliseconds() + log.Printf("[%s] ✅ Step %d/%d completed (%dms)", time.Now().Format(time.RFC3339), i, steps, elapsed) + + result := stepResult{ + Step: i, + Timestamp: time.Now().Format(time.RFC3339), + Elapsed: elapsed, + } + timeline = append(timeline, result) + + // Stream progress if callback provided + if cb != nil { + if err := cb(ctx, result); err != nil { + return longRunningResult{}, err + } + } + } + + totalDuration := time.Since(startTime).Milliseconds() + log.Printf("🎉 Long-running flow completed in %dms", totalDuration) + + return longRunningResult{ + TotalDuration: totalDuration, + StepsCompleted: steps, + Timeline: timeline, + }, nil + }) + mux := http.NewServeMux() for _, a := range genkit.ListFlows(g) { mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) From 8e952898ec7eba3846d129162a6ff3c88a2494a4 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 2 Dec 2025 20:23:57 -0600 Subject: [PATCH 02/10] fix --- go/core/tracing/tracing_test.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/go/core/tracing/tracing_test.go b/go/core/tracing/tracing_test.go index 6d6f1dd0d3..535f92c967 100644 --- a/go/core/tracing/tracing_test.go +++ b/go/core/tracing/tracing_test.go @@ -314,7 +314,7 @@ func TestRunInNewSpanWithMetadata(t *testing.T) { ctx := context.Background() input := "test input" - output, err := RunInNewSpan(ctx, tc.metadata, input, + output, err := RunInNewSpan(ctx, tc.metadata, input, nil, func(ctx context.Context, input string) (string, error) { // Verify that span metadata is available in context sm := spanMetaKey.FromContext(ctx) @@ -360,7 +360,7 @@ func TestRunInNewSpanWithTypeConvenience(t *testing.T) { Subtype: "tool", } - output, err := RunInNewSpan(ctx, metadata, "input", + output, err := RunInNewSpan(ctx, metadata, "input", nil, func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { @@ -390,10 +390,10 @@ func TestNestedSpanPaths(t *testing.T) { ctx := context.Background() // Test nested spans to verify path building - _, err := RunInNewSpan(ctx, &SpanMetadata{Name: "chatFlow", IsRoot: true, Type: "action", Subtype: "flow"}, "input", + _, err := RunInNewSpan(ctx, &SpanMetadata{Name: "chatFlow", IsRoot: true, Type: "action", Subtype: "flow"}, "input", nil, func(ctx context.Context, input string) (string, error) { // Nested action span - return RunInNewSpan(ctx, &SpanMetadata{Name: "myTool", IsRoot: false, Type: "action", Subtype: "tool"}, input, + return RunInNewSpan(ctx, &SpanMetadata{Name: "myTool", IsRoot: false, Type: "action", Subtype: "tool"}, input, nil, func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { @@ -425,7 +425,7 @@ func TestIsFailureSourceOnError(t *testing.T) { _, err := RunInNewSpan(ctx, &SpanMetadata{ Name: "failing-action", Type: "action", - }, "input", func(ctx context.Context, input string) (string, error) { + }, "input", nil, func(ctx context.Context, input string) (string, error) { return "", testErr }) @@ -446,7 +446,7 @@ func TestRootSpanAutoDetection(t *testing.T) { Type: "action", Subtype: "flow", IsRoot: false, // Even when explicitly set to false, should be overridden - }, "input", func(ctx context.Context, input string) (string, error) { + }, "input", nil, func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { t.Fatal("Expected span metadata in context") @@ -469,7 +469,7 @@ func TestRootSpanAutoDetection(t *testing.T) { Name: "explicitRootFlow", Type: "action", IsRoot: true, // Explicitly set to true - }, "input", func(ctx context.Context, input string) (string, error) { + }, "input", nil, func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { t.Fatal("Expected span metadata in context") @@ -492,13 +492,13 @@ func TestRootSpanAutoDetection(t *testing.T) { Name: "parentFlow", Type: "action", IsRoot: true, - }, "input", func(ctx context.Context, input string) (string, error) { + }, "input", nil, func(ctx context.Context, input string) (string, error) { // This is a nested span - should NOT be root _, err := RunInNewSpan(ctx, &SpanMetadata{ Name: "childAction", Type: "action", IsRoot: false, - }, input, func(ctx context.Context, input string) (string, error) { + }, input, nil, func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { t.Fatal("Expected span metadata in context") From fff7c7ab8dd46877e91d68e0ec57fea99f38269f Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Wed, 3 Dec 2025 19:18:45 -0600 Subject: [PATCH 03/10] address comments --- go/core/action.go | 8 ++--- go/core/api/action.go | 7 +---- go/core/flow.go | 6 ++-- go/core/tracing/tracing.go | 14 +++++++++ go/genkit/reflection.go | 64 ++++++++++++++++++-------------------- 5 files changed, 52 insertions(+), 47 deletions(-) diff --git a/go/core/action.go b/go/core/action.go index cf598b19f3..bb1a580bbf 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -253,15 +253,15 @@ func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input // RunJSON runs the action with a JSON input, and returns a JSON result. func (a *ActionDef[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (json.RawMessage, error) { - r, err := a.RunJSONWithTelemetry(ctx, input, cb, nil) + r, err := a.RunJSONWithTelemetry(ctx, input, cb) if err != nil { return nil, err } return r.Result, nil } -// RunJSON runs the action with a JSON input, and returns a JSON result along with telemetry info. -func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], telemetryCb api.TelemetryCallback) (*api.ActionRunResult[json.RawMessage], error) { +// RunJSONWithTelemetry runs the action with a JSON input, and returns a JSON result along with telemetry info. +func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { i, err := base.UnmarshalAndNormalize[In](input, a.desc.InputSchema) if err != nil { return nil, NewError(INVALID_ARGUMENT, err.Error()) @@ -278,7 +278,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, i } } - r, err := a.runWithTelemetry(ctx, i, scb, telemetryCb) + r, err := a.runWithTelemetry(ctx, i, scb, tracing.TelemetryCb(ctx)) if err != nil { return &api.ActionRunResult[json.RawMessage]{ TraceId: r.TraceId, diff --git a/go/core/api/action.go b/go/core/api/action.go index 51b2d804e9..3cfd2689f1 100644 --- a/go/core/api/action.go +++ b/go/core/api/action.go @@ -27,10 +27,6 @@ type ActionRunResult[T any] struct { SpanId string } -// TelemetryCallback is called when telemetry information becomes available. -// It receives the trace ID and span ID as soon as the span is created. -type TelemetryCallback func(traceID, spanID string) - // Action is the interface that all Genkit primitives (e.g. flows, models, tools) have in common. type Action interface { Registerable @@ -39,8 +35,7 @@ type Action interface { // RunJSON runs the action with the given JSON input and streaming callback and returns the output as JSON. RunJSON(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (json.RawMessage, error) // RunJSONWithTelemetry runs the action with the given JSON input and streaming callback and returns the output as JSON along with telemetry info. - // The telemetryCb callback, if provided, is called as soon as the trace span is created with the trace ID and span ID. - RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error, telemetryCb TelemetryCallback) (*ActionRunResult[json.RawMessage], error) + RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb func(context.Context, json.RawMessage) error) (*ActionRunResult[json.RawMessage], error) // Desc returns a descriptor of the action. Desc() ActionDesc } diff --git a/go/core/flow.go b/go/core/flow.go index ce8dd9eea8..4a264a4c93 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -112,9 +112,9 @@ func (f *Flow[In, Out, Stream]) RunJSON(ctx context.Context, input json.RawMessa return (*ActionDef[In, Out, Stream])(f).RunJSON(ctx, input, cb) } -// RunJSON runs the flow with JSON input and streaming callback and returns the output as JSON. -func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage], telemetryCb api.TelemetryCallback) (*api.ActionRunResult[json.RawMessage], error) { - return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb, telemetryCb) +// RunJSONWithTelemetry runs the flow with JSON input and streaming callback and returns the output as JSON along with telemetry info. +func (f *Flow[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, input json.RawMessage, cb StreamCallback[json.RawMessage]) (*api.ActionRunResult[json.RawMessage], error) { + return (*ActionDef[In, Out, Stream])(f).RunJSONWithTelemetry(ctx, input, cb) } // Desc returns the descriptor of the flow. diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index bf7025d549..c12b81e3d8 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -379,6 +379,20 @@ func (sm *spanMetadata) attributes() []attribute.KeyValue { // spanMetaKey is for storing spanMetadatas in a context. var spanMetaKey = base.NewContextKey[*spanMetadata]() +// telemetryCbKey is the context key for telemetry callbacks. +var telemetryCbKey = base.NewContextKey[func(traceID, spanID string)]() + +// WithTelemetryCb returns a context with the telemetry callback attached. +// Used by the reflection server to pass callbacks to actions. +func WithTelemetryCb(ctx context.Context, cb func(traceID, spanID string)) context.Context { + return telemetryCbKey.NewContext(ctx, cb) +} + +// TelemetryCb retrieves the telemetry callback from context, or nil if not set. +func TelemetryCb(ctx context.Context) func(traceID, spanID string) { + return telemetryCbKey.FromContext(ctx) +} + // SpanPath returns the path as recorded in the current span metadata. func SpanPath(ctx context.Context) string { return spanMetaKey.FromContext(ctx).Path diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index b5970070ba..87968ecff4 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -361,7 +361,7 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res // Track whether headers have been sent headersSent := false - var traceID string + var callbackTraceID string // Trace ID captured from telemetry callback for early header sending var mu sync.Mutex // Set up telemetry callback to capture and send trace ID early @@ -371,17 +371,17 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res defer mu.Unlock() if !headersSent { - traceID = tid + callbackTraceID = tid // Track active action for cancellation - activeActions.Set(traceID, &activeAction{ + activeActions.Set(callbackTraceID, &activeAction{ cancel: cancel, startTime: time.Now(), - traceID: traceID, + traceID: callbackTraceID, }) // Send headers immediately with trace ID - w.Header().Set("X-Genkit-Trace-Id", traceID) + w.Header().Set("X-Genkit-Trace-Id", callbackTraceID) w.Header().Set("X-Genkit-Span-Id", sid) w.Header().Set("X-Genkit-Version", "go/"+internal.Version) @@ -420,35 +420,33 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res json.Unmarshal(body.Context, &contextMap) } - // Run the action with telemetry callback - resp, err := runAction(actionCtx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap, telemetryCb) + // Attach telemetry callback to context so action can invoke it when span is created + actionCtx = tracing.WithTelemetryCb(actionCtx, telemetryCb) + resp, err := runAction(actionCtx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap) - // Clean up active action if we have a trace ID - if traceID != "" { - activeActions.Delete(traceID) - } else if resp != nil && resp.Telemetry.TraceID != "" { - // If we didn't get trace ID from callback (non-streaming), track and clean up - traceID = resp.Telemetry.TraceID - activeActions.Set(traceID, &activeAction{ - cancel: cancel, - startTime: time.Now(), - traceID: traceID, - }) - defer activeActions.Delete(traceID) + // Clean up active action using the trace ID from response + if resp != nil && resp.Telemetry.TraceID != "" { + activeActions.Delete(resp.Telemetry.TraceID) } if err != nil { // Check if context was cancelled if errors.Is(err, context.Canceled) { - errorResponse := map[string]interface{}{ - "error": map[string]interface{}{ - "code": 1, // gRPC CANCELLED (matches JS) - "message": "Action was cancelled", - "details": map[string]interface{}{ - "traceId": traceID, - }, + // Use gRPC CANCELLED code (1) in JSON body to match TypeScript behavior + var traceIDPtr *string + if resp != nil && resp.Telemetry.TraceID != "" { + traceIDPtr = &resp.Telemetry.TraceID + } + cancelledErr := core.ReflectionError{ + Code: core.CodeCancelled, // gRPC CANCELLED = 1 + Message: "Action was cancelled", + Details: &core.ReflectionErrorDetails{ + TraceID: traceIDPtr, }, } + errorResponse := struct { + Error core.ReflectionError `json:"error"` + }{Error: cancelledErr} if stream { // For streaming, write error as final chunk @@ -456,7 +454,7 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res } else { // For non-streaming, return error response if !headersSent { - w.WriteHeader(499) + w.WriteHeader(http.StatusOK) // Match TS: response.status(200).json(...) } json.NewEncoder(w).Encode(errorResponse) } @@ -493,11 +491,9 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res if stream { // For streaming, write the final chunk with result and telemetry // This matches JS: response.write(JSON.stringify({result, telemetry})) - finalResponse := map[string]interface{}{ - "result": resp.Result, - "telemetry": map[string]interface{}{ - "traceId": resp.Telemetry.TraceID, - }, + finalResponse := runActionResponse{ + Result: resp.Result, + Telemetry: telemetry{TraceID: resp.Telemetry.TraceID}, } json.NewEncoder(w).Encode(finalResponse) } else { @@ -640,7 +636,7 @@ type telemetry struct { TraceID string `json:"traceId"` } -func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any, telemetryCb api.TelemetryCallback) (*runActionResponse, error) { +func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { action := g.reg.ResolveAction(key) if action == nil { return nil, core.NewError(core.NOT_FOUND, "action %q not found", key) @@ -661,7 +657,7 @@ func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage // Run the action and capture trace ID. We need to ensure there's a valid trace context. var traceID string output, err := func() (json.RawMessage, error) { - r, err := action.RunJSONWithTelemetry(ctx, input, cb, telemetryCb) + r, err := action.RunJSONWithTelemetry(ctx, input, cb) if r != nil { traceID = r.TraceId } From ee9155731a8e09dd8bc9ae374ac5ab3d939ccdc4 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Wed, 3 Dec 2025 19:36:36 -0600 Subject: [PATCH 04/10] Add errorResponse type --- go/genkit/reflection.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index 87968ecff4..f027f90e2d 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -437,26 +437,25 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res if resp != nil && resp.Telemetry.TraceID != "" { traceIDPtr = &resp.Telemetry.TraceID } - cancelledErr := core.ReflectionError{ - Code: core.CodeCancelled, // gRPC CANCELLED = 1 - Message: "Action was cancelled", - Details: &core.ReflectionErrorDetails{ - TraceID: traceIDPtr, + errResp := errorResponse{ + Error: core.ReflectionError{ + Code: core.CodeCancelled, // gRPC CANCELLED = 1 + Message: "Action was cancelled", + Details: &core.ReflectionErrorDetails{ + TraceID: traceIDPtr, + }, }, } - errorResponse := struct { - Error core.ReflectionError `json:"error"` - }{Error: cancelledErr} if stream { // For streaming, write error as final chunk - json.NewEncoder(w).Encode(errorResponse) + json.NewEncoder(w).Encode(errResp) } else { // For non-streaming, return error response if !headersSent { w.WriteHeader(http.StatusOK) // Match TS: response.status(200).json(...) } - json.NewEncoder(w).Encode(errorResponse) + json.NewEncoder(w).Encode(errResp) } return nil } @@ -468,10 +467,7 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res refErr.Details.TraceID = &resp.Telemetry.TraceID } - errorResp := map[string]interface{}{ - "error": refErr, - } - json.NewEncoder(w).Encode(errorResp) + json.NewEncoder(w).Encode(errorResponse{Error: refErr}) return nil } @@ -636,6 +632,10 @@ type telemetry struct { TraceID string `json:"traceId"` } +type errorResponse struct { + Error core.ReflectionError `json:"error"` +} + func runAction(ctx context.Context, g *Genkit, key string, input json.RawMessage, telemetryLabels json.RawMessage, cb streamingCallback[json.RawMessage], runtimeContext map[string]any) (*runActionResponse, error) { action := g.reg.ResolveAction(key) if action == nil { From fd87aad59e8939a1525292757733d44063f36479 Mon Sep 17 00:00:00 2001 From: huangjeff5 <64040981+huangjeff5@users.noreply.github.com> Date: Thu, 4 Dec 2025 14:40:09 -0600 Subject: [PATCH 05/10] Update go/core/tracing/tracing.go Co-authored-by: Alex Pascal --- go/core/tracing/tracing.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index c12b81e3d8..8bdd7e7cef 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -389,7 +389,7 @@ func WithTelemetryCb(ctx context.Context, cb func(traceID, spanID string)) conte } // TelemetryCb retrieves the telemetry callback from context, or nil if not set. -func TelemetryCb(ctx context.Context) func(traceID, spanID string) { +func telemetryCallback(ctx context.Context) func(traceID, spanID string) { return telemetryCbKey.FromContext(ctx) } From 1fba81921c9b9f3389ddcc04c3bd3207d90c9f27 Mon Sep 17 00:00:00 2001 From: huangjeff5 <64040981+huangjeff5@users.noreply.github.com> Date: Thu, 4 Dec 2025 15:11:44 -0600 Subject: [PATCH 06/10] Update go/core/tracing/tracing.go Co-authored-by: Alex Pascal --- go/core/tracing/tracing.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index 8bdd7e7cef..3a60877209 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -384,7 +384,7 @@ var telemetryCbKey = base.NewContextKey[func(traceID, spanID string)]() // WithTelemetryCb returns a context with the telemetry callback attached. // Used by the reflection server to pass callbacks to actions. -func WithTelemetryCb(ctx context.Context, cb func(traceID, spanID string)) context.Context { +func WithTelemetryCallback(ctx context.Context, cb func(traceID, spanID string)) context.Context { return telemetryCbKey.NewContext(ctx, cb) } From e85a455ab862eec421a02b39dcf7a55fe396871d Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 4 Dec 2025 16:12:23 -0600 Subject: [PATCH 07/10] fix comments --- js/core/src/reflection.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/core/src/reflection.ts b/js/core/src/reflection.ts index d9504ce198..1337b26a9e 100644 --- a/js/core/src/reflection.ts +++ b/js/core/src/reflection.ts @@ -287,7 +287,7 @@ export class ReflectionServer { } } catch (err) { if (isAbortError(err)) { - // Handle cancellation - headers may have been sent via onTelemetry + // Handle cancellation - headers may have been sent via onTraceStart const errorResponse: Status = { code: StatusCodes.CANCELLED, message: 'Action was cancelled', @@ -375,7 +375,7 @@ export class ReflectionServer { }, }; - // Headers may have been sent already (via onTelemetry), so check before setting status + // Headers may have been sent already (via onTraceStart), so check before setting status if (!res.headersSent) { res.status(500).json(errorResponse); } else { From 1768fef35fa857c22f82a7134b687451d8475df0 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 4 Dec 2025 16:25:32 -0600 Subject: [PATCH 08/10] remove from RunInNewSpan --- go/ai/evaluator.go | 4 ++-- go/ai/generate.go | 2 +- go/core/action.go | 8 ++++---- go/core/flow.go | 2 +- go/core/tracing/tracing.go | 14 +++++++------- go/core/tracing/tracing_test.go | 20 ++++++++++---------- go/genkit/reflection.go | 2 +- 7 files changed, 26 insertions(+), 26 deletions(-) diff --git a/go/ai/evaluator.go b/go/ai/evaluator.go index bf941f20fc..cd1fb3b74a 100644 --- a/go/ai/evaluator.go +++ b/go/ai/evaluator.go @@ -201,8 +201,8 @@ func NewEvaluator(name string, opts *EvaluatorOptions, fn EvaluatorFunc) Evaluat Type: "evaluator", Subtype: "evaluator", } - _, err := tracing.RunInNewSpan(ctx, spanMetadata, datapoint, nil, - func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) { + _, err := tracing.RunInNewSpan(ctx, spanMetadata, datapoint, + func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) { traceId := trace.SpanContextFromContext(ctx).TraceID().String() spanId := trace.SpanContextFromContext(ctx).SpanID().String() diff --git a/go/ai/generate.go b/go/ai/generate.go index 61bfe14be0..d0743de7ac 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -330,7 +330,7 @@ func GenerateWithRequest(ctx context.Context, r api.Registry, opts *GenerateActi Subtype: "util", } - return tracing.RunInNewSpan(ctx, spanMetadata, req, nil, func(ctx context.Context, req *ModelRequest) (*ModelResponse, error) { + return tracing.RunInNewSpan(ctx, spanMetadata, req, func(ctx context.Context, req *ModelRequest) (*ModelResponse, error) { var wrappedCb ModelStreamCallback currentRole := RoleModel currentIndex := messageIndex diff --git a/go/core/action.go b/go/core/action.go index bb1a580bbf..9acfc03008 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -176,7 +176,7 @@ func (a *ActionDef[In, Out, Stream]) Name() string { return a.desc.Name } // Run executes the Action's function in a new trace span. func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb StreamCallback[Stream]) (output Out, err error) { - r, err := a.runWithTelemetry(ctx, input, cb, nil) + r, err := a.runWithTelemetry(ctx, input, cb) if err != nil { return base.Zero[Out](), err } @@ -184,7 +184,7 @@ func (a *ActionDef[In, Out, Stream]) Run(ctx context.Context, input In, cb Strea } // Run executes the Action's function in a new trace span. -func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream], telemetryCb func(traceID, spanID string)) (output api.ActionRunResult[Out], err error) { +func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input In, cb StreamCallback[Stream]) (output api.ActionRunResult[Out], err error) { inputBytes, _ := json.Marshal(input) logger.FromContext(ctx).Debug("Action.Run", "name", a.Name(), @@ -215,7 +215,7 @@ func (a *ActionDef[In, Out, Stream]) runWithTelemetry(ctx context.Context, input var traceID string var spanID string - o, err := tracing.RunInNewSpan(ctx, spanMetadata, input, telemetryCb, + o, err := tracing.RunInNewSpan(ctx, spanMetadata, input, func(ctx context.Context, input In) (Out, error) { traceInfo := tracing.SpanTraceInfo(ctx) traceID = traceInfo.TraceID @@ -278,7 +278,7 @@ func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, i } } - r, err := a.runWithTelemetry(ctx, i, scb, tracing.TelemetryCb(ctx)) + r, err := a.runWithTelemetry(ctx, i, scb) if err != nil { return &api.ActionRunResult[json.RawMessage]{ TraceId: r.TraceId, diff --git a/go/core/flow.go b/go/core/flow.go index 4a264a4c93..0cd12120f2 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -93,7 +93,7 @@ func Run[Out any](ctx context.Context, name string, fn func() (Out, error)) (Out Type: "flowStep", Subtype: "flowStep", } - return tracing.RunInNewSpan(ctx, spanMetadata, nil, nil, func(ctx context.Context, _ any) (Out, error) { + return tracing.RunInNewSpan(ctx, spanMetadata, nil, func(ctx context.Context, _ any) (Out, error) { o, err := fn() if err != nil { return base.Zero[Out](), err diff --git a/go/core/tracing/tracing.go b/go/core/tracing/tracing.go index 3a60877209..727625f75a 100644 --- a/go/core/tracing/tracing.go +++ b/go/core/tracing/tracing.go @@ -173,12 +173,12 @@ type SpanMetadata struct { // RunInNewSpan runs f on input in a new span with the provided metadata. // The metadata contains all span configuration including name, type, labels, etc. -// If telemetryCb is provided, it will be called with the trace ID and span ID as soon as the span is created. +// If a telemetry callback was set on the context via WithTelemetryCallback, +// it will be called with the trace ID and span ID as soon as the span is created. func RunInNewSpan[I, O any]( ctx context.Context, metadata *SpanMetadata, input I, - telemetryCb func(traceID, spanID string), f func(context.Context, I) (O, error), ) (O, error) { // TODO: support span links. @@ -242,9 +242,9 @@ func RunInNewSpan[I, O any]( SpanID: span.SpanContext().SpanID().String(), } - // Fire telemetry callback immediately if provided - if telemetryCb != nil { - telemetryCb(sm.TraceInfo.TraceID, sm.TraceInfo.SpanID) + // Fire telemetry callback immediately if one was set on the context + if cb := telemetryCallback(ctx); cb != nil { + cb(sm.TraceInfo.TraceID, sm.TraceInfo.SpanID) } defer span.End() @@ -382,13 +382,13 @@ var spanMetaKey = base.NewContextKey[*spanMetadata]() // telemetryCbKey is the context key for telemetry callbacks. var telemetryCbKey = base.NewContextKey[func(traceID, spanID string)]() -// WithTelemetryCb returns a context with the telemetry callback attached. +// WithTelemetryCallback returns a context with the telemetry callback attached. // Used by the reflection server to pass callbacks to actions. func WithTelemetryCallback(ctx context.Context, cb func(traceID, spanID string)) context.Context { return telemetryCbKey.NewContext(ctx, cb) } -// TelemetryCb retrieves the telemetry callback from context, or nil if not set. +// telemetryCallback retrieves the telemetry callback from context, or nil if not set. func telemetryCallback(ctx context.Context) func(traceID, spanID string) { return telemetryCbKey.FromContext(ctx) } diff --git a/go/core/tracing/tracing_test.go b/go/core/tracing/tracing_test.go index 535f92c967..979a23101b 100644 --- a/go/core/tracing/tracing_test.go +++ b/go/core/tracing/tracing_test.go @@ -314,8 +314,8 @@ func TestRunInNewSpanWithMetadata(t *testing.T) { ctx := context.Background() input := "test input" - output, err := RunInNewSpan(ctx, tc.metadata, input, nil, - func(ctx context.Context, input string) (string, error) { + output, err := RunInNewSpan(ctx, tc.metadata, input, + func(ctx context.Context, input string) (string, error) { // Verify that span metadata is available in context sm := spanMetaKey.FromContext(ctx) if sm == nil { @@ -360,7 +360,7 @@ func TestRunInNewSpanWithTypeConvenience(t *testing.T) { Subtype: "tool", } - output, err := RunInNewSpan(ctx, metadata, "input", nil, + output, err := RunInNewSpan(ctx, metadata, "input", func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { @@ -390,10 +390,10 @@ func TestNestedSpanPaths(t *testing.T) { ctx := context.Background() // Test nested spans to verify path building - _, err := RunInNewSpan(ctx, &SpanMetadata{Name: "chatFlow", IsRoot: true, Type: "action", Subtype: "flow"}, "input", nil, + _, err := RunInNewSpan(ctx, &SpanMetadata{Name: "chatFlow", IsRoot: true, Type: "action", Subtype: "flow"}, "input", func(ctx context.Context, input string) (string, error) { // Nested action span - return RunInNewSpan(ctx, &SpanMetadata{Name: "myTool", IsRoot: false, Type: "action", Subtype: "tool"}, input, nil, + return RunInNewSpan(ctx, &SpanMetadata{Name: "myTool", IsRoot: false, Type: "action", Subtype: "tool"}, input, func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { @@ -425,7 +425,7 @@ func TestIsFailureSourceOnError(t *testing.T) { _, err := RunInNewSpan(ctx, &SpanMetadata{ Name: "failing-action", Type: "action", - }, "input", nil, func(ctx context.Context, input string) (string, error) { + }, "input", func(ctx context.Context, input string) (string, error) { return "", testErr }) @@ -446,7 +446,7 @@ func TestRootSpanAutoDetection(t *testing.T) { Type: "action", Subtype: "flow", IsRoot: false, // Even when explicitly set to false, should be overridden - }, "input", nil, func(ctx context.Context, input string) (string, error) { + }, "input", func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { t.Fatal("Expected span metadata in context") @@ -469,7 +469,7 @@ func TestRootSpanAutoDetection(t *testing.T) { Name: "explicitRootFlow", Type: "action", IsRoot: true, // Explicitly set to true - }, "input", nil, func(ctx context.Context, input string) (string, error) { + }, "input", func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { t.Fatal("Expected span metadata in context") @@ -492,13 +492,13 @@ func TestRootSpanAutoDetection(t *testing.T) { Name: "parentFlow", Type: "action", IsRoot: true, - }, "input", nil, func(ctx context.Context, input string) (string, error) { + }, "input", func(ctx context.Context, input string) (string, error) { // This is a nested span - should NOT be root _, err := RunInNewSpan(ctx, &SpanMetadata{ Name: "childAction", Type: "action", IsRoot: false, - }, input, nil, func(ctx context.Context, input string) (string, error) { + }, input, func(ctx context.Context, input string) (string, error) { sm := spanMetaKey.FromContext(ctx) if sm == nil { t.Fatal("Expected span metadata in context") diff --git a/go/genkit/reflection.go b/go/genkit/reflection.go index f027f90e2d..f0dcbf5578 100644 --- a/go/genkit/reflection.go +++ b/go/genkit/reflection.go @@ -421,7 +421,7 @@ func handleRunAction(g *Genkit, activeActions *activeActionsMap) func(w http.Res } // Attach telemetry callback to context so action can invoke it when span is created - actionCtx = tracing.WithTelemetryCb(actionCtx, telemetryCb) + actionCtx = tracing.WithTelemetryCallback(actionCtx, telemetryCb) resp, err := runAction(actionCtx, g, body.Key, body.Input, body.TelemetryLabels, cb, contextMap) // Clean up active action using the trace ID from response From 0ec19678eb1e95d9d2fef8e1425f17abc94d869b Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 5 Dec 2025 10:43:58 -0600 Subject: [PATCH 09/10] Format --- go/ai/evaluator.go | 4 ++-- go/core/tracing/tracing_test.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go/ai/evaluator.go b/go/ai/evaluator.go index cd1fb3b74a..aa536fac9b 100644 --- a/go/ai/evaluator.go +++ b/go/ai/evaluator.go @@ -201,8 +201,8 @@ func NewEvaluator(name string, opts *EvaluatorOptions, fn EvaluatorFunc) Evaluat Type: "evaluator", Subtype: "evaluator", } - _, err := tracing.RunInNewSpan(ctx, spanMetadata, datapoint, - func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) { + _, err := tracing.RunInNewSpan(ctx, spanMetadata, datapoint, + func(ctx context.Context, input *Example) (*EvaluatorCallbackResponse, error) { traceId := trace.SpanContextFromContext(ctx).TraceID().String() spanId := trace.SpanContextFromContext(ctx).SpanID().String() diff --git a/go/core/tracing/tracing_test.go b/go/core/tracing/tracing_test.go index 979a23101b..6d6f1dd0d3 100644 --- a/go/core/tracing/tracing_test.go +++ b/go/core/tracing/tracing_test.go @@ -314,8 +314,8 @@ func TestRunInNewSpanWithMetadata(t *testing.T) { ctx := context.Background() input := "test input" - output, err := RunInNewSpan(ctx, tc.metadata, input, - func(ctx context.Context, input string) (string, error) { + output, err := RunInNewSpan(ctx, tc.metadata, input, + func(ctx context.Context, input string) (string, error) { // Verify that span metadata is available in context sm := spanMetaKey.FromContext(ctx) if sm == nil { From 022290b9f1e8adbc2c05eac87eeb45d021b1fefa Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 5 Dec 2025 16:50:02 -0600 Subject: [PATCH 10/10] add omit type --- go/core/schemas.config | 1 + 1 file changed, 1 insertion(+) diff --git a/go/core/schemas.config b/go/core/schemas.config index fba6b938ec..1d4ff98001 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -28,6 +28,7 @@ TimeEventAnnotation omit TraceData omit SpanStartEvent omit SpanEndEvent omit +SpanEventBase omit TraceEvent omit GenerationCommonConfig.maxOutputTokens type int