blob: d913d44e150a3a58a860d0d5ff151b4024e37925 [file] [log] [blame]
Abhay Kumara2ae5992025-11-10 14:02:24 +00001package sarama
2
3import (
4 "bytes"
5 "encoding/binary"
6 "errors"
7 "fmt"
8 "io"
9 "maps"
10 "net"
11 "reflect"
12 "strconv"
13 "sync"
14 "syscall"
15 "time"
16
17 "github.com/davecgh/go-spew/spew"
18)
19
// expectationTimeout bounds how long defaultRequestHandler waits for a
// response queued via Returns before giving up and leaving the request
// unanswered.
const expectationTimeout = 500 * time.Millisecond
23
24type GSSApiHandlerFunc func([]byte) []byte
25
26type requestHandlerFunc func(req *request) (res encoderWithHeader)
27
28// RequestNotifierFunc is invoked when a mock broker processes a request successfully
29// and will provides the number of bytes read and written.
30type RequestNotifierFunc func(bytesRead, bytesWritten int)
31
// MockBroker is a mock Kafka broker that is used in unit tests. It is exposed
// to facilitate testing of higher level or specialized consumers and producers
// built on top of Sarama. Note that it does not 'mimic' the Kafka API protocol,
// but rather provides a facility to do that. It takes care of the TCP
// transport, request unmarshalling and response marshaling; it is the test
// writer's responsibility to program MockBroker behavior that is correct
// according to the Kafka API protocol.
//
// MockBroker is implemented as a TCP server listening (by default) on a
// kernel-selected localhost port that can accept many connections. It reads
// Kafka requests from those connections and returns responses programmed by
// the SetHandlerByMap function. If a MockBroker receives a request that it has
// no programmed response for, then it returns nothing and the request times out.
//
// A set of MockRequest builders to define mappings used by MockBroker is
// provided by Sarama. But users can develop MockRequests of their own and use
// them along with or instead of the standard ones.
//
// When running tests with MockBroker it is strongly recommended to specify
// a timeout to `go test` so that if the broker hangs waiting for a response,
// the test panics.
//
// It is not necessary to prefix message length or correlation ID to your
// response bytes; the server adds them automatically as a convenience.
type MockBroker struct {
	// brokerID is reported by BrokerID and used to label log output.
	brokerID int32
	// port is the TCP port parsed from the listener address at construction.
	port int32
	// closing is closed by Close to tell the accept loop and the
	// per-connection goroutines to shut down.
	closing chan none
	// stopper is closed by serverLoop when it exits; Close waits on it.
	stopper chan none
	// expectations queues responses registered with Returns. They are consumed
	// by defaultRequestHandler, and Close reports any left unconsumed.
	expectations chan encoderWithHeader
	// listener accepts client connections; it is closed once closing is closed.
	listener net.Listener
	// t receives errors detected while serving.
	t TestReporter
	// latency, when positive, is slept before each Kafka request is handled.
	latency time.Duration
	// handler builds the response for each decoded Kafka request; guarded by lock.
	handler requestHandlerFunc
	// notifier, if set, is told the bytes read/written per request; guarded by lock.
	notifier RequestNotifierFunc
	// history records each handled Kafka request with its response; guarded by lock.
	history []RequestResponse
	// lock guards handler, notifier and history; handleRequests also holds it
	// while invoking gssApiHandler.
	lock sync.Mutex
	// gssApiHandler answers frames that isGSSAPI recognizes.
	gssApiHandler GSSApiHandlerFunc
}
71
// RequestResponse represents a Request/Response pair processed by MockBroker.
// The broker appends one entry per handled Kafka request; History returns them.
type RequestResponse struct {
	// Request is the decoded body of the Kafka request.
	Request protocolBody
	// Response is what the request handler returned for it; nil when the
	// request was ignored.
	Response encoder
}
77
78// SetLatency makes broker pause for the specified period every time before
79// replying.
80func (b *MockBroker) SetLatency(latency time.Duration) {
81 b.latency = latency
82}
83
84// SetHandlerByMap defines mapping of Request types to MockResponses. When a
85// request is received by the broker, it looks up the request type in the map
86// and uses the found MockResponse instance to generate an appropriate reply.
87// If the request type is not found in the map then nothing is sent.
88func (b *MockBroker) SetHandlerByMap(handlerMap map[string]MockResponse) {
89 fnMap := maps.Clone(handlerMap)
90 b.setHandler(func(req *request) (res encoderWithHeader) {
91 reqTypeName := reflect.TypeOf(req.body).Elem().Name()
92 mockResponse := fnMap[reqTypeName]
93 if mockResponse == nil {
94 return nil
95 }
96 return mockResponse.For(req.body)
97 })
98}
99
100// SetHandlerFuncByMap defines mapping of Request types to RequestHandlerFunc. When a
101// request is received by the broker, it looks up the request type in the map
102// and invoke the found RequestHandlerFunc instance to generate an appropriate reply.
103func (b *MockBroker) SetHandlerFuncByMap(handlerMap map[string]requestHandlerFunc) {
104 fnMap := maps.Clone(handlerMap)
105 b.setHandler(func(req *request) (res encoderWithHeader) {
106 reqTypeName := reflect.TypeOf(req.body).Elem().Name()
107 return fnMap[reqTypeName](req)
108 })
109}
110
111// SetNotifier set a function that will get invoked whenever a request has been
112// processed successfully and will provide the number of bytes read and written
113func (b *MockBroker) SetNotifier(notifier RequestNotifierFunc) {
114 b.lock.Lock()
115 b.notifier = notifier
116 b.lock.Unlock()
117}
118
119// BrokerID returns broker ID assigned to the broker.
120func (b *MockBroker) BrokerID() int32 {
121 return b.brokerID
122}
123
124// History returns a slice of RequestResponse pairs in the order they were
125// processed by the broker. Note that in case of multiple connections to the
126// broker the order expected by a test can be different from the order recorded
127// in the history, unless some synchronization is implemented in the test.
128func (b *MockBroker) History() []RequestResponse {
129 b.lock.Lock()
130 history := make([]RequestResponse, len(b.history))
131 copy(history, b.history)
132 b.lock.Unlock()
133 return history
134}
135
136// Port returns the TCP port number the broker is listening for requests on.
137func (b *MockBroker) Port() int32 {
138 return b.port
139}
140
141// Addr returns the broker connection string in the form "<address>:<port>".
142func (b *MockBroker) Addr() string {
143 return b.listener.Addr().String()
144}
145
146// Close terminates the broker blocking until it stops internal goroutines and
147// releases all resources.
148func (b *MockBroker) Close() {
149 close(b.expectations)
150 if len(b.expectations) > 0 {
151 buf := bytes.NewBufferString(fmt.Sprintf("mockbroker/%d: not all expectations were satisfied! Still waiting on:\n", b.BrokerID()))
152 for e := range b.expectations {
153 _, _ = buf.WriteString(spew.Sdump(e))
154 }
155 b.t.Error(buf.String())
156 }
157 close(b.closing)
158 <-b.stopper
159}
160
161// setHandler sets the specified function as the request handler. Whenever
162// a mock broker reads a request from the wire it passes the request to the
163// function and sends back whatever the handler function returns.
164func (b *MockBroker) setHandler(handler requestHandlerFunc) {
165 b.lock.Lock()
166 b.handler = handler
167 b.lock.Unlock()
168}
169
170func (b *MockBroker) serverLoop() {
171 defer close(b.stopper)
172 var err error
173 var conn net.Conn
174
175 go func() {
176 <-b.closing
177 err := b.listener.Close()
178 if err != nil {
179 b.t.Error(err)
180 }
181 }()
182
183 wg := &sync.WaitGroup{}
184 i := 0
185 for conn, err = b.listener.Accept(); err == nil; conn, err = b.listener.Accept() {
186 wg.Add(1)
187 go b.handleRequests(conn, i, wg)
188 i++
189 }
190 wg.Wait()
191 if !isConnectionClosedError(err) {
192 Logger.Printf("*** mockbroker/%d: listener closed, err=%v", b.BrokerID(), err)
193 }
194}
195
196func (b *MockBroker) SetGSSAPIHandler(handler GSSApiHandlerFunc) {
197 b.gssApiHandler = handler
198}
199
200func (b *MockBroker) readToBytes(r io.Reader) ([]byte, error) {
201 var (
202 bytesRead int
203 lengthBytes = make([]byte, 4)
204 )
205
206 if _, err := io.ReadFull(r, lengthBytes); err != nil {
207 return nil, err
208 }
209
210 bytesRead += len(lengthBytes)
211 length := int32(binary.BigEndian.Uint32(lengthBytes))
212
213 if length <= 4 || length > MaxRequestSize {
214 return nil, PacketDecodingError{fmt.Sprintf("message of length %d too large or too small", length)}
215 }
216
217 encodedReq := make([]byte, length)
218 if _, err := io.ReadFull(r, encodedReq); err != nil {
219 return nil, err
220 }
221
222 bytesRead += len(encodedReq)
223
224 fullBytes := append(lengthBytes, encodedReq...)
225
226 return fullBytes, nil
227}
228
229func (b *MockBroker) isGSSAPI(buffer []byte) bool {
230 return buffer[4] == 0x60 || bytes.Equal(buffer[4:6], []byte{0x05, 0x04})
231}
232
233func (b *MockBroker) handleRequests(conn io.ReadWriteCloser, idx int, wg *sync.WaitGroup) {
234 defer wg.Done()
235 defer func() {
236 _ = conn.Close()
237 }()
238 s := spew.NewDefaultConfig()
239 s.MaxDepth = 1
240 Logger.Printf("*** mockbroker/%d/%d: connection opened", b.BrokerID(), idx)
241 var err error
242
243 abort := make(chan none)
244 defer close(abort)
245 go func() {
246 select {
247 case <-b.closing:
248 _ = conn.Close()
249 case <-abort:
250 }
251 }()
252
253 var bytesWritten int
254 var bytesRead int
255 for {
256 buffer, err := b.readToBytes(conn)
257 if err != nil {
258 if !isConnectionClosedError(err) {
259 Logger.Printf("*** mockbroker/%d/%d: invalid request: err=%+v, %+v", b.brokerID, idx, err, spew.Sdump(buffer))
260 b.serverError(err)
261 }
262 break
263 }
264
265 bytesWritten = 0
266 if !b.isGSSAPI(buffer) {
267 req, br, err := decodeRequest(bytes.NewReader(buffer))
268 bytesRead = br
269 if err != nil {
270 if !isConnectionClosedError(err) {
271 Logger.Printf("*** mockbroker/%d/%d: invalid request: err=%+v, %+v", b.brokerID, idx, err, spew.Sdump(req))
272 b.serverError(err)
273 }
274 break
275 }
276
277 if b.latency > 0 {
278 time.Sleep(b.latency)
279 }
280
281 b.lock.Lock()
282 res := b.handler(req)
283 b.history = append(b.history, RequestResponse{req.body, res})
284 b.lock.Unlock()
285
286 if res == nil {
287 Logger.Printf("*** mockbroker/%d/%d: ignored %v", b.brokerID, idx, spew.Sdump(req))
288 continue
289 }
290 Logger.Printf(
291 "*** mockbroker/%d/%d: replied to %T with %T\n-> %s\n-> %s",
292 b.brokerID, idx, req.body, res,
293 s.Sprintf("%#v", req.body),
294 s.Sprintf("%#v", res),
295 )
296
297 encodedRes, err := encode(res, nil)
298 if err != nil {
299 b.serverError(fmt.Errorf("failed to encode %T - %w", res, err))
300 break
301 }
302 if len(encodedRes) == 0 {
303 b.lock.Lock()
304 if b.notifier != nil {
305 b.notifier(bytesRead, 0)
306 }
307 b.lock.Unlock()
308 continue
309 }
310
311 resHeader := b.encodeHeader(res.headerVersion(), req.correlationID, uint32(len(encodedRes)))
312 if _, err = conn.Write(resHeader); err != nil {
313 b.serverError(err)
314 break
315 }
316 if _, err = conn.Write(encodedRes); err != nil {
317 b.serverError(err)
318 break
319 }
320 bytesWritten = len(resHeader) + len(encodedRes)
321 } else {
322 // GSSAPI is not part of kafka protocol, but is supported for authentication proposes.
323 // Don't support history for this kind of request as is only used for test GSSAPI authentication mechanism
324 b.lock.Lock()
325 res := b.gssApiHandler(buffer)
326 b.lock.Unlock()
327 if res == nil {
328 Logger.Printf("*** mockbroker/%d/%d: ignored %v", b.brokerID, idx, spew.Sdump(buffer))
329 continue
330 }
331 if _, err = conn.Write(res); err != nil {
332 b.serverError(err)
333 break
334 }
335 bytesWritten = len(res)
336 }
337
338 b.lock.Lock()
339 if b.notifier != nil {
340 b.notifier(bytesRead, bytesWritten)
341 }
342 b.lock.Unlock()
343 }
344 Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err)
345}
346
347func (b *MockBroker) encodeHeader(headerVersion int16, correlationId int32, payloadLength uint32) []byte {
348 headerLength := uint32(8)
349
350 if headerVersion >= 1 {
351 headerLength = 9
352 }
353
354 resHeader := make([]byte, headerLength)
355 binary.BigEndian.PutUint32(resHeader, payloadLength+headerLength-4)
356 binary.BigEndian.PutUint32(resHeader[4:], uint32(correlationId))
357
358 if headerVersion >= 1 {
359 binary.PutUvarint(resHeader[8:], 0)
360 }
361
362 return resHeader
363}
364
365func (b *MockBroker) defaultRequestHandler(req *request) (res encoderWithHeader) {
366 select {
367 case res, ok := <-b.expectations:
368 if !ok {
369 return nil
370 }
371 return res
372 case <-time.After(expectationTimeout):
373 return nil
374 }
375}
376
// isConnectionClosedError reports whether err looks like the result of a
// connection or listener being closed. Callers treat such errors as a normal
// shutdown rather than a test failure. A nil error is never a
// closed-connection error.
//
// NOTE: any *net.OpError anywhere in the wrap chain is treated as "closed".
// This also covers other network failures such as resets and timeouts; the
// broad match is kept for compatibility with the original behavior.
func isConnectionClosedError(err error) bool {
	if err == nil {
		return false
	}
	var opErr *net.OpError
	switch {
	case errors.As(err, &opErr):
		return true
	case errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed):
		return true
	default:
		// Fallback for errors that carry the message text without wrapping
		// net.ErrClosed.
		return err.Error() == "use of closed network connection"
	}
}
390
391func (b *MockBroker) serverError(err error) {
392 b.t.Helper()
393 if isConnectionClosedError(err) {
394 return
395 }
396 b.t.Errorf(err.Error())
397}
398
399// NewMockBroker launches a fake Kafka broker. It takes a TestReporter as provided by the
400// test framework and a channel of responses to use. If an error occurs it is
401// simply logged to the TestReporter and the broker exits.
402func NewMockBroker(t TestReporter, brokerID int32) *MockBroker {
403 return NewMockBrokerAddr(t, brokerID, "localhost:0")
404}
405
406// NewMockBrokerAddr behaves like newMockBroker but listens on the address you give
407// it rather than just some ephemeral port.
408func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker {
409 var (
410 listener net.Listener
411 err error
412 )
413
414 // retry up to 20 times if address already in use (e.g., if replacing broker which hasn't cleanly shutdown)
415 for i := 0; i < 20; i++ {
416 listener, err = net.Listen("tcp", addr)
417 if err != nil {
418 if errors.Is(err, syscall.EADDRINUSE) {
419 Logger.Printf("*** mockbroker/%d waiting for %s (address already in use)", brokerID, addr)
420 time.Sleep(time.Millisecond * 100)
421 continue
422 }
423 t.Fatal(err)
424 }
425 break
426 }
427
428 if err != nil {
429 t.Fatal(err)
430 }
431
432 return NewMockBrokerListener(t, brokerID, listener)
433}
434
435// NewMockBrokerListener behaves like newMockBrokerAddr but accepts connections on the listener specified.
436func NewMockBrokerListener(t TestReporter, brokerID int32, listener net.Listener) *MockBroker {
437 var err error
438
439 broker := &MockBroker{
440 closing: make(chan none),
441 stopper: make(chan none),
442 t: t,
443 brokerID: brokerID,
444 expectations: make(chan encoderWithHeader, 512),
445 listener: listener,
446 }
447 broker.handler = broker.defaultRequestHandler
448
449 Logger.Printf("*** mockbroker/%d listening on %s\n", brokerID, broker.listener.Addr().String())
450 _, portStr, err := net.SplitHostPort(broker.listener.Addr().String())
451 if err != nil {
452 t.Fatal(err)
453 }
454 tmp, err := strconv.ParseInt(portStr, 10, 32)
455 if err != nil {
456 t.Fatal(err)
457 }
458 broker.port = int32(tmp)
459
460 go broker.serverLoop()
461
462 return broker
463}
464
465func (b *MockBroker) Returns(e encoderWithHeader) {
466 b.expectations <- e
467}