diff --git a/internal/integration/clam_prose_test.go b/internal/integration/clam_prose_test.go index 7dbb564280..e14b39a973 100644 --- a/internal/integration/clam_prose_test.go +++ b/internal/integration/clam_prose_test.go @@ -198,7 +198,6 @@ func clamMultiByteTruncLogs(mt *mtest.T) []truncValidator { // Insert started. validators[0] = newTruncValidator(mt, cmd, func(cmd string) error { - // Remove the suffix from the command string. cmd = cmd[:len(cmd)-len(logger.TruncationSuffix)] diff --git a/mongo/errors.go b/mongo/errors.go index 234445ab86..88dc943f90 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -334,6 +334,12 @@ type LabeledError interface { HasErrorLabel(string) bool } +type errorCoder interface { + ErrorCodes() []int +} + +var _ errorCoder = ServerError(nil) + // ServerError is the interface implemented by errors returned from the server. Custom implementations of this // interface should not be used in production. type ServerError interface { @@ -364,10 +370,12 @@ func hasErrorCode(srvErr ServerError, code int) bool { return false } -var _ ServerError = CommandError{} -var _ ServerError = WriteError{} -var _ ServerError = WriteException{} -var _ ServerError = BulkWriteException{} +var ( + _ ServerError = CommandError{} + _ ServerError = WriteError{} + _ ServerError = WriteException{} + _ ServerError = BulkWriteException{} +) var _ error = ClientBulkWriteException{} @@ -901,3 +909,23 @@ func joinBatchErrors(errs []error) string { return buf.String() } + +// ErrorCodes returns the list of server error codes contained in err. +func ErrorCodes(err error) []int { + if err == nil { + return nil + } + + var ec errorCoder + // First check if the error is already wrapped (common case) + if errors.As(err, &ec) { + return ec.ErrorCodes() + } + + // Only wrap if necessary (for internal errors) + if errors.As(wrapErrors(err), &ec) { + return ec.ErrorCodes() + } + + return []int{} +} diff --git a/mongo/errors_test.go b/mongo/errors_test.go index 2ff04c4dd2..ef51f5435d 100644 --- a/mongo/errors_test.go +++ b/mongo/errors_test.go @@ -760,3 +760,146 @@ func (n netErr) Temporary() bool { } var _ net.Error = (*netErr)(nil) + +func TestErrorCodes(t *testing.T) { + tests := []struct { + name string + input error + want []int + }{ + { + name: "nil error", + input: nil, + want: nil, + }, + { + name: "non-server error", + input: errors.New("boom"), + want: []int{}, + }, + { + name: "CommandError single code", + input: CommandError{Code: 1}, + want: []int{1}, + }, + { + name: "WriteError single code", + input: WriteError{Code: 1}, + want: []int{1}, + }, + { + name: "WriteException write errors only", + input: WriteException{WriteErrors: WriteErrors{{Code: 1}, {Code: 2}}}, + want: []int{1, 2}, + }, + { + name: "WriteException with write concern error", + input: WriteException{WriteErrors: WriteErrors{{Code: 1}}, WriteConcernError: &WriteConcernError{Code: 2}}, + want: []int{1, 2}, + }, + { + name: "BulkWriteException write errors only", + input: BulkWriteException{ + WriteErrors: []BulkWriteError{ + {WriteError: WriteError{Code: 1}}, + {WriteError: WriteError{Code: 2}}, + }, + }, + want: []int{1, 2}, + }, + { + name: "BulkWriteException with write concern error", + input: BulkWriteException{ + WriteErrors: []BulkWriteError{ + {WriteError: WriteError{Code: 1}}, + {WriteError: WriteError{Code: 2}}, + }, + WriteConcernError: &WriteConcernError{Code: 3}, + }, + want: []int{1, 2, 3}, + }, + { + name: "driver.Error wraps to CommandError", + input: driver.Error{Code: 1, Message: "shutdown in progress"}, + want: []int{1}, + }, + { + name: "wrapped driver.Error", + input: fmt.Errorf("context: %w", driver.Error{Code: 1, Message: "ExceededTimeLimit"}), + want: []int{1}, + }, + { + input: wrapErrors(driver.Error{Code: 1, Message: "Custom error"}), + name: "double wrapped driver.Error", + want: []int{1}, + }, + { + name: "already wrapped CommandError", + input: CommandError{Code: 1}, + want: []int{1}, + }, + { + name: "CommandError wrapped in fmt.Errorf", + input: fmt.Errorf("operation failed: %w", CommandError{Code: 1}), + want: []int{1}, + }, + { + name: "WriteException wrapped in fmt.Errorf", + input: fmt.Errorf("batch failed: %w", WriteException{ + WriteErrors: WriteErrors{{Code: 1}, {Code: 2}}, + }), + want: []int{1, 2}, + }, + { + name: "BulkWriteException with all error types", + input: BulkWriteException{ + WriteErrors: []BulkWriteError{ + {WriteError: WriteError{Code: 1}}, + {WriteError: WriteError{Code: 2}}, + {WriteError: WriteError{Code: 1}}, + }, + WriteConcernError: &WriteConcernError{Code: 2}, + }, + want: []int{1, 2, 1, 2}, + }, + { + name: "driver.Error with multiple fields", + input: driver.Error{Code: 1, Message: "test", Name: "TestError", Labels: []string{"label1"}}, + want: []int{1}, + }, + { + name: "topology.ErrTopologyClosed converts to ErrClientDisconnected", + input: topology.ErrTopologyClosed, + want: []int{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, ErrorCodes(tt.input)) + }) + } +} + +func TestErrorCodesNoDoubleWrapping(t *testing.T) { + driverErr := driver.Error{Code: 1, Message: "test error"} + + // Wrap it once + wrapped := wrapErrors(driverErr) + cmdErr, ok := wrapped.(CommandError) + require.True(t, ok, "wrapErrors should return CommandError") + require.Equal(t, int32(1), cmdErr.Code) + + // Call ErrorCodes on the wrapped error + codes := ErrorCodes(wrapped) + require.Equal(t, []int{1}, codes) + + // The wrapped error's structure should not have changed + cmdErrAfter, ok := wrapped.(CommandError) + require.True(t, ok, "error should still be CommandError") + require.Equal(t, cmdErr.Code, cmdErrAfter.Code) + + // Verify that calling ErrorCodes again gives same result + codes2 := ErrorCodes(wrapped) + require.Equal(t, codes, codes2) +} diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index ac2d5f69e1..9906563100 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -747,7 +747,6 @@ func (op Operation) Execute(ctx context.Context) error { var moreToCome bool var startedInfo startedInformation *wm, moreToCome, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) - if err != nil { return err }