blob: 7c559cfcbb6a9096ef6fb2fadd44e6978cdc9249 [file] [log] [blame]
Abhay Kumar40252eb2025-10-13 13:25:53 +00001package sarama
2
3import (
4 "crypto/tls"
5 "encoding/binary"
6 "errors"
7 "fmt"
8 "io"
9 "math/rand"
10 "net"
11 "sort"
12 "strconv"
13 "strings"
14 "sync"
15 "sync/atomic"
16 "time"
17
18 "github.com/rcrowley/go-metrics"
19)
20
// Broker represents a single Kafka broker connection. All operations on this object are entirely concurrency-safe.
//
// A Broker is created with NewBroker and connected with Open; connection
// state (conn, connErr, responses, done, ...) is guarded by lock.
type Broker struct {
	// conf is the validated configuration passed to Open; it is assigned once
	// the TCP dial has succeeded.
	conf *Config
	// rack is the broker's rack as reported by cluster metadata; nil when
	// unknown (Rack then returns "").
	rack *string

	// id is the broker ID from cluster metadata, or -1 when unknown
	// (NewBroker initialises it to -1).
	id int32
	// addr is the host:port this broker connects to.
	addr string
	// correlationID is stamped on every outgoing request and incremented
	// after each successful write in sendInternal.
	correlationID int32
	// conn is the live connection, wrapped in TLS when enabled and always in
	// a bufConn by Open; nil when not connected.
	conn net.Conn
	// connErr holds the error from the latest dial/authentication attempt,
	// reported by Connected and returned by sendWithPromise while conn is nil.
	connErr error
	// lock guards the connection state. Open acquires it and the background
	// connection goroutine releases it only after the attempt has finished.
	lock sync.Mutex
	// opened is set when Open starts and cleared when the connection attempt
	// fails or Close completes.
	opened atomic.Bool
	// responses queues promises for written requests that expect a reply.
	// NOTE(review): presumably drained by the responseReceiver goroutine
	// started in Open; Close closes it and then waits on done.
	responses chan *responsePromise
	// done is waited on after responses is closed; presumably signalled by
	// responseReceiver when it exits — verify.
	done chan bool

	// metricRegistry is created by Open (via newCleanupRegistry) on first use
	// and cleared of registrations by Close.
	metricRegistry metrics.Registry
	// The following metrics are shared between brokers and are obtained with
	// GetOrRegister* calls in Open.
	incomingByteRate     metrics.Meter
	requestRate          metrics.Meter
	fetchRate            metrics.Meter
	requestSize          metrics.Histogram
	requestLatency       metrics.Histogram
	outgoingByteRate     metrics.Meter
	responseRate         metrics.Meter
	responseSize         metrics.Histogram
	requestsInFlight     metrics.Counter
	protocolRequestsRate map[int16]metrics.Meter
	// The broker-prefixed metrics below are presumably populated by
	// registerMetrics, which Open calls only for brokers with id >= 0 when
	// metrics.UseNilMetrics is false; they may therefore be nil.
	brokerIncomingByteRate     metrics.Meter
	brokerRequestRate          metrics.Meter
	brokerFetchRate            metrics.Meter
	brokerRequestSize          metrics.Histogram
	brokerRequestLatency       metrics.Histogram
	brokerOutgoingByteRate     metrics.Meter
	brokerResponseRate         metrics.Meter
	brokerResponseSize         metrics.Histogram
	brokerRequestsInFlight     metrics.Counter
	brokerThrottleTime         metrics.Histogram
	brokerProtocolRequestsRate map[int16]metrics.Meter
	// brokerAPIVersions holds the per-API version ranges advertised in the
	// broker's ApiVersions response; sendInternal uses it to restrict
	// request versions.
	brokerAPIVersions apiVersionMap

	// kerberosAuthenticator is used for SASL/GSSAPI authentication.
	// NOTE(review): initialisation happens outside this view — confirm.
	kerberosAuthenticator GSSAPIKerberosAuth
	// clientSessionReauthenticationTimeMs is a Unix-millisecond deadline;
	// once passed, sendWithPromise re-authenticates via SASL v1 before
	// sending. A value <= 0 disables re-authentication.
	clientSessionReauthenticationTimeMs int64

	// throttleTimer and its lock presumably implement broker-requested
	// throttling (see waitIfThrottled / handleThrottledResponse) — verify.
	throttleTimer     *time.Timer
	throttleTimerLock sync.Mutex
}
66
// SASLMechanism specifies the SASL mechanism the client uses to authenticate with the broker.
// The SASLType* string constants name the supported mechanisms.
type SASLMechanism string
69
// SASL mechanism names. They are deliberately untyped string constants so
// they convert implicitly to SASLMechanism (or compare against plain strings).
const (
	// SASLTypeOAuth represents the SASL/OAUTHBEARER mechanism (Kafka 2.0.0+).
	SASLTypeOAuth = "OAUTHBEARER"
	// SASLTypePlaintext represents the SASL/PLAIN mechanism.
	SASLTypePlaintext = "PLAIN"
	// SASLTypeSCRAMSHA256 represents the SCRAM-SHA-256 mechanism.
	SASLTypeSCRAMSHA256 = "SCRAM-SHA-256"
	// SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism.
	SASLTypeSCRAMSHA512 = "SCRAM-SHA-512"
	// SASLTypeGSSAPI represents the SASL/GSSAPI (Kerberos) mechanism.
	SASLTypeGSSAPI = "GSSAPI"
)

// Versions of the Kafka SASL handshake protocol.
const (
	// SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol: after the
	// handshake, client and server exchange opaque SASL packets.
	SASLHandshakeV0 int16 = 0
	// SASLHandshakeV1 is v1 of the Kafka SASL handshake protocol: SASL tokens
	// are wrapped in Kafka protocol headers.
	SASLHandshakeV1 int16 = 1
)

