Skip to content

Commit

Permalink
fix PLI and FIR handling, wrongly triggering track.OnEnded (#420)
Browse files Browse the repository at this point in the history
  • Loading branch information
EmrysMyrddin authored Aug 16, 2022
1 parent 6f204fa commit 8ad810e
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 26 deletions.
63 changes: 37 additions & 26 deletions track.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

const (
rtpOutboundMTU = 1200
rtcpInboundMTU = 1500
)

var (
Expand Down Expand Up @@ -223,38 +224,48 @@ func (track *baseTrack) bind(ctx webrtc.TrackLocalContext, specializedTrack Trac
keyFrameController, ok := encodedReader.Controller().(codec.KeyFrameController)
if ok {
stopRead = make(chan struct{})
go func() {
reader := ctx.RTCPReader()
for {
select {
case <-stopRead:
return
default:
}
go track.rtcpReadLoop(ctx.RTCPReader(), keyFrameController, stopRead)
}

var readerBuffer []byte
_, _, err := reader.Read(readerBuffer, interceptor.Attributes{})
if err != nil {
track.onError(err)
return
}
return selectedCodec, nil
}

pkts, err := rtcp.Unmarshal(readerBuffer)
func (track *baseTrack) rtcpReadLoop(reader interceptor.RTCPReader, keyFrameController codec.KeyFrameController, stopRead chan struct{}) {
readerBuffer := make([]byte, rtcpInboundMTU)

for _, pkt := range pkts {
switch pkt.(type) {
case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
if err := keyFrameController.ForceKeyFrame(); err != nil {
track.onError(err)
return
}
}
readLoop:
for {
select {
case <-stopRead:
return
default:
}

readLength, _, err := reader.Read(readerBuffer, interceptor.Attributes{})
if err != nil {
if errors.Is(err, io.EOF) {
return
}
logger.Warnf("failed to read rtcp packet: %s", err)
continue
}

pkts, err := rtcp.Unmarshal(readerBuffer[:readLength])
if err != nil {
logger.Warnf("failed to unmarshal rtcp packet: %s", err)
continue
}

for _, pkt := range pkts {
switch pkt.(type) {
case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
if err := keyFrameController.ForceKeyFrame(); err != nil {
logger.Warnf("failed to force key frame: %s", err)
continue readLoop
}
}
}()
}
}

return selectedCodec, nil
}

func (track *baseTrack) unbind(ctx webrtc.TrackLocalContext) error {
Expand Down
100 changes: 100 additions & 0 deletions track_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package mediadevices

import (
"errors"
"github.com/pion/interceptor"
"io"
"testing"
"time"
)
Expand Down Expand Up @@ -53,3 +55,101 @@ func TestOnEnded(t *testing.T) {
}
})
}

type fakeRTCPReader struct {
mockReturn chan []byte
end chan struct{}
}

func (mock *fakeRTCPReader) Read(buffer []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) {
select {
case <-mock.end:
return 0, nil, io.EOF
case mockReturn := <-mock.mockReturn:
if len(buffer) < len(mock.mockReturn) {
return 0, nil, io.ErrShortBuffer
}

return copy(buffer, mockReturn), attributes, nil
}
}

type fakeKeyFrameController struct {
called chan struct{}
}

func (mock *fakeKeyFrameController) ForceKeyFrame() error {
mock.called <- struct{}{}
return nil
}

func TestRtcpHandler(t *testing.T) {

t.Run("ShouldStopReading", func(t *testing.T) {
tr := &baseTrack{}
stop := make(chan struct{}, 1)
stopped := make(chan struct{})
go func() {
tr.rtcpReadLoop(&fakeRTCPReader{end: stop}, &fakeKeyFrameController{}, stop)
stopped <- struct{}{}
}()

stop <- struct{}{}

select {
case <-time.After(100 * time.Millisecond):
t.Error("Timeout")
case <-stopped:
}
})

t.Run("ShouldForceKeyFrame", func(t *testing.T) {
for packetType, packet := range map[string][]byte{
"PLI": {
// v=2, p=0, FMT=1, PSFB, len=1
0x81, 0xce, 0x00, 0x02,
// ssrc=0x0
0x00, 0x00, 0x00, 0x00,
// ssrc=0x4bc4fcb4
0x4b, 0xc4, 0xfc, 0xb4,
},
"FIR": {
// v=2, p=0, FMT=4, PSFB, len=3
0x84, 0xce, 0x00, 0x04,
// ssrc=0x0
0x00, 0x00, 0x00, 0x00,
// ssrc=0x4bc4fcb4
0x4b, 0xc4, 0xfc, 0xb4,
// ssrc=0x12345678
0x12, 0x34, 0x56, 0x78,
// Seqno=0x42
0x42, 0x00, 0x00, 0x00,
},
} {
t.Run(packetType, func(t *testing.T) {
tr := &baseTrack{}
tr.OnEnded(func(err error) {
if err != io.EOF {
t.Error(err)
}
})
stop := make(chan struct{}, 1)
defer func() {
stop <- struct{}{}
}()
mockKeyFrameController := &fakeKeyFrameController{called: make(chan struct{}, 1)}
mockRTCPReader := &fakeRTCPReader{end: stop, mockReturn: make(chan []byte, 1)}

go tr.rtcpReadLoop(mockRTCPReader, mockKeyFrameController, stop)

mockRTCPReader.mockReturn <- packet

select {
case <-time.After(1000 * time.Millisecond):
t.Error("Timeout")
case <-mockKeyFrameController.called:
}
})
}
})
}

0 comments on commit 8ad810e

Please sign in to comment.