blob: 99a1eb95f7c48eda86b0a232237aeb0b2b8a8374 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/encoding/protowire"
15 "google.golang.org/protobuf/internal/encoding/messageset"
16 "google.golang.org/protobuf/internal/flags"
17 "google.golang.org/protobuf/internal/genid"
18 "google.golang.org/protobuf/internal/strs"
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053019 "google.golang.org/protobuf/reflect/protoreflect"
20 "google.golang.org/protobuf/reflect/protoregistry"
21 "google.golang.org/protobuf/runtime/protoiface"
khenaidoo106c61a2021-08-11 18:05:46 -040022)
23
// ValidationStatus is the result of validating the wire-format encoding of a message.
type ValidationStatus int

const (
	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
	// The validator was unable to render a judgement.
	//
	// The only causes of this status are an aberrant message type appearing somewhere
	// in the message or a failure in the extension resolver.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid indicates that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid indicates that unmarshaling the message will succeed.
	ValidationValid

	// ValidationWrongWireType indicates that a validated field does not have
	// the expected wire type.
	ValidationWrongWireType
)

// String returns the name of the status constant. Values that do not
// correspond to a declared constant are rendered as "ValidationStatus(N)".
func (v ValidationStatus) String() string {
	switch v {
	case ValidationUnknown:
		return "ValidationUnknown"
	case ValidationInvalid:
		return "ValidationInvalid"
	case ValidationValid:
		return "ValidationValid"
	case ValidationWrongWireType:
		// Every declared status gets its own name; without this case the
		// constant would be printed as the opaque "ValidationStatus(4)".
		return "ValidationWrongWireType"
	default:
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
}
58
59// Validate determines whether the contents of the buffer are a valid wire encoding
60// of the message type.
61//
62// This function is exposed for testing.
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053063func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
khenaidoo106c61a2021-08-11 18:05:46 -040064 mi, ok := mt.(*MessageInfo)
65 if !ok {
66 return out, ValidationUnknown
67 }
68 if in.Resolver == nil {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053069 in.Resolver = protoregistry.GlobalTypes
khenaidoo106c61a2021-08-11 18:05:46 -040070 }
bseeniva0b9cbcb2026-02-12 19:11:11 +053071 if in.Depth == 0 {
72 in.Depth = protowire.DefaultRecursionLimit
73 }
khenaidoo106c61a2021-08-11 18:05:46 -040074 o, st := mi.validate(in.Buf, 0, unmarshalOptions{
75 flags: in.Flags,
76 resolver: in.Resolver,
bseeniva0b9cbcb2026-02-12 19:11:11 +053077 depth: in.Depth,
khenaidoo106c61a2021-08-11 18:05:46 -040078 })
79 if o.initialized {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053080 out.Flags |= protoiface.UnmarshalInitialized
khenaidoo106c61a2021-08-11 18:05:46 -040081 }
82 return out, st
83}
84
85type validationInfo struct {
86 mi *MessageInfo
87 typ validationType
88 keyType, valType validationType
89
90 // For non-required fields, requiredBit is 0.
91 //
92 // For required fields, requiredBit's nth bit is set, where n is a
93 // unique index in the range [0, MessageInfo.numRequiredFields).
94 //
95 // If there are more than 64 required fields, requiredBit is 0.
96 requiredBit uint64
97}
98
// validationType classifies a field by how its encoded value is validated.
type validationType uint8

