blob: 2f2b342431d6f82773796e6ca699117fffcfce13 [file] [log] [blame]
Abhay Kumara2ae5992025-11-10 14:02:24 +00001package runtime
2
3import (
4 "context"
5 "encoding/base64"
6 "fmt"
7 "net"
8 "net/http"
9 "net/textproto"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "google.golang.org/grpc/codes"
16 "google.golang.org/grpc/grpclog"
17 "google.golang.org/grpc/metadata"
18 "google.golang.org/grpc/status"
19)
20
// Header names and prefixes used when translating between HTTP headers and
// gRPC metadata.
const (
	// MetadataHeaderPrefix is the http prefix that represents custom metadata
	// parameters to or from a gRPC call.
	MetadataHeaderPrefix = "Grpc-Metadata-"

	// MetadataPrefix is prepended to permanent HTTP header keys (as specified
	// by the IANA) when added to the gRPC context.
	MetadataPrefix = "grpcgateway-"

	// MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
	// HTTP headers in a response handled by grpc-gateway.
	MetadataTrailerPrefix = "Grpc-Trailer-"

	// metadataGrpcTimeout is the inbound header that carries the call timeout.
	metadataGrpcTimeout = "Grpc-Timeout"
	// metadataHeaderBinarySuffix marks (canonicalized) header names whose
	// values are base64-encoded binary data.
	metadataHeaderBinarySuffix = "-Bin"

	// Forwarding headers; these are handled separately from the generic
	// header-to-metadata mapping.
	xForwardedFor  = "X-Forwarded-For"
	xForwardedHost = "X-Forwarded-Host"
)
38
// DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
// header isn't present. If the value is 0 the sent `context` will not have a timeout.
var DefaultContextTimeout = 0 * time.Second

// malformedHTTPHeaders lists the headers that the gRPC server may reject outright as malformed.
// See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more context.
// Keys are lowercase; isMalformedHTTPHeader lowercases its argument before the lookup.
var malformedHTTPHeaders = map[string]struct{}{
	"connection": {},
}

