blob: 8549dfb97afb0c9e68aa1349d05c87a40fdca117 [file] [log] [blame]
Abhay Kumara2ae5992025-11-10 14:02:24 +00001package runtime
2
3import (
4 "errors"
5 "fmt"
6 "net/url"
7 "regexp"
8 "strconv"
9 "strings"
10 "time"
11
12 "github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
13 "google.golang.org/grpc/grpclog"
14 "google.golang.org/protobuf/encoding/protojson"
15 "google.golang.org/protobuf/proto"
16 "google.golang.org/protobuf/reflect/protoreflect"
17 "google.golang.org/protobuf/reflect/protoregistry"
18 "google.golang.org/protobuf/types/known/durationpb"
19 field_mask "google.golang.org/protobuf/types/known/fieldmaskpb"
20 "google.golang.org/protobuf/types/known/structpb"
21 "google.golang.org/protobuf/types/known/timestamppb"
22 "google.golang.org/protobuf/types/known/wrapperspb"
23)
24
25var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
26
27var currentQueryParser QueryParameterParser = &DefaultQueryParser{}
28
29// QueryParameterParser defines interface for all query parameter parsers
30type QueryParameterParser interface {
31 Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
32}
33
// PopulateQueryParameters decodes the URL query parameters in values into
// msg. The work is delegated to the package's current QueryParameterParser
// (initially a *DefaultQueryParser), which also decides how filter is applied.
func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
	parser := currentQueryParser
	return parser.Parse(msg, values, filter)
}
39
// DefaultQueryParser is the stock QueryParameterParser, implementing the
// default query parameter parsing behavior. It carries no state, and it is the
// initial parser used by PopulateQueryParameters.
//
// Background: https://github.com/grpc-ecosystem/grpc-gateway/issues/2632
type DefaultQueryParser struct{}
45
// Parse populates "values" into "msg".
//
// Each query key is a dot-separated field path whose elements may be proto
// text names or JSON names. A key of the form "name[k]" is treated as the path
// "name" with k prepended to its values; this is how map entries
// ("labels[k]=v") are expressed.
//
// A value is ignored if its key starts with one of the elements in "filter".
// A nil filter ignores nothing.
//
// The first error encountered is returned. Because url.Values is a map, the
// order in which keys are applied is not deterministic.
func (*DefaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
	if len(values) == 0 {
		return nil
	}

	// The reflective view is the same for every key, so build it once.
	msgValue := msg.ProtoReflect()
	for key, vals := range values {
		if match := valuesKeyRegexp.FindStringSubmatch(key); len(match) == 3 {
			key = match[1]
			// Prepend into a fresh slice so the caller's url.Values is not mutated.
			vals = append([]string{match[2]}, vals...)
		}

		fieldPath := normalizeFieldPath(msgValue, strings.Split(key, "."))
		// Guard against a nil filter: dereferencing it would panic.
		if filter != nil && filter.HasCommonPrefix(fieldPath) {
			continue
		}
		if err := populateFieldValueFromPath(msgValue, fieldPath, vals); err != nil {
			return err
		}
	}
	return nil
}
66
// PopulateFieldFromPath sets a single value in a (possibly nested) Protobuf
// message. fieldPathString is a dot-separated path. Each element may be a
// proto field name or a JSON name, and every element except the last must
// name a singular message field.
func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
	return populateFieldValueFromPath(
		msg.ProtoReflect(),
		strings.Split(fieldPathString, "."),
		[]string{value},
	)
}
72
73func normalizeFieldPath(msgValue protoreflect.Message, fieldPath []string) []string {
74 newFieldPath := make([]string, 0, len(fieldPath))
75 for i, fieldName := range fieldPath {
76 fields := msgValue.Descriptor().Fields()
77 fieldDesc := fields.ByTextName(fieldName)
78 if fieldDesc == nil {
79 fieldDesc = fields.ByJSONName(fieldName)
80 }
81 if fieldDesc == nil {
82 // return initial field path values if no matching message field was found
83 return fieldPath
84 }
85
86 newFieldPath = append(newFieldPath, string(fieldDesc.Name()))
87
88 // If this is the last element, we're done
89 if i == len(fieldPath)-1 {
90 break
91 }
92
93 // Only singular message fields are allowed
94 if fieldDesc.Message() == nil || fieldDesc.Cardinality() == protoreflect.Repeated {
95 return fieldPath
96 }
97
98 // Get the nested message
99 msgValue = msgValue.Get(fieldDesc).Message()
100 }
101
102 return newFieldPath
103}
104
105func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
106 if len(fieldPath) < 1 {
107 return errors.New("no field path")
108 }
109 if len(values) < 1 {
110 return errors.New("no value provided")
111 }
112
113 var fieldDescriptor protoreflect.FieldDescriptor
114 for i, fieldName := range fieldPath {
115 fields := msgValue.Descriptor().Fields()
116
117 // Get field by name
118 fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
119 if fieldDescriptor == nil {
120 fieldDescriptor = fields.ByJSONName(fieldName)
121 if fieldDescriptor == nil {
122 // We're not returning an error here because this could just be
123 // an extra query parameter that isn't part of the request.
124 grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
125 return nil
126 }
127 }
128
129 // Check if oneof already set
130 if of := fieldDescriptor.ContainingOneof(); of != nil && !of.IsSynthetic() {
131 if f := msgValue.WhichOneof(of); f != nil {
132 if fieldDescriptor.Message() == nil || fieldDescriptor.FullName() != f.FullName() {
133 return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
134 }
135 }
136 }
137
138 // If this is the last element, we're done
139 if i == len(fieldPath)-1 {
140 break
141 }
142
143 // Only singular message fields are allowed
144 if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
145 return fmt.Errorf("invalid path: %q is not a message", fieldName)
146 }
147
148 // Get the nested message
149 msgValue = msgValue.Mutable(fieldDescriptor).Message()
150 }
151
152 switch {
153 case fieldDescriptor.IsList():
154 return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
155 case fieldDescriptor.IsMap():
156 return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
157 }
158
159 if len(values) > 1 {
160 return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
161 }
162
163 return populateField(fieldDescriptor, msgValue, values[0])
164}
165
// populateField parses value according to fieldDescriptor and stores the
// result in msgValue. On a parse failure msgValue is left untouched, and the
// error is wrapped with the field's short name.
func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
	parsed, parseErr := parseField(fieldDescriptor, value)
	if parseErr != nil {
		name := fieldDescriptor.FullName().Name()
		return fmt.Errorf("parsing field %q: %w", name, parseErr)
	}

	msgValue.Set(fieldDescriptor, parsed)
	return nil
}
175
// populateRepeatedField parses each element of values according to
// fieldDescriptor and appends the results to list, in order. It stops at the
// first value that fails to parse. Elements parsed before the failure remain
// appended.
func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
	for i := range values {
		elem, err := parseField(fieldDescriptor, values[i])
		if err != nil {
			return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
		}
		list.Append(elem)
	}
	return nil
}
187
// populateMapField stores one entry in the map field described by
// fieldDescriptor. values must contain exactly two elements: the entry key
// followed by the entry value. They are parsed with the map's key and value
// descriptors, respectively, and the entry is stored with mp.Set.
//
// Errors are returned for a wrong number of values and for a key or value that
// fails to parse.
func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
	// Report the actual problem. A single value (e.g. "labels=foo" without the
	// bracketed key syntax) is a missing value, not "more than one value".
	switch len(values) {
	case 2:
		// Exactly one key and one value: the only valid shape.
	case 0:
		return fmt.Errorf("no key provided for map %q", fieldDescriptor.FullName())
	case 1:
		return fmt.Errorf("no value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
	default:
		return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
	}

	key, err := parseField(fieldDescriptor.MapKey(), values[0])
	if err != nil {
		return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
	}

	value, err := parseField(fieldDescriptor.MapValue(), values[1])
	if err != nil {
		return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
	}

	mp.Set(key.MapKey(), value)

	return nil
}
207
// parseField converts the textual value into a protoreflect.Value of the kind
// declared by fieldDescriptor:
//   - bool: strconv.ParseBool syntax;
//   - enum: the value's name, or else its decimal number (must fit in int32);
//   - integer kinds: base 10, range-checked against the field's bit size;
//   - float/double: strconv.ParseFloat syntax;
//   - string: used verbatim;
//   - bytes: decoded by Bytes;
//   - message/group: delegated to parseMessage.
//
// Enum fields require the enum type to be registered in
// protoregistry.GlobalTypes. Any other kind panics, because it indicates a
// descriptor kind this function does not know about.
func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
	switch fieldDescriptor.Kind() {
	case protoreflect.BoolKind:
		v, err := strconv.ParseBool(value)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfBool(v), nil
	case protoreflect.EnumKind:
		enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
		if err != nil {
			if errors.Is(err, protoregistry.NotFound) {
				return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
			}
			return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
		}
		// Look for enum by name
		v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
		if v == nil {
			// Fall back to the numeric form. Enum numbers are int32, so parse
			// with a 32-bit limit: converting a wider integer would silently
			// truncate (e.g. "4294967297" would wrap around to 1).
			i, err := strconv.ParseInt(value, 10, 32)
			if err != nil {
				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
			}
			// Look for enum by number
			if v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)); v == nil {
				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
			}
		}
		return protoreflect.ValueOfEnum(v.Number()), nil
	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
		v, err := strconv.ParseInt(value, 10, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfInt32(int32(v)), nil
	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		v, err := strconv.ParseInt(value, 10, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfInt64(v), nil
	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
		v, err := strconv.ParseUint(value, 10, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfUint32(uint32(v)), nil
	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		v, err := strconv.ParseUint(value, 10, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfUint64(v), nil
	case protoreflect.FloatKind:
		v, err := strconv.ParseFloat(value, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfFloat32(float32(v)), nil
	case protoreflect.DoubleKind:
		v, err := strconv.ParseFloat(value, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfFloat64(v), nil
	case protoreflect.StringKind:
		return protoreflect.ValueOfString(value), nil
	case protoreflect.BytesKind:
		v, err := Bytes(value)
		if err != nil {
			return protoreflect.Value{}, err
		}
		return protoreflect.ValueOfBytes(v), nil
	case protoreflect.MessageKind, protoreflect.GroupKind:
		return parseMessage(fieldDescriptor.Message(), value)
	default:
		// Every protoreflect.Kind is handled above; reaching this means an
		// unknown kind, which is a programming error rather than bad input.
		panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
	}
}
287
// parseMessage parses value into the well-known message type identified by
// msgDescriptor:
//   - Timestamp: RFC 3339 with optional fractional seconds. The instant must
//     lie within the range google.protobuf.Timestamp can represent
//     (0001-01-01 through 9999-12-31 UTC).
//   - Duration: time.ParseDuration syntax (e.g. "1.5s", "2h45m").
//   - Wrapper types: the same syntax as the corresponding scalar.
//   - FieldMask: comma-separated paths, used verbatim. An empty string
//     yields an empty mask, as protojson does.
//   - Value / Struct: JSON, decoded with protojson.
//
// Any other message type results in an "unsupported message type" error.
func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
	var msg proto.Message
	switch msgDescriptor.FullName() {
	case "google.protobuf.Timestamp":
		t, err := time.Parse(time.RFC3339Nano, value)
		if err != nil {
			return protoreflect.Value{}, err
		}
		timestamp := timestamppb.New(t)
		if !timestamp.IsValid() {
			// IsValid rejects both ends of the range, so report which one was
			// hit. Instants below the minimum have negative seconds, and
			// instants above the maximum (reachable via a negative UTC offset on
			// 9999-12-31) have positive seconds.
			if timestamp.GetSeconds() < 0 {
				return protoreflect.Value{}, fmt.Errorf("%s before 0001-01-01", value)
			}
			return protoreflect.Value{}, fmt.Errorf("%s after 9999-12-31", value)
		}
		msg = timestamp
	case "google.protobuf.Duration":
		d, err := time.ParseDuration(value)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = durationpb.New(d)
	case "google.protobuf.DoubleValue":
		v, err := strconv.ParseFloat(value, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = wrapperspb.Double(v)
	case "google.protobuf.FloatValue":
		v, err := strconv.ParseFloat(value, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = wrapperspb.Float(float32(v))
	case "google.protobuf.Int64Value":
		v, err := strconv.ParseInt(value, 10, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = wrapperspb.Int64(v)
	case "google.protobuf.Int32Value":
		v, err := strconv.ParseInt(value, 10, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = wrapperspb.Int32(int32(v))
	case "google.protobuf.UInt64Value":
		v, err := strconv.ParseUint(value, 10, 64)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = wrapperspb.UInt64(v)
	case "google.protobuf.UInt32Value":
		v, err := strconv.ParseUint(value, 10, 32)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = wrapperspb.UInt32(uint32(v))
	case "google.protobuf.BoolValue":
		v, err := strconv.ParseBool(value)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = wrapperspb.Bool(v)
	case "google.protobuf.StringValue":
		msg = wrapperspb.String(value)
	case "google.protobuf.BytesValue":
		v, err := Bytes(value)
		if err != nil {
			return protoreflect.Value{}, err
		}
		msg = wrapperspb.Bytes(v)
	case "google.protobuf.FieldMask":
		fm := &field_mask.FieldMask{}
		// strings.Split("", ",") would yield a single empty path, which is not
		// a valid field mask path. Treat "" as an empty mask instead.
		if value != "" {
			fm.Paths = strings.Split(value, ",")
		}
		msg = fm
	case "google.protobuf.Value":
		var v structpb.Value
		if err := protojson.Unmarshal([]byte(value), &v); err != nil {
			return protoreflect.Value{}, err
		}
		msg = &v
	case "google.protobuf.Struct":
		var v structpb.Struct
		if err := protojson.Unmarshal([]byte(value), &v); err != nil {
			return protoreflect.Value{}, err
		}
		msg = &v
	default:
		return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
	}

	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
}