Skip to content

Commit

Permalink
kgo / kmsg: process the RecordBatches field in FetchResponses properly
Browse files Browse the repository at this point in the history
As it turns out, a FetchResponse can be of high enough version to use
actual record batches, but still use message sets. Rather than relying
on the response version, we need to check the magic in each batch.

It may be possible that each batch has a different magic from one batch
to the next, so we cannot just check the first batch and use that for
all decoding.

For message set v1, it is also possible that a client could have used
messages v0, compressed them, and used that as the "inner messages" in a
message v1.

All of these cases are now handled properly, which necessitated the
removal of some functions from kmsg. The functions in kmsg were not
necessarily correct. Further, it was a bit odd to stuff message decoding
and validating into kmsg. This makes kmsg a more dedicated encoding /
decoding package.
  • Loading branch information
twmb committed Feb 22, 2021
1 parent bcb330b commit 0cdc2b6
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 198 deletions.
3 changes: 3 additions & 0 deletions pkg/kgo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package kgo
import (
"context"
"fmt"
"hash/crc32"
"math/rand"
"reflect"
"sort"
Expand All @@ -33,6 +34,8 @@ import (
"github.com/twmb/franz-go/pkg/kmsg"
)

var crc32c = crc32.MakeTable(crc32.Castagnoli) // record crc's use Castagnoli table; for consuming/producing

// Client issues requests and handles responses to a Kafka cluster.
type Client struct {
cfg cfg
Expand Down
2 changes: 0 additions & 2 deletions pkg/kgo/sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -1581,8 +1581,6 @@ func (r seqRecBatch) appendTo(
return dst
}

var crc32c = crc32.MakeTable(crc32.Castagnoli) // record crc's use Castagnoli table

func (pnr promisedNumberedRecord) appendTo(dst []byte, offsetDelta int32) []byte {
dst = kbin.AppendVarint(dst, pnr.lengthField)
dst = kbin.AppendInt8(dst, 0) // attributes, currently unused
Expand Down
298 changes: 223 additions & 75 deletions pkg/kgo/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@ package kgo

import (
"context"
"encoding/binary"
"fmt"
"hash/crc32"
"sync"
"sync/atomic"
"time"

"github.com/twmb/franz-go/pkg/kbin"
"github.com/twmb/franz-go/pkg/kerr"
"github.com/twmb/franz-go/pkg/kmsg"
)

type readerFrom interface {
ReadFrom([]byte) error
}

// A source consumes from an individual broker.
//
// As long as there is at least one active cursor, a source aims to have *one*
Expand Down Expand Up @@ -764,37 +771,91 @@ func (o *cursorOffsetNext) processRespPartition(version int16, rp *kmsg.FetchRes
LogStartOffset: rp.LogStartOffset,
}

switch version {
case 0, 1:
msgs, err := kmsg.ReadV0Messages(rp.RecordBatches)
if err != nil {
fp.Err = err
aborter := buildAborter(rp)

// A response could contain any of message v0, message v1, or record
// batches, and this is solely dictated by the magic byte (not the
// fetch response version). The magic byte is located at byte 17.
//
// 0 thru 8: int64 offset / first offset
// 9 thru 12: int32 length
// 13 thru 16: crc (magic 0 or 1), or partition leader epoch (magic 2)
// 17: magic
//
// We decode and validate similarly for messages and record batches, so
// we "abstract" away the high level stuff into a check function just
// below, and then switch based on the magic for how to process.
var (
in = rp.RecordBatches

r readerFrom
length int32
lengthField *int32
crcField *int32
crcTable *crc32.Table
crcAt int

check = func() bool {
if err := r.ReadFrom(in[:length]); err != nil {
return false
}
if length := int32(len(in[12:length])); length != *lengthField {
fp.Err = fmt.Errorf("encoded length %d does not match read length %d", *lengthField, length)
return false
}
if crcCalc := int32(crc32.Checksum(in[crcAt:length], crcTable)); crcCalc != *crcField {
fp.Err = fmt.Errorf("encoded crc %x does not match calculated crc %x", *crcField, crcCalc)
return false
}
return true
}
o.processV0Messages(&fp, msgs, decompressor)
)

case 2, 3:
msgs, err := kmsg.ReadV1Messages(rp.RecordBatches)
if err != nil {
fp.Err = err
for len(in) > 17 && fp.Err == nil {
length = int32(binary.BigEndian.Uint32(in[8:]))
length += 12 // for the int64 offset we skipped and int32 length field itself
if len(in) < int(length) {
break
}
o.processV1Messages(&fp, msgs, decompressor)

default:
batches, err := kmsg.ReadRecordBatches(rp.RecordBatches)
if err != nil {
fp.Err = err
switch magic := in[16]; magic {
case 0:
m := new(kmsg.MessageV0)
lengthField = &m.MessageSize
crcField = &m.CRC
crcTable = crc32.IEEETable
crcAt = 16
r = m
case 1:
m := new(kmsg.MessageV1)
lengthField = &m.MessageSize
crcField = &m.CRC
crcTable = crc32.IEEETable
crcAt = 16
r = m
case 2:
rb := new(kmsg.RecordBatch)
lengthField = &rb.Length
crcField = &rb.CRC
crcTable = crc32c
crcAt = 21
r = rb

}
var numPartitionRecords int
for i := range batches {
numPartitionRecords += int(batches[i].NumRecords)

if !check() {
break
}
fp.Records = make([]*Record, 0, numPartitionRecords)
aborter := buildAborter(rp)
for i := range batches {
o.processRecordBatch(&fp, &batches[i], aborter, decompressor)
if fp.Err != nil {
break
}

in = in[length:]

switch t := r.(type) {
case *kmsg.MessageV0:
o.processV0OuterMessage(&fp, t, decompressor)
case *kmsg.MessageV1:
o.processV1OuterMessage(&fp, t, decompressor)
case *kmsg.RecordBatch:
o.processRecordBatch(&fp, t, aborter, decompressor)
}
}

Expand Down Expand Up @@ -843,6 +904,24 @@ func (a aborter) trackAbortedPID(producerID int64) {
// processing records to fetch part //
//////////////////////////////////////

// readRawRecords reads n records from in and returns them, returning
// kbin.ErrNotEnoughData if in does not contain enough data.
func readRawRecords(n int, in []byte) ([]kmsg.Record, error) {
rs := make([]kmsg.Record, n)
for i := 0; i < n; i++ {
length, used := kbin.Varint(in)
total := used + int(length)
if used == 0 || length < 0 || len(in) < total {
return nil, kbin.ErrNotEnoughData
}
if err := (&rs[i]).ReadFrom(in[:total]); err != nil {
return nil, err
}
in = in[total:]
}
return rs, nil
}

func (o *cursorOffsetNext) processRecordBatch(
fp *FetchPartition,
batch *kmsg.RecordBatch,
Expand Down Expand Up @@ -870,7 +949,7 @@ func (o *cursorOffsetNext) processRecordBatch(
return
}

krecords, err := kmsg.ReadRecords(int(batch.NumRecords), rawRecords)
krecords, err := readRawRecords(int(batch.NumRecords), rawRecords)
if err != nil {
fp.Err = fmt.Errorf("invalid record batch: %v", err)
return
Expand Down Expand Up @@ -903,36 +982,88 @@ func (o *cursorOffsetNext) processRecordBatch(
}
}

func (o *cursorOffsetNext) processV1Messages(
// Processes an outer v1 message. There could be no inner message, which makes
// this easy, but if not, we decompress and process each inner message as
// either v0 or v1. We only expect the inner message to be v1, but technically
// a crazy pipeline could have v0 anywhere.
func (o *cursorOffsetNext) processV1OuterMessage(
fp *FetchPartition,
messages []kmsg.MessageV1,
message *kmsg.MessageV1,
decompressor *decompressor,
) {
for i := range messages {
message := &messages[i]
compression := byte(message.Attributes & 0x0003)
if compression == 0 {
if !o.processV1Message(fp, message) {
return
}
continue
compression := byte(message.Attributes & 0x0003)
if compression == 0 {
o.processV1Message(fp, message)
return
}

rawInner, err := decompressor.decompress(message.Value, compression)
if err != nil {
fp.Err = fmt.Errorf("unable to decompress messages: %v", err)
return
}

var innerMessages []readerFrom
for len(rawInner) > 17 { // magic at byte 17
length := int32(binary.BigEndian.Uint32(rawInner[8:]))
length += 12 // skip offset and length fields
if len(rawInner) < int(length) {
break
}

rawMessages, err := decompressor.decompress(message.Value, compression)
if err != nil {
fp.Err = fmt.Errorf("unable to decompress messages: %v", err)
return
var (
magic = rawInner[16]

msg readerFrom
lengthField *int32
crcField *int32
)

switch magic {
case 0:
m := new(kmsg.MessageV0)
msg = m
lengthField = &m.MessageSize
crcField = &m.CRC
case 1:
m := new(kmsg.MessageV1)
msg = m
lengthField = &m.MessageSize
crcField = &m.CRC

default:
fp.Err = fmt.Errorf("message set v1 has inner message with invalid magic %d", magic)
break
}
innerMessages, err := kmsg.ReadV1Messages(rawMessages)
if err != nil {
fp.Err = err

if err := msg.ReadFrom(rawInner[:length]); err != nil {
break
}
if len(innerMessages) == 0 {
return
if length := int32(len(rawInner[12:length])); length != *lengthField {
fp.Err = fmt.Errorf("encoded length %d does not match read length %d", *lengthField, length)
break
}
firstOffset := message.Offset - int64(len(innerMessages)) + 1
for i := range innerMessages {
innerMessage := &innerMessages[i]
if crcCalc := int32(crc32.ChecksumIEEE(rawInner[16:length])); crcCalc != *crcField {
fp.Err = fmt.Errorf("encoded crc %x does not match calculated crc %x", *crcField, crcCalc)
break
}
innerMessages = append(innerMessages, msg)
rawInner = rawInner[length:]
}
if len(innerMessages) == 0 {
return
}

firstOffset := message.Offset - int64(len(innerMessages)) + 1
for i := range innerMessages {
innerMessage := innerMessages[i]
switch innerMessage := innerMessage.(type) {
case *kmsg.MessageV0:
innerMessage.Offset = firstOffset + int64(i)
if !o.processV0Message(fp, innerMessage) {
return
}
case *kmsg.MessageV1:
innerMessage.Offset = firstOffset + int64(i)
if !o.processV1Message(fp, innerMessage) {
return
Expand All @@ -958,40 +1089,57 @@ func (o *cursorOffsetNext) processV1Message(
return true
}

func (o *cursorOffsetNext) processV0Messages(
// Processes an outer v0 message. We expect inner messages to be entirely v0 as
// well, so this only tries v0 always.
func (o *cursorOffsetNext) processV0OuterMessage(
fp *FetchPartition,
messages []kmsg.MessageV0,
message *kmsg.MessageV0,
decompressor *decompressor,
) {
for i := range messages {
message := &messages[i]
compression := byte(message.Attributes & 0x0003)
if compression == 0 {
if !o.processV0Message(fp, message) {
return
}
continue
}
compression := byte(message.Attributes & 0x0003)
if compression == 0 {
o.processV0Message(fp, message)
return
}

rawMessages, err := decompressor.decompress(message.Value, compression)
if err != nil {
fp.Err = fmt.Errorf("unable to decompress messages: %v", err)
return
rawInner, err := decompressor.decompress(message.Value, compression)
if err != nil {
fp.Err = fmt.Errorf("unable to decompress messages: %v", err)
return
}

var innerMessages []kmsg.MessageV0
for len(rawInner) > 17 { // magic at byte 17
length := int32(binary.BigEndian.Uint32(rawInner[8:]))
length += 12 // skip offset and length fields
if len(rawInner) < int(length) {
break
}
innerMessages, err := kmsg.ReadV0Messages(rawMessages)
if err != nil {
fp.Err = err
var m kmsg.MessageV0
if err := m.ReadFrom(rawInner[:length]); err != nil {
break
}
if len(innerMessages) == 0 {
return
if length := int32(len(rawInner[12:length])); length != m.MessageSize {
fp.Err = fmt.Errorf("encoded length %d does not match read length %d", m.MessageSize, length)
break
}
firstOffset := message.Offset - int64(len(innerMessages)) + 1
for i := range innerMessages {
innerMessage := &innerMessages[i]
innerMessage.Offset = firstOffset + int64(i)
if !o.processV0Message(fp, innerMessage) {
return
}
if crcCalc := int32(crc32.ChecksumIEEE(rawInner[16:length])); crcCalc != m.CRC {
fp.Err = fmt.Errorf("encoded crc %x does not match calculated crc %x", m.CRC, crcCalc)
break
}
innerMessages = append(innerMessages, m)
rawInner = rawInner[length:]
}
if len(innerMessages) == 0 {
return
}

firstOffset := message.Offset - int64(len(innerMessages)) + 1
for i := range innerMessages {
innerMessage := &innerMessages[i]
innerMessage.Offset = firstOffset + int64(i)
if !o.processV0Message(fp, innerMessage) {
return
}
}
}
Expand Down
Loading

0 comments on commit 0cdc2b6

Please sign in to comment.