type (
	// rpcMethodKey is the context key under which withRPCMethod stores the
	// RPC method name.
	rpcMethodKey struct{}
	// httpPathPatternKey is the context key under which withHTTPPathPattern
	// stores the path pattern string.
	httpPathPatternKey struct{}
	// httpPatternKey is the context key under which withHTTPPattern stores
	// the Pattern value.
	httpPatternKey struct{}

	// AnnotateContextOption receives a context and returns the context to use
	// in its place; annotateContext applies every supplied option in order.
	AnnotateContextOption func(ctx context.Context) context.Context
)
56
57func WithHTTPPathPattern(pattern string) AnnotateContextOption {
58 return func(ctx context.Context) context.Context {
59 return withHTTPPathPattern(ctx, pattern)
60 }
61}
62
// decodeBinHeader base64-decodes the value of a binary header. Values whose
// length is a multiple of four are decoded with the padded standard encoding
// (they were either padded or needed no padding); all other values are
// decoded with the unpadded standard encoding.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}
70
71/*
72AnnotateContext adds context information such as metadata from the request.
73
74At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
75except that the forwarded destination is not another HTTP service but rather
76a gRPC service.
77*/
78func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
79 ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
80 if err != nil {
81 return nil, err
82 }
83 if md == nil {
84 return ctx, nil
85 }
86
87 return metadata.NewOutgoingContext(ctx, md), nil
88}
89
90// AnnotateIncomingContext adds context information such as metadata from the request.
91// Attach metadata as incoming context.
92func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
93 ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
94 if err != nil {
95 return nil, err
96 }
97 if md == nil {
98 return ctx, nil
99 }
100
101 return metadata.NewIncomingContext(ctx, md), nil
102}
103
// isValidGRPCMetadataKey reports whether key is a valid gRPC "Header-Name" as
// defined in
// https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
// which allows 0-9 a-z _ - .
//
// Only lowercase letters are valid in the wire protocol, but the client
// library will normalize uppercase ASCII to lowercase, so uppercase ASCII is
// accepted as well. The check runs byte by byte because gRPC validates
// strings on the byte level, not on Unicode code points.
func isValidGRPCMetadataKey(key string) bool {
	for i := 0; i < len(key); i++ {
		switch c := key[i]; {
		case 'a' <= c && c <= 'z':
		case 'A' <= c && c <= 'Z':
		case '0' <= c && c <= '9':
		case c == '.', c == '-', c == '_':
		default:
			return false
		}
	}
	return true
}
122
// isValidGRPCMetadataTextValue reports whether textValue is a valid gRPC
// "ASCII-Value" as defined in
// https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
// i.e. it consists only of printable ASCII including the space character
// (bytes 0x20 through 0x7E inclusive). The check runs byte by byte because
// gRPC validates strings on the byte level, not on Unicode code points.
func isValidGRPCMetadataTextValue(textValue string) bool {
	for i := 0; i < len(textValue); i++ {
		if c := textValue[i]; c < ' ' || c > '~' {
			return false
		}
	}
	return true
}
135
// annotateContext builds the context and the gRPC metadata for a call from an
// inbound HTTP request.
//
// It records rpcMethodName in the context, applies the given options in
// order, and collects metadata pairs from the request headers that
// mux.incomingHeaderMatcher accepts, plus x-forwarded-host and
// x-forwarded-for entries. If a timeout applies (the Grpc-Timeout header, or
// DefaultContextTimeout when the header is absent) the returned context
// carries it.
//
// It returns an InvalidArgument status error, with a nil context and nil
// metadata, when the Grpc-Timeout header or a binary ("-Bin") header value
// cannot be decoded. When no pairs were collected the returned metadata is
// nil and mux.metadataAnnotators are not consulted.
func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) {
	ctx = withRPCMethod(ctx, rpcMethodName)
	for _, o := range options {
		ctx = o(ctx)
	}
	// An explicit Grpc-Timeout header overrides the package default.
	timeout := DefaultContextTimeout
	if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
		var err error
		timeout, err = timeoutDecode(tm)
		if err != nil {
			return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
		}
	}
	// pairs is a flat key, value, key, value, ... list handed to metadata.Pairs.
	var pairs []string
	for key, vals := range req.Header {
		key = textproto.CanonicalMIMEHeaderKey(key)
		switch key {
		case xForwardedFor, xForwardedHost:
			// Handled separately below
			continue
		}

		for _, val := range vals {
			// For backwards-compatibility, pass through 'authorization' header with no prefix.
			if key == "Authorization" {
				pairs = append(pairs, "authorization", val)
			}
			// The matcher decides whether the header is forwarded and under
			// which metadata key (h).
			if h, ok := mux.incomingHeaderMatcher(key); ok {
				if !isValidGRPCMetadataKey(h) {
					grpclog.Errorf("HTTP header name %q is not valid as gRPC metadata key; skipping", h)
					continue
				}
				// Handles "-bin" metadata in grpc, since grpc will do another base64
				// encode before sending to server, we need to decode it first.
				if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
					b, err := decodeBinHeader(val)
					if err != nil {
						return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
					}

					val = string(b)
				} else if !isValidGRPCMetadataTextValue(val) {
					// Text values that are not printable ASCII are dropped
					// (and logged) rather than failing the whole request.
					grpclog.Errorf("Value of HTTP header %q contains non-ASCII value (not valid as gRPC metadata): skipping", h)
					continue
				}
				pairs = append(pairs, h, val)
			}
		}
	}
	// x-forwarded-host: prefer the inbound header, fall back to the request's Host.
	if host := req.Header.Get(xForwardedHost); host != "" {
		pairs = append(pairs, strings.ToLower(xForwardedHost), host)
	} else if req.Host != "" {
		pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
	}

	// x-forwarded-for: inbound values followed by the IP of the immediate peer.
	// A RemoteAddr that does not parse as host:port is silently left out.
	// NOTE(review): Header.Values returns the header's own slice, not a copy,
	// so the append below may write into its spare capacity — confirm this is
	// acceptable.
	xff := req.Header.Values(xForwardedFor)
	if addr := req.RemoteAddr; addr != "" {
		if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
			xff = append(xff, remoteIP)
		}
	}
	if len(xff) > 0 {
		pairs = append(pairs, strings.ToLower(xForwardedFor), strings.Join(xff, ", "))
	}

	if timeout != 0 {
		// The cancel func is discarded: the derived context is returned to
		// the caller, so it cannot be cancelled from inside this function.
		ctx, _ = context.WithTimeout(ctx, timeout)
	}
	if len(pairs) == 0 {
		// Nothing collected: return early without running the annotators.
		return ctx, nil, nil
	}
	md := metadata.Pairs(pairs...)
	for _, mda := range mux.metadataAnnotators {
		md = metadata.Join(md, mda(ctx, req))
	}
	return ctx, md, nil
}
213
// ServerMetadata consists of metadata sent from gRPC server.
type ServerMetadata struct {
	// HeaderMD is the header metadata (presumably what the server sent as
	// response headers — confirm against the callers that populate it).
	HeaderMD metadata.MD
	// TrailerMD is the trailer metadata (presumably what the server sent as
	// response trailers — confirm against the callers that populate it).
	TrailerMD metadata.MD
}

