Skip to content

Commit

Permalink
extract messaging components from IpfsDHT into its own struct. create…
Browse files Browse the repository at this point in the history
… a new struct that manages sending DHT messages that can be used independently from the DHT.
  • Loading branch information
aschmahmann committed Oct 6, 2020
1 parent 9304f55 commit 1cbdbdf
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 168 deletions.
87 changes: 3 additions & 84 deletions dht.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package dht

import (
"bytes"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -33,7 +32,6 @@ import (
goprocessctx "github.com/jbenet/goprocess/context"
"github.com/multiformats/go-base32"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
"go.opencensus.io/tag"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -97,8 +95,7 @@ type IpfsDHT struct {
ctx context.Context
proc goprocess.Process

strmap map[peer.ID]*messageSender
smlk sync.Mutex
protoMessenger *ProtocolMessenger

plk sync.Mutex

Expand Down Expand Up @@ -188,6 +185,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error)
dht.enableValues = cfg.enableValues

dht.Validator = cfg.validator
dht.protoMessenger = NewProtocolMessenger(dht.host, dht.protocols, dht.Validator)

dht.testAddressUpdateProcessing = cfg.testAddressUpdateProcessing

Expand Down Expand Up @@ -274,7 +272,6 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) {
selfKey: kb.ConvertPeerID(h.ID()),
peerstore: h.Peerstore(),
host: h,
strmap: make(map[peer.ID]*messageSender),
birth: time.Now(),
protocols: protocols,
protocolsStrs: protocol.ConvertToStrings(protocols),
Expand Down Expand Up @@ -507,67 +504,8 @@ func (dht *IpfsDHT) persistRTPeersInPeerStore() {
}
}

// putValueToPeer stores the given key/value pair at the peer 'p'
func (dht *IpfsDHT) putValueToPeer(ctx context.Context, p peer.ID, rec *recpb.Record) error {
pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0)
pmes.Record = rec
rpmes, err := dht.sendRequest(ctx, p, pmes)
if err != nil {
logger.Debugw("failed to put value to peer", "to", p, "key", loggableRecordKeyBytes(rec.Key), "error", err)
return err
}

if !bytes.Equal(rpmes.GetRecord().Value, pmes.GetRecord().Value) {
logger.Infow("value not put correctly", "put-message", pmes, "get-message", rpmes)
return errors.New("value not put correctly")
}

return nil
}

var errInvalidRecord = errors.New("received invalid record")

// getValueOrPeers queries a particular peer p for the value for
// key. It returns either the value or a list of closer peers.
// NOTE: It will update the dht's peerstore with any new addresses
// it finds for the given peer.
func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) {
pmes, err := dht.getValueSingle(ctx, p, key)
if err != nil {
return nil, nil, err
}

// Perhaps we were given closer peers
peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers())

if rec := pmes.GetRecord(); rec != nil {
// Success! We were given the value
logger.Debug("got value")

// make sure record is valid.
err = dht.Validator.Validate(string(rec.GetKey()), rec.GetValue())
if err != nil {
logger.Debug("received invalid record (discarded)")
// return a sentinal to signify an invalid record was received
err = errInvalidRecord
rec = new(recpb.Record)
}
return rec, peers, err
}

if len(peers) > 0 {
return nil, peers, nil
}

return nil, nil, routing.ErrNotFound
}

// getValueSingle simply performs the get value RPC with the given parameters
func (dht *IpfsDHT) getValueSingle(ctx context.Context, p peer.ID, key string) (*pb.Message, error) {
pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0)
return dht.sendRequest(ctx, p, pmes)
}

// getLocal attempts to retrieve the value from the datastore
func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) {
logger.Debugw("finding value in datastore", "key", loggableRecordKeyString(key))
Expand Down Expand Up @@ -696,17 +634,6 @@ func (dht *IpfsDHT) FindLocal(id peer.ID) peer.AddrInfo {
}
}

// findPeerSingle asks peer 'p' if they know where the peer with id 'id' is
func (dht *IpfsDHT) findPeerSingle(ctx context.Context, p peer.ID, id peer.ID) (*pb.Message, error) {
pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0)
return dht.sendRequest(ctx, p, pmes)
}

func (dht *IpfsDHT) findProvidersSingle(ctx context.Context, p peer.ID, key multihash.Multihash) (*pb.Message, error) {
pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0)
return dht.sendRequest(ctx, p, pmes)
}

// nearestPeersToQuery returns the routing tables closest peers.
func (dht *IpfsDHT) nearestPeersToQuery(pmes *pb.Message, count int) []peer.ID {
closer := dht.routingTable.NearestPeers(kb.ConvertKey(string(pmes.GetKey())), count)
Expand Down Expand Up @@ -847,15 +774,7 @@ func (dht *IpfsDHT) Host() host.Host {

// Ping sends a ping message to the passed peer and waits for a response.
func (dht *IpfsDHT) Ping(ctx context.Context, p peer.ID) error {
req := pb.NewMessage(pb.Message_PING, nil, 0)
resp, err := dht.sendRequest(ctx, p, req)
if err != nil {
return fmt.Errorf("sending request: %w", err)
}
if resp.Type != pb.Message_PING {
return fmt.Errorf("got unexpected response type: %v", resp.Type)
}
return nil
return dht.protoMessenger.Ping(ctx, p)
}

// newContextWithLocalTags returns a new context.Context with the InstanceID and
Expand Down
72 changes: 50 additions & 22 deletions dht_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import (
"time"

"github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"
"github.com/libp2p/go-msgio/protoio"

"github.com/libp2p/go-libp2p-kad-dht/metrics"
Expand Down Expand Up @@ -207,12 +209,38 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool {
}
}

type messageManager struct {
host host.Host // the network services we need
strmap map[peer.ID]*messageSender
smlk sync.Mutex
protocols []protocol.ID
}

func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) {
m.smlk.Lock()
defer m.smlk.Unlock()
ms, ok := m.strmap[p]
if !ok {
return
}
delete(m.strmap, p)

// Do this asynchronously as ms.lk can block for a while.
go func() {
if err := ms.lk.Lock(ctx); err != nil {
return
}
defer ms.lk.Unlock()
ms.invalidate()
}()
}

// sendRequest sends out a request, but also makes sure to
// measure the RTT for latency measurements.
func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
func (m *messageManager) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) {
ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

ms, err := dht.messageSenderForPeer(ctx, p)
ms, err := m.messageSenderForPeer(ctx, p)
if err != nil {
stats.Record(ctx,
metrics.SentRequests.M(1),
Expand All @@ -239,15 +267,15 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message
metrics.SentBytes.M(int64(pmes.Size())),
metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)),
)
dht.peerstore.RecordLatency(p, time.Since(start))
m.host.Peerstore().RecordLatency(p, time.Since(start))
return rpmes, nil
}