// SASLExtKeyAuth is the reserved extension key name sent as part of the
// SASL/OAUTHBEARER initial client response.
const SASLExtKeyAuth = "auth"
90
// AccessToken contains an access token used to authenticate a
// SASL/OAUTHBEARER client along with associated metadata.
type AccessToken struct {
	// Token is the access token payload.
	Token string
	// Extensions is an optional map of arbitrary key-value pairs that can be
	// sent with the SASL/OAUTHBEARER initial client response. These values are
	// ignored by the SASL server if they are unexpected. This feature is only
	// supported by Kafka >= 2.1.0. The key SASLExtKeyAuth is reserved.
	Extensions map[string]string
}
102
// AccessTokenProvider is the interface that encapsulates how implementors
// can generate access tokens for Kafka broker authentication
// (SASL/OAUTHBEARER).
type AccessTokenProvider interface {
	// Token returns an access token. The implementation should ensure token
	// reuse so that multiple calls at connect time do not create multiple
	// tokens. The implementation should also periodically refresh the token in
	// order to guarantee that each call returns an unexpired token. This
	// method should not block indefinitely--a timeout error should be returned
	// after a short period of inactivity so that the broker connection logic
	// can log debugging information and retry.
	Token() (*AccessToken, error)
}
115
// SCRAMClient is an interface to a SCRAM client implementation, used for the
// SCRAM-SHA-256 / SCRAM-SHA-512 SASL mechanisms.
type SCRAMClient interface {
	// Begin prepares the client for the SCRAM exchange
	// with the server with a user name, a password and an
	// optional authorization identity (authzID).
	Begin(userName, password, authzID string) error
	// Step steps client through the SCRAM exchange. It is
	// called repeatedly until it errors or `Done` returns true.
	Step(challenge string) (response string, err error)
	// Done should return true when the SCRAM conversation
	// is over.
	Done() bool
}
129
// responsePromise tracks a written request that expects a response. The
// outcome is delivered through handler when it is set, otherwise through the
// packets/errors channels (see handle).
type responsePromise struct {
	// requestTime is when the request was written; set by sendInternal.
	requestTime time.Time
	// correlationID is the ID stamped on the request; set by sendInternal,
	// presumably used to match the incoming response — verify in responseReceiver.
	correlationID int32
	// response is the body the reply is decoded into; sendInternal aligns its
	// version with the request's version.
	response protocolBody
	// handler, when non-nil, receives the raw packets or error instead of the channels.
	handler func([]byte, error)
	// packets receives the raw response bytes on success (channel mode).
	packets chan []byte
	// errors receives the failure reason (channel mode).
	errors chan error
}
138
139func (p *responsePromise) handle(packets []byte, err error) {
140 // Use callback when provided
141 if p.handler != nil {
142 p.handler(packets, err)
143 return
144 }
145 // Otherwise fallback to using channels
146 if err != nil {
147 p.errors <- err
148 return
149 }
150 p.packets <- packets
151}
152
153// NewBroker creates and returns a Broker targeting the given host:port address.
154// This does not attempt to actually connect, you have to call Open() for that.
155func NewBroker(addr string) *Broker {
156 return &Broker{id: -1, addr: addr}
157}
158
159// Open tries to connect to the Broker if it is not already connected or connecting, but does not block
160// waiting for the connection to complete. This means that any subsequent operations on the broker will
161// block waiting for the connection to succeed or fail. To get the effect of a fully synchronous Open call,
162// follow it by a call to Connected(). The only errors Open will return directly are ConfigurationError or
163// AlreadyConnected. If conf is nil, the result of NewConfig() is used.
164func (b *Broker) Open(conf *Config) error {
165 if !b.opened.CompareAndSwap(false, true) {
166 return ErrAlreadyConnected
167 }
168
169 if conf == nil {
170 conf = NewConfig()
171 }
172
173 err := conf.Validate()
174 if err != nil {
175 return err
176 }
177
178 b.lock.Lock()
179
180 if b.metricRegistry == nil {
181 b.metricRegistry = newCleanupRegistry(conf.MetricRegistry)
182 }
183
184 go withRecover(func() {
185 defer b.lock.Unlock()
186
187 dialer := conf.getDialer()
188 b.conn, b.connErr = dialer.Dial("tcp", b.addr)
189 if b.connErr != nil {
190 Logger.Printf("Failed to connect to broker %s: %s\n", b.addr, b.connErr)
191 b.conn = nil
192 b.opened.Store(false)
193 return
194 }
195 if conf.Net.TLS.Enable {
196 b.conn = tls.Client(b.conn, validServerNameTLS(b.addr, conf.Net.TLS.Config))
197 }
198
199 b.conn = newBufConn(b.conn)
200 b.conf = conf
201
202 // Create or reuse the global metrics shared between brokers
203 b.incomingByteRate = metrics.GetOrRegisterMeter("incoming-byte-rate", b.metricRegistry)
204 b.requestRate = metrics.GetOrRegisterMeter("request-rate", b.metricRegistry)
205 b.fetchRate = metrics.GetOrRegisterMeter("consumer-fetch-rate", b.metricRegistry)
206 b.requestSize = getOrRegisterHistogram("request-size", b.metricRegistry)
207 b.requestLatency = getOrRegisterHistogram("request-latency-in-ms", b.metricRegistry)
208 b.outgoingByteRate = metrics.GetOrRegisterMeter("outgoing-byte-rate", b.metricRegistry)
209 b.responseRate = metrics.GetOrRegisterMeter("response-rate", b.metricRegistry)
210 b.responseSize = getOrRegisterHistogram("response-size", b.metricRegistry)
211 b.requestsInFlight = metrics.GetOrRegisterCounter("requests-in-flight", b.metricRegistry)
212 b.protocolRequestsRate = map[int16]metrics.Meter{}
213 // Do not gather metrics for seeded broker (only used during bootstrap) because they share
214 // the same id (-1) and are already exposed through the global metrics above
215 if b.id >= 0 && !metrics.UseNilMetrics {
216 b.registerMetrics()
217 }
218
219 // Send an ApiVersionsRequest to identify the client (KIP-511).
220 // Store the response in the brokerAPIVersions map.
221 // It will be used to determine the supported API versions for each request.
222 // This should happen before SASL authentication: https://kafka.apache.org/26/protocol.html#api_versions
223 if conf.ApiVersionsRequest {
224 apiVersionsResponse, err := b.sendAndReceiveApiVersions(3)
225 if err != nil {
226 Logger.Printf("Error while sending ApiVersionsRequest V3 to broker %s: %s\n", b.addr, err)
227 // send a lower version request in case remote cluster is <= 2.4.0.0
228 maxVersion := int16(0)
229 if apiVersionsResponse != nil {
230 for _, k := range apiVersionsResponse.ApiKeys {
231 if k.ApiKey == apiKeyApiVersions {
232 maxVersion = k.MaxVersion
233 break
234 }
235 }
236 }
237 apiVersionsResponse, err = b.sendAndReceiveApiVersions(maxVersion)
238 if err != nil {
239 Logger.Printf("Error while sending ApiVersionsRequest V%d to broker %s: %s\n", maxVersion, b.addr, err)
240 }
241 }
242 if apiVersionsResponse != nil {
243 b.brokerAPIVersions = make(apiVersionMap, len(apiVersionsResponse.ApiKeys))
244 for _, key := range apiVersionsResponse.ApiKeys {
245 b.brokerAPIVersions[key.ApiKey] = &apiVersionRange{
246 minVersion: key.MinVersion,
247 maxVersion: key.MaxVersion,
248 }
249 }
250 }
251 }
252
253 if conf.Net.SASL.Mechanism == SASLTypeOAuth && conf.Net.SASL.Version == SASLHandshakeV0 {
254 conf.Net.SASL.Version = SASLHandshakeV1
255 }
256
257 useSaslV0 := conf.Net.SASL.Version == SASLHandshakeV0
258 if conf.Net.SASL.Enable && useSaslV0 {
259 b.connErr = b.authenticateViaSASLv0()
260
261 if b.connErr != nil {
262 err = b.conn.Close()
263 if err == nil {
264 DebugLogger.Printf("Closed connection to broker %s due to SASL v0 auth error: %s\n", b.addr, b.connErr)
265 } else {
266 Logger.Printf("Error while closing connection to broker %s (due to SASL v0 auth error: %s): %s\n", b.addr, b.connErr, err)
267 }
268 b.conn = nil
269 b.opened.Store(false)
270 return
271 }
272 }
273
274 b.done = make(chan bool)
275 b.responses = make(chan *responsePromise, b.conf.Net.MaxOpenRequests-1)
276
277 go withRecover(b.responseReceiver)
278 if conf.Net.SASL.Enable && !useSaslV0 {
279 b.connErr = b.authenticateViaSASLv1()
280 if b.connErr != nil {
281 close(b.responses)
282 <-b.done
283 err = b.conn.Close()
284 if err == nil {
285 DebugLogger.Printf("Closed connection to broker %s due to SASL v1 auth error: %s\n", b.addr, b.connErr)
286 } else {
287 Logger.Printf("Error while closing connection to broker %s (due to SASL v1 auth error: %s): %s\n", b.addr, b.connErr, err)
288 }
289 b.conn = nil
290 b.opened.Store(false)
291 return
292 }
293 }
294 if b.id >= 0 {
295 DebugLogger.Printf("Connected to broker at %s (registered as #%d)\n", b.addr, b.id)
296 } else {
297 DebugLogger.Printf("Connected to broker at %s (unregistered)\n", b.addr)
298 }
299 })
300
301 return nil
302}
303
304func (b *Broker) ResponseSize() int {
305 b.lock.Lock()
306 defer b.lock.Unlock()
307
308 return len(b.responses)
309}
310
311// Connected returns true if the broker is connected and false otherwise. If the broker is not
312// connected but it had tried to connect, the error from that connection attempt is also returned.
313func (b *Broker) Connected() (bool, error) {
314 b.lock.Lock()
315 defer b.lock.Unlock()
316
317 return b.conn != nil, b.connErr
318}
319
320// TLSConnectionState returns the client's TLS connection state. The second return value is false if this is not a tls connection or the connection has not yet been established.
321func (b *Broker) TLSConnectionState() (state tls.ConnectionState, ok bool) {
322 b.lock.Lock()
323 defer b.lock.Unlock()
324
325 if b.conn == nil {
326 return state, false
327 }
328 conn := b.conn
329 if bconn, ok := b.conn.(*bufConn); ok {
330 conn = bconn.Conn
331 }
332 if tc, ok := conn.(*tls.Conn); ok {
333 return tc.ConnectionState(), true
334 }
335 return state, false
336}
337
338// Close closes the broker resources
339func (b *Broker) Close() error {
340 b.lock.Lock()
341 defer b.lock.Unlock()
342
343 if b.conn == nil {
344 return ErrNotConnected
345 }
346
347 close(b.responses)
348 <-b.done
349
350 err := b.conn.Close()
351
352 b.conn = nil
353 b.connErr = nil
354 b.done = nil
355 b.responses = nil
356
357 b.metricRegistry.UnregisterAll()
358
359 if err == nil {
360 DebugLogger.Printf("Closed connection to broker %s\n", b.addr)
361 } else {
362 Logger.Printf("Error while closing connection to broker %s: %s\n", b.addr, err)
363 }
364 b.opened.Store(false)
365
366 return err
367}
368
369// ID returns the broker ID retrieved from Kafka's metadata, or -1 if that is not known.
370func (b *Broker) ID() int32 {
371 return b.id
372}
373
374// Addr returns the broker address as either retrieved from Kafka's metadata or passed to NewBroker.
375func (b *Broker) Addr() string {
376 return b.addr
377}
378
379// Rack returns the broker's rack as retrieved from Kafka's metadata or the
380// empty string if it is not known. The returned value corresponds to the
381// broker's broker.rack configuration setting. Requires protocol version to be
382// at least v0.10.0.0.
383func (b *Broker) Rack() string {
384 if b.rack == nil {
385 return ""
386 }
387 return *b.rack
388}
389
390// GetMetadata send a metadata request and returns a metadata response or error
391func (b *Broker) GetMetadata(request *MetadataRequest) (*MetadataResponse, error) {
392 response := new(MetadataResponse)
393 response.Version = request.Version // Required to ensure use of the correct response header version
394
395 err := b.sendAndReceive(request, response)
396 if err != nil {
397 return nil, err
398 }
399
400 return response, nil
401}
402
403// GetConsumerMetadata send a consumer metadata request and returns a consumer metadata response or error
404func (b *Broker) GetConsumerMetadata(request *ConsumerMetadataRequest) (*ConsumerMetadataResponse, error) {
405 response := new(ConsumerMetadataResponse)
406
407 err := b.sendAndReceive(request, response)
408 if err != nil {
409 return nil, err
410 }
411
412 return response, nil
413}
414
415// FindCoordinator sends a find coordinate request and returns a response or error
416func (b *Broker) FindCoordinator(request *FindCoordinatorRequest) (*FindCoordinatorResponse, error) {
417 response := new(FindCoordinatorResponse)
418
419 err := b.sendAndReceive(request, response)
420 if err != nil {
421 return nil, err
422 }
423
424 return response, nil
425}
426
427// GetAvailableOffsets return an offset response or error
428func (b *Broker) GetAvailableOffsets(request *OffsetRequest) (*OffsetResponse, error) {
429 response := new(OffsetResponse)
430
431 err := b.sendAndReceive(request, response)
432 if err != nil {
433 return nil, err
434 }
435
436 return response, nil
437}
438
// ProduceCallback function is called once the produce response has been parsed
// or could not be read. On success it receives the decoded response and a nil
// error; on failure it receives a nil response and the error.
type ProduceCallback func(*ProduceResponse, error)
442
443// AsyncProduce sends a produce request and eventually call the provided callback
444// with a produce response or an error.
445//
446// Waiting for the response is generally not blocking on the contrary to using Produce.
447// If the maximum number of in flight request configured is reached then
448// the request will be blocked till a previous response is received.
449//
450// When configured with RequiredAcks == NoResponse, the callback will not be invoked.
451// If an error is returned because the request could not be sent then the callback
452// will not be invoked either.
453//
454// Make sure not to Close the broker in the callback as it will lead to a deadlock.
455func (b *Broker) AsyncProduce(request *ProduceRequest, cb ProduceCallback) error {
456 b.lock.Lock()
457 defer b.lock.Unlock()
458
459 needAcks := request.RequiredAcks != NoResponse
460 // Use a nil promise when no acks is required
461 var promise *responsePromise
462
463 if needAcks {
464 metricRegistry := b.metricRegistry
465
466 // Create ProduceResponse early to provide the header version
467 res := new(ProduceResponse)
468 promise = &responsePromise{
469 response: res,
470 // Packets will be converted to a ProduceResponse in the responseReceiver goroutine
471 handler: func(packets []byte, err error) {
472 if err != nil {
473 // Failed request
474 cb(nil, err)
475 return
476 }
477
478 if err := versionedDecode(packets, res, request.version(), metricRegistry); err != nil {
479 // Malformed response
480 cb(nil, err)
481 return
482 }
483
484 // Well-formed response
485 b.handleThrottledResponse(res)
486 cb(res, nil)
487 },
488 }
489 }
490
491 return b.sendWithPromise(request, promise)
492}
493
494// Produce returns a produce response or error
495func (b *Broker) Produce(request *ProduceRequest) (*ProduceResponse, error) {
496 var (
497 response *ProduceResponse
498 err error
499 )
500
501 if request.RequiredAcks == NoResponse {
502 err = b.sendAndReceive(request, nil)
503 } else {
504 response = new(ProduceResponse)
505 err = b.sendAndReceive(request, response)
506 }
507
508 if err != nil {
509 return nil, err
510 }
511
512 return response, nil
513}
514
515// Fetch returns a FetchResponse or error
516func (b *Broker) Fetch(request *FetchRequest) (*FetchResponse, error) {
517 defer func() {
518 if b.fetchRate != nil {
519 b.fetchRate.Mark(1)
520 }
521 if b.brokerFetchRate != nil {
522 b.brokerFetchRate.Mark(1)
523 }
524 }()
525
526 response := new(FetchResponse)
527
528 err := b.sendAndReceive(request, response)
529 if err != nil {
530 return nil, err
531 }
532
533 return response, nil
534}
535
536// CommitOffset return an Offset commit response or error
537func (b *Broker) CommitOffset(request *OffsetCommitRequest) (*OffsetCommitResponse, error) {
538 response := new(OffsetCommitResponse)
539
540 err := b.sendAndReceive(request, response)
541 if err != nil {
542 return nil, err
543 }
544
545 return response, nil
546}
547
548// FetchOffset returns an offset fetch response or error
549func (b *Broker) FetchOffset(request *OffsetFetchRequest) (*OffsetFetchResponse, error) {
550 response := new(OffsetFetchResponse)
551 response.Version = request.Version // needed to handle the two header versions
552
553 err := b.sendAndReceive(request, response)
554 if err != nil {
555 return nil, err
556 }
557
558 return response, nil
559}
560
561// JoinGroup returns a join group response or error
562func (b *Broker) JoinGroup(request *JoinGroupRequest) (*JoinGroupResponse, error) {
563 response := new(JoinGroupResponse)
564
565 err := b.sendAndReceive(request, response)
566 if err != nil {
567 return nil, err
568 }
569
570 return response, nil
571}
572
573// SyncGroup returns a sync group response or error
574func (b *Broker) SyncGroup(request *SyncGroupRequest) (*SyncGroupResponse, error) {
575 response := new(SyncGroupResponse)
576
577 err := b.sendAndReceive(request, response)
578 if err != nil {
579 return nil, err
580 }
581
582 return response, nil
583}
584
585// LeaveGroup return a leave group response or error
586func (b *Broker) LeaveGroup(request *LeaveGroupRequest) (*LeaveGroupResponse, error) {
587 response := new(LeaveGroupResponse)
588
589 err := b.sendAndReceive(request, response)
590 if err != nil {
591 return nil, err
592 }
593
594 return response, nil
595}
596
597// Heartbeat returns a heartbeat response or error
598func (b *Broker) Heartbeat(request *HeartbeatRequest) (*HeartbeatResponse, error) {
599 response := new(HeartbeatResponse)
600
601 err := b.sendAndReceive(request, response)
602 if err != nil {
603 return nil, err
604 }
605
606 return response, nil
607}
608
609// ListGroups return a list group response or error
610func (b *Broker) ListGroups(request *ListGroupsRequest) (*ListGroupsResponse, error) {
611 response := new(ListGroupsResponse)
612 response.Version = request.Version // Required to ensure use of the correct response header version
613
614 err := b.sendAndReceive(request, response)
615 if err != nil {
616 return nil, err
617 }
618
619 return response, nil
620}
621
622// DescribeGroups return describe group response or error
623func (b *Broker) DescribeGroups(request *DescribeGroupsRequest) (*DescribeGroupsResponse, error) {
624 response := new(DescribeGroupsResponse)
625
626 err := b.sendAndReceive(request, response)
627 if err != nil {
628 return nil, err
629 }
630
631 return response, nil
632}
633
634// ApiVersions return api version response or error
635func (b *Broker) ApiVersions(request *ApiVersionsRequest) (*ApiVersionsResponse, error) {
636 response := new(ApiVersionsResponse)
637
638 err := b.sendAndReceive(request, response)
639 if err != nil {
640 return nil, err
641 }
642
643 return response, nil
644}
645
646// CreateTopics send a create topic request and returns create topic response
647func (b *Broker) CreateTopics(request *CreateTopicsRequest) (*CreateTopicsResponse, error) {
648 response := new(CreateTopicsResponse)
649
650 err := b.sendAndReceive(request, response)
651 if err != nil {
652 return nil, err
653 }
654
655 return response, nil
656}
657
658// DeleteTopics sends a delete topic request and returns delete topic response
659func (b *Broker) DeleteTopics(request *DeleteTopicsRequest) (*DeleteTopicsResponse, error) {
660 response := new(DeleteTopicsResponse)
661
662 err := b.sendAndReceive(request, response)
663 if err != nil {
664 return nil, err
665 }
666
667 return response, nil
668}
669
670// CreatePartitions sends a create partition request and returns create
671// partitions response or error
672func (b *Broker) CreatePartitions(request *CreatePartitionsRequest) (*CreatePartitionsResponse, error) {
673 response := new(CreatePartitionsResponse)
674
675 err := b.sendAndReceive(request, response)
676 if err != nil {
677 return nil, err
678 }
679
680 return response, nil
681}
682
683// AlterPartitionReassignments sends a alter partition reassignments request and
684// returns alter partition reassignments response
685func (b *Broker) AlterPartitionReassignments(request *AlterPartitionReassignmentsRequest) (*AlterPartitionReassignmentsResponse, error) {
686 response := new(AlterPartitionReassignmentsResponse)
687
688 err := b.sendAndReceive(request, response)
689 if err != nil {
690 return nil, err
691 }
692
693 return response, nil
694}
695
696// ListPartitionReassignments sends a list partition reassignments request and
697// returns list partition reassignments response
698func (b *Broker) ListPartitionReassignments(request *ListPartitionReassignmentsRequest) (*ListPartitionReassignmentsResponse, error) {
699 response := new(ListPartitionReassignmentsResponse)
700
701 err := b.sendAndReceive(request, response)
702 if err != nil {
703 return nil, err
704 }
705
706 return response, nil
707}
708
709// ElectLeaders sends aa elect leaders request and returns list partitions elect result
710func (b *Broker) ElectLeaders(request *ElectLeadersRequest) (*ElectLeadersResponse, error) {
711 response := new(ElectLeadersResponse)
712
713 err := b.sendAndReceive(request, response)
714 if err != nil {
715 return nil, err
716 }
717
718 return response, nil
719}
720
721// DeleteRecords send a request to delete records and return delete record
722// response or error
723func (b *Broker) DeleteRecords(request *DeleteRecordsRequest) (*DeleteRecordsResponse, error) {
724 response := new(DeleteRecordsResponse)
725
726 err := b.sendAndReceive(request, response)
727 if err != nil {
728 return nil, err
729 }
730
731 return response, nil
732}
733
734// DescribeAcls sends a describe acl request and returns a response or error
735func (b *Broker) DescribeAcls(request *DescribeAclsRequest) (*DescribeAclsResponse, error) {
736 response := new(DescribeAclsResponse)
737
738 err := b.sendAndReceive(request, response)
739 if err != nil {
740 return nil, err
741 }
742
743 return response, nil
744}
745
746// CreateAcls sends a create acl request and returns a response or error
747func (b *Broker) CreateAcls(request *CreateAclsRequest) (*CreateAclsResponse, error) {
748 response := new(CreateAclsResponse)
749
750 err := b.sendAndReceive(request, response)
751 if err != nil {
752 return nil, err
753 }
754
755 errs := make([]error, 0)
756 for _, res := range response.AclCreationResponses {
757 if !errors.Is(res.Err, ErrNoError) {
758 errs = append(errs, res.Err)
759 }
760 }
761
762 if len(errs) > 0 {
763 return response, Wrap(ErrCreateACLs, errs...)
764 }
765
766 return response, nil
767}
768
769// DeleteAcls sends a delete acl request and returns a response or error
770func (b *Broker) DeleteAcls(request *DeleteAclsRequest) (*DeleteAclsResponse, error) {
771 response := new(DeleteAclsResponse)
772
773 err := b.sendAndReceive(request, response)
774 if err != nil {
775 return nil, err
776 }
777
778 return response, nil
779}
780
781// InitProducerID sends an init producer request and returns a response or error
782func (b *Broker) InitProducerID(request *InitProducerIDRequest) (*InitProducerIDResponse, error) {
783 response := new(InitProducerIDResponse)
784 response.Version = request.version()
785
786 err := b.sendAndReceive(request, response)
787 if err != nil {
788 return nil, err
789 }
790
791 return response, nil
792}
793
794// AddPartitionsToTxn send a request to add partition to txn and returns
795// a response or error
796func (b *Broker) AddPartitionsToTxn(request *AddPartitionsToTxnRequest) (*AddPartitionsToTxnResponse, error) {
797 response := new(AddPartitionsToTxnResponse)
798
799 err := b.sendAndReceive(request, response)
800 if err != nil {
801 return nil, err
802 }
803
804 return response, nil
805}
806
807// AddOffsetsToTxn sends a request to add offsets to txn and returns a response
808// or error
809func (b *Broker) AddOffsetsToTxn(request *AddOffsetsToTxnRequest) (*AddOffsetsToTxnResponse, error) {
810 response := new(AddOffsetsToTxnResponse)
811
812 err := b.sendAndReceive(request, response)
813 if err != nil {
814 return nil, err
815 }
816
817 return response, nil
818}
819
820// EndTxn sends a request to end txn and returns a response or error
821func (b *Broker) EndTxn(request *EndTxnRequest) (*EndTxnResponse, error) {
822 response := new(EndTxnResponse)
823
824 err := b.sendAndReceive(request, response)
825 if err != nil {
826 return nil, err
827 }
828
829 return response, nil
830}
831
832// TxnOffsetCommit sends a request to commit transaction offsets and returns
833// a response or error
834func (b *Broker) TxnOffsetCommit(request *TxnOffsetCommitRequest) (*TxnOffsetCommitResponse, error) {
835 response := new(TxnOffsetCommitResponse)
836
837 err := b.sendAndReceive(request, response)
838 if err != nil {
839 return nil, err
840 }
841
842 return response, nil
843}
844
845// DescribeConfigs sends a request to describe config and returns a response or
846// error
847func (b *Broker) DescribeConfigs(request *DescribeConfigsRequest) (*DescribeConfigsResponse, error) {
848 response := new(DescribeConfigsResponse)
849
850 err := b.sendAndReceive(request, response)
851 if err != nil {
852 return nil, err
853 }
854
855 return response, nil
856}
857
858// AlterConfigs sends a request to alter config and return a response or error
859func (b *Broker) AlterConfigs(request *AlterConfigsRequest) (*AlterConfigsResponse, error) {
860 response := new(AlterConfigsResponse)
861
862 err := b.sendAndReceive(request, response)
863 if err != nil {
864 return nil, err
865 }
866
867 return response, nil
868}
869
870// IncrementalAlterConfigs sends a request to incremental alter config and return a response or error
871func (b *Broker) IncrementalAlterConfigs(request *IncrementalAlterConfigsRequest) (*IncrementalAlterConfigsResponse, error) {
872 response := new(IncrementalAlterConfigsResponse)
873
874 err := b.sendAndReceive(request, response)
875 if err != nil {
876 return nil, err
877 }
878
879 return response, nil
880}
881
882// DeleteGroups sends a request to delete groups and returns a response or error
883func (b *Broker) DeleteGroups(request *DeleteGroupsRequest) (*DeleteGroupsResponse, error) {
884 response := new(DeleteGroupsResponse)
885
886 if err := b.sendAndReceive(request, response); err != nil {
887 return nil, err
888 }
889
890 return response, nil
891}
892
893// DeleteOffsets sends a request to delete group offsets and returns a response or error
894func (b *Broker) DeleteOffsets(request *DeleteOffsetsRequest) (*DeleteOffsetsResponse, error) {
895 response := new(DeleteOffsetsResponse)
896
897 if err := b.sendAndReceive(request, response); err != nil {
898 return nil, err
899 }
900
901 return response, nil
902}
903
904// DescribeLogDirs sends a request to get the broker's log dir paths and sizes
905func (b *Broker) DescribeLogDirs(request *DescribeLogDirsRequest) (*DescribeLogDirsResponse, error) {
906 response := new(DescribeLogDirsResponse)
907
908 err := b.sendAndReceive(request, response)
909 if err != nil {
910 return nil, err
911 }
912
913 return response, nil
914}
915
916// DescribeUserScramCredentials sends a request to get SCRAM users
917func (b *Broker) DescribeUserScramCredentials(req *DescribeUserScramCredentialsRequest) (*DescribeUserScramCredentialsResponse, error) {
918 res := new(DescribeUserScramCredentialsResponse)
919
920 err := b.sendAndReceive(req, res)
921 if err != nil {
922 return nil, err
923 }
924
925 return res, err
926}
927
928func (b *Broker) AlterUserScramCredentials(req *AlterUserScramCredentialsRequest) (*AlterUserScramCredentialsResponse, error) {
929 res := new(AlterUserScramCredentialsResponse)
930
931 err := b.sendAndReceive(req, res)
932 if err != nil {
933 return nil, err
934 }
935
936 return res, nil
937}
938
939// DescribeClientQuotas sends a request to get the broker's quotas
940func (b *Broker) DescribeClientQuotas(request *DescribeClientQuotasRequest) (*DescribeClientQuotasResponse, error) {
941 response := new(DescribeClientQuotasResponse)
942
943 err := b.sendAndReceive(request, response)
944 if err != nil {
945 return nil, err
946 }
947
948 return response, nil
949}
950
951// AlterClientQuotas sends a request to alter the broker's quotas
952func (b *Broker) AlterClientQuotas(request *AlterClientQuotasRequest) (*AlterClientQuotasResponse, error) {
953 response := new(AlterClientQuotasResponse)
954
955 err := b.sendAndReceive(request, response)
956 if err != nil {
957 return nil, err
958 }
959
960 return response, nil
961}
962
963// readFull ensures the conn ReadDeadline has been setup before making a
964// call to io.ReadFull
965func (b *Broker) readFull(buf []byte) (n int, err error) {
966 if err := b.conn.SetReadDeadline(time.Now().Add(b.conf.Net.ReadTimeout)); err != nil {
967 return 0, err
968 }
969
970 return io.ReadFull(b.conn, buf)
971}
972
973// write ensures the conn Deadline has been setup before making a
974// call to conn.Write
975func (b *Broker) write(buf []byte) (n int, err error) {
976 now := time.Now()
977 if err := b.conn.SetWriteDeadline(now.Add(b.conf.Net.WriteTimeout)); err != nil {
978 return 0, err
979 }
980 // TLS connections require both read and write deadlines to be set
981 // to avoid handshake indefinite blocking
982 // see https://github.com/golang/go/blob/go1.23.0/src/crypto/tls/conn.go#L1192-L1195
983 if b.conf.Net.TLS.Enable {
984 if err := b.conn.SetReadDeadline(now.Add(b.conf.Net.ReadTimeout)); err != nil {
985 return 0, err
986 }
987 }
988
989 return b.conn.Write(buf)
990}
991
992// b.lock must be held by caller
993//
994// a non-nil res results in a response promise being created
995func (b *Broker) send(req, res protocolBody) (*responsePromise, error) {
996 var promise *responsePromise
997 if res != nil {
998 // Packets or error will be sent to the following channels
999 // once the response is received
1000 promise = makeResponsePromise(res)
1001 }
1002
1003 if err := b.sendWithPromise(req, promise); err != nil {
1004 return nil, err
1005 }
1006
1007 return promise, nil
1008}
1009
1010func makeResponsePromise(res protocolBody) *responsePromise {
1011 promise := &responsePromise{
1012 response: res,
1013 packets: make(chan []byte),
1014 errors: make(chan error),
1015 }
1016 return promise
1017}
1018
1019// b.lock must be held by caller
1020func (b *Broker) sendWithPromise(rb protocolBody, promise *responsePromise) error {
1021 if b.conn == nil {
1022 if b.connErr != nil {
1023 return b.connErr
1024 }
1025 return ErrNotConnected
1026 }
1027
1028 if b.clientSessionReauthenticationTimeMs > 0 && currentUnixMilli() > b.clientSessionReauthenticationTimeMs {
1029 err := b.authenticateViaSASLv1()
1030 if err != nil {
1031 return err
1032 }
1033 }
1034
1035 return b.sendInternal(rb, promise)
1036}
1037
// sendInternal encodes rb under the current correlation ID and writes it to
// the connection. The caller must hold b.lock.
//
// Before writing, it:
//   - tries to restrict rb's API version to the range the broker advertised;
//   - aligns the promised response's version with the request's;
//   - rejects requests that need a newer Kafka version than configured;
//   - waits out any active throttle.
//
// After a successful write the correlation ID advances. With a nil promise,
// latency and in-flight metrics are settled immediately. Otherwise the promise
// is queued on b.responses for responseReceiver to complete.
func (b *Broker) sendInternal(rb protocolBody, promise *responsePromise) error {
	// try restricting API version to ranges advertised by the broker
	if err := restrictApiVersion(rb, b.brokerAPIVersions); err != nil {
		return err
	}

	// response versions must always match their corresponding request's
	if promise != nil && promise.response != nil {
		promise.response.setVersion(rb.version())
	}

	if !b.conf.Version.IsAtLeast(rb.requiredVersion()) {
		return ErrUnsupportedVersion
	}

	req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
	buf, err := encode(req, b.metricRegistry)
	if err != nil {
		return err
	}

	// check and wait if throttled (blocks until a pending throttle timer fires)
	b.waitIfThrottled()

	requestTime := time.Now()
	// Balanced later by responseReceiver (via updateIncomingCommunicationMetrics),
	// or below: explicitly on write error, or via
	// updateRequestLatencyAndInFlightMetrics when there is no promise.
	b.addRequestInFlightMetrics(1)
	bytes, err := b.write(buf)
	b.updateOutgoingCommunicationMetrics(bytes)
	b.updateProtocolMetrics(rb)
	if err != nil {
		b.addRequestInFlightMetrics(-1)
		return err
	}
	// Only advance the correlation ID once the request is actually on the wire.
	b.correlationID++

	if promise == nil {
		// Record request latency without the response
		b.updateRequestLatencyAndInFlightMetrics(time.Since(requestTime))
		return nil
	}

	promise.requestTime = requestTime
	promise.correlationID = req.correlationID
	// Queued while b.lock is held, so promises reach responseReceiver in send order.
	// NOTE(review): blocks if b.responses is full; its capacity is set where the
	// channel is created (not visible here).
	b.responses <- promise

	return nil
}
1087
1088func (b *Broker) sendAndReceive(req protocolBody, res protocolBody) error {
1089 b.lock.Lock()
1090 defer b.lock.Unlock()
1091
1092 promise, err := b.send(req, res)
1093 if err != nil {
1094 return err
1095 }
1096
1097 if promise == nil {
1098 return nil
1099 }
1100
1101 err = handleResponsePromise(req, res, promise, b.metricRegistry)
1102 if err != nil {
1103 return err
1104 }
1105 if res != nil {
1106 b.handleThrottledResponse(res)
1107 }
1108 return nil
1109}
1110
1111func handleResponsePromise(req protocolBody, res protocolBody, promise *responsePromise, metricRegistry metrics.Registry) error {
1112 select {
1113 case buf := <-promise.packets:
1114 return versionedDecode(buf, res, req.version(), metricRegistry)
1115 case err := <-promise.errors:
1116 return err
1117 }
1118}
1119
1120func (b *Broker) decode(pd packetDecoder, version int16) (err error) {
1121 b.id, err = pd.getInt32()
1122 if err != nil {
1123 return err
1124 }
1125
1126 host, err := pd.getString()
1127 if err != nil {
1128 return err
1129 }
1130
1131 port, err := pd.getInt32()
1132 if err != nil {
1133 return err
1134 }
1135
1136 if version >= 1 {
1137 b.rack, err = pd.getNullableString()
1138 if err != nil {
1139 return err
1140 }
1141 }
1142
1143 b.addr = net.JoinHostPort(host, fmt.Sprint(port))
1144 if _, _, err := net.SplitHostPort(b.addr); err != nil {
1145 return err
1146 }
1147
1148 _, err = pd.getEmptyTaggedFieldArray()
1149 return err
1150}
1151
1152func (b *Broker) encode(pe packetEncoder, version int16) (err error) {
1153 host, portstr, err := net.SplitHostPort(b.addr)
1154 if err != nil {
1155 return err
1156 }
1157
1158 port, err := strconv.ParseInt(portstr, 10, 32)
1159 if err != nil {
1160 return err
1161 }
1162
1163 pe.putInt32(b.id)
1164
1165 err = pe.putString(host)
1166 if err != nil {
1167 return err
1168 }
1169
1170 pe.putInt32(int32(port))
1171
1172 if version >= 1 {
1173 err = pe.putNullableString(b.rack)
1174 if err != nil {
1175 return err
1176 }
1177 }
1178
1179 pe.putEmptyTaggedFieldArray()
1180 return nil
1181}
1182
// responseReceiver is the broker's reader loop. It takes promises from
// b.responses in the order they were queued. For each one it reads a response
// header and body from the connection and hands the outcome to the promise.
// When b.responses is closed the loop ends and b.done is closed.
//
// The first read error, header-decode error or correlation-ID mismatch marks
// the connection dead. That error goes to the current promise, and every
// later promise is failed with the same error without touching the connection.
func (b *Broker) responseReceiver() {
	var dead error

	for promise := range b.responses {
		if dead != nil {
			// This was previously incremented in sendInternal() and
			// we are not calling updateIncomingCommunicationMetrics()
			b.addRequestInFlightMetrics(-1)
			promise.handle(nil, dead)
			continue
		}

		// 8 bytes (length + correlation ID) for header v0, 9 once the
		// tagged-field count byte is present (see getHeaderLength).
		headerLength := getHeaderLength(promise.response.headerVersion())
		header := make([]byte, headerLength)

		bytesReadHeader, err := b.readFull(header)
		requestLatency := time.Since(promise.requestTime)
		if err != nil {
			b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
			dead = err
			promise.handle(nil, err)
			continue
		}

		decodedHeader := responseHeader{}
		err = versionedDecode(header, &decodedHeader, promise.response.headerVersion(), b.metricRegistry)
		if err != nil {
			b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
			dead = err
			promise.handle(nil, err)
			continue
		}
		// Responses are expected strictly in request order; any mismatch means
		// the stream is out of sync and the connection is abandoned.
		if decodedHeader.correlationID != promise.correlationID {
			b.updateIncomingCommunicationMetrics(bytesReadHeader, requestLatency)
			// TODO if decoded ID < cur ID, discard until we catch up
			// TODO if decoded ID > cur ID, save it so when cur ID catches up we have a response
			dead = PacketDecodingError{fmt.Sprintf("correlation ID didn't match, wanted %d, got %d", promise.correlationID, decodedHeader.correlationID)}
			promise.handle(nil, dead)
			continue
		}

		// The length prefix excludes its own 4 bytes, so the remaining body is
		// length - (headerLength - 4) bytes.
		// NOTE(review): this relies on responseHeader decoding rejecting lengths
		// smaller than the header; otherwise the size goes negative and make
		// panics. Confirm in responseHeader's decode.
		buf := make([]byte, decodedHeader.length-int32(headerLength)+4)
		bytesReadBody, err := b.readFull(buf)
		b.updateIncomingCommunicationMetrics(bytesReadHeader+bytesReadBody, requestLatency)
		if err != nil {
			dead = err
			promise.handle(nil, err)
			continue
		}

		promise.handle(buf, nil)
	}
	close(b.done)
}
1237
// getHeaderLength returns how many bytes to read for a response header of the
// given version, including the 4-byte length prefix. Version 0 and below use
// 8 bytes (length + correlation ID). Version 1 onward adds one byte for the
// tagged-field count, which is always 0 here because actual tags are not
// supported yet.
func getHeaderLength(headerVersion int16) int8 {
	const (
		lengthAndCorrelationID = 8
		withTaggedFieldCount   = lengthAndCorrelationID + 1
	)
	if headerVersion >= 1 {
		return withTaggedFieldCount
	}
	return lengthAndCorrelationID
}
1246
1247func (b *Broker) sendAndReceiveApiVersions(v int16) (*ApiVersionsResponse, error) {
1248 rb := &ApiVersionsRequest{
1249 Version: v,
1250 ClientSoftwareName: defaultClientSoftwareName,
1251 ClientSoftwareVersion: version(),
1252 }
1253
1254 req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
1255 buf, err := encode(req, b.metricRegistry)
1256 if err != nil {
1257 return nil, err
1258 }
1259
1260 requestTime := time.Now()
1261 // Will be decremented in updateIncomingCommunicationMetrics (except error)
1262 b.addRequestInFlightMetrics(1)
1263 bytes, err := b.write(buf)
1264 b.updateOutgoingCommunicationMetrics(bytes)
1265 if err != nil {
1266 b.addRequestInFlightMetrics(-1)
1267 Logger.Printf("Failed to send ApiVersionsRequest V%d to %s: %s\n", v, b.addr, err)
1268 return nil, err
1269 }
1270 b.correlationID++
1271
1272 // Kafka protocol response structure:
1273 // - Message length (4 bytes): Total length of the response excluding this field
1274 // - ResponseHeader v0 (4 bytes): Contains correlation ID for request-response matching
1275 header := make([]byte, 8)
1276 _, err = b.readFull(header)
1277 if err != nil {
1278 b.addRequestInFlightMetrics(-1)
1279 Logger.Printf("Failed to read ApiVersionsResponse V%d header from %s: %s\n", v, b.addr, err)
1280 return nil, err
1281 }
1282
1283 length := binary.BigEndian.Uint32(header[:4])
1284 // we're not using the correlation ID here, but it is part of the response header
1285 // correlationID := binary.BigEndian.Uint32(header[4:])
1286
1287 payload := make([]byte, length-4)
1288 n, err := b.readFull(payload)
1289 if err != nil {
1290 b.addRequestInFlightMetrics(-1)
1291 Logger.Printf("Failed to read ApiVersionsResponse V%d payload from %s: %s\n", v, b.addr, err)
1292 return nil, err
1293 }
1294
1295 b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime))
1296 res := &ApiVersionsResponse{Version: rb.version()}
1297 err = versionedDecode(payload, res, rb.version(), b.metricRegistry)
1298 if err != nil {
1299 Logger.Printf("Failed to parse ApiVersionsResponse V%d from %s: %s\n", v, b.addr, err)
1300 return nil, err
1301 }
1302
1303 kerr := KError(res.ErrorCode)
1304 if kerr != ErrNoError {
1305 return res, fmt.Errorf("Error in ApiVersionsResponse V%d from %s: %w", res.Version, b.addr, kerr)
1306 }
1307
1308 DebugLogger.Printf("Completed ApiVersionsRequest V%d to %s. Broker supports %d APIs\n", v, b.addr, len(res.ApiKeys))
1309 return res, nil
1310}
1311
1312func (b *Broker) authenticateViaSASLv0() error {
1313 switch b.conf.Net.SASL.Mechanism {
1314 case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
1315 return b.sendAndReceiveSASLSCRAMv0()
1316 case SASLTypeGSSAPI:
1317 return b.sendAndReceiveKerberos()
1318 default:
1319 return b.sendAndReceiveSASLPlainAuthV0()
1320 }
1321}
1322
// authenticateViaSASLv1 authenticates using SaslAuthenticate requests, which
// wrap the SASL exchange in the Kafka protocol. It writes through sendInternal
// and waits on each promise directly, so it does not take b.lock itself.
// sendWithPromise calls it when re-authentication is due, with b.lock already
// held.
//
// When Net.SASL.Handshake is set, a SaslHandshake for the configured mechanism
// is exchanged first. The mechanism-specific flow then drives the exchange
// through authSendReceiver.
func (b *Broker) authenticateViaSASLv1() error {
	metricRegistry := b.metricRegistry
	if b.conf.Net.SASL.Handshake {
		handshakeRequest := &SaslHandshakeRequest{Mechanism: string(b.conf.Net.SASL.Mechanism), Version: b.conf.Net.SASL.Version}
		handshakeResponse := new(SaslHandshakeResponse)
		prom := makeResponsePromise(handshakeResponse)

		handshakeErr := b.sendInternal(handshakeRequest, prom)
		if handshakeErr != nil {
			Logger.Printf("Error while performing SASL handshake %s: %s\n", b.addr, handshakeErr)
			return handshakeErr
		}
		handshakeErr = handleResponsePromise(handshakeRequest, handshakeResponse, prom, metricRegistry)
		if handshakeErr != nil {
			Logger.Printf("Error while handling SASL handshake response %s: %s\n", b.addr, handshakeErr)
			return handshakeErr
		}

		// The broker rejected the mechanism (or otherwise failed the handshake).
		if !errors.Is(handshakeResponse.Err, ErrNoError) {
			return handshakeResponse.Err
		}
	}

	// authSendReceiver sends one SaslAuthenticate round trip carrying authBytes.
	// A broker-side error is returned, wrapped with the broker's error message
	// when one is present. On success it records the session lifetime for
	// re-authentication (computeSaslSessionLifetime) and returns the response.
	authSendReceiver := func(authBytes []byte) (*SaslAuthenticateResponse, error) {
		authenticateRequest := b.createSaslAuthenticateRequest(authBytes)
		authenticateResponse := new(SaslAuthenticateResponse)
		prom := makeResponsePromise(authenticateResponse)
		authErr := b.sendInternal(authenticateRequest, prom)
		if authErr != nil {
			// NOTE(review): unlike the other log lines, this one omits authErr.
			Logger.Printf("Error while performing SASL Auth %s\n", b.addr)
			return nil, authErr
		}
		authErr = handleResponsePromise(authenticateRequest, authenticateResponse, prom, metricRegistry)
		if authErr != nil {
			Logger.Printf("Error while performing SASL Auth %s: %s\n", b.addr, authErr)
			return nil, authErr
		}

		if !errors.Is(authenticateResponse.Err, ErrNoError) {
			var err error = authenticateResponse.Err
			if authenticateResponse.ErrorMessage != nil {
				err = Wrap(authenticateResponse.Err, errors.New(*authenticateResponse.ErrorMessage))
			}
			return nil, err
		}

		b.computeSaslSessionLifetime(authenticateResponse)
		return authenticateResponse, nil
	}

	// Dispatch on mechanism; anything unrecognised falls back to SASL/PLAIN.
	switch b.conf.Net.SASL.Mechanism {
	case SASLTypeGSSAPI:
		b.kerberosAuthenticator.Config = &b.conf.Net.SASL.GSSAPI
		if b.kerberosAuthenticator.NewKerberosClientFunc == nil {
			b.kerberosAuthenticator.NewKerberosClientFunc = NewKerberosClient
		}
		return b.kerberosAuthenticator.AuthorizeV2(b, authSendReceiver)
	case SASLTypeOAuth:
		provider := b.conf.Net.SASL.TokenProvider
		return b.sendAndReceiveSASLOAuth(authSendReceiver, provider)
	case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512:
		return b.sendAndReceiveSASLSCRAMv1(authSendReceiver, b.conf.Net.SASL.SCRAMClientGeneratorFunc())
	default:
		return b.sendAndReceiveSASLPlainAuthV1(authSendReceiver)
	}
}
1389
1390func (b *Broker) sendAndReceiveKerberos() error {
1391 b.kerberosAuthenticator.Config = &b.conf.Net.SASL.GSSAPI
1392 if b.kerberosAuthenticator.NewKerberosClientFunc == nil {
1393 b.kerberosAuthenticator.NewKerberosClientFunc = NewKerberosClient
1394 }
1395 return b.kerberosAuthenticator.Authorize(b)
1396}
1397
1398func (b *Broker) sendAndReceiveSASLHandshake(saslType SASLMechanism, version int16) error {
1399 rb := &SaslHandshakeRequest{Mechanism: string(saslType), Version: version}
1400
1401 req := &request{correlationID: b.correlationID, clientID: b.conf.ClientID, body: rb}
1402 buf, err := encode(req, b.metricRegistry)
1403 if err != nil {
1404 return err
1405 }
1406
1407 requestTime := time.Now()
1408 // Will be decremented in updateIncomingCommunicationMetrics (except error)
1409 b.addRequestInFlightMetrics(1)
1410 bytes, err := b.write(buf)
1411 b.updateOutgoingCommunicationMetrics(bytes)
1412 if err != nil {
1413 b.addRequestInFlightMetrics(-1)
1414 Logger.Printf("Failed to send SASL handshake %s: %s\n", b.addr, err.Error())
1415 return err
1416 }
1417 b.correlationID++
1418
1419 header := make([]byte, 8) // response header
1420 _, err = b.readFull(header)
1421 if err != nil {
1422 b.addRequestInFlightMetrics(-1)
1423 Logger.Printf("Failed to read SASL handshake header : %s\n", err.Error())
1424 return err
1425 }
1426
1427 length := binary.BigEndian.Uint32(header[:4])
1428 payload := make([]byte, length-4)
1429 n, err := b.readFull(payload)
1430 if err != nil {
1431 b.addRequestInFlightMetrics(-1)
1432 Logger.Printf("Failed to read SASL handshake payload : %s\n", err.Error())
1433 return err
1434 }
1435
1436 b.updateIncomingCommunicationMetrics(n+8, time.Since(requestTime))
1437 res := &SaslHandshakeResponse{}
1438
1439 err = versionedDecode(payload, res, 0, b.metricRegistry)
1440 if err != nil {
1441 Logger.Printf("Failed to parse SASL handshake : %s\n", err.Error())
1442 return err
1443 }
1444
1445 if !errors.Is(res.Err, ErrNoError) {
1446 Logger.Printf("Invalid SASL Mechanism : %s\n", res.Err.Error())
1447 return res.Err
1448 }
1449
1450 DebugLogger.Print("Completed pre-auth SASL handshake. Available mechanisms: ", res.EnabledMechanisms)
1451 return nil
1452}
1453
1454//
1455// In SASL Plain, Kafka expects the auth header to be in the following format
1456// Message format (from https://tools.ietf.org/html/rfc4616):
1457//
1458// message = [authzid] UTF8NUL authcid UTF8NUL passwd
1459// authcid = 1*SAFE ; MUST accept up to 255 octets
1460// authzid = 1*SAFE ; MUST accept up to 255 octets
1461// passwd = 1*SAFE ; MUST accept up to 255 octets
1462// UTF8NUL = %x00 ; UTF-8 encoded NUL character
1463//
1464// SAFE = UTF1 / UTF2 / UTF3 / UTF4
1465// ;; any UTF-8 encoded Unicode character except NUL
1466//
1467//
1468
1469// Kafka 0.10.x supported SASL PLAIN/Kerberos via KAFKA-3149 (KIP-43).
1470// sendAndReceiveSASLPlainAuthV0 flows the v0 sasl auth NOT wrapped in the kafka protocol
1471//
1472// With SASL v0 handshake and auth then:
1473// When credentials are valid, Kafka returns a 4 byte array of null characters.
1474// When credentials are invalid, Kafka closes the connection.
1475func (b *Broker) sendAndReceiveSASLPlainAuthV0() error {
1476 // default to V0 to allow for backward compatibility when SASL is enabled
1477 // but not the handshake
1478 if b.conf.Net.SASL.Handshake {
1479 handshakeErr := b.sendAndReceiveSASLHandshake(SASLTypePlaintext, b.conf.Net.SASL.Version)
1480 if handshakeErr != nil {
1481 Logger.Printf("Error while performing SASL handshake %s: %s\n", b.addr, handshakeErr)
1482 return handshakeErr
1483 }
1484 }
1485
1486 length := len(b.conf.Net.SASL.AuthIdentity) + 1 + len(b.conf.Net.SASL.User) + 1 + len(b.conf.Net.SASL.Password)
1487 authBytes := make([]byte, length+4) // 4 byte length header + auth data
1488 binary.BigEndian.PutUint32(authBytes, uint32(length))
1489 copy(authBytes[4:], b.conf.Net.SASL.AuthIdentity+"\x00"+b.conf.Net.SASL.User+"\x00"+b.conf.Net.SASL.Password)
1490
1491 requestTime := time.Now()
1492 // Will be decremented in updateIncomingCommunicationMetrics (except error)
1493 b.addRequestInFlightMetrics(1)
1494 bytesWritten, err := b.write(authBytes)
1495 b.updateOutgoingCommunicationMetrics(bytesWritten)
1496 if err != nil {
1497 b.addRequestInFlightMetrics(-1)
1498 Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
1499 return err
1500 }
1501
1502 header := make([]byte, 4)
1503 n, err := b.readFull(header)
1504 b.updateIncomingCommunicationMetrics(n, time.Since(requestTime))
1505 // If the credentials are valid, we would get a 4 byte response filled with null characters.
1506 // Otherwise, the broker closes the connection and we get an EOF
1507 if err != nil {
1508 Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
1509 return err
1510 }
1511
1512 DebugLogger.Printf("SASL authentication successful with broker %s:%v - %v\n", b.addr, n, header)
1513 return nil
1514}
1515
1516// Kafka 1.x.x onward added a SaslAuthenticate request/response message which
1517// wraps the SASL flow in the Kafka protocol, which allows for returning
1518// meaningful errors on authentication failure.
1519func (b *Broker) sendAndReceiveSASLPlainAuthV1(authSendReceiver func(authBytes []byte) (*SaslAuthenticateResponse, error)) error {
1520 authBytes := []byte(b.conf.Net.SASL.AuthIdentity + "\x00" + b.conf.Net.SASL.User + "\x00" + b.conf.Net.SASL.Password)
1521 _, err := authSendReceiver(authBytes)
1522 return err
1523}
1524
// currentUnixMilli returns the current wall-clock time as milliseconds since
// the Unix epoch.
func currentUnixMilli() int64 {
	return time.Now().UnixMilli()
}
1528
1529// sendAndReceiveSASLOAuth performs the authentication flow as described by KIP-255
1530// https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=75968876
1531func (b *Broker) sendAndReceiveSASLOAuth(authSendReceiver func(authBytes []byte) (*SaslAuthenticateResponse, error), provider AccessTokenProvider) error {
1532 token, err := provider.Token()
1533 if err != nil {
1534 return err
1535 }
1536
1537 message, err := buildClientFirstMessage(token)
1538 if err != nil {
1539 return err
1540 }
1541
1542 res, err := authSendReceiver(message)
1543 if err != nil {
1544 return err
1545 }
1546 isChallenge := len(res.SaslAuthBytes) > 0
1547
1548 if isChallenge {
1549 // Abort the token exchange. The broker returns the failure code.
1550 _, err = authSendReceiver([]byte(`\x01`))
1551 }
1552 return err
1553}
1554
1555func (b *Broker) sendAndReceiveSASLSCRAMv0() error {
1556 if err := b.sendAndReceiveSASLHandshake(b.conf.Net.SASL.Mechanism, SASLHandshakeV0); err != nil {
1557 return err
1558 }
1559
1560 scramClient := b.conf.Net.SASL.SCRAMClientGeneratorFunc()
1561 if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
1562 return fmt.Errorf("failed to start SCRAM exchange with the server: %w", err)
1563 }
1564
1565 msg, err := scramClient.Step("")
1566 if err != nil {
1567 return fmt.Errorf("failed to advance the SCRAM exchange: %w", err)
1568 }
1569
1570 for !scramClient.Done() {
1571 requestTime := time.Now()
1572 // Will be decremented in updateIncomingCommunicationMetrics (except error)
1573 b.addRequestInFlightMetrics(1)
1574 length := len(msg)
1575 authBytes := make([]byte, length+4) // 4 byte length header + auth data
1576 binary.BigEndian.PutUint32(authBytes, uint32(length))
1577 copy(authBytes[4:], msg)
1578 _, err := b.write(authBytes)
1579 b.updateOutgoingCommunicationMetrics(length + 4)
1580 if err != nil {
1581 b.addRequestInFlightMetrics(-1)
1582 Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
1583 return err
1584 }
1585 b.correlationID++
1586 header := make([]byte, 4)
1587 _, err = b.readFull(header)
1588 if err != nil {
1589 b.addRequestInFlightMetrics(-1)
1590 Logger.Printf("Failed to read response header while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
1591 return err
1592 }
1593 payload := make([]byte, int32(binary.BigEndian.Uint32(header)))
1594 n, err := b.readFull(payload)
1595 if err != nil {
1596 b.addRequestInFlightMetrics(-1)
1597 Logger.Printf("Failed to read response payload while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
1598 return err
1599 }
1600 b.updateIncomingCommunicationMetrics(n+4, time.Since(requestTime))
1601 msg, err = scramClient.Step(string(payload))
1602 if err != nil {
1603 Logger.Println("SASL authentication failed", err)
1604 return err
1605 }
1606 }
1607
1608 DebugLogger.Println("SASL authentication succeeded")
1609 return nil
1610}
1611
1612func (b *Broker) sendAndReceiveSASLSCRAMv1(authSendReceiver func(authBytes []byte) (*SaslAuthenticateResponse, error), scramClient SCRAMClient) error {
1613 if err := scramClient.Begin(b.conf.Net.SASL.User, b.conf.Net.SASL.Password, b.conf.Net.SASL.SCRAMAuthzID); err != nil {
1614 return fmt.Errorf("failed to start SCRAM exchange with the server: %w", err)
1615 }
1616
1617 msg, err := scramClient.Step("")
1618 if err != nil {
1619 return fmt.Errorf("failed to advance the SCRAM exchange: %w", err)
1620 }
1621
1622 for !scramClient.Done() {
1623 res, err := authSendReceiver([]byte(msg))
1624 if err != nil {
1625 return err
1626 }
1627
1628 msg, err = scramClient.Step(string(res.SaslAuthBytes))
1629 if err != nil {
1630 Logger.Println("SASL authentication failed", err)
1631 return err
1632 }
1633 }
1634
1635 DebugLogger.Println("SASL authentication succeeded")
1636
1637 return nil
1638}
1639
1640func (b *Broker) createSaslAuthenticateRequest(msg []byte) *SaslAuthenticateRequest {
1641 authenticateRequest := SaslAuthenticateRequest{SaslAuthBytes: msg}
1642 if b.conf.Version.IsAtLeast(V2_2_0_0) {
1643 authenticateRequest.Version = 1
1644 }
1645
1646 return &authenticateRequest
1647}
1648
1649// Build SASL/OAUTHBEARER initial client response as described by RFC-7628
1650// https://tools.ietf.org/html/rfc7628
1651func buildClientFirstMessage(token *AccessToken) ([]byte, error) {
1652 var ext string
1653
1654 if token == nil {
1655 return []byte{}, fmt.Errorf("failed to build client first message: token is nil")
1656 }
1657
1658 if len(token.Extensions) > 0 {
1659 if _, ok := token.Extensions[SASLExtKeyAuth]; ok {
1660 return []byte{}, fmt.Errorf("the extension `%s` is invalid", SASLExtKeyAuth)
1661 }
1662 ext = "\x01" + mapToString(token.Extensions, "=", "\x01")
1663 }
1664
1665 resp := fmt.Appendf(nil, "n,,\x01auth=Bearer %s%s\x01\x01", token.Token, ext)
1666
1667 return resp, nil
1668}
1669
// mapToString renders extensions as key/value pairs ordered by key.
// keyValSep separates each key from its value; elemSep separates pairs.
// A nil or empty map yields "".
//
// The keys themselves are sorted. Sorting the joined "key<sep>value" strings
// would not give key order when one key is a prefix of another and the next
// character sorts before the separator (e.g. keys "a" and "a!" with "=").
func mapToString(extensions map[string]string, keyValSep string, elemSep string) string {
	keys := make([]string, 0, len(extensions))
	for k := range extensions {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var sb strings.Builder
	for i, k := range keys {
		if i > 0 {
			sb.WriteString(elemSep)
		}
		sb.WriteString(k)
		sb.WriteString(keyValSep)
		sb.WriteString(extensions[k])
	}
	return sb.String()
}
1683
1684func (b *Broker) computeSaslSessionLifetime(res *SaslAuthenticateResponse) {
1685 if res.SessionLifetimeMs > 0 {
1686 // Follows the Java Kafka implementation from SaslClientAuthenticator.ReauthInfo#setAuthenticationEndAndSessionReauthenticationTimes
1687 // pick a random percentage between 85% and 95% for session re-authentication
1688 positiveSessionLifetimeMs := res.SessionLifetimeMs
1689 authenticationEndMs := currentUnixMilli()
1690 pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount := 0.85
1691 pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously := 0.10
1692 pctToUse := pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + rand.Float64()*pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously
1693 sessionLifetimeMsToUse := int64(float64(positiveSessionLifetimeMs) * pctToUse)
1694 DebugLogger.Printf("Session expiration in %d ms and session re-authentication on or after %d ms", positiveSessionLifetimeMs, sessionLifetimeMsToUse)
1695 b.clientSessionReauthenticationTimeMs = authenticationEndMs + sessionLifetimeMsToUse
1696 } else {
1697 b.clientSessionReauthenticationTimeMs = 0
1698 }
1699}
1700
1701func (b *Broker) updateIncomingCommunicationMetrics(bytes int, requestLatency time.Duration) {
1702 b.updateRequestLatencyAndInFlightMetrics(requestLatency)
1703 b.responseRate.Mark(1)
1704
1705 if b.brokerResponseRate != nil {
1706 b.brokerResponseRate.Mark(1)
1707 }
1708
1709 responseSize := int64(bytes)
1710 b.incomingByteRate.Mark(responseSize)
1711 if b.brokerIncomingByteRate != nil {
1712 b.brokerIncomingByteRate.Mark(responseSize)
1713 }
1714
1715 b.responseSize.Update(responseSize)
1716 if b.brokerResponseSize != nil {
1717 b.brokerResponseSize.Update(responseSize)
1718 }
1719}
1720
1721func (b *Broker) updateRequestLatencyAndInFlightMetrics(requestLatency time.Duration) {
1722 requestLatencyInMs := int64(requestLatency / time.Millisecond)
1723 b.requestLatency.Update(requestLatencyInMs)
1724
1725 if b.brokerRequestLatency != nil {
1726 b.brokerRequestLatency.Update(requestLatencyInMs)
1727 }
1728
1729 b.addRequestInFlightMetrics(-1)
1730}
1731
1732func (b *Broker) addRequestInFlightMetrics(i int64) {
1733 b.requestsInFlight.Inc(i)
1734 if b.brokerRequestsInFlight != nil {
1735 b.brokerRequestsInFlight.Inc(i)
1736 }
1737}
1738
1739func (b *Broker) updateOutgoingCommunicationMetrics(bytes int) {
1740 b.requestRate.Mark(1)
1741 if b.brokerRequestRate != nil {
1742 b.brokerRequestRate.Mark(1)
1743 }
1744
1745 requestSize := int64(bytes)
1746 b.outgoingByteRate.Mark(requestSize)
1747 if b.brokerOutgoingByteRate != nil {
1748 b.brokerOutgoingByteRate.Mark(requestSize)
1749 }
1750
1751 b.requestSize.Update(requestSize)
1752 if b.brokerRequestSize != nil {
1753 b.brokerRequestSize.Update(requestSize)
1754 }
1755}
1756
1757func (b *Broker) updateProtocolMetrics(rb protocolBody) {
1758 protocolRequestsRate := b.protocolRequestsRate[rb.key()]
1759 if protocolRequestsRate == nil {
1760 protocolRequestsRate = metrics.GetOrRegisterMeter(fmt.Sprintf("protocol-requests-rate-%d", rb.key()), b.metricRegistry)
1761 b.protocolRequestsRate[rb.key()] = protocolRequestsRate
1762 }
1763 protocolRequestsRate.Mark(1)
1764
1765 if b.brokerProtocolRequestsRate != nil {
1766 brokerProtocolRequestsRate := b.brokerProtocolRequestsRate[rb.key()]
1767 if brokerProtocolRequestsRate == nil {
1768 brokerProtocolRequestsRate = b.registerMeter(fmt.Sprintf("protocol-requests-rate-%d", rb.key()))
1769 b.brokerProtocolRequestsRate[rb.key()] = brokerProtocolRequestsRate
1770 }
1771 brokerProtocolRequestsRate.Mark(1)
1772 }
1773}
1774
// throttleSupport is implemented by responses that report a broker-side
// throttle time. handleThrottledResponse type-asserts responses against it.
type throttleSupport interface {
	throttleTime() time.Duration
}
1778
1779func (b *Broker) handleThrottledResponse(resp protocolBody) {
1780 throttledResponse, ok := resp.(throttleSupport)
1781 if !ok {
1782 return
1783 }
1784 throttleTime := throttledResponse.throttleTime()
1785 if throttleTime == time.Duration(0) {
1786 return
1787 }
1788 DebugLogger.Printf(
1789 "broker/%d %T throttled %v\n", b.ID(), resp, throttleTime)
1790 b.setThrottle(throttleTime)
1791 b.updateThrottleMetric(throttleTime)
1792}
1793
1794func (b *Broker) setThrottle(throttleTime time.Duration) {
1795 b.throttleTimerLock.Lock()
1796 defer b.throttleTimerLock.Unlock()
1797 if b.throttleTimer != nil {
1798 // if there is an existing timer stop/clear it
1799 if !b.throttleTimer.Stop() {
1800 <-b.throttleTimer.C
1801 }
1802 }
1803 b.throttleTimer = time.NewTimer(throttleTime)
1804}
1805
1806func (b *Broker) waitIfThrottled() {
1807 b.throttleTimerLock.Lock()
1808 defer b.throttleTimerLock.Unlock()
1809 if b.throttleTimer != nil {
1810 DebugLogger.Printf("broker/%d waiting for throttle timer\n", b.ID())
1811 <-b.throttleTimer.C
1812 b.throttleTimer = nil
1813 }
1814}
1815
1816func (b *Broker) updateThrottleMetric(throttleTime time.Duration) {
1817 if b.brokerThrottleTime != nil {
1818 throttleTimeInMs := int64(throttleTime / time.Millisecond)
1819 b.brokerThrottleTime.Update(throttleTimeInMs)
1820 }
1821}
1822
1823func (b *Broker) registerMetrics() {
1824 b.brokerIncomingByteRate = b.registerMeter("incoming-byte-rate")
1825 b.brokerRequestRate = b.registerMeter("request-rate")
1826 b.brokerFetchRate = b.registerMeter("consumer-fetch-rate")
1827 b.brokerRequestSize = b.registerHistogram("request-size")
1828 b.brokerRequestLatency = b.registerHistogram("request-latency-in-ms")
1829 b.brokerOutgoingByteRate = b.registerMeter("outgoing-byte-rate")
1830 b.brokerResponseRate = b.registerMeter("response-rate")
1831 b.brokerResponseSize = b.registerHistogram("response-size")
1832 b.brokerRequestsInFlight = b.registerCounter("requests-in-flight")
1833 b.brokerThrottleTime = b.registerHistogram("throttle-time-in-ms")
1834 b.brokerProtocolRequestsRate = map[int16]metrics.Meter{}
1835}
1836
1837func (b *Broker) registerMeter(name string) metrics.Meter {
1838 nameForBroker := getMetricNameForBroker(name, b)
1839 return metrics.GetOrRegisterMeter(nameForBroker, b.metricRegistry)
1840}
1841
1842func (b *Broker) registerHistogram(name string) metrics.Histogram {
1843 nameForBroker := getMetricNameForBroker(name, b)
1844 return getOrRegisterHistogram(nameForBroker, b.metricRegistry)
1845}
1846
1847func (b *Broker) registerCounter(name string) metrics.Counter {
1848 nameForBroker := getMetricNameForBroker(name, b)
1849 return metrics.GetOrRegisterCounter(nameForBroker, b.metricRegistry)
1850}
1851
1852func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config {
1853 if cfg == nil {
1854 cfg = &tls.Config{
1855 MinVersion: tls.VersionTLS12,
1856 }
1857 }
1858 if cfg.ServerName != "" {
1859 return cfg
1860 }
1861
1862 c := cfg.Clone()
1863 sn, _, err := net.SplitHostPort(addr)
1864 if err != nil {
1865 Logger.Println(fmt.Errorf("failed to get ServerName from addr %w", err))
1866 }
1867 c.ServerName = sn
1868 return c
1869}