// serverMetadataKey is the context key under which NewServerMetadataContext
// stores a ServerMetadata value.
type serverMetadataKey struct{}
221
222// NewServerMetadataContext creates a new context with ServerMetadata
223func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
224 if ctx == nil {
225 ctx = context.Background()
226 }
227 return context.WithValue(ctx, serverMetadataKey{}, md)
228}
229
230// ServerMetadataFromContext returns the ServerMetadata in ctx
231func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
232 if ctx == nil {
233 return md, false
234 }
235 md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
236 return
237}
238
239// ServerTransportStream implements grpc.ServerTransportStream.
240// It should only be used by the generated files to support grpc.SendHeader
241// outside of gRPC server use.
242type ServerTransportStream struct {
243 mu sync.Mutex
244 header metadata.MD
245 trailer metadata.MD
246}
247
248// Method returns the method for the stream.
249func (s *ServerTransportStream) Method() string {
250 return ""
251}
252
253// Header returns the header metadata of the stream.
254func (s *ServerTransportStream) Header() metadata.MD {
255 s.mu.Lock()
256 defer s.mu.Unlock()
257 return s.header.Copy()
258}
259
260// SetHeader sets the header metadata.
261func (s *ServerTransportStream) SetHeader(md metadata.MD) error {
262 if md.Len() == 0 {
263 return nil
264 }
265
266 s.mu.Lock()
267 s.header = metadata.Join(s.header, md)
268 s.mu.Unlock()
269 return nil
270}
271
272// SendHeader sets the header metadata.
273func (s *ServerTransportStream) SendHeader(md metadata.MD) error {
274 return s.SetHeader(md)
275}
276
277// Trailer returns the cached trailer metadata.
278func (s *ServerTransportStream) Trailer() metadata.MD {
279 s.mu.Lock()
280 defer s.mu.Unlock()
281 return s.trailer.Copy()
282}
283
284// SetTrailer sets the trailer metadata.
285func (s *ServerTransportStream) SetTrailer(md metadata.MD) error {
286 if md.Len() == 0 {
287 return nil
288 }
289
290 s.mu.Lock()
291 s.trailer = metadata.Join(s.trailer, md)
292 s.mu.Unlock()
293 return nil
294}
295
// timeoutDecode parses a Grpc-Timeout header value: a decimal integer
// immediately followed by a one-character unit (H, M, S, m, u or n).
//
// It returns an error when s is shorter than two bytes, when the unit is not
// recognized, or when the numeric part does not parse as a 64-bit integer.
// A value too large to be represented as a time.Duration saturates at the
// largest (or, for negative input, the most negative symmetric) Duration
// instead of silently wrapping around; a wrapped value could turn a huge
// timeout into a negative one and make the call expire immediately.
func timeoutDecode(s string) (time.Duration, error) {
	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("timeout string is too short: %q", s)
	}
	unit, ok := timeoutUnitToDuration(s[size-1])
	if !ok {
		return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
	}
	n, err := strconv.ParseInt(s[:size-1], 10, 64)
	if err != nil {
		return 0, err
	}

	// limit is the largest multiple of unit that still fits in a Duration.
	const maxDuration = time.Duration(1<<63 - 1)
	limit := int64(maxDuration / unit)
	switch {
	case n > limit:
		return maxDuration, nil
	case n < -limit:
		return -maxDuration, nil
	}
	return unit * time.Duration(n), nil
}

