| Abhay Kumar | a2ae599 | 2025-11-10 14:02:24 +0000 | [diff] [blame^] | 1 | package runtime |
| 2 | |
| 3 | import ( |
| 4 | "context" |
| 5 | "encoding/base64" |
| 6 | "fmt" |
| 7 | "net" |
| 8 | "net/http" |
| 9 | "net/textproto" |
| 10 | "strconv" |
| 11 | "strings" |
| 12 | "sync" |
| 13 | "time" |
| 14 | |
| 15 | "google.golang.org/grpc/codes" |
| 16 | "google.golang.org/grpc/grpclog" |
| 17 | "google.golang.org/grpc/metadata" |
| 18 | "google.golang.org/grpc/status" |
| 19 | ) |
| 20 | |
const (
	// MetadataHeaderPrefix is the HTTP header prefix that marks custom
	// metadata passed to or from a gRPC call.
	MetadataHeaderPrefix = "Grpc-Metadata-"

	// MetadataPrefix is prepended to permanent HTTP header keys (as
	// registered with IANA) when they are added to the gRPC context.
	MetadataPrefix = "grpcgateway-"

	// MetadataTrailerPrefix is prepended to gRPC metadata when it is
	// converted to HTTP headers in a response handled by grpc-gateway.
	MetadataTrailerPrefix = "Grpc-Trailer-"
)

const (
	// metadataGrpcTimeout is the inbound header carrying a gRPC-style
	// timeout such as "5S"; its value is parsed by timeoutDecode.
	metadataGrpcTimeout = "Grpc-Timeout"

	// metadataHeaderBinarySuffix marks HTTP headers whose values are
	// base64-encoded binary metadata.
	metadataHeaderBinarySuffix = "-Bin"

	// Proxy headers that annotateContext rebuilds itself instead of
	// forwarding them verbatim.
	xForwardedFor  = "X-Forwarded-For"
	xForwardedHost = "X-Forwarded-Host"
)

var (
	// DefaultContextTimeout bounds the gRPC call context whenever the inbound
	// request carries no Grpc-Timeout header. The zero value (the default)
	// means the context is given no timeout at all.
	DefaultContextTimeout time.Duration

	// malformedHTTPHeaders lists, in lowercase, the headers that the gRPC
	// server may reject outright as malformed.
	// See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more context.
	malformedHTTPHeaders = map[string]struct{}{
		"connection": {},
	}
)
| 48 | |
// Context keys are distinct unexported struct types, so values stored under
// them cannot collide with keys defined by any other package.
type rpcMethodKey struct{}

type httpPathPatternKey struct{}

type httpPatternKey struct{}

