blob: 2f0b9e9e0f8691d89d455222e8025a89044460f6 [file] [log] [blame]
Abhay Kumara2ae5992025-11-10 14:02:24 +00001package runtime
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "net/http"
9 "net/textproto"
10 "strconv"
11 "strings"
12
13 "google.golang.org/genproto/googleapis/api/httpbody"
14 "google.golang.org/grpc/codes"
15 "google.golang.org/grpc/grpclog"
16 "google.golang.org/grpc/status"
17 "google.golang.org/protobuf/proto"
18)
19
20// ForwardResponseStream forwards the stream from gRPC server to REST client.
21func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
22 rc := http.NewResponseController(w)
23 md, ok := ServerMetadataFromContext(ctx)
24 if !ok {
25 grpclog.Error("Failed to extract ServerMetadata from context")
26 http.Error(w, "unexpected error", http.StatusInternalServerError)
27 return
28 }
29 handleForwardResponseServerMetadata(w, mux, md)
30
31 w.Header().Set("Transfer-Encoding", "chunked")
32 if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
33 HTTPError(ctx, mux, marshaler, w, req, err)
34 return
35 }
36
37 var delimiter []byte
38 if d, ok := marshaler.(Delimited); ok {
39 delimiter = d.Delimiter()
40 } else {
41 delimiter = []byte("\n")
42 }
43
44 var wroteHeader bool
45 for {
46 resp, err := recv()
47 if errors.Is(err, io.EOF) {
48 return
49 }
50 if err != nil {
51 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
52 return
53 }
54 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
55 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
56 return
57 }
58
59 respRw, err := mux.forwardResponseRewriter(ctx, resp)
60 if err != nil {
61 grpclog.Errorf("Rewrite error: %v", err)
62 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
63 return
64 }
65
66 if !wroteHeader {
67 var contentType string
68 if sct, ok := marshaler.(StreamContentType); ok {
69 contentType = sct.StreamContentType(respRw)
70 } else {
71 contentType = marshaler.ContentType(respRw)
72 }
73 w.Header().Set("Content-Type", contentType)
74 }
75
76 var buf []byte
77 httpBody, isHTTPBody := respRw.(*httpbody.HttpBody)
78 switch {
79 case respRw == nil:
80 buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
81 case isHTTPBody:
82 buf = httpBody.GetData()
83 default:
84 result := map[string]interface{}{"result": respRw}
85 if rb, ok := respRw.(responseBody); ok {
86 result["result"] = rb.XXX_ResponseBody()
87 }
88
89 buf, err = marshaler.Marshal(result)
90 }
91
92 if err != nil {
93 grpclog.Errorf("Failed to marshal response chunk: %v", err)
94 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
95 return
96 }
97 if _, err := w.Write(buf); err != nil {
98 grpclog.Errorf("Failed to send response chunk: %v", err)
99 return
100 }
101 wroteHeader = true
102 if _, err := w.Write(delimiter); err != nil {
103 grpclog.Errorf("Failed to send delimiter chunk: %v", err)
104 return
105 }
106 err = rc.Flush()
107 if err != nil {
108 if errors.Is(err, http.ErrNotSupported) {
109 grpclog.Errorf("Flush not supported in %T", w)
110 http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
111 return
112 }
113 grpclog.Errorf("Failed to flush response to client: %v", err)
114 return
115 }
116 }
117}
118
119func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
120 for k, vs := range md.HeaderMD {
121 if h, ok := mux.outgoingHeaderMatcher(k); ok {
122 for _, v := range vs {
123 w.Header().Add(h, v)
124 }
125 }
126 }
127}
128
129func handleForwardResponseTrailerHeader(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
130 for k := range md.TrailerMD {
131 if h, ok := mux.outgoingTrailerMatcher(k); ok {
132 w.Header().Add("Trailer", textproto.CanonicalMIMEHeaderKey(h))
133 }
134 }
135}
136
137func handleForwardResponseTrailer(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
138 for k, vs := range md.TrailerMD {
139 if h, ok := mux.outgoingTrailerMatcher(k); ok {
140 for _, v := range vs {
141 w.Header().Add(h, v)
142 }
143 }
144 }
145}
146
// responseBody is implemented by response types whose HTTP body is one field
// of the message rather than the whole message. XXX_ResponseBody is generated
// from the `response_body` value of the `google.api.HttpRule`. It returns the
// value to marshal as the response body.
type responseBody interface {
	XXX_ResponseBody() any
}
152
153// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
154func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
155 md, ok := ServerMetadataFromContext(ctx)
156 if ok {
157 handleForwardResponseServerMetadata(w, mux, md)
158 }
159
160 // RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
161 // Unless the request includes a TE header field indicating "trailers"
162 // is acceptable, as described in Section 4.3, a server SHOULD NOT
163 // generate trailer fields that it believes are necessary for the user
164 // agent to receive.
165 doForwardTrailers := requestAcceptsTrailers(req)
166
167 if ok && doForwardTrailers {
168 handleForwardResponseTrailerHeader(w, mux, md)
169 w.Header().Set("Transfer-Encoding", "chunked")
170 }
171
172 contentType := marshaler.ContentType(resp)
173 w.Header().Set("Content-Type", contentType)
174
175 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
176 HTTPError(ctx, mux, marshaler, w, req, err)
177 return
178 }
179 respRw, err := mux.forwardResponseRewriter(ctx, resp)
180 if err != nil {
181 grpclog.Errorf("Rewrite error: %v", err)
182 HTTPError(ctx, mux, marshaler, w, req, err)
183 return
184 }
185 var buf []byte
186 if rb, ok := respRw.(responseBody); ok {
187 buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
188 } else {
189 buf, err = marshaler.Marshal(respRw)
190 }
191 if err != nil {
192 grpclog.Errorf("Marshal error: %v", err)
193 HTTPError(ctx, mux, marshaler, w, req, err)
194 return
195 }
196
197 if !doForwardTrailers && mux.writeContentLength {
198 w.Header().Set("Content-Length", strconv.Itoa(len(buf)))
199 }
200
201 if _, err = w.Write(buf); err != nil && !errors.Is(err, http.ErrBodyNotAllowed) {
202 grpclog.Errorf("Failed to write response: %v", err)
203 }
204
205 if ok && doForwardTrailers {
206 handleForwardResponseTrailer(w, mux, md)
207 }
208}
209
// requestAcceptsTrailers reports whether req's TE header lists the "trailers"
// keyword, meaning the client will accept trailer fields (RFC 7230 §4.3).
//
// The check handles three details:
//   - Every TE field line is inspected, because a client may split the list
//     across several fields.
//   - Each comma-separated element is compared case-insensitively.
//   - Any ";" parameters are dropped before comparing.
//
// Matching whole elements means a coding that merely contains the text
// "trailers" is not mistaken for the keyword.
func requestAcceptsTrailers(req *http.Request) bool {
	for _, field := range req.Header.Values("TE") {
		for _, element := range strings.Split(field, ",") {
			coding, _, _ := strings.Cut(element, ";")
			if strings.EqualFold(strings.TrimSpace(coding), "trailers") {
				return true
			}
		}
	}
	return false
}
214
215func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
216 if len(opts) == 0 {
217 return nil
218 }
219 for _, opt := range opts {
220 if err := opt(ctx, w, resp); err != nil {
221 return fmt.Errorf("error handling ForwardResponseOptions: %w", err)
222 }
223 }
224 return nil
225}
226
227func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error, delimiter []byte) {
228 st := mux.streamErrorHandler(ctx, err)
229 msg := errorChunk(st)
230 if !wroteHeader {
231 w.Header().Set("Content-Type", marshaler.ContentType(msg))
232 w.WriteHeader(HTTPStatusFromCode(st.Code()))
233 }
234 buf, err := marshaler.Marshal(msg)
235 if err != nil {
236 grpclog.Errorf("Failed to marshal an error: %v", err)
237 return
238 }
239 if _, err := w.Write(buf); err != nil {
240 grpclog.Errorf("Failed to notify error to client: %v", err)
241 return
242 }
243 if _, err := w.Write(delimiter); err != nil {
244 grpclog.Errorf("Failed to send delimiter chunk: %v", err)
245 return
246 }
247}
248
249func errorChunk(st *status.Status) map[string]proto.Message {
250 return map[string]proto.Message{"error": st.Proto()}
251}