// timeoutUnitToDuration maps a Grpc-Timeout unit character to the duration of
// one such unit. ok is false (and d is zero) for an unrecognized character.
func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
	switch u {
	case 'H':
		return time.Hour, true
	case 'M':
		return time.Minute, true
	case 'S':
		return time.Second, true
	case 'm':
		return time.Millisecond, true
	case 'u':
		return time.Microsecond, true
	case 'n':
		return time.Nanosecond, true
	default:
		return 0, false
	}
}
330
// isPermanentHTTPHeader checks whether hdr belongs to the list of
// permanent request headers maintained by IANA.
// http://www.iana.org/assignments/message-headers/message-headers.xml
//
// The comparison is exact (case-sensitive).
func isPermanentHTTPHeader(hdr string) bool {
	switch hdr {
	case "Accept", "Accept-Charset", "Accept-Language", "Accept-Ranges":
	case "Authorization", "Cookie", "From", "Host", "Origin", "Referer", "User-Agent":
	case "Cache-Control", "Content-Type", "Date", "Expect", "Max-Forwards", "Pragma", "Via", "Warning":
	case "If-Match", "If-Modified-Since", "If-None-Match", "If-Schedule-Tag-Match", "If-Unmodified-Since":
	default:
		return false
	}
	return true
}
365
366// isMalformedHTTPHeader checks whether header belongs to the list of
367// "malformed headers" and would be rejected by the gRPC server.
368func isMalformedHTTPHeader(header string) bool {
369 _, isMalformed := malformedHTTPHeaders[strings.ToLower(header)]
370 return isMalformed
371}
372
373// RPCMethod returns the method string for the server context. The returned
374// string is in the format of "/package.service/method".
375func RPCMethod(ctx context.Context) (string, bool) {
376 m := ctx.Value(rpcMethodKey{})
377 if m == nil {
378 return "", false
379 }
380 ms, ok := m.(string)
381 if !ok {
382 return "", false
383 }
384 return ms, true
385}
386
387func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context {
388 return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName)
389}
390
391// HTTPPathPattern returns the HTTP path pattern string relating to the HTTP handler, if one exists.
392// The format of the returned string is defined by the google.api.http path template type.
393func HTTPPathPattern(ctx context.Context) (string, bool) {
394 m := ctx.Value(httpPathPatternKey{})
395 if m == nil {
396 return "", false
397 }
398 ms, ok := m.(string)
399 if !ok {
400 return "", false
401 }
402 return ms, true
403}
404
405func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
406 return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
407}
408
409// HTTPPattern returns the HTTP path pattern struct relating to the HTTP handler, if one exists.
410func HTTPPattern(ctx context.Context) (Pattern, bool) {
411 v, ok := ctx.Value(httpPatternKey{}).(Pattern)
412 return v, ok
413}
414
415func withHTTPPattern(ctx context.Context, httpPattern Pattern) context.Context {
416 return context.WithValue(ctx, httpPatternKey{}, httpPattern)
417}