kgo: allow record ctx cancelation to propagate a bit more
If a record's context is canceled, we now allow it to be failed in two
more locations:

* while the producer ID is loading -- we can actually now cancel the
  producer ID loading request (which may also benefit people using
  transactions that want to force quit the client)

* while a sink is backing off due to request failures

For people using transactions, canceling a context now allows you to
force quit in more areas, but the same caveat applies: your client will
likely end up in an invalid transactional state and be unable to
continue.

For #769.
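To make the new behavior concrete, here is a rough usage sketch (not part of this commit; the seed address and timings are assumptions). Producing against an unreachable cluster and then canceling the record's context should now fail the buffered record while the producer ID load is still in flight or the sink is backing off, instead of hanging until the client is closed:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/twmb/franz-go/pkg/kgo"
)

func main() {
	// Assumed seed address; the point is that it is unreachable, so the
	// client sits in producer ID loading / request backoff.
	cl, err := kgo.NewClient(kgo.SeedBrokers("localhost:9092"))
	if err != nil {
		panic(err)
	}
	defer cl.Close()

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	cl.Produce(ctx, &kgo.Record{Topic: "foo", Value: []byte("bar")}, func(_ *kgo.Record, err error) {
		// With this commit, the promise should fire shortly after
		// cancel(), with an error that wraps the cancelation (the
		// wrapping goes through errProducerIDLoadFail internally).
		fmt.Println("record failed:", err)
		close(done)
	})

	time.Sleep(100 * time.Millisecond) // let the producer ID load start
	cancel()
	<-done
}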
twmb committed Jul 29, 2024
1 parent d4982d7 commit 305d8dc
Showing 7 changed files with 191 additions and 30 deletions.
17 changes: 14 additions & 3 deletions pkg/kgo/errors.go
@@ -53,7 +53,7 @@ func isRetryableBrokerErr(err error) bool {
}
// We could have a retryable producer ID failure, which then bubbled up
// as errProducerIDLoadFail so as to be retried later.
if errors.Is(err, errProducerIDLoadFail) {
if pe := (*errProducerIDLoadFail)(nil); errors.As(err, &pe) {
return true
}
// We could have chosen a broker, and then a concurrent metadata update
@@ -139,8 +139,6 @@ var (
// restart a new connection ourselves.
errSaslReauthLoop = errors.New("the broker is repeatedly giving us sasl lifetimes that are too short to write a request")

errProducerIDLoadFail = errors.New("unable to initialize a producer ID due to request failures")

// A temporary error returned when Kafka replies with a different
// correlation ID than we were expecting for the request the client
// issued.
@@ -224,6 +222,19 @@ type ErrFirstReadEOF struct {
err error
}

type errProducerIDLoadFail struct {
err error
}

func (e *errProducerIDLoadFail) Error() string {
if e.err == nil {
return "unable to initialize a producer ID due to request failures"
}
return fmt.Sprintf("unable to initialize a producer ID due to request failures: %v", e.err)
}

func (e *errProducerIDLoadFail) Unwrap() error { return e.err }

const (
firstReadSASL uint8 = iota
firstReadTLS
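The reason for turning the sentinel error into a wrapping type is easiest to see in a minimal standalone sketch (this mirrors the kgo internals above but is not the library code): errors.As still matches the error category, while errors.Is can now see the underlying cause.

package main

import (
	"context"
	"errors"
	"fmt"
)

type errProducerIDLoadFail struct{ err error }

func (e *errProducerIDLoadFail) Error() string {
	if e.err == nil {
		return "unable to initialize a producer ID due to request failures"
	}
	return fmt.Sprintf("unable to initialize a producer ID due to request failures: %v", e.err)
}

func (e *errProducerIDLoadFail) Unwrap() error { return e.err }

func main() {
	err := error(&errProducerIDLoadFail{context.Canceled})

	var pe *errProducerIDLoadFail
	fmt.Println(errors.As(err, &pe))              // true: still matches the load-fail category
	fmt.Println(errors.Is(err, context.Canceled)) // true: the cause is now visible through Unwrap
}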
26 changes: 26 additions & 0 deletions pkg/kgo/helpers_test.go
@@ -55,6 +55,32 @@
npartitionsAt int64
)

type slowConn struct {
net.Conn
}

func (s *slowConn) Write(p []byte) (int, error) {
time.Sleep(100 * time.Millisecond)
return s.Conn.Write(p)
}

func (s *slowConn) Read(p []byte) (int, error) {
time.Sleep(100 * time.Millisecond)
return s.Conn.Read(p)
}

type slowDialer struct {
d net.Dialer
}

func (s *slowDialer) DialContext(ctx context.Context, network, host string) (net.Conn, error) {
c, err := s.d.DialContext(ctx, network, host)
if err != nil {
return nil, err
}
return &slowConn{c}, nil
}

func init() {
var err error
if n, _ := strconv.Atoi(os.Getenv("KGO_TEST_RF")); n > 0 {
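The new helpers are presumably wired into test clients through the Dialer option; a sketch (this exact helper is not part of the commit, and the broker address is an assumption):

// newSlowTestClient builds a client whose every connection read/write is
// delayed by 100ms via slowConn, widening the window in which a record
// context cancelation can race an in-flight producer ID load.
func newSlowTestClient(t *testing.T) *Client {
	t.Helper()
	cl, err := NewClient(
		SeedBrokers("localhost:9092"),       // assumed test broker address
		Dialer((&slowDialer{}).DialContext), // every dial returns a slowConn
	)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(cl.Close)
	return cl
}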
2 changes: 1 addition & 1 deletion pkg/kgo/produce_request_test.go
@@ -150,7 +150,7 @@ func TestIssue769(t *testing.T) {
case <-timer.C:
t.Fatal("expected record to fail within 3s")
}
if pe := (*errProducerIDLoadFail)(nil); !errors.As(rerr, &pe) || !errors.Is(pe.err, context.Canceled) {
if pe := (*errProducerIDLoadFail)(nil); !errors.As(rerr, &pe) || !(errors.Is(pe.err, context.Canceled) || strings.Contains(pe.err.Error(), "canceled")) {
t.Errorf("got %v != exp errProducerIDLoadFail{context.Canceled}", rerr)
}
}
26 changes: 17 additions & 9 deletions pkg/kgo/producer.go
@@ -362,7 +362,11 @@ func (cl *Client) TryProduce(
// retries. If any of these conditions are hit and it is currently safe to fail
// records, all buffered records for the relevant partition are failed. Only
// the first record's context in a batch is considered when determining whether
// the batch should be canceled.
// the batch should be canceled. A record is not safe to fail if the client
// is idempotently producing and a request has been sent; in this case, the
// client cannot know if the broker actually processed the request (if so, then
// removing the records from the client will create errors the next time you
// produce).
//
// If the client is transactional and a transaction has not been begun, the
// promise is immediately called with an error corresponding to not being in a
@@ -679,7 +683,7 @@ func (cl *Client) ProducerID(ctx context.Context) (int64, int16, error) {

go func() {
defer close(done)
id, epoch, err = cl.producerID()
id, epoch, err = cl.producerID(ctx2fn(ctx))
}()

select {
@@ -701,7 +705,7 @@ var errReloadProducerID = errors.New("producer id needs reloading")
// initProducerID initializes the client's producer ID for idempotent
// producing only (no transactions, which are more special). After the first
// load, this clears all buffered unknown topics.
func (cl *Client) producerID() (int64, int16, error) {
func (cl *Client) producerID(ctxFn func() context.Context) (int64, int16, error) {
p := &cl.producer

id := p.id.Load().(*producerID)
@@ -730,7 +734,7 @@ func (cl *Client) producerID() (int64, int16, error) {
}
p.id.Store(id)
} else {
newID, keep := cl.doInitProducerID(id.id, id.epoch)
newID, keep := cl.doInitProducerID(ctxFn, id.id, id.epoch)
if keep {
id = newID
// Whenever we have a new producer ID, we need
@@ -748,7 +752,7 @@
id = &producerID{
id: id.id,
epoch: id.epoch,
err: errProducerIDLoadFail,
err: &errProducerIDLoadFail{newID.err},
}
}
}
@@ -825,7 +829,7 @@ func (cl *Client) failProducerID(id int64, epoch int16, err error) {

// doInitProducerID inits the idempotent ID and potentially the transactional
// producer epoch, returning whether to keep the result.
func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID, bool) {
func (cl *Client) doInitProducerID(ctxFn func() context.Context, lastID int64, lastEpoch int16) (*producerID, bool) {
cl.cfg.logger.Log(LogLevelInfo, "initializing producer id")
req := kmsg.NewPtrInitProducerIDRequest()
req.TransactionalID = cl.cfg.txnID
@@ -835,7 +839,8 @@ func (cl *Client) doInitProducerID(lastID int64, lastEpoch int16) (*producerID,
req.TransactionTimeoutMillis = int32(cl.cfg.txnTimeout.Milliseconds())
}

resp, err := req.RequestWith(cl.ctx, cl)
ctx := ctxFn()
resp, err := req.RequestWith(ctx, cl)
if err != nil {
if errors.Is(err, errUnknownRequestKey) || errors.Is(err, errBrokerTooOld) {
cl.cfg.logger.Log(LogLevelInfo, "unable to initialize a producer id because the broker is too old or the client is pinned to an old version, continuing without a producer id")
@@ -940,13 +945,14 @@ func (cl *Client) addUnknownTopicRecord(pr promisedRec) {
}
unknown.buffered = append(unknown.buffered, pr)
if len(unknown.buffered) == 1 {
go cl.waitUnknownTopic(pr.ctx, pr.Topic, unknown)
go cl.waitUnknownTopic(pr.ctx, pr.Record.Context, pr.Topic, unknown)
}
}

// waitUnknownTopic waits for a notification
func (cl *Client) waitUnknownTopic(
rctx context.Context,
pctx context.Context, // context passed to Produce
rctx context.Context, // context on the record itself
topic string,
unknown *unknownTopicProduces,
) {
@@ -974,6 +980,8 @@ func (cl *Client) waitUnknownTopic(

for err == nil {
select {
case <-pctx.Done():
err = pctx.Err()
case <-rctx.Done():
err = rctx.Err()
case <-cl.ctx.Done():
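The ctx2fn helper referenced in ProducerID is not shown in this diff; presumably it just wraps an already-known context in a func, so that callers like the sink (below) can instead pass a closure that defers the lock-heavy head-of-line context lookup until a request is actually issued. A sketch of that shape, under that assumption:

// Assumed shape of ctx2fn: adapt a concrete context to the lazy
// func() context.Context that producerID / doInitProducerID now take.
func ctx2fn(ctx context.Context) func() context.Context {
	return func() context.Context { return ctx }
}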
121 changes: 116 additions & 5 deletions pkg/kgo/sink.go
@@ -208,6 +208,7 @@ func (s *sink) maybeBackoff() {
select {
case <-after.C:
case <-s.cl.ctx.Done():
case <-s.anyCtx().Done():
}
}

@@ -247,6 +248,34 @@ func (s *sink) drain() {
}
}

// Returns the first cancelable context encountered while ranging across
// all records. This avoids defers so that it is clear at each return that
// all unlocks are called in the proper order. Ideally, call this func
// sparingly, due to how lock intensive it is.
func (s *sink) anyCtx() context.Context {
s.recBufsMu.Lock()
for _, recBuf := range s.recBufs {
recBuf.mu.Lock()
if len(recBuf.batches) > 0 {
batch0 := recBuf.batches[0]
batch0.mu.Lock()
if batch0.canFailFromLoadErrs && len(batch0.records) > 0 {
r0 := batch0.records[0]
if rctx := r0.cancelingCtx(); rctx != nil {
batch0.mu.Unlock()
recBuf.mu.Unlock()
s.recBufsMu.Unlock()
return rctx
}
}
batch0.mu.Unlock()
}
recBuf.mu.Unlock()
}
s.recBufsMu.Unlock()
return context.Background()
}

func (s *sink) produce(sem <-chan struct{}) bool {
var produced bool
defer func() {
@@ -267,6 +296,7 @@ func (s *sink) produce(sem <-chan struct{}) bool {
// - auth failure
// - transactional: a produce failure that failed the producer ID
// - AddPartitionsToTxn failure (see just below)
// - some head-of-line context failure
//
// All but the first error is fatal. Recovery may be possible with
// EndTransaction in specific cases, but regardless, all buffered
Expand All @@ -275,10 +305,71 @@ func (s *sink) produce(sem <-chan struct{}) bool {
// NOTE: we init the producer ID before creating a request to ensure we
// are always using the latest id/epoch with the proper sequence
// numbers. (i.e., resetAllSequenceNumbers && producerID logic combo).
id, epoch, err := s.cl.producerID()
//
// For the first discovered head-of-line record context, we want to
// avoid looking it up if possible (which is why producerID takes a
// ctxFn). If we do use one, we want to be sure that the
// context.Canceled error is from *that* context rather than the client
// context or something else. So, we take special care to track which
// ctx we handed out and to check whether it has been canceled.
var holCtxMu sync.Mutex
var holCtx context.Context
ctxFn := func() context.Context {
holCtxMu.Lock()
defer holCtxMu.Unlock()
holCtx = s.anyCtx()
return holCtx
}
isHolCtxDone := func() bool {
holCtxMu.Lock()
defer holCtxMu.Unlock()
if holCtx == nil {
return false
}
select {
case <-holCtx.Done():
return true
default:
}
return false
}

id, epoch, err := s.cl.producerID(ctxFn)
if err != nil {
var pe *errProducerIDLoadFail
switch {
case errors.Is(err, errProducerIDLoadFail):
case errors.As(err, &pe):
if errors.Is(pe.err, context.Canceled) && isHolCtxDone() {
// Some head-of-line record in a partition had a context cancelation.
// We look for any partition with HOL cancelations and fail them all.
s.cl.cfg.logger.Log(LogLevelInfo, "the first record in some partition(s) had a context cancelation; failing all relevant partitions", "broker", logID(s.nodeID))
s.recBufsMu.Lock()
defer s.recBufsMu.Unlock()
for _, recBuf := range s.recBufs {
recBuf.mu.Lock()
var failAll bool
if len(recBuf.batches) > 0 {
batch0 := recBuf.batches[0]
batch0.mu.Lock()
if batch0.canFailFromLoadErrs && len(batch0.records) > 0 {
r0 := batch0.records[0]
if rctx := r0.cancelingCtx(); rctx != nil {
select {
case <-rctx.Done():
failAll = true // we must not call failAllRecords here, because failAllRecords locks batches!
default:
}
}
}
batch0.mu.Unlock()
}
if failAll {
recBuf.failAllRecords(err)
}
recBuf.mu.Unlock()
}
return true
}
s.cl.bumpRepeatedLoadErr(err)
s.cl.cfg.logger.Log(LogLevelWarn, "unable to load producer ID, bumping client's buffered record load errors by 1 and retrying")
return true // whatever caused our produce, we did nothing, so keep going
@@ -385,6 +476,9 @@ func (s *sink) doSequenced(
promise: promise,
}

// We can NOT use any record context. If we do, we force the request to
// fail while also forcing the batch to be unfailable (due to no
// response).
br, err := s.cl.brokerOrErr(s.cl.ctx, s.nodeID, errUnknownBroker)
if err != nil {
wait.err = err
@@ -432,6 +526,11 @@ func (s *sink) doTxnReq(
req.batches.eachOwnerLocked(seqRecBatch.removeFromTxn)
}
}()
// We do NOT let record context cancelations fail this request: doing
// so would put the transactional ID in an unknown state. This is
// similar to the warning we give in the txn.go file, but the
// difference there is the user knows explicitly at the function call
// that canceling the context will opt them into invalid state.
err = s.cl.doWithConcurrentTransactions(s.cl.ctx, "AddPartitionsToTxn", func() error {
stripped, err = s.issueTxnReq(req, txnReq)
return err
@@ -1422,6 +1521,16 @@ type promisedRec struct {
*Record
}

func (pr promisedRec) cancelingCtx() context.Context {
if pr.ctx.Done() != nil {
return pr.ctx
}
if pr.Context.Done() != nil {
return pr.Context
}
return nil
}

// recBatch is the type used for buffering records before they are written.
type recBatch struct {
owner *recBuf // who owns us
@@ -1454,10 +1563,12 @@ type recBatch struct {
// Returns an error if the batch should fail.
func (b *recBatch) maybeFailErr(cfg *cfg) error {
if len(b.records) > 0 {
ctx := b.records[0].ctx
r0 := &b.records[0]
select {
case <-ctx.Done():
return ctx.Err()
case <-r0.ctx.Done():
return r0.ctx.Err()
case <-r0.Context.Done():
return r0.Context.Err()
default:
}
}
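One non-obvious detail in cancelingCtx above: comparing Done() against nil works because the default contexts can never be canceled and return a nil Done channel, so only a genuinely cancelable context is returned. A tiny standalone illustration:

package main

import (
	"context"
	"fmt"
)

func main() {
	// context.Background (and context.TODO) can never be canceled, so
	// their Done channel is nil -- cancelingCtx skips them.
	fmt.Println(context.Background().Done() == nil) // true

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	fmt.Println(ctx.Done() == nil) // false: this one can fail records
}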
3 changes: 3 additions & 0 deletions pkg/kgo/source.go
@@ -956,6 +956,9 @@ func (s *source) fetch(consumerSession *consumerSession, doneFetch chan<- struct
// reload offsets *always* triggers a metadata update.
if updateWhy != nil {
why := updateWhy.reason(fmt.Sprintf("fetch had inner topic errors from broker %d", s.nodeID))
// loadWithSessionNow triggers a metadata update IF there are
// offsets to reload. If there are no offsets to reload, we
// trigger one here.
if !reloadOffsets.loadWithSessionNow(consumerSession, why) {
if updateWhy.isOnly(kerr.UnknownTopicOrPartition) || updateWhy.isOnly(kerr.UnknownTopicID) {
s.cl.triggerUpdateMetadata(false, why)