Skip to content

Commit

Permalink
Merge pull request #9945 from brian-kelley/FixTemplatedLocalView
Browse files Browse the repository at this point in the history
Tpetra: fix modify flags of WDV::getView<T>
  • Loading branch information
brian-kelley authored Dec 13, 2021
2 parents 419ef91 + 8f6a1c9 commit f975d21
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 58 deletions.
108 changes: 51 additions & 57 deletions packages/tpetra/core/src/Tpetra_Details_WrappedDualView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ class WrappedDualView {

// This is an expert-only constructor
// For WrappedDualView to manage synchronizations correctly,
// it must have an DualView which is not a subview to due the
// sync's on. This is what origDualV is for. In this case,
// it must have an DualView which is not a subview to due the
// sync's on. This is what origDualV is for. In this case,
// dualV is a subview of origDualV.
WrappedDualView(DualViewType dualV,DualViewType origDualV)
: originalDualView(origDualV),
Expand Down Expand Up @@ -201,7 +201,7 @@ class WrappedDualView {
size_t extent(const int i) const {
return dualView.h_view.extent(i);
}

void stride(size_t * stride_) const {
dualView.stride(stride_);
}
Expand All @@ -210,7 +210,7 @@ class WrappedDualView {
size_t origExtent(const int i) const {
return originalDualView.h_view.extent(i);
}

const char * label() const {
return dualView.d_view.label();
}
Expand All @@ -219,7 +219,7 @@ class WrappedDualView {
typename t_host::const_type
getHostView(Access::ReadOnlyStruct
DEBUG_UVM_REMOVAL_ARGUMENT
) const
) const
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getHostViewReadOnly");
if(needsSyncPath()) {
Expand All @@ -232,7 +232,7 @@ class WrappedDualView {
t_host
getHostView(Access::ReadWriteStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getHostViewReadWrite");
static_assert(dualViewHasNonConstData,
Expand All @@ -249,7 +249,7 @@ class WrappedDualView {
t_host
getHostView(Access::OverwriteAllStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getHostViewOverwriteAll");
static_assert(dualViewHasNonConstData,
Expand All @@ -259,7 +259,7 @@ class WrappedDualView {
}
if(needsSyncPath()) {
throwIfDeviceViewAlive();
if (deviceMemoryIsHostAccessible) Kokkos::fence();
if (deviceMemoryIsHostAccessible) Kokkos::fence();
dualView.clear_sync_state();
dualView.modify_host();
}
Expand All @@ -269,7 +269,7 @@ class WrappedDualView {
typename t_dev::const_type
getDeviceView(Access::ReadOnlyStruct
DEBUG_UVM_REMOVAL_ARGUMENT
) const
) const
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceViewReadOnly");
if(needsSyncPath()) {
Expand All @@ -282,7 +282,7 @@ class WrappedDualView {
t_dev
getDeviceView(Access::ReadWriteStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceViewReadWrite");
static_assert(dualViewHasNonConstData,
Expand All @@ -298,7 +298,7 @@ class WrappedDualView {
t_dev
getDeviceView(Access::OverwriteAllStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceViewOverwriteAll");
static_assert(dualViewHasNonConstData,
Expand All @@ -317,13 +317,10 @@ class WrappedDualView {

template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type::const_type
getView (Access::ReadOnlyStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
bool returnDevice = true;
{
auto tmp = dualView.template view<TargetDeviceType>();
if (tmp == this->dualView.view_host()) returnDevice = false;
}

getView (Access::ReadOnlyStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
using ReturnViewType = typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type::const_type;
using ReturnDeviceType = typename ReturnViewType::device_type;
constexpr bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;
if(returnDevice) {
DEBUG_UVM_REMOVAL_PRINT_CALLER("getView<Device>ReadOnly");
if(needsSyncPath()) {
Expand All @@ -338,19 +335,17 @@ class WrappedDualView {
impl::sync_host(originalDualView);
}
}

return dualView.template view<TargetDeviceType>();
}
}


template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type
getView (Access::ReadWriteStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
bool returnDevice = true;
{
auto tmp = dualView.template view<TargetDeviceType>();
if (tmp == this->dualView.view_host()) returnDevice = false;
}
getView (Access::ReadWriteStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
using ReturnViewType = typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type;
using ReturnDeviceType = typename ReturnViewType::device_type;
constexpr bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;

if(returnDevice) {
DEBUG_UVM_REMOVAL_PRINT_CALLER("getView<Device>ReadWrite");
Expand All @@ -369,26 +364,25 @@ class WrappedDualView {
if(needsSyncPath()) {
throwIfDeviceViewAlive();
impl::sync_host(originalDualView);
originalDualView.modify_host();
originalDualView.modify_host();
}
}

return dualView.template view<TargetDeviceType>();
}
}


template<class TargetDeviceType>
typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type
getView (Access::OverwriteAllStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
getView (Access::OverwriteAllStruct s DEBUG_UVM_REMOVAL_ARGUMENT) const {
using ReturnViewType = typename std::remove_reference<decltype(std::declval<DualViewType>().template view<TargetDeviceType>())>::type;
using ReturnDeviceType = typename ReturnViewType::device_type;

if (iAmASubview())
return getView<TargetDeviceType>(Access::ReadWrite);

bool returnDevice = true;
{
auto tmp = dualView.template view<TargetDeviceType>();
if (tmp == this->dualView.view_host()) returnDevice = false;
}

constexpr bool returnDevice = std::is_same<ReturnDeviceType, DeviceType>::value;

if(returnDevice) {
DEBUG_UVM_REMOVAL_PRINT_CALLER("getView<Device>OverwriteAll");
static_assert(dualViewHasNonConstData,
Expand All @@ -399,25 +393,25 @@ class WrappedDualView {
dualView.modify_host();
}
}
else {
else {
DEBUG_UVM_REMOVAL_PRINT_CALLER("getView<Host>OverwriteAll");
static_assert(dualViewHasNonConstData,
"OverwriteAll views are not available for DualView with const data");
if(needsSyncPath()) {
throwIfDeviceViewAlive();
dualView.clear_sync_state();
dualView.modify_device();
dualView.modify_device();
}
}

return dualView.template view<TargetDeviceType>();
}
}


typename t_host::const_type
getHostSubview(int offset, int numEntries, Access::ReadOnlyStruct
DEBUG_UVM_REMOVAL_ARGUMENT
) const
) const
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getHostSubviewReadOnly");
if(needsSyncPath()) {
Expand All @@ -430,7 +424,7 @@ class WrappedDualView {
t_host
getHostSubview(int offset, int numEntries, Access::ReadWriteStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getHostSubviewReadWrite");
static_assert(dualViewHasNonConstData,
Expand All @@ -446,7 +440,7 @@ class WrappedDualView {
t_host
getHostSubview(int offset, int numEntries, Access::OverwriteAllStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getHostSubviewOverwriteAll");
static_assert(dualViewHasNonConstData,
Expand All @@ -470,7 +464,7 @@ class WrappedDualView {
t_dev
getDeviceSubview(int offset, int numEntries, Access::ReadWriteStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceSubviewReadWrite");
static_assert(dualViewHasNonConstData,
Expand All @@ -486,7 +480,7 @@ class WrappedDualView {
t_dev
getDeviceSubview(int offset, int numEntries, Access::OverwriteAllStruct
DEBUG_UVM_REMOVAL_ARGUMENT
)
)
{
DEBUG_UVM_REMOVAL_PRINT_CALLER("getDeviceSubviewOverwriteAll");
static_assert(dualViewHasNonConstData,
Expand Down Expand Up @@ -543,7 +537,7 @@ class WrappedDualView {
bool need_sync_device() const {
return originalDualView.need_sync_device();
}

int host_view_use_count() const {
return originalDualView.h_view.use_count();
}
Expand All @@ -553,14 +547,14 @@ class WrappedDualView {
}


// MultiVector really needs to get at the raw DualViews,
// MultiVector really needs to get at the raw DualViews,
// but we'd very much prefer that users not.
template<typename SC, typename LO, typename GO, typename NO>
friend class ::Tpetra::MultiVector;

private:
// A Kokkos implementation of WrappedDualView will have to make these
// functions publically accessable, but in the Tpetra version, we'd
// functions publically accessable, but in the Tpetra version, we'd
// really rather not.
DualViewType getOriginalDualView() const {
return originalDualView;
Expand Down Expand Up @@ -596,22 +590,22 @@ class WrappedDualView {
}

bool needsSyncPath() const {
// needsSyncPath tells us whether we need the "sync path" where we (potentially) fence,
// needsSyncPath tells us whether we need the "sync path" where we (potentially) fence,
// check use counts and take care of sync/modify for the underlying DualView
//
// The logic is this:
// The logic is this:
// 1) For non-CUDA archtectures where there the host/device pointers are aliased
// we don't need the "sync path."
// 2) For CUDA, we always need the "sync path" if we're using the CudaUVMSpace (we need to make sure
// 2) For CUDA, we always need the "sync path" if we're using the CudaUVMSpace (we need to make sure
// to fence before reading memory on host) OR if the host/device pointers are aliased.
//
// Avoiding the "sync path" speeds up calculations on architectures where we can
// Avoiding the "sync path" speeds up calculations on architectures where we can
// avoid it (e.g. SerialNode) by not not touching the modify flags.
//
// Note for the future: Memory spaces that can be addressed on both host and device
// that don't otherwise have an intrinsic fencing mechanism will need to trigger the
// Note for the future: Memory spaces that can be addressed on both host and device
// that don't otherwise have an intrinsic fencing mechanism will need to trigger the
// "sync path"

#ifdef KOKKOS_ENABLE_CUDA
return std::is_same<typename t_dev::memory_space,Kokkos::CudaUVMSpace>::value || !memoryIsAliased();
#else
Expand All @@ -622,15 +616,15 @@ class WrappedDualView {
void throwIfHostViewAlive() const {
if (dualView.h_view.use_count() > dualView.d_view.use_count()) {
std::ostringstream msg;
msg << "Tpetra::Details::WrappedDualView (name = " << dualView.d_view.label()
msg << "Tpetra::Details::WrappedDualView (name = " << dualView.d_view.label()
<< "; host use_count = " << dualView.h_view.use_count()
<< "; device use_count = " << dualView.d_view.use_count() << "): "
<< "Cannot access data on device while a host view is alive";
throw std::runtime_error(msg.str());
}
}

void throwIfDeviceViewAlive() const {
void throwIfDeviceViewAlive() const {
if (dualView.d_view.use_count() > dualView.h_view.use_count()) {
std::ostringstream msg;
msg << "Tpetra::Details::WrappedDualView (name = " << dualView.d_view.label()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,52 @@ TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL( MultiVector, DeviceHostView, LO, GO, Scalar ,
TEST_ASSERT(gerr == 0);
}

TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL( MultiVector, TemplatedGetLocalView, LO, GO, Scalar , Node )
{
Teuchos::RCP<const Teuchos::Comm<int> > comm = Tpetra::getDefaultComm();
int np = comm->getSize();

using vector_t = Tpetra::Vector<Scalar,LO,GO,Node>;
using map_t = Tpetra::Map<LO,GO,Node>;
using WDV = typename vector_t::wrapped_dual_view_type;
using device_t = typename WDV::DeviceType;
using host_t = typename WDV::HostType;

const size_t nGlobalEntries = 8 * np;
Teuchos::Array<GO> myEntries(nGlobalEntries);

// Default one-to-one linear block map in Trilinos
Teuchos::RCP<const map_t> defaultMap =
rcp(new map_t(nGlobalEntries, 0, comm));

// Create vector
vector_t x(defaultMap);

{
// Check that getLocalView<device_t> produces a view with the correct device type
auto deviceView = x.template getLocalView<device_t>(Tpetra::Access::ReadWrite);
bool correctType = std::is_same<typename decltype(deviceView)::device_type, device_t>::value;
TEST_ASSERT(correctType);
}
constexpr bool needsSyncPath = !std::is_same<Kokkos::HostSpace, typename device_t::memory_space>::value;
if(needsSyncPath)
{
TEST_ASSERT(x.need_sync_host());
TEST_ASSERT(!x.need_sync_device());
}
// Assuming device/host device types aren't the same, make sure getting the host view also works
{
auto hostView = x.template getLocalView<host_t>(Tpetra::Access::ReadWrite);
bool correctType = std::is_same<typename decltype(hostView)::device_type, host_t>::value;
TEST_ASSERT(correctType);
if(needsSyncPath)
{
TEST_ASSERT(!x.need_sync_host());
TEST_ASSERT(x.need_sync_device());
}
}
}

//
// INSTANTIATIONS
//
Expand All @@ -331,7 +377,8 @@ TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL( MultiVector, DeviceHostView, LO, GO, Scalar ,
TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( MultiVector, HostView, LO, GO, SCALAR, NODE ) \
TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( MultiVector, DeviceView, LO, GO, SCALAR, NODE ) \
TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( MultiVector, HostDeviceView, LO, GO, SCALAR, NODE ) \
TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( MultiVector, DeviceHostView, LO, GO, SCALAR, NODE )
TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( MultiVector, DeviceHostView, LO, GO, SCALAR, NODE ) \
TEUCHOS_UNIT_TEST_TEMPLATE_4_INSTANT( MultiVector, TemplatedGetLocalView, LO, GO, SCALAR, NODE )

TPETRA_ETI_MANGLING_TYPEDEFS()

Expand Down

0 comments on commit f975d21

Please sign in to comment.