// sendMessage sends out a message
func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
func (m *messageManager) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error {
ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes))

ms, err := dht.messageSenderForPeer(ctx, p)
ms, err := m.messageSenderForPeer(ctx, p)
if err != nil {
stats.Record(ctx,
metrics.SentMessages.M(1),
Expand All @@ -273,30 +301,30 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message
return nil
}

func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
dht.smlk.Lock()
ms, ok := dht.strmap[p]
func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) {
m.smlk.Lock()
ms, ok := m.strmap[p]
if ok {
dht.smlk.Unlock()
m.smlk.Unlock()
return ms, nil
}
ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
dht.strmap[p] = ms
dht.smlk.Unlock()
ms = &messageSender{p: p, m: m, lk: newCtxMutex()}
m.strmap[p] = ms
m.smlk.Unlock()

if err := ms.prepOrInvalidate(ctx); err != nil {
dht.smlk.Lock()
defer dht.smlk.Unlock()
m.smlk.Lock()
defer m.smlk.Unlock()

if msCur, ok := dht.strmap[p]; ok {
if msCur, ok := m.strmap[p]; ok {
// Changed. Use the new one, old one is invalid and
// not in the map so we can just throw it away.
if ms != msCur {
return msCur, nil
}
// Not changed, remove the now invalid stream from the
// map.
delete(dht.strmap, p)
delete(m.strmap, p)
}
// Invalid but not in map. Must have been removed by a disconnect.
return nil, err
Expand All @@ -306,11 +334,11 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
}

type messageSender struct {
s network.Stream
r msgio.ReadCloser
lk ctxMutex
p peer.ID
dht *IpfsDHT
s network.Stream
r msgio.ReadCloser
lk ctxMutex
p peer.ID
m *messageManager

invalid bool
singleMes int
Expand Down Expand Up @@ -351,7 +379,7 @@ func (ms *messageSender) prep(ctx context.Context) error {
// We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks
// one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for
// backwards compatibility reasons).
nstr, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...)
nstr, err := ms.m.host.NewStream(ctx, ms.p, ms.m.protocols...)
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,14 +570,14 @@ func TestInvalidMessageSenderTracking(t *testing.T) {
defer dht.Close()

foo := peer.ID("asdasd")
_, err := dht.messageSenderForPeer(ctx, foo)
_, err := dht.protoMessenger.m.messageSenderForPeer(ctx, foo)
if err == nil {
t.Fatal("that shouldnt have succeeded")
}

dht.smlk.Lock()
mscnt := len(dht.strmap)
dht.smlk.Unlock()
dht.protoMessenger.m.smlk.Lock()
mscnt := len(dht.protoMessenger.m.strmap)
dht.protoMessenger.m.smlk.Unlock()

if mscnt > 0 {
t.Fatal("should have no message senders in map")
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4er
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
Expand All @@ -66,6 +67,7 @@ github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
Expand All @@ -91,6 +93,7 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huin/goupnp v1.0.0 h1:wg75sLpL6DZqwHQN6E1Cfk6mtfzS45z8OV+ic+DtHRo=
github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc=
Expand Down Expand Up @@ -467,6 +470,7 @@ github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU=
github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg=
github.com/onsi/ginkgo v1.12.1 h1:mFwc4LvZ0xpSvDZ3E+k8Yte0hLOMxXUlP+yXtJqkYfQ=
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
Expand Down Expand Up @@ -626,6 +630,7 @@ golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA=
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand All @@ -650,6 +655,7 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/src-d/go-cli.v0 v0.0.0-20181105080154-d492247bbc0d/go.mod h1:z+K8VcOYVYcSwSjGebuDL6176A1XskgbtNl64NSg+n8=
gopkg.in/src-d/go-log.v1 v1.0.1/go.mod h1:GN34hKP0g305ysm2/hctJ0Y8nWP3zxXXJ8GFabTyABE=
Expand Down
4 changes: 1 addition & 3 deletions lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/routing"

pb "github.com/libp2p/go-libp2p-kad-dht/pb"
kb "github.com/libp2p/go-libp2p-kbucket"
)

Expand All @@ -30,12 +29,11 @@ func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan pee
ID: p,
})

pmes, err := dht.findPeerSingle(ctx, p, peer.ID(key))
peers, err := dht.protoMessenger.GetClosestPeers(ctx, p, peer.ID(key))
if err != nil {
logger.Debugf("error getting closer peers: %s", err)
return nil, err
}
peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers())

// For DHT query command
routing.PublishQueryEvent(ctx, &routing.QueryEvent{
Expand Down
Loading

0 comments on commit 1cbdbdf

Please sign in to comment.