| Abhay Kumar | a2ae599 | 2025-11-10 14:02:24 +0000 | [diff] [blame^] | 1 | package runtime |
| 2 | |
| 3 | import ( |
| 4 | "errors" |
| 5 | "fmt" |
| 6 | "net/url" |
| 7 | "regexp" |
| 8 | "strconv" |
| 9 | "strings" |
| 10 | "time" |
| 11 | |
| 12 | "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" |
| 13 | "google.golang.org/grpc/grpclog" |
| 14 | "google.golang.org/protobuf/encoding/protojson" |
| 15 | "google.golang.org/protobuf/proto" |
| 16 | "google.golang.org/protobuf/reflect/protoreflect" |
| 17 | "google.golang.org/protobuf/reflect/protoregistry" |
| 18 | "google.golang.org/protobuf/types/known/durationpb" |
| 19 | field_mask "google.golang.org/protobuf/types/known/fieldmaskpb" |
| 20 | "google.golang.org/protobuf/types/known/structpb" |
| 21 | "google.golang.org/protobuf/types/known/timestamppb" |
| 22 | "google.golang.org/protobuf/types/known/wrapperspb" |
| 23 | ) |
| 24 | |
// valuesKeyRegexp splits a query key of the form "name[sub]" into "name"
// (submatch 1) and "sub" (submatch 2). The first group is greedy, so for
// "a[b][c]" submatch 1 is "a[b]" and submatch 2 is "c".
var valuesKeyRegexp = regexp.MustCompile("^(.*)\\[(.*)\\]$")
| 26 | |
| 27 | var currentQueryParser QueryParameterParser = &DefaultQueryParser{} |
| 28 | |
| 29 | // QueryParameterParser defines interface for all query parameter parsers |
| 30 | type QueryParameterParser interface { |
| 31 | Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error |
| 32 | } |
| 33 | |
| 34 | // PopulateQueryParameters parses query parameters |
| 35 | // into "msg" using current query parser |
| 36 | func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { |
| 37 | return currentQueryParser.Parse(msg, values, filter) |
| 38 | } |
| 39 | |
| 40 | // DefaultQueryParser is a QueryParameterParser which implements the default |
| 41 | // query parameters parsing behavior. |
| 42 | // |
| 43 | // See https://github.com/grpc-ecosystem/grpc-gateway/issues/2632 for more context. |
| 44 | type DefaultQueryParser struct{} |
| 45 | |
| 46 | // Parse populates "values" into "msg". |
| 47 | // A value is ignored if its key starts with one of the elements in "filter". |
| 48 | func (*DefaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { |
| 49 | for key, values := range values { |
| 50 | if match := valuesKeyRegexp.FindStringSubmatch(key); len(match) == 3 { |
| 51 | key = match[1] |
| 52 | values = append([]string{match[2]}, values...) |
| 53 | } |
| 54 | |
| 55 | msgValue := msg.ProtoReflect() |
| 56 | fieldPath := normalizeFieldPath(msgValue, strings.Split(key, ".")) |
| 57 | if filter.HasCommonPrefix(fieldPath) { |
| 58 | continue |
| 59 | } |
| 60 | if err := populateFieldValueFromPath(msgValue, fieldPath, values); err != nil { |
| 61 | return err |
| 62 | } |
| 63 | } |
| 64 | return nil |
| 65 | } |
| 66 | |
| 67 | // PopulateFieldFromPath sets a value in a nested Protobuf structure. |
| 68 | func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error { |
| 69 | fieldPath := strings.Split(fieldPathString, ".") |
| 70 | return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value}) |
| 71 | } |
| 72 | |
| 73 | func normalizeFieldPath(msgValue protoreflect.Message, fieldPath []string) []string { |
| 74 | newFieldPath := make([]string, 0, len(fieldPath)) |
| 75 | for i, fieldName := range fieldPath { |
| 76 | fields := msgValue.Descriptor().Fields() |
| 77 | fieldDesc := fields.ByTextName(fieldName) |
| 78 | if fieldDesc == nil { |
| 79 | fieldDesc = fields.ByJSONName(fieldName) |
| 80 | } |
| 81 | if fieldDesc == nil { |
| 82 | // return initial field path values if no matching message field was found |
| 83 | return fieldPath |
| 84 | } |
| 85 | |
| 86 | newFieldPath = append(newFieldPath, string(fieldDesc.Name())) |
| 87 | |
| 88 | // If this is the last element, we're done |
| 89 | if i == len(fieldPath)-1 { |
| 90 | break |
| 91 | } |
| 92 | |
| 93 | // Only singular message fields are allowed |
| 94 | if fieldDesc.Message() == nil || fieldDesc.Cardinality() == protoreflect.Repeated { |
| 95 | return fieldPath |
| 96 | } |
| 97 | |
| 98 | // Get the nested message |
| 99 | msgValue = msgValue.Get(fieldDesc).Message() |
| 100 | } |
| 101 | |
| 102 | return newFieldPath |
| 103 | } |
| 104 | |
| 105 | func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error { |
| 106 | if len(fieldPath) < 1 { |
| 107 | return errors.New("no field path") |
| 108 | } |
| 109 | if len(values) < 1 { |
| 110 | return errors.New("no value provided") |
| 111 | } |
| 112 | |
| 113 | var fieldDescriptor protoreflect.FieldDescriptor |
| 114 | for i, fieldName := range fieldPath { |
| 115 | fields := msgValue.Descriptor().Fields() |
| 116 | |
| 117 | // Get field by name |
| 118 | fieldDescriptor = fields.ByName(protoreflect.Name(fieldName)) |
| 119 | if fieldDescriptor == nil { |
| 120 | fieldDescriptor = fields.ByJSONName(fieldName) |
| 121 | if fieldDescriptor == nil { |
| 122 | // We're not returning an error here because this could just be |
| 123 | // an extra query parameter that isn't part of the request. |
| 124 | grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, ".")) |
| 125 | return nil |
| 126 | } |
| 127 | } |
| 128 | |
| 129 | // Check if oneof already set |
| 130 | if of := fieldDescriptor.ContainingOneof(); of != nil && !of.IsSynthetic() { |
| 131 | if f := msgValue.WhichOneof(of); f != nil { |
| 132 | if fieldDescriptor.Message() == nil || fieldDescriptor.FullName() != f.FullName() { |
| 133 | return fmt.Errorf("field already set for oneof %q", of.FullName().Name()) |
| 134 | } |
| 135 | } |
| 136 | } |
| 137 | |
| 138 | // If this is the last element, we're done |
| 139 | if i == len(fieldPath)-1 { |
| 140 | break |
| 141 | } |
| 142 | |
| 143 | // Only singular message fields are allowed |
| 144 | if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated { |
| 145 | return fmt.Errorf("invalid path: %q is not a message", fieldName) |
| 146 | } |
| 147 | |
| 148 | // Get the nested message |
| 149 | msgValue = msgValue.Mutable(fieldDescriptor).Message() |
| 150 | } |
| 151 | |
| 152 | switch { |
| 153 | case fieldDescriptor.IsList(): |
| 154 | return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values) |
| 155 | case fieldDescriptor.IsMap(): |
| 156 | return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values) |
| 157 | } |
| 158 | |
| 159 | if len(values) > 1 { |
| 160 | return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", ")) |
| 161 | } |
| 162 | |
| 163 | return populateField(fieldDescriptor, msgValue, values[0]) |
| 164 | } |
| 165 | |
| 166 | func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error { |
| 167 | v, err := parseField(fieldDescriptor, value) |
| 168 | if err != nil { |
| 169 | return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err) |
| 170 | } |
| 171 | |
| 172 | msgValue.Set(fieldDescriptor, v) |
| 173 | return nil |
| 174 | } |
| 175 | |
| 176 | func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error { |
| 177 | for _, value := range values { |
| 178 | v, err := parseField(fieldDescriptor, value) |
| 179 | if err != nil { |
| 180 | return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err) |
| 181 | } |
| 182 | list.Append(v) |
| 183 | } |
| 184 | |
| 185 | return nil |
| 186 | } |
| 187 | |
| 188 | func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error { |
| 189 | if len(values) != 2 { |
| 190 | return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName()) |
| 191 | } |
| 192 | |
| 193 | key, err := parseField(fieldDescriptor.MapKey(), values[0]) |
| 194 | if err != nil { |
| 195 | return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err) |
| 196 | } |
| 197 | |
| 198 | value, err := parseField(fieldDescriptor.MapValue(), values[1]) |
| 199 | if err != nil { |
| 200 | return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err) |
| 201 | } |
| 202 | |
| 203 | mp.Set(key.MapKey(), value) |
| 204 | |
| 205 | return nil |
| 206 | } |
| 207 | |
| 208 | func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) { |
| 209 | switch fieldDescriptor.Kind() { |
| 210 | case protoreflect.BoolKind: |
| 211 | v, err := strconv.ParseBool(value) |
| 212 | if err != nil { |
| 213 | return protoreflect.Value{}, err |
| 214 | } |
| 215 | return protoreflect.ValueOfBool(v), nil |
| 216 | case protoreflect.EnumKind: |
| 217 | enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName()) |
| 218 | if err != nil { |
| 219 | if errors.Is(err, protoregistry.NotFound) { |
| 220 | return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName()) |
| 221 | } |
| 222 | return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err) |
| 223 | } |
| 224 | // Look for enum by name |
| 225 | v := enum.Descriptor().Values().ByName(protoreflect.Name(value)) |
| 226 | if v == nil { |
| 227 | i, err := strconv.Atoi(value) |
| 228 | if err != nil { |
| 229 | return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value) |
| 230 | } |
| 231 | // Look for enum by number |
| 232 | if v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)); v == nil { |
| 233 | return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value) |
| 234 | } |
| 235 | } |
| 236 | return protoreflect.ValueOfEnum(v.Number()), nil |
| 237 | case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: |
| 238 | v, err := strconv.ParseInt(value, 10, 32) |
| 239 | if err != nil { |
| 240 | return protoreflect.Value{}, err |
| 241 | } |
| 242 | return protoreflect.ValueOfInt32(int32(v)), nil |
| 243 | case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: |
| 244 | v, err := strconv.ParseInt(value, 10, 64) |
| 245 | if err != nil { |
| 246 | return protoreflect.Value{}, err |
| 247 | } |
| 248 | return protoreflect.ValueOfInt64(v), nil |
| 249 | case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: |
| 250 | v, err := strconv.ParseUint(value, 10, 32) |
| 251 | if err != nil { |
| 252 | return protoreflect.Value{}, err |
| 253 | } |
| 254 | return protoreflect.ValueOfUint32(uint32(v)), nil |
| 255 | case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: |
| 256 | v, err := strconv.ParseUint(value, 10, 64) |
| 257 | if err != nil { |
| 258 | return protoreflect.Value{}, err |
| 259 | } |
| 260 | return protoreflect.ValueOfUint64(v), nil |
| 261 | case protoreflect.FloatKind: |
| 262 | v, err := strconv.ParseFloat(value, 32) |
| 263 | if err != nil { |
| 264 | return protoreflect.Value{}, err |
| 265 | } |
| 266 | return protoreflect.ValueOfFloat32(float32(v)), nil |
| 267 | case protoreflect.DoubleKind: |
| 268 | v, err := strconv.ParseFloat(value, 64) |
| 269 | if err != nil { |
| 270 | return protoreflect.Value{}, err |
| 271 | } |
| 272 | return protoreflect.ValueOfFloat64(v), nil |
| 273 | case protoreflect.StringKind: |
| 274 | return protoreflect.ValueOfString(value), nil |
| 275 | case protoreflect.BytesKind: |
| 276 | v, err := Bytes(value) |
| 277 | if err != nil { |
| 278 | return protoreflect.Value{}, err |
| 279 | } |
| 280 | return protoreflect.ValueOfBytes(v), nil |
| 281 | case protoreflect.MessageKind, protoreflect.GroupKind: |
| 282 | return parseMessage(fieldDescriptor.Message(), value) |
| 283 | default: |
| 284 | panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind())) |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) { |
| 289 | var msg proto.Message |
| 290 | switch msgDescriptor.FullName() { |
| 291 | case "google.protobuf.Timestamp": |
| 292 | t, err := time.Parse(time.RFC3339Nano, value) |
| 293 | if err != nil { |
| 294 | return protoreflect.Value{}, err |
| 295 | } |
| 296 | timestamp := timestamppb.New(t) |
| 297 | if ok := timestamp.IsValid(); !ok { |
| 298 | return protoreflect.Value{}, fmt.Errorf("%s before 0001-01-01", value) |
| 299 | } |
| 300 | msg = timestamp |
| 301 | case "google.protobuf.Duration": |
| 302 | d, err := time.ParseDuration(value) |
| 303 | if err != nil { |
| 304 | return protoreflect.Value{}, err |
| 305 | } |
| 306 | msg = durationpb.New(d) |
| 307 | case "google.protobuf.DoubleValue": |
| 308 | v, err := strconv.ParseFloat(value, 64) |
| 309 | if err != nil { |
| 310 | return protoreflect.Value{}, err |
| 311 | } |
| 312 | msg = wrapperspb.Double(v) |
| 313 | case "google.protobuf.FloatValue": |
| 314 | v, err := strconv.ParseFloat(value, 32) |
| 315 | if err != nil { |
| 316 | return protoreflect.Value{}, err |
| 317 | } |
| 318 | msg = wrapperspb.Float(float32(v)) |
| 319 | case "google.protobuf.Int64Value": |
| 320 | v, err := strconv.ParseInt(value, 10, 64) |
| 321 | if err != nil { |
| 322 | return protoreflect.Value{}, err |
| 323 | } |
| 324 | msg = wrapperspb.Int64(v) |
| 325 | case "google.protobuf.Int32Value": |
| 326 | v, err := strconv.ParseInt(value, 10, 32) |
| 327 | if err != nil { |
| 328 | return protoreflect.Value{}, err |
| 329 | } |
| 330 | msg = wrapperspb.Int32(int32(v)) |
| 331 | case "google.protobuf.UInt64Value": |
| 332 | v, err := strconv.ParseUint(value, 10, 64) |
| 333 | if err != nil { |
| 334 | return protoreflect.Value{}, err |
| 335 | } |
| 336 | msg = wrapperspb.UInt64(v) |
| 337 | case "google.protobuf.UInt32Value": |
| 338 | v, err := strconv.ParseUint(value, 10, 32) |
| 339 | if err != nil { |
| 340 | return protoreflect.Value{}, err |
| 341 | } |
| 342 | msg = wrapperspb.UInt32(uint32(v)) |
| 343 | case "google.protobuf.BoolValue": |
| 344 | v, err := strconv.ParseBool(value) |
| 345 | if err != nil { |
| 346 | return protoreflect.Value{}, err |
| 347 | } |
| 348 | msg = wrapperspb.Bool(v) |
| 349 | case "google.protobuf.StringValue": |
| 350 | msg = wrapperspb.String(value) |
| 351 | case "google.protobuf.BytesValue": |
| 352 | v, err := Bytes(value) |
| 353 | if err != nil { |
| 354 | return protoreflect.Value{}, err |
| 355 | } |
| 356 | msg = wrapperspb.Bytes(v) |
| 357 | case "google.protobuf.FieldMask": |
| 358 | fm := &field_mask.FieldMask{} |
| 359 | fm.Paths = append(fm.Paths, strings.Split(value, ",")...) |
| 360 | msg = fm |
| 361 | case "google.protobuf.Value": |
| 362 | var v structpb.Value |
| 363 | if err := protojson.Unmarshal([]byte(value), &v); err != nil { |
| 364 | return protoreflect.Value{}, err |
| 365 | } |
| 366 | msg = &v |
| 367 | case "google.protobuf.Struct": |
| 368 | var v structpb.Struct |
| 369 | if err := protojson.Unmarshal([]byte(value), &v); err != nil { |
| 370 | return protoreflect.Value{}, err |
| 371 | } |
| 372 | msg = &v |
| 373 | default: |
| 374 | return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName())) |
| 375 | } |
| 376 | |
| 377 | return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil |
| 378 | } |