const (
	// validationTypeOther is the zero value: no type-specific check applies.
	validationTypeOther validationType = iota
	// Nested structures.
	validationTypeMessage
	validationTypeGroup
	validationTypeMap
	// Repeated scalars, which may appear in packed form.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64
	// Singular scalars.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	// Length-delimited values.
	validationTypeBytes
	validationTypeUTF8String
	// Item group of a MessageSet.
	validationTypeMessageSetItem
)
116
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530117func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
khenaidoo106c61a2021-08-11 18:05:46 -0400118 var vi validationInfo
119 switch {
120 case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
121 switch fd.Kind() {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530122 case protoreflect.MessageKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400123 vi.typ = validationTypeMessage
124 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
125 vi.mi = getMessageInfo(ot.Field(0).Type)
126 }
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530127 case protoreflect.GroupKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400128 vi.typ = validationTypeGroup
129 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
130 vi.mi = getMessageInfo(ot.Field(0).Type)
131 }
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530132 case protoreflect.StringKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400133 if strs.EnforceUTF8(fd) {
134 vi.typ = validationTypeUTF8String
135 }
136 }
137 default:
138 vi = newValidationInfo(fd, ft)
139 }
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530140 if fd.Cardinality() == protoreflect.Required {
khenaidoo106c61a2021-08-11 18:05:46 -0400141 // Avoid overflow. The required field check is done with a 64-bit mask, with
142 // any message containing more than 64 required fields always reported as
143 // potentially uninitialized, so it is not important to get a precise count
144 // of the required fields past 64.
145 if mi.numRequiredFields < math.MaxUint8 {
146 mi.numRequiredFields++
147 vi.requiredBit = 1 << (mi.numRequiredFields - 1)
148 }
149 }
150 return vi
151}
152
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530153func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
khenaidoo106c61a2021-08-11 18:05:46 -0400154 var vi validationInfo
155 switch {
156 case fd.IsList():
157 switch fd.Kind() {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530158 case protoreflect.MessageKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400159 vi.typ = validationTypeMessage
Abhay Kumara61c5222025-11-10 07:32:50 +0000160
161 if ft.Kind() == reflect.Ptr {
162 // Repeated opaque message fields are *[]*T.
163 ft = ft.Elem()
164 }
165
khenaidoo106c61a2021-08-11 18:05:46 -0400166 if ft.Kind() == reflect.Slice {
167 vi.mi = getMessageInfo(ft.Elem())
168 }
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530169 case protoreflect.GroupKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400170 vi.typ = validationTypeGroup
Abhay Kumara61c5222025-11-10 07:32:50 +0000171
172 if ft.Kind() == reflect.Ptr {
173 // Repeated opaque message fields are *[]*T.
174 ft = ft.Elem()
175 }
176
khenaidoo106c61a2021-08-11 18:05:46 -0400177 if ft.Kind() == reflect.Slice {
178 vi.mi = getMessageInfo(ft.Elem())
179 }
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530180 case protoreflect.StringKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400181 vi.typ = validationTypeBytes
182 if strs.EnforceUTF8(fd) {
183 vi.typ = validationTypeUTF8String
184 }
185 default:
186 switch wireTypes[fd.Kind()] {
187 case protowire.VarintType:
188 vi.typ = validationTypeRepeatedVarint
189 case protowire.Fixed32Type:
190 vi.typ = validationTypeRepeatedFixed32
191 case protowire.Fixed64Type:
192 vi.typ = validationTypeRepeatedFixed64
193 }
194 }
195 case fd.IsMap():
196 vi.typ = validationTypeMap
197 switch fd.MapKey().Kind() {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530198 case protoreflect.StringKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400199 if strs.EnforceUTF8(fd) {
200 vi.keyType = validationTypeUTF8String
201 }
202 }
203 switch fd.MapValue().Kind() {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530204 case protoreflect.MessageKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400205 vi.valType = validationTypeMessage
206 if ft.Kind() == reflect.Map {
207 vi.mi = getMessageInfo(ft.Elem())
208 }
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530209 case protoreflect.StringKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400210 if strs.EnforceUTF8(fd) {
211 vi.valType = validationTypeUTF8String
212 }
213 }
214 default:
215 switch fd.Kind() {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530216 case protoreflect.MessageKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400217 vi.typ = validationTypeMessage
Abhay Kumara61c5222025-11-10 07:32:50 +0000218 vi.mi = getMessageInfo(ft)
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530219 case protoreflect.GroupKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400220 vi.typ = validationTypeGroup
221 vi.mi = getMessageInfo(ft)
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530222 case protoreflect.StringKind:
khenaidoo106c61a2021-08-11 18:05:46 -0400223 vi.typ = validationTypeBytes
224 if strs.EnforceUTF8(fd) {
225 vi.typ = validationTypeUTF8String
226 }
227 default:
228 switch wireTypes[fd.Kind()] {
229 case protowire.VarintType:
230 vi.typ = validationTypeVarint
231 case protowire.Fixed32Type:
232 vi.typ = validationTypeFixed32
233 case protowire.Fixed64Type:
234 vi.typ = validationTypeFixed64
235 case protowire.BytesType:
236 vi.typ = validationTypeBytes
237 }
238 }
239 }
240 return vi
241}
242
243func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
244 mi.init()
245 type validationState struct {
246 typ validationType
247 keyType, valType validationType
248 endGroup protowire.Number
249 mi *MessageInfo
250 tail []byte
251 requiredMask uint64
252 }
253
254 // Pre-allocate some slots to avoid repeated slice reallocation.
255 states := make([]validationState, 0, 16)
256 states = append(states, validationState{
257 typ: validationTypeMessage,
258 mi: mi,
259 })
260 if groupTag > 0 {
261 states[0].typ = validationTypeGroup
262 states[0].endGroup = groupTag
263 }
bseeniva0b9cbcb2026-02-12 19:11:11 +0530264 if opts.depth--; opts.depth < 0 {
265 return out, ValidationInvalid
266 }
khenaidoo106c61a2021-08-11 18:05:46 -0400267 initialized := true
268 start := len(b)
269State:
270 for len(states) > 0 {
271 st := &states[len(states)-1]
272 for len(b) > 0 {
273 // Parse the tag (field number and wire type).
274 var tag uint64
275 if b[0] < 0x80 {
276 tag = uint64(b[0])
277 b = b[1:]
278 } else if len(b) >= 2 && b[1] < 128 {
279 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
280 b = b[2:]
281 } else {
282 var n int
283 tag, n = protowire.ConsumeVarint(b)
284 if n < 0 {
285 return out, ValidationInvalid
286 }
287 b = b[n:]
288 }
289 var num protowire.Number
290 if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
291 return out, ValidationInvalid
292 } else {
293 num = protowire.Number(n)
294 }
295 wtyp := protowire.Type(tag & 7)
296
297 if wtyp == protowire.EndGroupType {
298 if st.endGroup == num {
299 goto PopState
300 }
301 return out, ValidationInvalid
302 }
303 var vi validationInfo
304 switch {
305 case st.typ == validationTypeMap:
306 switch num {
307 case genid.MapEntry_Key_field_number:
308 vi.typ = st.keyType
309 case genid.MapEntry_Value_field_number:
310 vi.typ = st.valType
311 vi.mi = st.mi
312 vi.requiredBit = 1
313 }
314 case flags.ProtoLegacy && st.mi.isMessageSet:
315 switch num {
316 case messageset.FieldItem:
317 vi.typ = validationTypeMessageSetItem
318 }
319 default:
320 var f *coderFieldInfo
321 if int(num) < len(st.mi.denseCoderFields) {
322 f = st.mi.denseCoderFields[num]
323 } else {
324 f = st.mi.coderFields[num]
325 }
326 if f != nil {
327 vi = f.validation
khenaidoo106c61a2021-08-11 18:05:46 -0400328 break
329 }
330 // Possible extension field.
331 //
332 // TODO: We should return ValidationUnknown when:
333 // 1. The resolver is not frozen. (More extensions may be added to it.)
334 // 2. The resolver returns preg.NotFound.
335 // In this case, a type added to the resolver in the future could cause
336 // unmarshaling to begin failing. Supporting this requires some way to
337 // determine if the resolver is frozen.
338 xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530339 if err != nil && err != protoregistry.NotFound {
khenaidoo106c61a2021-08-11 18:05:46 -0400340 return out, ValidationUnknown
341 }
342 if err == nil {
343 vi = getExtensionFieldInfo(xt).validation
344 }
345 }
346 if vi.requiredBit != 0 {
347 // Check that the field has a compatible wire type.
348 // We only need to consider non-repeated field types,
349 // since repeated fields (and maps) can never be required.
350 ok := false
351 switch vi.typ {
352 case validationTypeVarint:
353 ok = wtyp == protowire.VarintType
354 case validationTypeFixed32:
355 ok = wtyp == protowire.Fixed32Type
356 case validationTypeFixed64:
357 ok = wtyp == protowire.Fixed64Type
358 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
359 ok = wtyp == protowire.BytesType
360 case validationTypeGroup:
361 ok = wtyp == protowire.StartGroupType
362 }
363 if ok {
364 st.requiredMask |= vi.requiredBit
365 }
366 }
367
368 switch wtyp {
369 case protowire.VarintType:
370 if len(b) >= 10 {
371 switch {
372 case b[0] < 0x80:
373 b = b[1:]
374 case b[1] < 0x80:
375 b = b[2:]
376 case b[2] < 0x80:
377 b = b[3:]
378 case b[3] < 0x80:
379 b = b[4:]
380 case b[4] < 0x80:
381 b = b[5:]
382 case b[5] < 0x80:
383 b = b[6:]
384 case b[6] < 0x80:
385 b = b[7:]
386 case b[7] < 0x80:
387 b = b[8:]
388 case b[8] < 0x80:
389 b = b[9:]
390 case b[9] < 0x80 && b[9] < 2:
391 b = b[10:]
392 default:
393 return out, ValidationInvalid
394 }
395 } else {
396 switch {
397 case len(b) > 0 && b[0] < 0x80:
398 b = b[1:]
399 case len(b) > 1 && b[1] < 0x80:
400 b = b[2:]
401 case len(b) > 2 && b[2] < 0x80:
402 b = b[3:]
403 case len(b) > 3 && b[3] < 0x80:
404 b = b[4:]
405 case len(b) > 4 && b[4] < 0x80:
406 b = b[5:]
407 case len(b) > 5 && b[5] < 0x80:
408 b = b[6:]
409 case len(b) > 6 && b[6] < 0x80:
410 b = b[7:]
411 case len(b) > 7 && b[7] < 0x80:
412 b = b[8:]
413 case len(b) > 8 && b[8] < 0x80:
414 b = b[9:]
415 case len(b) > 9 && b[9] < 2:
416 b = b[10:]
417 default:
418 return out, ValidationInvalid
419 }
420 }
421 continue State
422 case protowire.BytesType:
423 var size uint64
424 if len(b) >= 1 && b[0] < 0x80 {
425 size = uint64(b[0])
426 b = b[1:]
427 } else if len(b) >= 2 && b[1] < 128 {
428 size = uint64(b[0]&0x7f) + uint64(b[1])<<7
429 b = b[2:]
430 } else {
431 var n int
432 size, n = protowire.ConsumeVarint(b)
433 if n < 0 {
434 return out, ValidationInvalid
435 }
436 b = b[n:]
437 }
438 if size > uint64(len(b)) {
439 return out, ValidationInvalid
440 }
441 v := b[:size]
442 b = b[size:]
443 switch vi.typ {
444 case validationTypeMessage:
445 if vi.mi == nil {
446 return out, ValidationUnknown
447 }
448 vi.mi.init()
449 fallthrough
450 case validationTypeMap:
451 if vi.mi != nil {
452 vi.mi.init()
453 }
454 states = append(states, validationState{
455 typ: vi.typ,
456 keyType: vi.keyType,
457 valType: vi.valType,
458 mi: vi.mi,
459 tail: b,
460 })
bseeniva0b9cbcb2026-02-12 19:11:11 +0530461 if vi.typ == validationTypeMessage ||
462 vi.typ == validationTypeGroup ||
463 vi.typ == validationTypeMap {
464 if opts.depth--; opts.depth < 0 {
465 return out, ValidationInvalid
466 }
467 }
khenaidoo106c61a2021-08-11 18:05:46 -0400468 b = v
469 continue State
470 case validationTypeRepeatedVarint:
471 // Packed field.
472 for len(v) > 0 {
473 _, n := protowire.ConsumeVarint(v)
474 if n < 0 {
475 return out, ValidationInvalid
476 }
477 v = v[n:]
478 }
479 case validationTypeRepeatedFixed32:
480 // Packed field.
481 if len(v)%4 != 0 {
482 return out, ValidationInvalid
483 }
484 case validationTypeRepeatedFixed64:
485 // Packed field.
486 if len(v)%8 != 0 {
487 return out, ValidationInvalid
488 }
489 case validationTypeUTF8String:
490 if !utf8.Valid(v) {
491 return out, ValidationInvalid
492 }
493 }
494 case protowire.Fixed32Type:
495 if len(b) < 4 {
496 return out, ValidationInvalid
497 }
498 b = b[4:]
499 case protowire.Fixed64Type:
500 if len(b) < 8 {
501 return out, ValidationInvalid
502 }
503 b = b[8:]
504 case protowire.StartGroupType:
505 switch {
506 case vi.typ == validationTypeGroup:
507 if vi.mi == nil {
508 return out, ValidationUnknown
509 }
510 vi.mi.init()
511 states = append(states, validationState{
512 typ: validationTypeGroup,
513 mi: vi.mi,
514 endGroup: num,
515 })
bseeniva0b9cbcb2026-02-12 19:11:11 +0530516 if opts.depth--; opts.depth < 0 {
517 return out, ValidationInvalid
518 }
khenaidoo106c61a2021-08-11 18:05:46 -0400519 continue State
520 case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
521 typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
522 if err != nil {
523 return out, ValidationInvalid
524 }
525 xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
526 switch {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530527 case err == protoregistry.NotFound:
khenaidoo106c61a2021-08-11 18:05:46 -0400528 b = b[n:]
529 case err != nil:
530 return out, ValidationUnknown
531 default:
532 xvi := getExtensionFieldInfo(xt).validation
533 if xvi.mi != nil {
534 xvi.mi.init()
535 }
536 states = append(states, validationState{
537 typ: xvi.typ,
538 mi: xvi.mi,
539 tail: b[n:],
540 })
bseeniva0b9cbcb2026-02-12 19:11:11 +0530541 if xvi.typ == validationTypeMessage ||
542 xvi.typ == validationTypeGroup ||
543 xvi.typ == validationTypeMap {
544 if opts.depth--; opts.depth < 0 {
545 return out, ValidationInvalid
546 }
547 }
khenaidoo106c61a2021-08-11 18:05:46 -0400548 b = v
549 continue State
550 }
551 default:
552 n := protowire.ConsumeFieldValue(num, wtyp, b)
553 if n < 0 {
554 return out, ValidationInvalid
555 }
556 b = b[n:]
557 }
558 default:
559 return out, ValidationInvalid
560 }
561 }
562 if st.endGroup != 0 {
563 return out, ValidationInvalid
564 }
565 if len(b) != 0 {
566 return out, ValidationInvalid
567 }
568 b = st.tail
569 PopState:
570 numRequiredFields := 0
571 switch st.typ {
572 case validationTypeMessage, validationTypeGroup:
573 numRequiredFields = int(st.mi.numRequiredFields)
bseeniva0b9cbcb2026-02-12 19:11:11 +0530574 opts.depth++
khenaidoo106c61a2021-08-11 18:05:46 -0400575 case validationTypeMap:
576 // If this is a map field with a message value that contains
577 // required fields, require that the value be present.
578 if st.mi != nil && st.mi.numRequiredFields > 0 {
579 numRequiredFields = 1
580 }
bseeniva0b9cbcb2026-02-12 19:11:11 +0530581 opts.depth++
khenaidoo106c61a2021-08-11 18:05:46 -0400582 }
583 // If there are more than 64 required fields, this check will
584 // always fail and we will report that the message is potentially
585 // uninitialized.
586 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
587 initialized = false
588 }
589 states = states[:len(states)-1]
590 }
591 out.n = start - len(b)
592 if initialized {
593 out.initialized = true
594 }
595 return out, ValidationValid
596}