blob: 19255ec441e61db620f37ae81b3aae9bdc42720e [file] [log] [blame]
Abhay Kumara2ae5992025-11-10 14:02:24 +00001package runtime
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "net/http"
8 "net/textproto"
9 "regexp"
10 "strings"
11
12 "github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
13 "google.golang.org/grpc/codes"
14 "google.golang.org/grpc/grpclog"
15 "google.golang.org/grpc/health/grpc_health_v1"
16 "google.golang.org/grpc/metadata"
17 "google.golang.org/grpc/status"
18 "google.golang.org/protobuf/proto"
19)
20
// UnescapingMode selects how ServeMux unescapes path parameters.
type UnescapingMode int

// The available unescaping modes, numbered from zero in declaration order.
const (
	// UnescapingModeLegacy is the default V2 behavior: routing is performed on
	// the request path after it has been unescaped as a whole.
	UnescapingModeLegacy UnescapingMode = iota

	// UnescapingModeAllExceptReserved unescapes every path parameter character
	// except the RFC 6570 reserved characters.
	UnescapingModeAllExceptReserved

	// UnescapingModeAllExceptSlash unescapes URL path parameters but keeps path
	// separators encoded as "%2F".
	UnescapingModeAllExceptSlash

	// UnescapingModeAllCharacters unescapes every character of URL path parameters.
	UnescapingModeAllCharacters
)