// AnnotateContextOption customizes the context built by AnnotateContext and
// AnnotateIncomingContext. Options run in the order given, after the RPC
// method name has been recorded in the context.
type AnnotateContextOption func(ctx context.Context) context.Context
| 56 | |
| 57 | func WithHTTPPathPattern(pattern string) AnnotateContextOption { |
| 58 | return func(ctx context.Context) context.Context { |
| 59 | return withHTTPPathPattern(ctx, pattern) |
| 60 | } |
| 61 | } |
| 62 | |
// decodeBinHeader decodes the value of a "-Bin" header from standard base64.
// Clients may omit the trailing '=' padding. A length that is a multiple of
// four is decoded with the padded encoding, because it is either padded or
// needs no padding. Any other length is decoded with the unpadded encoding.
func decodeBinHeader(v string) ([]byte, error) {
	enc := base64.RawStdEncoding
	if len(v)%4 == 0 {
		enc = base64.StdEncoding
	}
	return enc.DecodeString(v)
}
| 70 | |
| 71 | /* |
| 72 | AnnotateContext adds context information such as metadata from the request. |
| 73 | |
| 74 | At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For", |
| 75 | except that the forwarded destination is not another HTTP service but rather |
| 76 | a gRPC service. |
| 77 | */ |
| 78 | func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) { |
| 79 | ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...) |
| 80 | if err != nil { |
| 81 | return nil, err |
| 82 | } |
| 83 | if md == nil { |
| 84 | return ctx, nil |
| 85 | } |
| 86 | |
| 87 | return metadata.NewOutgoingContext(ctx, md), nil |
| 88 | } |
| 89 | |
| 90 | // AnnotateIncomingContext adds context information such as metadata from the request. |
| 91 | // Attach metadata as incoming context. |
| 92 | func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) { |
| 93 | ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...) |
| 94 | if err != nil { |
| 95 | return nil, err |
| 96 | } |
| 97 | if md == nil { |
| 98 | return ctx, nil |
| 99 | } |
| 100 | |
| 101 | return metadata.NewIncomingContext(ctx, md), nil |
| 102 | } |
| 103 | |
// isValidGRPCMetadataKey reports whether key is an acceptable gRPC
// "Header-Name" as defined in
// https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
//
// A valid key is non-empty and uses only 0-9 a-z _ - . bytes. The wire protocol
// allows only lowercase letters, but the client library lowercases ASCII, so
// uppercase ASCII letters are also accepted here. Validation is per byte, not
// per Unicode code point, matching gRPC.
func isValidGRPCMetadataKey(key string) bool {
	// gRPC rejects an empty metadata key outright. Reporting it as invalid
	// lets callers skip the header instead of failing the whole RPC.
	if key == "" {
		return false
	}
	// Index the string directly: ranging over it would decode runes, and
	// converting it to []byte would copy it.
	for i := 0; i < len(key); i++ {
		ch := key[i]
		switch {
		case 'a' <= ch && ch <= 'z', 'A' <= ch && ch <= 'Z', '0' <= ch && ch <= '9':
		case ch == '.', ch == '-', ch == '_':
		default:
			return false
		}
	}
	return true
}
| 122 | |
// isValidGRPCMetadataTextValue reports whether textValue is an acceptable
// gRPC "ASCII-Value" as defined in
// https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
//
// Every byte must be printable ASCII, 0x20 (space) through 0x7E inclusive.
// The check is per byte, not per Unicode code point, matching gRPC. An empty
// value is valid.
func isValidGRPCMetadataTextValue(textValue string) bool {
	for i := 0; i < len(textValue); i++ {
		if c := textValue[i]; c < 0x20 || c > 0x7E {
			return false
		}
	}
	return true
}
| 135 | |
| 136 | func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) { |
| 137 | ctx = withRPCMethod(ctx, rpcMethodName) |
| 138 | for _, o := range options { |
| 139 | ctx = o(ctx) |
| 140 | } |
| 141 | timeout := DefaultContextTimeout |
| 142 | if tm := req.Header.Get(metadataGrpcTimeout); tm != "" { |
| 143 | var err error |
| 144 | timeout, err = timeoutDecode(tm) |
| 145 | if err != nil { |
| 146 | return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm) |
| 147 | } |
| 148 | } |
| 149 | var pairs []string |
| 150 | for key, vals := range req.Header { |
| 151 | key = textproto.CanonicalMIMEHeaderKey(key) |
| 152 | switch key { |
| 153 | case xForwardedFor, xForwardedHost: |
| 154 | // Handled separately below |
| 155 | continue |
| 156 | } |
| 157 | |
| 158 | for _, val := range vals { |
| 159 | // For backwards-compatibility, pass through 'authorization' header with no prefix. |
| 160 | if key == "Authorization" { |
| 161 | pairs = append(pairs, "authorization", val) |
| 162 | } |
| 163 | if h, ok := mux.incomingHeaderMatcher(key); ok { |
| 164 | if !isValidGRPCMetadataKey(h) { |
| 165 | grpclog.Errorf("HTTP header name %q is not valid as gRPC metadata key; skipping", h) |
| 166 | continue |
| 167 | } |
| 168 | // Handles "-bin" metadata in grpc, since grpc will do another base64 |
| 169 | // encode before sending to server, we need to decode it first. |
| 170 | if strings.HasSuffix(key, metadataHeaderBinarySuffix) { |
| 171 | b, err := decodeBinHeader(val) |
| 172 | if err != nil { |
| 173 | return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err) |
| 174 | } |
| 175 | |
| 176 | val = string(b) |
| 177 | } else if !isValidGRPCMetadataTextValue(val) { |
| 178 | grpclog.Errorf("Value of HTTP header %q contains non-ASCII value (not valid as gRPC metadata): skipping", h) |
| 179 | continue |
| 180 | } |
| 181 | pairs = append(pairs, h, val) |
| 182 | } |
| 183 | } |
| 184 | } |
| 185 | if host := req.Header.Get(xForwardedHost); host != "" { |
| 186 | pairs = append(pairs, strings.ToLower(xForwardedHost), host) |
| 187 | } else if req.Host != "" { |
| 188 | pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host) |
| 189 | } |
| 190 | |
| 191 | xff := req.Header.Values(xForwardedFor) |
| 192 | if addr := req.RemoteAddr; addr != "" { |
| 193 | if remoteIP, _, err := net.SplitHostPort(addr); err == nil { |
| 194 | xff = append(xff, remoteIP) |
| 195 | } |
| 196 | } |
| 197 | if len(xff) > 0 { |
| 198 | pairs = append(pairs, strings.ToLower(xForwardedFor), strings.Join(xff, ", ")) |
| 199 | } |
| 200 | |
| 201 | if timeout != 0 { |
| 202 | ctx, _ = context.WithTimeout(ctx, timeout) |
| 203 | } |
| 204 | if len(pairs) == 0 { |
| 205 | return ctx, nil, nil |
| 206 | } |
| 207 | md := metadata.Pairs(pairs...) |
| 208 | for _, mda := range mux.metadataAnnotators { |
| 209 | md = metadata.Join(md, mda(ctx, req)) |
| 210 | } |
| 211 | return ctx, md, nil |
| 212 | } |
| 213 | |
| 214 | // ServerMetadata consists of metadata sent from gRPC server. |
| 215 | type ServerMetadata struct { |
| 216 | HeaderMD metadata.MD |
| 217 | TrailerMD metadata.MD |
| 218 | } |
| 219 | |
| 220 | type serverMetadataKey struct{} |
| 221 | |
| 222 | // NewServerMetadataContext creates a new context with ServerMetadata |
| 223 | func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context { |
| 224 | if ctx == nil { |
| 225 | ctx = context.Background() |
| 226 | } |
| 227 | return context.WithValue(ctx, serverMetadataKey{}, md) |
| 228 | } |
| 229 | |
| 230 | // ServerMetadataFromContext returns the ServerMetadata in ctx |
| 231 | func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) { |
| 232 | if ctx == nil { |
| 233 | return md, false |
| 234 | } |
| 235 | md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata) |
| 236 | return |
| 237 | } |
| 238 | |
| 239 | // ServerTransportStream implements grpc.ServerTransportStream. |
| 240 | // It should only be used by the generated files to support grpc.SendHeader |
| 241 | // outside of gRPC server use. |
| 242 | type ServerTransportStream struct { |
| 243 | mu sync.Mutex |
| 244 | header metadata.MD |
| 245 | trailer metadata.MD |
| 246 | } |
| 247 | |
| 248 | // Method returns the method for the stream. |
| 249 | func (s *ServerTransportStream) Method() string { |
| 250 | return "" |
| 251 | } |
| 252 | |
| 253 | // Header returns the header metadata of the stream. |
| 254 | func (s *ServerTransportStream) Header() metadata.MD { |
| 255 | s.mu.Lock() |
| 256 | defer s.mu.Unlock() |
| 257 | return s.header.Copy() |
| 258 | } |
| 259 | |
| 260 | // SetHeader sets the header metadata. |
| 261 | func (s *ServerTransportStream) SetHeader(md metadata.MD) error { |
| 262 | if md.Len() == 0 { |
| 263 | return nil |
| 264 | } |
| 265 | |
| 266 | s.mu.Lock() |
| 267 | s.header = metadata.Join(s.header, md) |
| 268 | s.mu.Unlock() |
| 269 | return nil |
| 270 | } |
| 271 | |
| 272 | // SendHeader sets the header metadata. |
| 273 | func (s *ServerTransportStream) SendHeader(md metadata.MD) error { |
| 274 | return s.SetHeader(md) |
| 275 | } |
| 276 | |
| 277 | // Trailer returns the cached trailer metadata. |
| 278 | func (s *ServerTransportStream) Trailer() metadata.MD { |
| 279 | s.mu.Lock() |
| 280 | defer s.mu.Unlock() |
| 281 | return s.trailer.Copy() |
| 282 | } |
| 283 | |
| 284 | // SetTrailer sets the trailer metadata. |
| 285 | func (s *ServerTransportStream) SetTrailer(md metadata.MD) error { |
| 286 | if md.Len() == 0 { |
| 287 | return nil |
| 288 | } |
| 289 | |
| 290 | s.mu.Lock() |
| 291 | s.trailer = metadata.Join(s.trailer, md) |
| 292 | s.mu.Unlock() |
| 293 | return nil |
| 294 | } |
| 295 | |
// timeoutDecode parses a gRPC "Grpc-Timeout" header value such as "5S" or
// "250m". The value is a run of ASCII decimal digits followed by a single
// unit byte; see timeoutUnitToDuration for the accepted units.
//
// A sign prefix is rejected: "-1S" is an error rather than a negative,
// already-expired deadline. A value too large for time.Duration is clamped
// to the largest representable duration instead of wrapping around.
func timeoutDecode(s string) (time.Duration, error) {
	const maxDuration = time.Duration(1<<63 - 1)

	size := len(s)
	if size < 2 {
		return 0, fmt.Errorf("timeout string is too short: %q", s)
	}
	unit, ok := timeoutUnitToDuration(s[size-1])
	if !ok {
		return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
	}
	// ParseUint (unlike ParseInt) accepts no sign, so negative timeouts fail here.
	count, err := strconv.ParseUint(s[:size-1], 10, 64)
	if err != nil {
		return 0, fmt.Errorf("timeout value is not a non-negative integer: %q: %w", s, err)
	}
	// unit*count would overflow int64, e.g. "99999999H"; clamp instead.
	if count > uint64(maxDuration/unit) {
		return maxDuration, nil
	}
	return unit * time.Duration(count), nil
}

