Akash Kankanala | 761955c | 2024-02-21 19:32:20 +0530 | [diff] [blame^] | 1 | /* |
| 2 | * |
| 3 | * Copyright 2023 gRPC authors. |
| 4 | * |
| 5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | * you may not use this file except in compliance with the License. |
| 7 | * You may obtain a copy of the License at |
| 8 | * |
| 9 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | * |
| 11 | * Unless required by applicable law or agreed to in writing, software |
| 12 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | * See the License for the specific language governing permissions and |
| 15 | * limitations under the License. |
| 16 | * |
| 17 | */ |
| 18 | |
| 19 | package serviceconfig |
| 20 | |
| 21 | import ( |
| 22 | "encoding/json" |
| 23 | "fmt" |
| 24 | "math" |
| 25 | "strconv" |
| 26 | "strings" |
| 27 | "time" |
| 28 | ) |
| 29 | |
// Duration is a time.Duration whose JSON marshal and unmarshal methods conform
// to the protobuf JSON spec defined [here]: the value is encoded as a JSON
// string holding a decimal number of seconds followed by an "s" suffix, for
// example "1.500s".
//
// [here]: https://protobuf.dev/reference/protobuf/google.protobuf/#duration
type Duration time.Duration
| 35 | |
| 36 | func (d Duration) String() string { |
| 37 | return fmt.Sprint(time.Duration(d)) |
| 38 | } |
| 39 | |
| 40 | // MarshalJSON converts from d to a JSON string output. |
| 41 | func (d Duration) MarshalJSON() ([]byte, error) { |
| 42 | ns := time.Duration(d).Nanoseconds() |
| 43 | sec := ns / int64(time.Second) |
| 44 | ns = ns % int64(time.Second) |
| 45 | |
| 46 | var sign string |
| 47 | if sec < 0 || ns < 0 { |
| 48 | sign, sec, ns = "-", -1*sec, -1*ns |
| 49 | } |
| 50 | |
| 51 | // Generated output always contains 0, 3, 6, or 9 fractional digits, |
| 52 | // depending on required precision. |
| 53 | str := fmt.Sprintf("%s%d.%09d", sign, sec, ns) |
| 54 | str = strings.TrimSuffix(str, "000") |
| 55 | str = strings.TrimSuffix(str, "000") |
| 56 | str = strings.TrimSuffix(str, ".000") |
| 57 | return []byte(fmt.Sprintf("\"%ss\"", str)), nil |
| 58 | } |
| 59 | |
| 60 | // UnmarshalJSON unmarshals b as a duration JSON string into d. |
| 61 | func (d *Duration) UnmarshalJSON(b []byte) error { |
| 62 | var s string |
| 63 | if err := json.Unmarshal(b, &s); err != nil { |
| 64 | return err |
| 65 | } |
| 66 | if !strings.HasSuffix(s, "s") { |
| 67 | return fmt.Errorf("malformed duration %q: missing seconds unit", s) |
| 68 | } |
| 69 | neg := false |
| 70 | if s[0] == '-' { |
| 71 | neg = true |
| 72 | s = s[1:] |
| 73 | } |
| 74 | ss := strings.SplitN(s[:len(s)-1], ".", 3) |
| 75 | if len(ss) > 2 { |
| 76 | return fmt.Errorf("malformed duration %q: too many decimals", s) |
| 77 | } |
| 78 | // hasDigits is set if either the whole or fractional part of the number is |
| 79 | // present, since both are optional but one is required. |
| 80 | hasDigits := false |
| 81 | var sec, ns int64 |
| 82 | if len(ss[0]) > 0 { |
| 83 | var err error |
| 84 | if sec, err = strconv.ParseInt(ss[0], 10, 64); err != nil { |
| 85 | return fmt.Errorf("malformed duration %q: %v", s, err) |
| 86 | } |
| 87 | // Maximum seconds value per the durationpb spec. |
| 88 | const maxProtoSeconds = 315_576_000_000 |
| 89 | if sec > maxProtoSeconds { |
| 90 | return fmt.Errorf("out of range: %q", s) |
| 91 | } |
| 92 | hasDigits = true |
| 93 | } |
| 94 | if len(ss) == 2 && len(ss[1]) > 0 { |
| 95 | if len(ss[1]) > 9 { |
| 96 | return fmt.Errorf("malformed duration %q: too many digits after decimal", s) |
| 97 | } |
| 98 | var err error |
| 99 | if ns, err = strconv.ParseInt(ss[1], 10, 64); err != nil { |
| 100 | return fmt.Errorf("malformed duration %q: %v", s, err) |
| 101 | } |
| 102 | for i := 9; i > len(ss[1]); i-- { |
| 103 | ns *= 10 |
| 104 | } |
| 105 | hasDigits = true |
| 106 | } |
| 107 | if !hasDigits { |
| 108 | return fmt.Errorf("malformed duration %q: contains no numbers", s) |
| 109 | } |
| 110 | |
| 111 | if neg { |
| 112 | sec *= -1 |
| 113 | ns *= -1 |
| 114 | } |
| 115 | |
| 116 | // Maximum/minimum seconds/nanoseconds representable by Go's time.Duration. |
| 117 | const maxSeconds = math.MaxInt64 / int64(time.Second) |
| 118 | const maxNanosAtMaxSeconds = math.MaxInt64 % int64(time.Second) |
| 119 | const minSeconds = math.MinInt64 / int64(time.Second) |
| 120 | const minNanosAtMinSeconds = math.MinInt64 % int64(time.Second) |
| 121 | |
| 122 | if sec > maxSeconds || (sec == maxSeconds && ns >= maxNanosAtMaxSeconds) { |
| 123 | *d = Duration(math.MaxInt64) |
| 124 | } else if sec < minSeconds || (sec == minSeconds && ns <= minNanosAtMinSeconds) { |
| 125 | *d = Duration(math.MinInt64) |
| 126 | } else { |
| 127 | *d = Duration(sec*int64(time.Second) + ns) |
| 128 | } |
| 129 | return nil |
| 130 | } |