diff --git a/CHANGELOG.md b/CHANGELOG.md index 967fac2a3..85f522e43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## v4.14.0 - 2025-12-xx + +**Security** + +* Logger middleware: escape string values when logger format looks like JSON + + +**Enhancements** + +* Add `middleware.RequestLogger` function to replace `middleware.Logger`. `middleware.RequestLogger` uses default slog logger. + Default slog logger output can be configured to JSON format like that: + ```go + slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil))) + e.Use(middleware.RequestLogger()) + ``` +* Deprecate `middleware.Logger` function and point users to `middleware.RequestLogger` and `middleware.RequestLoggerWithConfig` + ## v4.13.4 - 2025-05-22 **Enhancements** diff --git a/README.md b/README.md index 5a920e875..5e52d1d4e 100644 --- a/README.md +++ b/README.md @@ -73,8 +73,8 @@ func main() { e := echo.New() // Middleware - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) + e.Use(middleware.RequestLogger()) // use the default RequestLogger middleware with slog logger + e.Use(middleware.Recover()) // recover panics as errors for proper error handling // Routes e.GET("/", hello) diff --git a/middleware/logger.go b/middleware/logger.go index 5d9d29e1b..c800a8a90 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -197,6 +197,7 @@ type LoggerConfig struct { template *fasttemplate.Template colorer *color.Color pool *sync.Pool + timeNow func() time.Time } // DefaultLoggerConfig is the default Logger middleware config. @@ -208,6 +209,7 @@ var DefaultLoggerConfig = LoggerConfig{ `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", CustomTimeFormat: "2006-01-02 15:04:05.00000", colorer: color.New(), + timeNow: time.Now, } // Logger returns a middleware that logs HTTP requests using the default configuration. @@ -235,6 +237,8 @@ var DefaultLoggerConfig = LoggerConfig{ // "bytes_in":0,"bytes_out":42} // // For custom configurations, use LoggerWithConfig instead. +// +// Deprecated: please use middleware.RequestLogger or middleware.RequestLoggerWithConfig instead. func Logger() echo.MiddlewareFunc { return LoggerWithConfig(DefaultLoggerConfig) } @@ -259,6 +263,8 @@ func Logger() echo.MiddlewareFunc { // return c.Request().URL.Path == "/health" // }, // })) +// +// Deprecated: please use middleware.RequestLoggerWithConfig instead. func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { @@ -267,9 +273,18 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { if config.Format == "" { config.Format = DefaultLoggerConfig.Format } + writeString := func(buf *bytes.Buffer, in string) (int, error) { return buf.WriteString(in) } + if config.Format[0] == '{' { // format looks like JSON, so we need to escape invalid characters + writeString = writeJSONSafeString + } + if config.Output == nil { config.Output = DefaultLoggerConfig.Output } + timeNow := DefaultLoggerConfig.timeNow + if config.timeNow != nil { + timeNow = config.timeNow + } config.template = fasttemplate.New(config.Format, "${", "}") config.colorer = color.New() @@ -305,49 +320,47 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { } return config.CustomTagFunc(c, buf) case "time_unix": - return buf.WriteString(strconv.FormatInt(time.Now().Unix(), 10)) + return buf.WriteString(strconv.FormatInt(timeNow().Unix(), 10)) case "time_unix_milli": - // go 1.17 or later, it supports time#UnixMilli() - return buf.WriteString(strconv.FormatInt(time.Now().UnixNano()/1000000, 10)) + return buf.WriteString(strconv.FormatInt(timeNow().UnixMilli(), 10)) case "time_unix_micro": - // go 1.17 or later, it supports time#UnixMicro() - return buf.WriteString(strconv.FormatInt(time.Now().UnixNano()/1000, 10)) + return buf.WriteString(strconv.FormatInt(timeNow().UnixMicro(), 10)) case "time_unix_nano": - return buf.WriteString(strconv.FormatInt(time.Now().UnixNano(), 10)) + return buf.WriteString(strconv.FormatInt(timeNow().UnixNano(), 10)) case "time_rfc3339": - return buf.WriteString(time.Now().Format(time.RFC3339)) + return buf.WriteString(timeNow().Format(time.RFC3339)) case "time_rfc3339_nano": - return buf.WriteString(time.Now().Format(time.RFC3339Nano)) + return buf.WriteString(timeNow().Format(time.RFC3339Nano)) case "time_custom": - return buf.WriteString(time.Now().Format(config.CustomTimeFormat)) + return buf.WriteString(timeNow().Format(config.CustomTimeFormat)) case "id": id := req.Header.Get(echo.HeaderXRequestID) if id == "" { id = res.Header().Get(echo.HeaderXRequestID) } - return buf.WriteString(id) + return writeString(buf, id) case "remote_ip": - return buf.WriteString(c.RealIP()) + return writeString(buf, c.RealIP()) case "host": - return buf.WriteString(req.Host) + return writeString(buf, req.Host) case "uri": - return buf.WriteString(req.RequestURI) + return writeString(buf, req.RequestURI) case "method": - return buf.WriteString(req.Method) + return writeString(buf, req.Method) case "path": p := req.URL.Path if p == "" { p = "/" } - return buf.WriteString(p) + return writeString(buf, p) case "route": - return buf.WriteString(c.Path()) + return writeString(buf, c.Path()) case "protocol": - return buf.WriteString(req.Proto) + return writeString(buf, req.Proto) case "referer": - return buf.WriteString(req.Referer()) + return writeString(buf, req.Referer()) case "user_agent": - return buf.WriteString(req.UserAgent()) + return writeString(buf, req.UserAgent()) case "status": n := res.Status s := config.colorer.Green(n) @@ -377,17 +390,17 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { if cl == "" { cl = "0" } - return buf.WriteString(cl) + return writeString(buf, cl) case "bytes_out": return buf.WriteString(strconv.FormatInt(res.Size, 10)) default: switch { case strings.HasPrefix(tag, "header:"): - return buf.Write([]byte(c.Request().Header.Get(tag[7:]))) + return writeString(buf, c.Request().Header.Get(tag[7:])) case strings.HasPrefix(tag, "query:"): - return buf.Write([]byte(c.QueryParam(tag[6:]))) + return writeString(buf, c.QueryParam(tag[6:])) case strings.HasPrefix(tag, "form:"): - return buf.Write([]byte(c.FormValue(tag[5:]))) + return writeString(buf, c.FormValue(tag[5:])) case strings.HasPrefix(tag, "cookie:"): cookie, err := c.Cookie(tag[7:]) if err == nil { diff --git a/middleware/logger_strings.go b/middleware/logger_strings.go new file mode 100644 index 000000000..8476cb046 --- /dev/null +++ b/middleware/logger_strings.go @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: BSD-3-Clause +// SPDX-FileCopyrightText: Copyright 2010 The Go Authors +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// +// Go LICENSE https://raw.githubusercontent.com/golang/go/36bca3166e18db52687a4d91ead3f98ffe6d00b8/LICENSE +/** +Copyright 2009 The Go Authors. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google LLC nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +package middleware + +import ( + "bytes" + "unicode/utf8" +) + +// This function is modified copy from Go standard library encoding/json/encode.go `appendString` function +// Source: https://github.com/golang/go/blob/36bca3166e18db52687a4d91ead3f98ffe6d00b8/src/encoding/json/encode.go#L999 +func writeJSONSafeString(buf *bytes.Buffer, src string) (int, error) { + const hex = "0123456789abcdef" + + written := 0 + start := 0 + for i := 0; i < len(src); { + if b := src[i]; b < utf8.RuneSelf { + if safeSet[b] { + i++ + continue + } + + n, err := buf.Write([]byte(src[start:i])) + written += n + if err != nil { + return written, err + } + switch b { + case '\\', '"': + n, err := buf.Write([]byte{'\\', b}) + written += n + if err != nil { + return written, err + } + case '\b': + n, err := buf.Write([]byte{'\\', 'b'}) + written += n + if err != nil { + return n, err + } + case '\f': + n, err := buf.Write([]byte{'\\', 'f'}) + written += n + if err != nil { + return written, err + } + case '\n': + n, err := buf.Write([]byte{'\\', 'n'}) + written += n + if err != nil { + return written, err + } + case '\r': + n, err := buf.Write([]byte{'\\', 'r'}) + written += n + if err != nil { + return written, err + } + case '\t': + n, err := buf.Write([]byte{'\\', 't'}) + written += n + if err != nil { + return written, err + } + default: + // This encodes bytes < 0x20 except for \b, \f, \n, \r and \t. + n, err := buf.Write([]byte{'\\', 'u', '0', '0', hex[b>>4], hex[b&0xF]}) + written += n + if err != nil { + return written, err + } + } + i++ + start = i + continue + } + srcN := min(len(src)-i, utf8.UTFMax) + c, size := utf8.DecodeRuneInString(src[i : i+srcN]) + if c == utf8.RuneError && size == 1 { + n, err := buf.Write([]byte(src[start:i])) + written += n + if err != nil { + return written, err + } + n, err = buf.Write([]byte(`\ufffd`)) + written += n + if err != nil { + return written, err + } + i += size + start = i + continue + } + i += size + } + n, err := buf.Write([]byte(src[start:])) + written += n + return written, err +} + +// safeSet holds the value true if the ASCII character with the given array +// position can be represented inside a JSON string without any further +// escaping. +// +// All values are true except for the ASCII control characters (0-31), the +// double quote ("), and the backslash character ("\"). +var safeSet = [utf8.RuneSelf]bool{ + ' ': true, + '!': true, + '"': false, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '(': true, + ')': true, + '*': true, + '+': true, + ',': true, + '-': true, + '.': true, + '/': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + ':': true, + ';': true, + '<': true, + '=': true, + '>': true, + '?': true, + '@': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'V': true, + 'W': true, + 'X': true, + 'Y': true, + 'Z': true, + '[': true, + '\\': false, + ']': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '{': true, + '|': true, + '}': true, + '~': true, + '\u007f': true, +} diff --git a/middleware/logger_strings_test.go b/middleware/logger_strings_test.go new file mode 100644 index 000000000..90231a683 --- /dev/null +++ b/middleware/logger_strings_test.go @@ -0,0 +1,285 @@ +package middleware + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWriteJSONSafeString(t *testing.T) { + testCases := []struct { + name string + whenInput string + expect string + expectN int + }{ + // Basic cases + { + name: "empty string", + whenInput: "", + expect: "", + expectN: 0, + }, + { + name: "simple ASCII without special chars", + whenInput: "hello", + expect: "hello", + expectN: 5, + }, + { + name: "single character", + whenInput: "a", + expect: "a", + expectN: 1, + }, + { + name: "alphanumeric", + whenInput: "Hello123World", + expect: "Hello123World", + expectN: 13, + }, + + // Special character escaping + { + name: "backslash", + whenInput: `path\to\file`, + expect: `path\\to\\file`, + expectN: 14, + }, + { + name: "double quote", + whenInput: `say "hello"`, + expect: `say \"hello\"`, + expectN: 13, + }, + { + name: "backslash and quote combined", + whenInput: `a\b"c`, + expect: `a\\b\"c`, + expectN: 7, + }, + { + name: "single backslash", + whenInput: `\`, + expect: `\\`, + expectN: 2, + }, + { + name: "single quote", + whenInput: `"`, + expect: `\"`, + expectN: 2, + }, + + // Control character escaping + { + name: "backspace", + whenInput: "hello\bworld", + expect: `hello\bworld`, + expectN: 12, + }, + { + name: "form feed", + whenInput: "hello\fworld", + expect: `hello\fworld`, + expectN: 12, + }, + { + name: "newline", + whenInput: "hello\nworld", + expect: `hello\nworld`, + expectN: 12, + }, + { + name: "carriage return", + whenInput: "hello\rworld", + expect: `hello\rworld`, + expectN: 12, + }, + { + name: "tab", + whenInput: "hello\tworld", + expect: `hello\tworld`, + expectN: 12, + }, + { + name: "multiple newlines", + whenInput: "line1\nline2\nline3", + expect: `line1\nline2\nline3`, + expectN: 19, + }, + + // Low control characters (< 0x20) + { + name: "null byte", + whenInput: "hello\x00world", + expect: `hello\u0000world`, + expectN: 16, + }, + { + name: "control character 0x01", + whenInput: "test\x01value", + expect: `test\u0001value`, + expectN: 15, + }, + { + name: "control character 0x0e", + whenInput: "test\x0evalue", + expect: `test\u000evalue`, + expectN: 15, + }, + { + name: "control character 0x1f", + whenInput: "test\x1fvalue", + expect: `test\u001fvalue`, + expectN: 15, + }, + { + name: "multiple control characters", + whenInput: "\x00\x01\x02", + expect: `\u0000\u0001\u0002`, + expectN: 18, + }, + + // UTF-8 handling + { + name: "valid UTF-8 Chinese", + whenInput: "hello 世界", + expect: "hello 世界", + expectN: 12, + }, + { + name: "valid UTF-8 emoji", + whenInput: "party 🎉 time", + expect: "party 🎉 time", + expectN: 15, + }, + { + name: "mixed ASCII and UTF-8", + whenInput: "Hello世界123", + expect: "Hello世界123", + expectN: 14, + }, + { + name: "UTF-8 with special chars", + whenInput: "世界\n\"test\"", + expect: `世界\n\"test\"`, + expectN: 16, + }, + + // Invalid UTF-8 + { + name: "invalid UTF-8 sequence", + whenInput: "hello\xff\xfeworld", + expect: `hello\ufffd\ufffdworld`, + expectN: 22, + }, + { + name: "incomplete UTF-8 sequence", + whenInput: "test\xc3value", + expect: `test\ufffdvalue`, + expectN: 15, + }, + + // Complex mixed cases + { + name: "all common escapes", + whenInput: "tab\there\nquote\"backslash\\", + expect: `tab\there\nquote\"backslash\\`, + expectN: 29, + }, + { + name: "mixed controls and UTF-8", + whenInput: "hello\t世界\ntest\"", + expect: `hello\t世界\ntest\"`, + expectN: 21, + }, + { + name: "all control characters", + whenInput: "\b\f\n\r\t", + expect: `\b\f\n\r\t`, + expectN: 10, + }, + { + name: "control and low ASCII", + whenInput: "a\nb\x00c", + expect: `a\nb\u0000c`, + expectN: 11, + }, + + // Edge cases + { + name: "starts with special char", + whenInput: "\\start", + expect: `\\start`, + expectN: 7, + }, + { + name: "ends with special char", + whenInput: "end\"", + expect: `end\"`, + expectN: 5, + }, + { + name: "consecutive special chars", + whenInput: "\\\\\"\"", + expect: `\\\\\"\"`, + expectN: 8, + }, + { + name: "only special characters", + whenInput: "\"\\\n\t", + expect: `\"\\\n\t`, + expectN: 8, + }, + { + name: "spaces and punctuation", + whenInput: "Hello, World! How are you?", + expect: "Hello, World! How are you?", + expectN: 26, + }, + { + name: "JSON-like string", + whenInput: "{\"key\":\"value\"}", + expect: `{\"key\":\"value\"}`, + expectN: 19, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + n, err := writeJSONSafeString(buf, tt.whenInput) + + assert.NoError(t, err) + assert.Equal(t, tt.expect, buf.String()) + assert.Equal(t, tt.expectN, n) + }) + } +} + +func BenchmarkWriteJSONSafeString(b *testing.B) { + testCases := []struct { + name string + input string + }{ + {"simple", "hello world"}, + {"with escapes", "tab\there\nquote\"backslash\\"}, + {"utf8", "hello 世界 🎉"}, + {"mixed", "Hello\t世界\ntest\"value\\path"}, + {"long simple", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789"}, + {"long complex", "line1\nline2\tline3\"quote\\slash\x00null世界🎉"}, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + buf := &bytes.Buffer{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + writeJSONSafeString(buf, tc.input) + } + }) + } +} diff --git a/middleware/logger_test.go b/middleware/logger_test.go index d5236e1ac..7c58ce0b4 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -5,12 +5,13 @@ package middleware import ( "bytes" + "cmp" "encoding/json" "errors" "net/http" "net/http/httptest" "net/url" - "strconv" + "regexp" "strings" "testing" "time" @@ -20,72 +21,323 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLogger(t *testing.T) { - // Note: Just for the test coverage, not a real test. - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := Logger()(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - // Status 2xx - h(c) - - // Status 3xx - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = Logger()(func(c echo.Context) error { - return c.String(http.StatusTemporaryRedirect, "test") - }) - h(c) +func TestLoggerDefaultMW(t *testing.T) { + var testCases = []struct { + name string + whenHeader map[string]string + whenStatusCode int + whenResponse string + whenError error + expect string + }{ + { + name: "ok, status 200", + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", + }, + { + name: "ok, status 300", + whenStatusCode: http.StatusTemporaryRedirect, + whenResponse: "test", + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":307,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", + }, + { + name: "ok, handler error = status 500", + whenError: errors.New("error"), + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":36}` + "\n", + }, + { + name: "ok, remote_ip from X-Real-Ip header", + whenHeader: map[string]string{echo.HeaderXRealIP: "127.0.0.1"}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"127.0.0.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", + }, + { + name: "ok, remote_ip from X-Forwarded-For header", + whenHeader: map[string]string{echo.HeaderXForwardedFor: "127.0.0.1"}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"127.0.0.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", + }, + } - // Status 4xx - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = Logger()(func(c echo.Context) error { - return c.String(http.StatusNotFound, "test") - }) - h(c) - - // Status 5xx with empty path - req = httptest.NewRequest(http.MethodGet, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h = Logger()(func(c echo.Context) error { - return errors.New("error") - }) - h(c) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + if len(tc.whenHeader) > 0 { + for k, v := range tc.whenHeader { + req.Header.Add(k, v) + } + } + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + DefaultLoggerConfig.timeNow = func() time.Time { return time.Unix(1588037200, 0).UTC() } + h := Logger()(func(c echo.Context) error { + if tc.whenError != nil { + return tc.whenError + } + return c.String(tc.whenStatusCode, tc.whenResponse) + }) + buf := new(bytes.Buffer) + e.Logger.SetOutput(buf) + + err := h(c) + assert.NoError(t, err) + + result := buf.String() + // handle everchanging latency numbers + result = regexp.MustCompile(`"latency":\d+,`).ReplaceAllString(result, `"latency":1,`) + result = regexp.MustCompile(`"latency_human":"[^"]+"`).ReplaceAllString(result, `"latency_human":"1µs"`) + + assert.Equal(t, tc.expect, result) + }) + } } -func TestLoggerIPAddress(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - ip := "127.0.0.1" - h := Logger()(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) +func TestLoggerWithLoggerConfig(t *testing.T) { + // to handle everchanging latency numbers + jsonLatency := map[string]*regexp.Regexp{ + `"latency":1,`: regexp.MustCompile(`"latency":\d+,`), + `"latency_human":"1µs"`: regexp.MustCompile(`"latency_human":"[^"]+"`), + } - // With X-Real-IP - req.Header.Add(echo.HeaderXRealIP, ip) - h(c) - assert.Contains(t, buf.String(), ip) - - // With X-Forwarded-For - buf.Reset() - req.Header.Del(echo.HeaderXRealIP) - req.Header.Add(echo.HeaderXForwardedFor, ip) - h(c) - assert.Contains(t, buf.String(), ip) - - buf.Reset() - h(c) - assert.Contains(t, buf.String(), ip) + form := make(url.Values) + form.Set("csrf", "token") + form.Add("multiple", "1") + form.Add("multiple", "2") + + var testCases = []struct { + name string + givenConfig LoggerConfig + whenURI string + whenMethod string + whenHost string + whenPath string + whenRoute string + whenProto string + whenRequestURI string + whenHeader map[string]string + whenFormValues url.Values + whenStatusCode int + whenResponse string + whenError error + whenReplacers map[string]*regexp.Regexp + expect string + }{ + { + name: "ok, skipper", + givenConfig: LoggerConfig{ + Skipper: func(c echo.Context) bool { return true }, + }, + expect: ``, + }, + { // this is an example how format that does not seem to be JSON is not currently escaped + name: "ok, NON json string is not escaped: method", + givenConfig: LoggerConfig{Format: `method:"${method}"`}, + whenMethod: `","method":":D"`, + expect: `method:"","method":":D""`, + }, + { + name: "ok, json string escape: method", + givenConfig: LoggerConfig{Format: `{"method":"${method}"}`}, + whenMethod: `","method":":D"`, + expect: `{"method":"\",\"method\":\":D\""}`, + }, + { + name: "ok, json string escape: id", + givenConfig: LoggerConfig{Format: `{"id":"${id}"}`}, + whenHeader: map[string]string{echo.HeaderXRequestID: `\"127.0.0.1\"`}, + expect: `{"id":"\\\"127.0.0.1\\\""}`, + }, + { + name: "ok, json string escape: remote_ip", + givenConfig: LoggerConfig{Format: `{"remote_ip":"${remote_ip}"}`}, + whenHeader: map[string]string{echo.HeaderXForwardedFor: `\"127.0.0.1\"`}, + expect: `{"remote_ip":"\\\"127.0.0.1\\\""}`, + }, + { + name: "ok, json string escape: host", + givenConfig: LoggerConfig{Format: `{"host":"${host}"}`}, + whenHost: `\"127.0.0.1\"`, + expect: `{"host":"\\\"127.0.0.1\\\""}`, + }, + { + name: "ok, json string escape: path", + givenConfig: LoggerConfig{Format: `{"path":"${path}"}`}, + whenPath: `\","` + "\n", + expect: `{"path":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: route", + givenConfig: LoggerConfig{Format: `{"route":"${route}"}`}, + whenRoute: `\","` + "\n", + expect: `{"route":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: proto", + givenConfig: LoggerConfig{Format: `{"protocol":"${protocol}"}`}, + whenProto: `\","` + "\n", + expect: `{"protocol":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: referer", + givenConfig: LoggerConfig{Format: `{"referer":"${referer}"}`}, + whenHeader: map[string]string{"Referer": `\","` + "\n"}, + expect: `{"referer":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: user_agent", + givenConfig: LoggerConfig{Format: `{"user_agent":"${user_agent}"}`}, + whenHeader: map[string]string{"User-Agent": `\","` + "\n"}, + expect: `{"user_agent":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: bytes_in", + givenConfig: LoggerConfig{Format: `{"bytes_in":"${bytes_in}"}`}, + whenHeader: map[string]string{echo.HeaderContentLength: `\","` + "\n"}, + expect: `{"bytes_in":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: query param", + givenConfig: LoggerConfig{Format: `{"query":"${query:test}"}`}, + whenURI: `/?test=1","`, + expect: `{"query":"1\",\""}`, + }, + { + name: "ok, json string escape: header", + givenConfig: LoggerConfig{Format: `{"header":"${header:referer}"}`}, + whenHeader: map[string]string{"referer": `\","` + "\n"}, + expect: `{"header":"\\\",\"\n"}`, + }, + { + name: "ok, json string escape: form", + givenConfig: LoggerConfig{Format: `{"csrf":"${form:csrf}"}`}, + whenMethod: http.MethodPost, + whenFormValues: url.Values{"csrf": {`token","`}}, + expect: `{"csrf":"token\",\""}`, + }, + { + name: "nok, json string escape: cookie - will not accept invalid chars", + // net/cookie.go: validCookieValueByte function allows these byte in cookie value + // only `0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'` + givenConfig: LoggerConfig{Format: `{"cookie":"${cookie:session}"}`}, + whenHeader: map[string]string{"Cookie": `_ga=GA1.2.000000000.0000000000; session=test\n`}, + expect: `{"cookie":""}`, + }, + { + name: "ok, format time_unix", + givenConfig: LoggerConfig{Format: `${time_unix}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `1588037200`, + }, + { + name: "ok, format time_unix_milli", + givenConfig: LoggerConfig{Format: `${time_unix_milli}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `1588037200000`, + }, + { + name: "ok, format time_unix_micro", + givenConfig: LoggerConfig{Format: `${time_unix_micro}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `1588037200000000`, + }, + { + name: "ok, format time_unix_nano", + givenConfig: LoggerConfig{Format: `${time_unix_nano}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `1588037200000000000`, + }, + { + name: "ok, format time_rfc3339", + givenConfig: LoggerConfig{Format: `${time_rfc3339}`}, + whenStatusCode: http.StatusOK, + whenResponse: "test", + expect: `2020-04-28T01:26:40Z`, + }, + { + name: "ok, status 200", + whenStatusCode: http.StatusOK, + whenResponse: "test", + whenReplacers: jsonLatency, + expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, cmp.Or(tc.whenURI, "/"), nil) + if tc.whenFormValues != nil { + req = httptest.NewRequest(http.MethodGet, cmp.Or(tc.whenURI, "/"), strings.NewReader(tc.whenFormValues.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + } + + for k, v := range tc.whenHeader { + req.Header.Add(k, v) + } + if tc.whenHost != "" { + req.Host = tc.whenHost + } + if tc.whenMethod != "" { + req.Method = tc.whenMethod + } + if tc.whenProto != "" { + req.Proto = tc.whenProto + } + if tc.whenRequestURI != "" { + req.RequestURI = tc.whenRequestURI + } + if tc.whenPath != "" { + req.URL.Path = tc.whenPath + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tc.whenFormValues != nil { + c.FormValue("to trigger form parsing") + } + if tc.whenRoute != "" { + c.SetPath(tc.whenRoute) + } + + config := tc.givenConfig + if config.timeNow == nil { + config.timeNow = func() time.Time { return time.Unix(1588037200, 0).UTC() } + } + buf := new(bytes.Buffer) + if config.Output == nil { + e.Logger.SetOutput(buf) + } + + h := LoggerWithConfig(config)(func(c echo.Context) error { + if tc.whenError != nil { + return tc.whenError + } + return c.String(cmp.Or(tc.whenStatusCode, http.StatusOK), cmp.Or(tc.whenResponse, "test")) + }) + + err := h(c) + assert.NoError(t, err) + + result := buf.String() + + for replaceTo, replacer := range tc.whenReplacers { + result = replacer.ReplaceAllString(result, replaceTo) + } + + assert.Equal(t, tc.expect, result) + }) + } } func TestLoggerTemplate(t *testing.T) { @@ -271,49 +523,3 @@ func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) { buf.Reset() } } - -func TestLoggerTemplateWithTimeUnixMilli(t *testing.T) { - buf := new(bytes.Buffer) - - e := echo.New() - e.Use(LoggerWithConfig(LoggerConfig{ - Format: `${time_unix_milli}`, - Output: buf, - })) - - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusOK, "OK") - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - unixMillis, err := strconv.ParseInt(buf.String(), 10, 64) - assert.NoError(t, err) - assert.WithinDuration(t, time.Unix(unixMillis/1000, 0), time.Now(), 3*time.Second) -} - -func TestLoggerTemplateWithTimeUnixMicro(t *testing.T) { - buf := new(bytes.Buffer) - - e := echo.New() - e.Use(LoggerWithConfig(LoggerConfig{ - Format: `${time_unix_micro}`, - Output: buf, - })) - - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusOK, "OK") - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - unixMicros, err := strconv.ParseInt(buf.String(), 10, 64) - assert.NoError(t, err) - assert.WithinDuration(t, time.Unix(unixMicros/1000000, 0), time.Now(), 3*time.Second) -} diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 7c18200b0..211abf464 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -4,7 +4,9 @@ package middleware import ( + "context" "errors" + "log/slog" "net/http" "time" @@ -247,6 +249,72 @@ func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc { return mw } +// RequestLogger returns a RequestLogger middleware with default configuration which +// uses default slog.slog logger. +// +// To customize slog output format replace slog default logger: +// For JSON format: `slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))` +func RequestLogger() echo.MiddlewareFunc { + config := RequestLoggerConfig{ + LogLatency: true, + LogProtocol: false, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogURIPath: false, + LogRoutePath: false, + LogRequestID: true, + LogReferer: false, + LogUserAgent: true, + LogStatus: true, + LogError: true, + LogContentLength: true, + LogResponseSize: true, + LogHeaders: nil, + LogQueryParams: nil, + LogFormValues: nil, + HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code + LogValuesFunc: func(c echo.Context, v RequestLoggerValues) error { + if v.Error == nil { + slog.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", + slog.String("method", v.Method), + slog.String("uri", v.URI), + slog.Int("status", v.Status), + slog.Duration("latency", v.Latency), + slog.String("host", v.Host), + slog.String("bytes_in", v.ContentLength), + slog.Int64("bytes_out", v.ResponseSize), + slog.String("user_agent", v.UserAgent), + slog.String("remote_ip", v.RemoteIP), + slog.String("request_id", v.RequestID), + ) + } else { + slog.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", + slog.String("method", v.Method), + slog.String("uri", v.URI), + slog.Int("status", v.Status), + slog.Duration("latency", v.Latency), + slog.String("host", v.Host), + slog.String("bytes_in", v.ContentLength), + slog.Int64("bytes_out", v.ResponseSize), + slog.String("user_agent", v.UserAgent), + slog.String("remote_ip", v.RemoteIP), + slog.String("request_id", v.RequestID), + + slog.String("error", v.Error.Error()), + ) + } + return nil + }, + } + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + // ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration. func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index c612f5c22..510d34edd 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -4,8 +4,10 @@ package middleware import ( - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" + "bytes" + "encoding/json" + "errors" + "log/slog" "net/http" "net/http/httptest" "net/url" @@ -13,8 +15,105 @@ import ( "strings" "testing" "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" ) +func TestRequestLoggerOK(t *testing.T) { + old := slog.Default() + t.Cleanup(func() { + slog.SetDefault(old) + }) + + buf := new(bytes.Buffer) + slog.SetDefault(slog.New(slog.NewJSONHandler(buf, nil))) + + e := echo.New() + e.Use(RequestLogger()) + + e.POST("/test", func(c echo.Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + reader := strings.NewReader(`{"foo":"bar"}`) + req := httptest.NewRequest(http.MethodPost, "/test", reader) + req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") + req.Header.Set("User-Agent", "curl/7.68.0") + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + logAttrs := map[string]interface{}{} + assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs)) + logAttrs["latency"] = 123 + logAttrs["time"] = "x" + + expect := map[string]interface{}{ + "level": "INFO", + "msg": "REQUEST", + "method": "POST", + "uri": "/test", + "status": float64(418), + "bytes_in": "13", + "host": "example.com", + "bytes_out": float64(2), + "user_agent": "curl/7.68.0", + "remote_ip": "8.8.8.8", + "request_id": "", + + "time": "x", + "latency": 123, + } + assert.Equal(t, expect, logAttrs) +} + +func TestRequestLoggerError(t *testing.T) { + old := slog.Default() + t.Cleanup(func() { + slog.SetDefault(old) + }) + + buf := new(bytes.Buffer) + slog.SetDefault(slog.New(slog.NewJSONHandler(buf, nil))) + + e := echo.New() + e.Use(RequestLogger()) + + e.GET("/test", func(c echo.Context) error { + return errors.New("nope") + }) + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + logAttrs := map[string]interface{}{} + assert.NoError(t, json.Unmarshal(buf.Bytes(), &logAttrs)) + logAttrs["latency"] = 123 + logAttrs["time"] = "x" + + expect := map[string]interface{}{ + "level": "ERROR", + "msg": "REQUEST_ERROR", + "method": "GET", + "uri": "/test", + "status": float64(500), + "bytes_in": "", + "host": "example.com", + "bytes_out": float64(36.0), + "user_agent": "", + "remote_ip": "192.0.2.1", + "request_id": "", + "error": "nope", + + "latency": 123, + "time": "x", + } + assert.Equal(t, expect, logAttrs) +} + func TestRequestLoggerWithConfig(t *testing.T) { e := echo.New()