// timeoutUnitToDuration maps a gRPC timeout unit byte to its duration:
// H hours, M minutes, S seconds, m milliseconds, u microseconds and
// n nanoseconds. ok is false for any other byte.
func timeoutUnitToDuration(u uint8) (time.Duration, bool) {
	switch u {
	case 'H':
		return time.Hour, true
	case 'M':
		return time.Minute, true
	case 'S':
		return time.Second, true
	case 'm':
		return time.Millisecond, true
	case 'u':
		return time.Microsecond, true
	case 'n':
		return time.Nanosecond, true
	}
	return 0, false
}
| 330 | |
// isPermanentHTTPHeader reports whether hdr is one of the permanent request
// headers registered with IANA:
// http://www.iana.org/assignments/message-headers/message-headers.xml
//
// The comparison is exact and case-sensitive, so callers presumably pass
// keys already in canonical MIME form (for example "User-Agent").
func isPermanentHTTPHeader(hdr string) bool {
	switch hdr {
	// Content negotiation and payload description.
	case "Accept", "Accept-Charset", "Accept-Language", "Accept-Ranges", "Content-Type":
		return true
	// Conditional requests and expectations.
	case "Expect", "If-Match", "If-Modified-Since", "If-None-Match", "If-Schedule-Tag-Match", "If-Unmodified-Since":
		return true
	// Client identity, credentials and provenance.
	case "Authorization", "Cookie", "From", "Origin", "Referer", "User-Agent":
		return true
	// Caching, routing and general message information.
	case "Cache-Control", "Date", "Host", "Max-Forwards", "Pragma", "Via", "Warning":
		return true
	default:
		return false
	}
}
| 365 | |
| 366 | // isMalformedHTTPHeader checks whether header belongs to the list of |
| 367 | // "malformed headers" and would be rejected by the gRPC server. |
| 368 | func isMalformedHTTPHeader(header string) bool { |
| 369 | _, isMalformed := malformedHTTPHeaders[strings.ToLower(header)] |
| 370 | return isMalformed |
| 371 | } |
| 372 | |
| 373 | // RPCMethod returns the method string for the server context. The returned |
| 374 | // string is in the format of "/package.service/method". |
| 375 | func RPCMethod(ctx context.Context) (string, bool) { |
| 376 | m := ctx.Value(rpcMethodKey{}) |
| 377 | if m == nil { |
| 378 | return "", false |
| 379 | } |
| 380 | ms, ok := m.(string) |
| 381 | if !ok { |
| 382 | return "", false |
| 383 | } |
| 384 | return ms, true |
| 385 | } |
| 386 | |
| 387 | func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context { |
| 388 | return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName) |
| 389 | } |
| 390 | |
| 391 | // HTTPPathPattern returns the HTTP path pattern string relating to the HTTP handler, if one exists. |
| 392 | // The format of the returned string is defined by the google.api.http path template type. |
| 393 | func HTTPPathPattern(ctx context.Context) (string, bool) { |
| 394 | m := ctx.Value(httpPathPatternKey{}) |
| 395 | if m == nil { |
| 396 | return "", false |
| 397 | } |
| 398 | ms, ok := m.(string) |
| 399 | if !ok { |
| 400 | return "", false |
| 401 | } |
| 402 | return ms, true |
| 403 | } |
| 404 | |
| 405 | func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context { |
| 406 | return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern) |
| 407 | } |
| 408 | |
| 409 | // HTTPPattern returns the HTTP path pattern struct relating to the HTTP handler, if one exists. |
| 410 | func HTTPPattern(ctx context.Context) (Pattern, bool) { |
| 411 | v, ok := ctx.Value(httpPatternKey{}).(Pattern) |
| 412 | return v, ok |
| 413 | } |
| 414 | |
| 415 | func withHTTPPattern(ctx context.Context, httpPattern Pattern) context.Context { |
| 416 | return context.WithValue(ctx, httpPatternKey{}, httpPattern) |
| 417 | } |