From 28798100702545bbf2d5a004899e00eda55f7777 Mon Sep 17 00:00:00 2001 From: boks1971 Date: Mon, 16 Sep 2024 22:32:38 +0530 Subject: [PATCH] atomic and tests --- agent_test.go | 21 +++++++++++++++++++++ candidatepair.go | 42 +++++++++++++++++++++++------------------- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/agent_test.go b/agent_test.go index b53c7d0d..5622ec15 100644 --- a/agent_test.go +++ b/agent_test.go @@ -721,6 +721,10 @@ func TestCandidatePairStats(t *testing.T) { p := a.findPair(hostLocal, prflxRemote) p.state = CandidatePairStateFailed + for i := 0; i < 10; i++ { + p.UpdateRoundTripTime(time.Duration(i+1) * time.Second) + } + stats := a.GetCandidatePairsStats() if len(stats) != 4 { t.Fatal("expected 4 candidate pairs stats") @@ -766,6 +770,23 @@ func TestCandidatePairStats(t *testing.T) { t.Fatalf("expected host-prflx pair to have state failed, it has state %s instead", prflxPairStat.State.String()) } + + expectedCurrentRoundTripTime := time.Duration(10) * time.Second + if prflxPairStat.CurrentRoundTripTime != expectedCurrentRoundTripTime.Seconds() { + t.Fatalf("expected current round trip time to be %f, it is %f instead", + expectedCurrentRoundTripTime.Seconds(), prflxPairStat.CurrentRoundTripTime) + } + + expectedTotalRoundTripTime := time.Duration(55) * time.Second + if prflxPairStat.TotalRoundTripTime != expectedTotalRoundTripTime.Seconds() { + t.Fatalf("expected total round trip time to be %f, it is %f instead", + expectedTotalRoundTripTime.Seconds(), prflxPairStat.TotalRoundTripTime) + } + + if prflxPairStat.ResponsesReceived != 10 { + t.Fatalf("expected responses received to be 10, it is %d instead", + prflxPairStat.ResponsesReceived) + } } func TestLocalCandidateStats(t *testing.T) { diff --git a/candidatepair.go b/candidatepair.go index 306232c3..2744a978 100644 --- a/candidatepair.go +++ b/candidatepair.go @@ -5,7 +5,7 @@ package ice import ( "fmt" - "sync" + "sync/atomic" "time" "github.com/pion/stun/v3" @@ -32,9 +32,8 @@ type CandidatePair struct { nominateOnBindingSuccess bool // stats - statsMu sync.RWMutex - currentRoundTripTime time.Duration - totalRoundTripTime time.Duration + currentRoundTripTime atomic.Pointer[time.Duration] + totalRoundTripTime atomic.Pointer[time.Duration] responsesReceived uint64 } @@ -112,37 +111,42 @@ func (a *Agent) sendSTUN(msg *stun.Message, local, remote Candidate) { // UpdateRoundTripTime sets the current round time of this pair and // accumulates total round trip time and responses received func (p *CandidatePair) UpdateRoundTripTime(rtt time.Duration) { - p.statsMu.Lock() - defer p.statsMu.Unlock() + p.currentRoundTripTime.Store(&rtt) - p.currentRoundTripTime = rtt - p.totalRoundTripTime += rtt - p.responsesReceived++ + prevTotalRoundTripTime := p.totalRoundTripTime.Load() + totalRoundTripTime := rtt + if prevTotalRoundTripTime != nil { + totalRoundTripTime += *prevTotalRoundTripTime + } + p.totalRoundTripTime.CompareAndSwap(prevTotalRoundTripTime, &totalRoundTripTime) + + atomic.AddUint64(&p.responsesReceived, 1) } // CurrentRoundTripTime returns the current round trip time in seconds // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-currentroundtriptime func (p *CandidatePair) CurrentRoundTripTime() float64 { - p.statsMu.RLock() - defer p.statsMu.RUnlock() + crtt := p.currentRoundTripTime.Load() + if crtt != nil { + return crtt.Seconds() + } - return p.currentRoundTripTime.Seconds() + return 0 } // TotalRoundTripTime returns the current round trip time in seconds // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-totalroundtriptime func (p *CandidatePair) TotalRoundTripTime() float64 { - p.statsMu.RLock() - defer p.statsMu.RUnlock() + trtt := p.totalRoundTripTime.Load() + if trtt != nil { + return trtt.Seconds() + } - return p.totalRoundTripTime.Seconds() + return 0 } // ResponsesReceived returns the total number of connectivity responses received // https://www.w3.org/TR/webrtc-stats/#dom-rtcicecandidatepairstats-responsesreceived func (p *CandidatePair) ResponsesReceived() uint64 { - p.statsMu.RLock() - defer p.statsMu.RUnlock() - - return p.responsesReceived + return atomic.LoadUint64(&p.responsesReceived) }