blob: 889d8511d2759cf186944bc4362d117ce79b7bde [file] [log] [blame]
khenaidoo106c61a2021-08-11 18:05:46 -04001// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "google.golang.org/protobuf/encoding/protowire"
9 "google.golang.org/protobuf/internal/encoding/messageset"
10 "google.golang.org/protobuf/internal/errors"
khenaidoo106c61a2021-08-11 18:05:46 -040011 "google.golang.org/protobuf/internal/genid"
12 "google.golang.org/protobuf/internal/pragma"
13 "google.golang.org/protobuf/reflect/protoreflect"
14 "google.golang.org/protobuf/reflect/protoregistry"
15 "google.golang.org/protobuf/runtime/protoiface"
16)
17
18// UnmarshalOptions configures the unmarshaler.
19//
20// Example usage:
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053021//
22// err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
khenaidoo106c61a2021-08-11 18:05:46 -040023type UnmarshalOptions struct {
24 pragma.NoUnkeyedLiterals
25
26 // Merge merges the input into the destination message.
27 // The default behavior is to always reset the message before unmarshaling,
28 // unless Merge is specified.
29 Merge bool
30
31 // AllowPartial accepts input for messages that will result in missing
32 // required fields. If AllowPartial is false (the default), Unmarshal will
33 // return an error if there are any missing required fields.
34 AllowPartial bool
35
36 // If DiscardUnknown is set, unknown fields are ignored.
37 DiscardUnknown bool
38
39 // Resolver is used for looking up types when unmarshaling extension fields.
40 // If nil, this defaults to using protoregistry.GlobalTypes.
41 Resolver interface {
42 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
43 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
44 }
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053045
46 // RecursionLimit limits how deeply messages may be nested.
47 // If zero, a default limit is applied.
48 RecursionLimit int
Abhay Kumara61c5222025-11-10 07:32:50 +000049
50 //
51 // NoLazyDecoding turns off lazy decoding, which otherwise is enabled by
52 // default. Lazy decoding only affects submessages (annotated with [lazy =
53 // true] in the .proto file) within messages that use the Opaque API.
54 NoLazyDecoding bool
khenaidoo106c61a2021-08-11 18:05:46 -040055}
56
57// Unmarshal parses the wire-format message in b and places the result in m.
58// The provided message must be mutable (e.g., a non-nil pointer to a message).
Abhay Kumara61c5222025-11-10 07:32:50 +000059//
60// See the [UnmarshalOptions] type if you need more control.
khenaidoo106c61a2021-08-11 18:05:46 -040061func Unmarshal(b []byte, m Message) error {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053062 _, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
khenaidoo106c61a2021-08-11 18:05:46 -040063 return err
64}
65
66// Unmarshal parses the wire-format message in b and places the result in m.
67// The provided message must be mutable (e.g., a non-nil pointer to a message).
68func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053069 if o.RecursionLimit == 0 {
70 o.RecursionLimit = protowire.DefaultRecursionLimit
71 }
khenaidoo106c61a2021-08-11 18:05:46 -040072 _, err := o.unmarshal(b, m.ProtoReflect())
73 return err
74}
75
76// UnmarshalState parses a wire-format message and places the result in m.
77//
78// This method permits fine-grained control over the unmarshaler.
Abhay Kumara61c5222025-11-10 07:32:50 +000079// Most users should use [Unmarshal] instead.
khenaidoo106c61a2021-08-11 18:05:46 -040080func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +053081 if o.RecursionLimit == 0 {
82 o.RecursionLimit = protowire.DefaultRecursionLimit
83 }
khenaidoo106c61a2021-08-11 18:05:46 -040084 return o.unmarshal(in.Buf, in.Message)
85}
86
87// unmarshal is a centralized function that all unmarshal operations go through.
88// For profiling purposes, avoid changing the name of this function or
89// introducing other code paths for unmarshal that do not go through this.
90func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
91 if o.Resolver == nil {
92 o.Resolver = protoregistry.GlobalTypes
93 }
94 if !o.Merge {
95 Reset(m.Interface())
96 }
97 allowPartial := o.AllowPartial
98 o.Merge = true
99 o.AllowPartial = true
100 methods := protoMethods(m)
101 if methods != nil && methods.Unmarshal != nil &&
102 !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
103 in := protoiface.UnmarshalInput{
104 Message: m,
105 Buf: b,
106 Resolver: o.Resolver,
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530107 Depth: o.RecursionLimit,
khenaidoo106c61a2021-08-11 18:05:46 -0400108 }
109 if o.DiscardUnknown {
110 in.Flags |= protoiface.UnmarshalDiscardUnknown
111 }
Abhay Kumara61c5222025-11-10 07:32:50 +0000112
113 if !allowPartial {
114 // This does not affect how current unmarshal functions work, it just allows them
115 // to record this for lazy the decoding case.
116 in.Flags |= protoiface.UnmarshalCheckRequired
117 }
118 if o.NoLazyDecoding {
119 in.Flags |= protoiface.UnmarshalNoLazyDecoding
120 }
121
khenaidoo106c61a2021-08-11 18:05:46 -0400122 out, err = methods.Unmarshal(in)
123 } else {
bseeniva0b9cbcb2026-02-12 19:11:11 +0530124 if o.RecursionLimit--; o.RecursionLimit < 0 {
125 return out, errRecursionDepth
Akash Reddy Kankanalac6b6ca12025-06-12 14:26:57 +0530126 }
khenaidoo106c61a2021-08-11 18:05:46 -0400127 err = o.unmarshalMessageSlow(b, m)
128 }
129 if err != nil {
130 return out, err
131 }
132 if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
133 return out, nil
134 }
135 return out, checkInitialized(m)
136}
137
138func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
139 _, err := o.unmarshal(b, m)
140 return err
141}
142
143func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
144 md := m.Descriptor()
145 if messageset.IsMessageSet(md) {
146 return o.unmarshalMessageSet(b, m)
147 }
148 fields := md.Fields()
149 for len(b) > 0 {
150 // Parse the tag (field number and wire type).
151 num, wtyp, tagLen := protowire.ConsumeTag(b)
152 if tagLen < 0 {
153 return errDecode
154 }
155 if num > protowire.MaxValidNumber {
156 return errDecode
157 }
158
159 // Find the field descriptor for this field number.
160 fd := fields.ByNumber(num)
161 if fd == nil && md.ExtensionRanges().Has(num) {
162 extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
163 if err != nil && err != protoregistry.NotFound {
164 return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
165 }
166 if extType != nil {
167 fd = extType.TypeDescriptor()
168 }
169 }
170 var err error
171 if fd == nil {
172 err = errUnknown
khenaidoo106c61a2021-08-11 18:05:46 -0400173 }
174
175 // Parse the field value.
176 var valLen int
177 switch {
178 case err != nil:
179 case fd.IsList():
180 valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
181 case fd.IsMap():
182 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
183 default:
184 valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
185 }
186 if err != nil {
187 if err != errUnknown {
188 return err
189 }
190 valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
191 if valLen < 0 {
192 return errDecode
193 }
194 if !o.DiscardUnknown {
195 m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
196 }
197 }
198 b = b[tagLen+valLen:]
199 }
200 return nil
201}
202
203func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
204 v, n, err := o.unmarshalScalar(b, wtyp, fd)
205 if err != nil {
206 return 0, err
207 }
208 switch fd.Kind() {
209 case protoreflect.GroupKind, protoreflect.MessageKind:
210 m2 := m.Mutable(fd).Message()
211 if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
212 return n, err
213 }
214 default:
215 // Non-message scalars replace the previous value.
216 m.Set(fd, v)
217 }
218 return n, nil
219}
220
221func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
bseeniva0b9cbcb2026-02-12 19:11:11 +0530222 if o.RecursionLimit--; o.RecursionLimit < 0 {
223 return 0, errRecursionDepth
224 }
khenaidoo106c61a2021-08-11 18:05:46 -0400225 if wtyp != protowire.BytesType {
226 return 0, errUnknown
227 }
228 b, n = protowire.ConsumeBytes(b)
229 if n < 0 {
230 return 0, errDecode
231 }
232 var (
233 keyField = fd.MapKey()
234 valField = fd.MapValue()
235 key protoreflect.Value
236 val protoreflect.Value
237 haveKey bool
238 haveVal bool
239 )
240 switch valField.Kind() {
241 case protoreflect.GroupKind, protoreflect.MessageKind:
242 val = mapv.NewValue()
243 }
244 // Map entries are represented as a two-element message with fields
245 // containing the key and value.
246 for len(b) > 0 {
247 num, wtyp, n := protowire.ConsumeTag(b)
248 if n < 0 {
249 return 0, errDecode
250 }
251 if num > protowire.MaxValidNumber {
252 return 0, errDecode
253 }
254 b = b[n:]
255 err = errUnknown
256 switch num {
257 case genid.MapEntry_Key_field_number:
258 key, n, err = o.unmarshalScalar(b, wtyp, keyField)
259 if err != nil {
260 break
261 }
262 haveKey = true
263 case genid.MapEntry_Value_field_number:
264 var v protoreflect.Value
265 v, n, err = o.unmarshalScalar(b, wtyp, valField)
266 if err != nil {
267 break
268 }
269 switch valField.Kind() {
270 case protoreflect.GroupKind, protoreflect.MessageKind:
271 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
272 return 0, err
273 }
274 default:
275 val = v
276 }
277 haveVal = true
278 }
279 if err == errUnknown {
280 n = protowire.ConsumeFieldValue(num, wtyp, b)
281 if n < 0 {
282 return 0, errDecode
283 }
284 } else if err != nil {
285 return 0, err
286 }
287 b = b[n:]
288 }
289 // Every map entry should have entries for key and value, but this is not strictly required.
290 if !haveKey {
291 key = keyField.Default()
292 }
293 if !haveVal {
294 switch valField.Kind() {
295 case protoreflect.GroupKind, protoreflect.MessageKind:
296 default:
297 val = valField.Default()
298 }
299 }
300 mapv.Set(key.MapKey(), val)
301 return n, nil
302}
303
var (
	// errUnknown is used internally to indicate fields which should be added
	// to the unknown field set of a message. It is never returned from an
	// exported function.
	errUnknown = errors.New("BUG: internal error (unknown)")

	// errDecode reports malformed wire-format input, such as an unparsable
	// tag, length, or field value.
	errDecode = errors.New("cannot parse invalid wire-format data")

	// errRecursionDepth reports that decoding exhausted the RecursionLimit
	// budget.
	errRecursionDepth = errors.New("exceeded maximum recursion depth")
)