// UnescapingModeDefault is the unescaping mode used when none is configured.
// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
// reference implementation
const UnescapingModeDefault = UnescapingModeLegacy
45
// encodedPathSplitter matches a path separator, either literal ("/") or
// percent-encoded ("%2F").
// NOTE(review): only the upper-case encoding is matched; "%2f" is not treated
// as a separator — confirm this is intended.
var encodedPathSplitter = regexp.MustCompile(`(/|%2F)`)
47
// HandlerFunc is the signature of a function that serves requests for one
// combination of HTTP method and path pattern. pathParams carries the path
// parameters extracted while matching the pattern.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
50
51// A Middleware handler wraps another HandlerFunc to do some pre- and/or post-processing of the request. This is used as an alternative to gRPC interceptors when using the direct-to-implementation
52// registration methods. It is generally recommended to use gRPC client or server interceptors instead
53// where possible.
54type Middleware func(HandlerFunc) HandlerFunc
55
56// ServeMux is a request multiplexer for grpc-gateway.
57// It matches http requests to patterns and invokes the corresponding handler.
58type ServeMux struct {
59 // handlers maps HTTP method to a list of handlers.
60 handlers map[string][]handler
61 middlewares []Middleware
62 forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
63 forwardResponseRewriter ForwardResponseRewriter
64 marshalers marshalerRegistry
65 incomingHeaderMatcher HeaderMatcherFunc
66 outgoingHeaderMatcher HeaderMatcherFunc
67 outgoingTrailerMatcher HeaderMatcherFunc
68 metadataAnnotators []func(context.Context, *http.Request) metadata.MD
69 errorHandler ErrorHandlerFunc
70 streamErrorHandler StreamErrorHandlerFunc
71 routingErrorHandler RoutingErrorHandlerFunc
72 disablePathLengthFallback bool
73 unescapingMode UnescapingMode
74 writeContentLength bool
75}
76
77// ServeMuxOption is an option that can be given to a ServeMux on construction.
78type ServeMuxOption func(*ServeMux)
79
80// ForwardResponseRewriter is the signature of a function that is capable of rewriting messages
81// before they are forwarded in a unary, stream, or error response.
82type ForwardResponseRewriter func(ctx context.Context, response proto.Message) (any, error)
83
84// WithForwardResponseRewriter returns a ServeMuxOption that allows for implementers to insert logic
85// that can rewrite the final response before it is forwarded.
86//
87// The response rewriter function is called during unary message forwarding, stream message
88// forwarding and when errors are being forwarded.
89//
90// NOTE: Using this option will likely make what is generated by `protoc-gen-openapiv2` incorrect.
91// Since this option involves making runtime changes to the response shape or type.
92func WithForwardResponseRewriter(fwdResponseRewriter ForwardResponseRewriter) ServeMuxOption {
93 return func(sm *ServeMux) {
94 sm.forwardResponseRewriter = fwdResponseRewriter
95 }
96}
97
98// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
99//
100// forwardResponseOption is an option that will be called on the relevant context.Context,
101// http.ResponseWriter, and proto.Message before every forwarded response.
102//
103// The message may be nil in the case where just a header is being sent.
104func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
105 return func(serveMux *ServeMux) {
106 serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
107 }
108}
109
110// WithUnescapingMode sets the escaping type. See the definitions of UnescapingMode
111// for more information.
112func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
113 return func(serveMux *ServeMux) {
114 serveMux.unescapingMode = mode
115 }
116}
117
118// WithMiddlewares sets server middleware for all handlers. This is useful as an alternative to gRPC
119// interceptors when using the direct-to-implementation registration methods and cannot rely
120// on gRPC interceptors. It's recommended to use gRPC interceptors instead if possible.
121func WithMiddlewares(middlewares ...Middleware) ServeMuxOption {
122 return func(serveMux *ServeMux) {
123 serveMux.middlewares = append(serveMux.middlewares, middlewares...)
124 }
125}
126
127// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
128// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
129// done with careful consideration.
130func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
131 return func(serveMux *ServeMux) {
132 currentQueryParser = queryParameterParser
133 }
134}
135
// HeaderMatcherFunc decides whether a header key should be forwarded to/from
// the gRPC context. It returns the key to forward under and whether the header
// should be forwarded at all.
type HeaderMatcherFunc func(string) (string, bool)
138
139// DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
140// keys (as specified by the IANA, e.g: Accept, Cookie, Host) to the gRPC metadata with the grpcgateway- prefix. If you want to know which headers are considered permanent, you can view the isPermanentHTTPHeader function.
141// HTTP headers that start with 'Grpc-Metadata-' are mapped to gRPC metadata after removing the prefix 'Grpc-Metadata-'.
142// Other headers are not added to the gRPC metadata.
143func DefaultHeaderMatcher(key string) (string, bool) {
144 switch key = textproto.CanonicalMIMEHeaderKey(key); {
145 case isPermanentHTTPHeader(key):
146 return MetadataPrefix + key, true
147 case strings.HasPrefix(key, MetadataHeaderPrefix):
148 return key[len(MetadataHeaderPrefix):], true
149 }
150 return "", false
151}
152
153func defaultOutgoingHeaderMatcher(key string) (string, bool) {
154 return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
155}
156
157func defaultOutgoingTrailerMatcher(key string) (string, bool) {
158 return fmt.Sprintf("%s%s", MetadataTrailerPrefix, key), true
159}
160
161// WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
162//
163// This matcher will be called with each header in http.Request. If matcher returns true, that header will be
164// passed to gRPC context. To transform the header before passing to gRPC context, matcher should return the modified header.
165func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
166 for _, header := range fn.matchedMalformedHeaders() {
167 grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header)
168 }
169
170 return func(mux *ServeMux) {
171 mux.incomingHeaderMatcher = fn
172 }
173}
174
175// matchedMalformedHeaders returns the malformed headers that would be forwarded to gRPC server.
176func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string {
177 if fn == nil {
178 return nil
179 }
180 headers := make([]string, 0)
181 for header := range malformedHTTPHeaders {
182 out, accept := fn(header)
183 if accept && isMalformedHTTPHeader(out) {
184 headers = append(headers, out)
185 }
186 }
187 return headers
188}
189
190// WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
191//
192// This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
193// passed to http response returned from gateway. To transform the header before passing to response,
194// matcher should return the modified header.
195func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
196 return func(mux *ServeMux) {
197 mux.outgoingHeaderMatcher = fn
198 }
199}
200
201// WithOutgoingTrailerMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
202//
203// This matcher will be called with each header in response trailer metadata. If matcher returns true, that header will be
204// passed to http response returned from gateway. To transform the header before passing to response,
205// matcher should return the modified header.
206func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption {
207 return func(mux *ServeMux) {
208 mux.outgoingTrailerMatcher = fn
209 }
210}
211
212// WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
213//
214// This can be used by services that need to read from http.Request and modify gRPC context. A common use case
215// is reading token from cookie and adding it in gRPC context.
216func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
217 return func(serveMux *ServeMux) {
218 serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
219 }
220}
221
222// WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
223//
224// This can be used to configure a custom error response.
225func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
226 return func(serveMux *ServeMux) {
227 serveMux.errorHandler = fn
228 }
229}
230
231// WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
232// error handler, which allows for customizing the error trailer for server-streaming
233// calls.
234//
235// For stream errors that occur before any response has been written, the mux's
236// ErrorHandler will be invoked. However, once data has been written, the errors must
237// be handled differently: they must be included in the response body. The response body's
238// final message will include the error details returned by the stream error handler.
239func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
240 return func(serveMux *ServeMux) {
241 serveMux.streamErrorHandler = fn
242 }
243}
244
245// WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors.
246//
247// Method called for errors which can happen before gRPC route selected or executed.
248// The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest
249func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
250 return func(serveMux *ServeMux) {
251 serveMux.routingErrorHandler = fn
252 }
253}
254
255// WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
256func WithDisablePathLengthFallback() ServeMuxOption {
257 return func(serveMux *ServeMux) {
258 serveMux.disablePathLengthFallback = true
259 }
260}
261
262// WithWriteContentLength returns a ServeMuxOption to enable writing content length on non-streaming responses
263func WithWriteContentLength() ServeMuxOption {
264 return func(serveMux *ServeMux) {
265 serveMux.writeContentLength = true
266 }
267}
268
269// WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath.
270// When called the handler will forward the request to the upstream grpc service health check (defined in the
271// gRPC Health Checking Protocol).
272//
273// See here https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/health_check/ for more information on how
274// to setup the protocol in the grpc server.
275//
276// If you define a service as query parameter, this will also be forwarded as service in the HealthCheckRequest.
277func WithHealthEndpointAt(healthCheckClient grpc_health_v1.HealthClient, endpointPath string) ServeMuxOption {
278 return func(s *ServeMux) {
279 // error can be ignored since pattern is definitely valid
280 _ = s.HandlePath(
281 http.MethodGet, endpointPath, func(w http.ResponseWriter, r *http.Request, _ map[string]string,
282 ) {
283 _, outboundMarshaler := MarshalerForRequest(s, r)
284
285 resp, err := healthCheckClient.Check(r.Context(), &grpc_health_v1.HealthCheckRequest{
286 Service: r.URL.Query().Get("service"),
287 })
288 if err != nil {
289 s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
290 return
291 }
292
293 w.Header().Set("Content-Type", "application/json")
294
295 if resp.GetStatus() != grpc_health_v1.HealthCheckResponse_SERVING {
296 switch resp.GetStatus() {
297 case grpc_health_v1.HealthCheckResponse_NOT_SERVING, grpc_health_v1.HealthCheckResponse_UNKNOWN:
298 err = status.Error(codes.Unavailable, resp.String())
299 case grpc_health_v1.HealthCheckResponse_SERVICE_UNKNOWN:
300 err = status.Error(codes.NotFound, resp.String())
301 }
302
303 s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
304 return
305 }
306
307 _ = outboundMarshaler.NewEncoder(w).Encode(resp)
308 })
309 }
310}
311
312// WithHealthzEndpoint returns a ServeMuxOption that will add a /healthz endpoint to the created ServeMux.
313//
314// See WithHealthEndpointAt for the general implementation.
315func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMuxOption {
316 return WithHealthEndpointAt(healthCheckClient, "/healthz")
317}
318
319// NewServeMux returns a new ServeMux whose internal mapping is empty.
320func NewServeMux(opts ...ServeMuxOption) *ServeMux {
321 serveMux := &ServeMux{
322 handlers: make(map[string][]handler),
323 forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
324 forwardResponseRewriter: func(ctx context.Context, response proto.Message) (any, error) { return response, nil },
325 marshalers: makeMarshalerMIMERegistry(),
326 errorHandler: DefaultHTTPErrorHandler,
327 streamErrorHandler: DefaultStreamErrorHandler,
328 routingErrorHandler: DefaultRoutingErrorHandler,
329 unescapingMode: UnescapingModeDefault,
330 }
331
332 for _, opt := range opts {
333 opt(serveMux)
334 }
335
336 if serveMux.incomingHeaderMatcher == nil {
337 serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
338 }
339 if serveMux.outgoingHeaderMatcher == nil {
340 serveMux.outgoingHeaderMatcher = defaultOutgoingHeaderMatcher
341 }
342 if serveMux.outgoingTrailerMatcher == nil {
343 serveMux.outgoingTrailerMatcher = defaultOutgoingTrailerMatcher
344 }
345
346 return serveMux
347}
348
349// Handle associates "h" to the pair of HTTP method and path pattern.
350func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
351 if len(s.middlewares) > 0 {
352 h = chainMiddlewares(s.middlewares)(h)
353 }
354 s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
355}
356
357// HandlePath allows users to configure custom path handlers.
358// refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
359func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
360 compiler, err := httprule.Parse(pathPattern)
361 if err != nil {
362 return fmt.Errorf("parsing path pattern: %w", err)
363 }
364 tp := compiler.Compile()
365 pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
366 if err != nil {
367 return fmt.Errorf("creating new pattern: %w", err)
368 }
369 s.Handle(meth, pattern, h)
370 return nil
371}
372
373// ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.URL.Path.
374func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
375 ctx := r.Context()
376
377 path := r.URL.Path
378 if !strings.HasPrefix(path, "/") {
379 _, outboundMarshaler := MarshalerForRequest(s, r)
380 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
381 return
382 }
383
384 // TODO(v3): remove UnescapingModeLegacy
385 if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
386 path = r.URL.RawPath
387 }
388
389 if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
390 if err := r.ParseForm(); err != nil {
391 _, outboundMarshaler := MarshalerForRequest(s, r)
392 sterr := status.Error(codes.InvalidArgument, err.Error())
393 s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
394 return
395 }
396 r.Method = strings.ToUpper(override)
397 }
398
399 var pathComponents []string
400 // since in UnescapeModeLegacy, the URL will already have been fully unescaped, if we also split on "%2F"
401 // in this escaping mode we would be double unescaping but in UnescapingModeAllCharacters, we still do as the
402 // path is the RawPath (i.e. unescaped). That does mean that the behavior of this function will change its default
403 // behavior when the UnescapingModeDefault gets changed from UnescapingModeLegacy to UnescapingModeAllExceptReserved
404 if s.unescapingMode == UnescapingModeAllCharacters {
405 pathComponents = encodedPathSplitter.Split(path[1:], -1)
406 } else {
407 pathComponents = strings.Split(path[1:], "/")
408 }
409
410 lastPathComponent := pathComponents[len(pathComponents)-1]
411
412 for _, h := range s.handlers[r.Method] {
413 // If the pattern has a verb, explicitly look for a suffix in the last
414 // component that matches a colon plus the verb. This allows us to
415 // handle some cases that otherwise can't be correctly handled by the
416 // former LastIndex case, such as when the verb literal itself contains
417 // a colon. This should work for all cases that have run through the
418 // parser because we know what verb we're looking for, however, there
419 // are still some cases that the parser itself cannot disambiguate. See
420 // the comment there if interested.
421
422 var verb string
423 patVerb := h.pat.Verb()
424
425 idx := -1
426 if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
427 idx = len(lastPathComponent) - len(patVerb) - 1
428 }
429 if idx == 0 {
430 _, outboundMarshaler := MarshalerForRequest(s, r)
431 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
432 return
433 }
434
435 comps := make([]string, len(pathComponents))
436 copy(comps, pathComponents)
437
438 if idx > 0 {
439 comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
440 }
441
442 pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
443 if err != nil {
444 var mse MalformedSequenceError
445 if ok := errors.As(err, &mse); ok {
446 _, outboundMarshaler := MarshalerForRequest(s, r)
447 s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
448 HTTPStatus: http.StatusBadRequest,
449 Err: mse,
450 })
451 }
452 continue
453 }
454 s.handleHandler(h, w, r, pathParams)
455 return
456 }
457
458 // if no handler has found for the request, lookup for other methods
459 // to handle POST -> GET fallback if the request is subject to path
460 // length fallback.
461 // Note we are not eagerly checking the request here as we want to return the
462 // right HTTP status code, and we need to process the fallback candidates in
463 // order to do that.
464 for m, handlers := range s.handlers {
465 if m == r.Method {
466 continue
467 }
468 for _, h := range handlers {
469 var verb string
470 patVerb := h.pat.Verb()
471
472 idx := -1
473 if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
474 idx = len(lastPathComponent) - len(patVerb) - 1
475 }
476
477 comps := make([]string, len(pathComponents))
478 copy(comps, pathComponents)
479
480 if idx > 0 {
481 comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
482 }
483
484 pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
485 if err != nil {
486 var mse MalformedSequenceError
487 if ok := errors.As(err, &mse); ok {
488 _, outboundMarshaler := MarshalerForRequest(s, r)
489 s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
490 HTTPStatus: http.StatusBadRequest,
491 Err: mse,
492 })
493 }
494 continue
495 }
496
497 // X-HTTP-Method-Override is optional. Always allow fallback to POST.
498 // Also, only consider POST -> GET fallbacks, and avoid falling back to
499 // potentially dangerous operations like DELETE.
500 if s.isPathLengthFallback(r) && m == http.MethodGet {
501 if err := r.ParseForm(); err != nil {
502 _, outboundMarshaler := MarshalerForRequest(s, r)
503 sterr := status.Error(codes.InvalidArgument, err.Error())
504 s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
505 return
506 }
507 s.handleHandler(h, w, r, pathParams)
508 return
509 }
510 _, outboundMarshaler := MarshalerForRequest(s, r)
511 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
512 return
513 }
514 }
515
516 _, outboundMarshaler := MarshalerForRequest(s, r)
517 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
518}
519
520// GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
521func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
522 return s.forwardResponseOptions
523}
524
525func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
526 return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
527}
528
529type handler struct {
530 pat Pattern
531 h HandlerFunc
532}
533
534func (s *ServeMux) handleHandler(h handler, w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
535 h.h(w, r.WithContext(withHTTPPattern(r.Context(), h.pat)), pathParams)
536}
537
538func chainMiddlewares(mws []Middleware) Middleware {
539 return func(next HandlerFunc) HandlerFunc {
540 for i := len(mws); i > 0; i-- {
541 next = mws[i-1](next)
542 }
543 return next
544 }
545}