diff --git a/nat/nat.go b/nat/nat.go index 1ffe0355..135a826f 100644 --- a/nat/nat.go +++ b/nat/nat.go @@ -27,19 +27,15 @@ type PortSet map[Port]struct{} type Port string // NewPort creates a new instance of a Port given a protocol and port number or port range -func NewPort(proto, port string) (Port, error) { - // Check for parsing issues on "port" now so we can avoid having - // to check it later on. - - portStartInt, portEndInt, err := ParsePortRangeToInt(port) +func NewPort(proto, portOrRange string) (Port, error) { + start, end, err := parsePortRange(portOrRange) if err != nil { return "", err } - - if portStartInt == portEndInt { - return Port(fmt.Sprintf("%d/%s", portStartInt, proto)), nil + if start == end { + return Port(fmt.Sprintf("%d/%s", start, proto)), nil } - return Port(fmt.Sprintf("%d-%d/%s", portStartInt, portEndInt, proto)), nil + return Port(fmt.Sprintf("%d-%d/%s", start, end, proto)), nil } // ParsePort parses the port number string and returns an int @@ -47,49 +43,53 @@ func ParsePort(rawPort string) (int, error) { if rawPort == "" { return 0, nil } - port, err := strconv.ParseUint(rawPort, 10, 16) + port, err := parsePortNumber(rawPort) if err != nil { - return 0, fmt.Errorf("invalid port '%s': %w", rawPort, errors.Unwrap(err)) + return 0, fmt.Errorf("invalid port '%s': %w", rawPort, err) } - return int(port), nil + return port, nil } // ParsePortRangeToInt parses the port range string and returns start/end ints -func ParsePortRangeToInt(rawPort string) (int, int, error) { +func ParsePortRangeToInt(rawPort string) (startPort, endPort int, _ error) { if rawPort == "" { + // TODO(thaJeztah): consider making this an error; this was kept to keep existing behavior. return 0, 0, nil } - start, end, err := ParsePortRange(rawPort) - if err != nil { - return 0, 0, err - } - return int(start), int(end), nil + return parsePortRange(rawPort) } // Proto returns the protocol of a Port func (p Port) Proto() string { - proto, _ := SplitProtoPort(string(p)) + _, proto, _ := strings.Cut(string(p), "/") + if proto == "" { + proto = "tcp" + } return proto } // Port returns the port number of a Port func (p Port) Port() string { - _, port := SplitProtoPort(string(p)) + port, _, _ := strings.Cut(string(p), "/") return port } -// Int returns the port number of a Port as an int +// Int returns the port number of a Port as an int. It assumes [Port] +// is valid, and returns 0 otherwise. func (p Port) Int() int { - portStr := p.Port() // We don't need to check for an error because we're going to - // assume that any error would have been found, and reported, in NewPort() - port, _ := ParsePort(portStr) + // assume that any error would have been found, and reported, in [NewPort] + port, _ := parsePortNumber(p.Port()) return port } // Range returns the start/end port numbers of a Port range as ints func (p Port) Range() (int, int, error) { - return ParsePortRangeToInt(p.Port()) + portRange := p.Port() + if portRange == "" { + return 0, 0, nil + } + return parsePortRange(portRange) } // SplitProtoPort splits a port(range) and protocol, formatted as "/[]" @@ -193,14 +193,14 @@ func ParsePortSpec(rawPort string) ([]PortMapping, error) { return nil, fmt.Errorf("no port specified: %s", rawPort) } - startPort, endPort, err := ParsePortRange(containerPort) + startPort, endPort, err := parsePortRange(containerPort) if err != nil { return nil, errors.New("invalid containerPort: " + containerPort) } - var startHostPort, endHostPort uint64 + var startHostPort, endHostPort int if hostPort != "" { - startHostPort, endHostPort, err = ParsePortRange(hostPort) + startHostPort, endHostPort, err = parsePortRange(hostPort) if err != nil { return nil, errors.New("invalid hostPort: " + hostPort) } @@ -217,19 +217,18 @@ func ParsePortSpec(rawPort string) ([]PortMapping, error) { count := endPort - startPort + 1 ports := make([]PortMapping, 0, count) - for i := uint64(0); i < count; i++ { - cPort := Port(strconv.FormatUint(startPort+i, 10) + "/" + proto) + for i := 0; i < count; i++ { hPort := "" if hostPort != "" { - hPort = strconv.FormatUint(startHostPort+i, 10) + hPort = strconv.Itoa(startHostPort + i) // Set hostPort to a range only if there is a single container port // and a dynamic host port. if count == 1 && startHostPort != endHostPort { - hPort += "-" + strconv.FormatUint(endHostPort, 10) + hPort += "-" + strconv.Itoa(endHostPort) } } ports = append(ports, PortMapping{ - Port: cPort, + Port: Port(strconv.Itoa(startPort+i) + "/" + proto), Binding: PortBinding{HostIP: ip, HostPort: hPort}, }) } diff --git a/nat/nat_test.go b/nat/nat_test.go index 6ee59e1f..bccc307a 100644 --- a/nat/nat_test.go +++ b/nat/nat_test.go @@ -38,7 +38,7 @@ func TestParsePort(t *testing.T) { doc: "negative value", input: "-1", expPort: 0, - expErr: `invalid port '-1': invalid syntax`, + expErr: `invalid port '-1': value out of range (0–65535)`, }, // FIXME currently this is a valid port. I don't think it should be. // I'm leaving this test until we make a decision. @@ -57,7 +57,7 @@ func TestParsePort(t *testing.T) { doc: "value out of range", input: "65536", expPort: 0, - expErr: `invalid port '65536': value out of range`, + expErr: `invalid port '65536': value out of range (0–65535)`, }, } @@ -675,11 +675,11 @@ func TestParseNetworkOptsNegativePorts(t *testing.T) { t.Fail() } if len(ports) != 0 { - t.Logf("Expected nil got %d", len(ports)) + t.Logf("Expected 0 got %d: %#v", len(ports), ports) t.Fail() } if len(bindings) != 0 { - t.Logf("Expected 0 got %d", len(bindings)) + t.Logf("Expected 0 got %d: %#v", len(bindings), bindings) t.Fail() } } @@ -690,11 +690,11 @@ func TestParseNetworkOptsUdp(t *testing.T) { t.Fatal(err) } if len(ports) != 1 { - t.Logf("Expected 1 got %d", len(ports)) + t.Logf("Expected 1 got %d: %#v", len(ports), ports) t.FailNow() } if len(bindings) != 1 { - t.Logf("Expected 1 got %d", len(bindings)) + t.Logf("Expected 1 got %d: %#v", len(bindings), bindings) t.FailNow() } for k := range ports { @@ -732,11 +732,11 @@ func TestParseNetworkOptsSctp(t *testing.T) { t.Fatal(err) } if len(ports) != 1 { - t.Logf("Expected 1 got %d", len(ports)) + t.Logf("Expected 1 got %d: %#v", len(ports), ports) t.FailNow() } if len(bindings) != 1 { - t.Logf("Expected 1 got %d", len(bindings)) + t.Logf("Expected 1 got %d: %#v", len(bindings), bindings) t.FailNow() } for k := range ports { diff --git a/nat/parse.go b/nat/parse.go index 64affa2a..f6f86bd0 100644 --- a/nat/parse.go +++ b/nat/parse.go @@ -2,32 +2,59 @@ package nat import ( "errors" + "fmt" "strconv" "strings" ) -// ParsePortRange parses and validates the specified string as a port-range (8000-9000) -func ParsePortRange(ports string) (uint64, uint64, error) { +// ParsePortRange parses and validates the specified string as a port range (e.g., "8000-9000"). +func ParsePortRange(ports string) (startPort, endPort uint64, _ error) { + start, end, err := parsePortRange(ports) + return uint64(start), uint64(end), err +} + +// parsePortRange parses and validates the specified string as a port range (e.g., "8000-9000"). +func parsePortRange(ports string) (startPort, endPort int, _ error) { if ports == "" { return 0, 0, errors.New("empty string specified for ports") } - if !strings.Contains(ports, "-") { - start, err := strconv.ParseUint(ports, 10, 16) - end := start - return start, end, err + start, end, ok := strings.Cut(ports, "-") + + startPort, err := parsePortNumber(start) + if err != nil { + return 0, 0, fmt.Errorf("invalid start port '%s': %w", start, err) + } + if !ok || start == end { + return startPort, startPort, nil } - parts := strings.Split(ports, "-") - start, err := strconv.ParseUint(parts[0], 10, 16) + endPort, err = parsePortNumber(end) if err != nil { - return 0, 0, err + return 0, 0, fmt.Errorf("invalid end port '%s': %w", end, err) } - end, err := strconv.ParseUint(parts[1], 10, 16) + if endPort < startPort { + return 0, 0, errors.New("invalid port range: " + ports) + } + return startPort, endPort, nil +} + +// parsePortNumber parses rawPort into an int, unwrapping strconv errors +// and returning a single "out of range" error for any value outside 0–65535. +func parsePortNumber(rawPort string) (int, error) { + if rawPort == "" { + return 0, errors.New("value is empty") + } + port, err := strconv.ParseInt(rawPort, 10, 0) if err != nil { - return 0, 0, err + var numErr *strconv.NumError + if errors.As(err, &numErr) { + err = numErr.Err + } + return 0, err } - if end < start { - return 0, 0, errors.New("invalid range specified for port: " + ports) + if port < 0 || port > 65535 { + return 0, errors.New("value out of range (0–65535)") } - return start, end, nil + + return int(port), nil } diff --git a/nat/parse_test.go b/nat/parse_test.go index a9b3f089..039bd200 100644 --- a/nat/parse_test.go +++ b/nat/parse_test.go @@ -50,52 +50,52 @@ func TestParsePortRange(t *testing.T) { { doc: "non-numeric port", input: "asdf", - expErr: `strconv.ParseUint: parsing "asdf": invalid syntax`, + expErr: `invalid start port 'asdf': invalid syntax`, }, { doc: "reversed range", input: "9000-8000", - expErr: `invalid range specified for port: 9000-8000`, + expErr: `invalid port range: 9000-8000`, }, { doc: "range missing end", input: "8000-", - expErr: `strconv.ParseUint: parsing "": invalid syntax`, + expErr: `invalid end port '': value is empty`, }, { doc: "range missing start", input: "-9000", - expErr: `strconv.ParseUint: parsing "": invalid syntax`, + expErr: `invalid start port '': value is empty`, }, { doc: "invalid range end", input: "8000-a", - expErr: `strconv.ParseUint: parsing "a": invalid syntax`, + expErr: `invalid end port 'a': invalid syntax`, }, { doc: "invalid range end port", input: "8000-9000a", - expErr: `strconv.ParseUint: parsing "9000a": invalid syntax`, + expErr: `invalid end port '9000a': invalid syntax`, }, { doc: "range range start", input: "a-9000", - expErr: `strconv.ParseUint: parsing "a": invalid syntax`, + expErr: `invalid start port 'a': invalid syntax`, }, { doc: "range range start port", input: "8000a-9000", - expErr: `strconv.ParseUint: parsing "8000a": invalid syntax`, + expErr: `invalid start port '8000a': invalid syntax`, }, { doc: "range with trailing hyphen", input: "-8000-", - expErr: `strconv.ParseUint: parsing "": invalid syntax`, + expErr: `invalid start port '': value is empty`, }, { doc: "range without ports", input: "-", - expErr: `strconv.ParseUint: parsing "": invalid syntax`, + expErr: `invalid start port '': value is empty`, }, } @@ -120,3 +120,85 @@ func TestParsePortRange(t *testing.T) { }) } } + +func TestParsePortNumber(t *testing.T) { + tests := []struct { + doc string + input string + exp int + expErr string + }{ + { + doc: "empty string", + input: "", + expErr: "value is empty", + }, + { + doc: "whitespace only", + input: " ", + expErr: "invalid syntax", + }, + { + doc: "single valid port", + input: "1234", + exp: 1234, + }, + { + doc: "zero port", + input: "0", + exp: 0, + }, + { + doc: "max valid port", + input: "65535", + exp: 65535, + }, + { + doc: "leading/trailing spaces", + input: " 42 ", + expErr: "invalid syntax", + }, + { + doc: "negative port", + input: "-1", + expErr: "value out of range (0–65535)", + }, + { + doc: "too large port", + input: "70000", + expErr: "value out of range (0–65535)", + }, + { + doc: "non-numeric", + input: "foo", + expErr: "invalid syntax", + }, + { + doc: "trailing garbage", + input: "1234abc", + expErr: "invalid syntax", + }, + } + + for _, tc := range tests { + t.Run(tc.doc, func(t *testing.T) { + got, err := parsePortNumber(tc.input) + + if tc.expErr == "" { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.exp { + t.Errorf("expected %d, got %d", tc.exp, got) + } + } else { + if err == nil { + t.Fatalf("expected error %q, got nil", tc.expErr) + } + if err.Error() != tc.expErr { + t.Errorf("expected error %q, got %q", tc.expErr, err.Error()) + } + } + }) + } +}