From 6b4d102070c532f0f4159a48894c78ec5080dae2 Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Fri, 20 May 2016 09:45:23 -0400 Subject: [PATCH] NCCL integration --- CMakeLists.txt | 1 + Makefile | 6 + Makefile.config.example | 4 + cmake/ConfigGen.cmake | 4 + cmake/Dependencies.cmake | 8 + cmake/Modules/FindNCCL.cmake | 22 ++ cmake/Summary.cmake | 5 + cmake/Templates/caffe_config.h.in | 7 +- include/caffe/net.hpp | 12 + include/caffe/parallel.hpp | 41 +- include/caffe/sgd_solvers.hpp | 62 +-- include/caffe/solver.hpp | 12 +- include/caffe/solver_factory.hpp | 14 +- include/caffe/util/math_functions.hpp | 25 +- include/caffe/util/nccl.hpp | 37 ++ src/caffe/layer_factory.cpp | 1 - src/caffe/net.cpp | 36 ++ src/caffe/parallel.cpp | 369 +++++++----------- src/caffe/proto/caffe.proto | 5 +- src/caffe/solver.cpp | 47 ++- src/caffe/solvers/adagrad_solver.cpp | 1 - src/caffe/solvers/nesterov_solver.cpp | 1 - src/caffe/solvers/sgd_solver.cpp | 1 - src/caffe/test/test_gradient_based_solver.cpp | 27 +- src/caffe/util/math_functions.cu | 28 +- tools/caffe.cpp | 4 +- 26 files changed, 474 insertions(+), 306 deletions(-) create mode 100644 cmake/Modules/FindNCCL.cmake create mode 100644 include/caffe/util/nccl.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index ab427c17f53..41dec0fe747 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,6 +37,7 @@ caffe_option(BUILD_python_layer "Build the Caffe Python layer" ON) caffe_option(USE_OPENCV "Build with OpenCV support" ON) caffe_option(USE_LEVELDB "Build with levelDB" ON) caffe_option(USE_LMDB "Build with lmdb" ON) +caffe_option(USE_NCCL "Build with NCCL Library for multi-GPU support" ON IF NOT CPU_ONLY) caffe_option(ALLOW_LMDB_NOLOCK "Allow MDB_NOLOCK when reading LMDB files (only if necessary)" OFF) # ---[ Dependencies diff --git a/Makefile b/Makefile index 5379be90685..9b26df570c5 100644 --- a/Makefile +++ b/Makefile @@ -334,6 +334,12 @@ ifeq ($(USE_CUDNN), 1) COMMON_FLAGS += -DUSE_CUDNN endif +# NCCL acceleration configuration +ifeq ($(USE_NCCL), 1) + LIBRARIES += nccl + COMMON_FLAGS += -DUSE_NCCL +endif + # configure IO libraries ifeq ($(USE_OPENCV), 1) COMMON_FLAGS += -DUSE_OPENCV diff --git a/Makefile.config.example b/Makefile.config.example index 4b7d035bc0d..d5f269f6b65 100644 --- a/Makefile.config.example +++ b/Makefile.config.example @@ -5,6 +5,10 @@ # cuDNN version 4 or higher is required. # USE_CUDNN := 1 +# NCCL acceleration switch (uncomment to build with NCCL) +# See https://github.com/NVIDIA/nccl +# USE_NCCL := 1 + # CPU-only switch (uncomment to build without GPU support). # cuDNN version 4 or higher is required. 
# CPU_ONLY := 1 diff --git a/cmake/ConfigGen.cmake b/cmake/ConfigGen.cmake index 056371110b5..11c4296630f 100644 --- a/cmake/ConfigGen.cmake +++ b/cmake/ConfigGen.cmake @@ -81,6 +81,10 @@ function(caffe_generate_export_configs) list(APPEND Caffe_DEFINITIONS -DUSE_MKL) endif() + if(USE_NCCL) + list(APPEND Caffe_DEFINITIONS -DUSE_NCCL) + endif() + configure_file("cmake/Templates/CaffeConfig.cmake.in" "${PROJECT_BINARY_DIR}/CaffeConfig.cmake" @ONLY) # Add targets to the build-tree export set diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index c7b6a17aa69..d79263ce08b 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -170,3 +170,11 @@ endif() if(BUILD_docs) find_package(Doxygen) endif() + +# ---[ NCCL +if(USE_NCCL) + add_definitions(-DUSE_NCCL) + find_package(NCCL REQUIRED) + include_directories(SYSTEM ${NCCL_INCLUDE}) + list(APPEND Caffe_LINKER_LIBS ${NCCL_LIBRARIES}) +endif() diff --git a/cmake/Modules/FindNCCL.cmake b/cmake/Modules/FindNCCL.cmake new file mode 100644 index 00000000000..b6f6c4bdab2 --- /dev/null +++ b/cmake/Modules/FindNCCL.cmake @@ -0,0 +1,22 @@ +# Find the NCCL libraries +# +# The following variables are optionally searched for defaults +# NCCL_ROOT_DIR: Base directory where all NCCL components are found +# +# The following are set after configuration is done: +# NCCL_FOUND +# NCCL_INCLUDE_DIR +# NCCL_LIBRARIES + +find_path(NCCL_INCLUDE_DIR NAMES nccl.h PATHS ${NCCL_ROOT_DIR}) + +find_library(NCCL_LIBRARIES NAMES nccl PATHS ${NCCL_ROOT_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARIES) + +if(NCCL_FOUND) + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIR}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARIES) +endif() + diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index ba025cf81e0..ec23cb11c09 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -146,6 +146,11 @@ function(caffe_print_configuration_summary) else() caffe_status(" cuDNN : Disabled") endif() + if(USE_NCCL) + caffe_status(" NCCL : " NCCL_FOUND THEN "Yes" ELSE "Not found") + else() + caffe_status(" NCCL : Disabled") + endif() caffe_status("") endif() if(HAVE_PYTHON) diff --git a/cmake/Templates/caffe_config.h.in b/cmake/Templates/caffe_config.h.in index 5e2bfa9b9e7..80514f3b1b4 100644 --- a/cmake/Templates/caffe_config.h.in +++ b/cmake/Templates/caffe_config.h.in @@ -4,15 +4,16 @@ /* Binaries directory */ #define BINARY_FOLDER "${PROJECT_BINARY_DIR}" -/* NVIDA Cuda */ +/* NVIDIA Cuda */ #cmakedefine HAVE_CUDA -/* NVIDA cuDNN */ +/* NVIDIA cuDNN */ #cmakedefine HAVE_CUDNN #cmakedefine USE_CUDNN -/* NVIDA cuDNN */ +/* NVIDIA cuDNN */ #cmakedefine CPU_ONLY +#cmakedefine USE_NCCL /* Test device */ #define CUDA_TEST_DEVICE ${CUDA_TEST_DEVICE} diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index 0addb3c2a6d..ae004f2764f 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -14,6 +14,9 @@ namespace caffe { +template +class Solver; + /** * @brief Connects Layer%s together into a directed acyclic graph (DAG) * specified by a NetParameter. @@ -227,6 +230,11 @@ class Net { static bool StateMeetsRule(const NetState& state, const NetStateRule& rule, const string& layer_name); + /// @brief set a Solver for this net + void SetSolver(Solver* s) { + solver_ = s; + } + protected: // Helpers for Init. /// @brief Append a new top blob to the net. 
@@ -278,6 +286,8 @@ class Net { vector param_owners_; vector param_display_names_; vector > param_layer_indices_; + /// (layer, blob) -> param_id map + map, int> layer_index_params_; map param_names_index_; /// blob indices for the input and the output of the net vector net_input_blob_indices_; @@ -307,6 +317,8 @@ class Net { bool debug_info_; /// The root net that actually holds the shared layers in data parallelism const Net* const root_net_; + /// Pointer to the solver being used with this net + Solver* solver_; DISABLE_COPY_AND_ASSIGN(Net); }; diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp index 8d13df2fc64..1ff0dd7f3ff 100644 --- a/include/caffe/parallel.hpp +++ b/include/caffe/parallel.hpp @@ -14,6 +14,10 @@ #include "caffe/syncedmem.hpp" #include "caffe/util/blocking_queue.hpp" +#ifdef USE_NCCL +#include "caffe/util/nccl.hpp" +#endif + namespace caffe { // Represents a net parameters. Once a net is created, its parameter buffers can @@ -89,7 +93,7 @@ class P2PSync : public GPUParams, public Solver::Callback, public InternalThread { public: explicit P2PSync(shared_ptr > root_solver, - P2PSync* parent, const SolverParameter& param); + int rank, int nranks, const SolverParameter& param); virtual ~P2PSync(); inline const shared_ptr >& solver() const { @@ -104,18 +108,47 @@ class P2PSync : public GPUParams, public Solver::Callback, // Divide the batch size by the number of solvers static void divide_batch_size(NetParameter* net); +#ifdef USE_NCCL + // set the NCCL communicator + void setNCCLComm(ncclComm_t comm); +#endif + + public: + void allreduce(int param_id); + void syncCommStream(); + protected: + void SetupP2PAccess(); + void soft_barrier(); void on_start(); - void on_gradients_ready(); - + void allreduce(); + void syncAllStreams(); +#ifndef CPU_ONLY +#ifdef USE_NCCL + ncclComm_t getNCCLComm(); +#endif + cudaStream_t getCommStream(); +#endif void InternalThreadEntry(); + const int rank_; + const int nranks_; P2PSync* parent_; vector*> children_; +#ifndef CPU_ONLY +#ifdef USE_NCCL + std::vector nccl_comms_; +#endif + vector comm_streams_; +#endif BlockingQueue*> queue_; const int initial_iter_; - Dtype* parent_grads_; + shared_ptr > solver_; + const SolverParameter& params_; + + // per-parameter reduction enabled + bool per_parameter_reduce_; using Params::size_; using Params::data_; diff --git a/include/caffe/sgd_solvers.hpp b/include/caffe/sgd_solvers.hpp index 1fc52d87137..14536dc7c09 100644 --- a/include/caffe/sgd_solvers.hpp +++ b/include/caffe/sgd_solvers.hpp @@ -15,9 +15,11 @@ namespace caffe { template class SGDSolver : public Solver { public: - explicit SGDSolver(const SolverParameter& param) - : Solver(param) { PreSolve(); } - explicit SGDSolver(const string& param_file) + explicit SGDSolver(const SolverParameter& param, + Solver *root_solver = NULL) + : Solver(param, root_solver) { PreSolve(); } + explicit SGDSolver(const string& param_file, + Solver *root_solver = NULL) : Solver(param_file) { PreSolve(); } virtual inline const char* type() const { return "SGD"; } @@ -48,10 +50,12 @@ class SGDSolver : public Solver { template class NesterovSolver : public SGDSolver { public: - explicit NesterovSolver(const SolverParameter& param) - : SGDSolver(param) {} - explicit NesterovSolver(const string& param_file) - : SGDSolver(param_file) {} + explicit NesterovSolver(const SolverParameter& param, + Solver *root_solver = NULL) + : SGDSolver(param, root_solver) {} + explicit NesterovSolver(const string& param_file, + Solver *root_solver = NULL) + : 
SGDSolver(param_file, root_solver) {} virtual inline const char* type() const { return "Nesterov"; } protected: @@ -63,10 +67,14 @@ class NesterovSolver : public SGDSolver { template class AdaGradSolver : public SGDSolver { public: - explicit AdaGradSolver(const SolverParameter& param) - : SGDSolver(param) { constructor_sanity_check(); } - explicit AdaGradSolver(const string& param_file) - : SGDSolver(param_file) { constructor_sanity_check(); } + explicit AdaGradSolver(const SolverParameter& param, + Solver *root_solver = NULL) + : SGDSolver(param, root_solver) + { constructor_sanity_check(); } + explicit AdaGradSolver(const string& param_file, + Solver *root_solver = NULL) + : SGDSolver(param_file, root_solver) + { constructor_sanity_check(); } virtual inline const char* type() const { return "AdaGrad"; } protected: @@ -83,10 +91,14 @@ class AdaGradSolver : public SGDSolver { template class RMSPropSolver : public SGDSolver { public: - explicit RMSPropSolver(const SolverParameter& param) - : SGDSolver(param) { constructor_sanity_check(); } - explicit RMSPropSolver(const string& param_file) - : SGDSolver(param_file) { constructor_sanity_check(); } + explicit RMSPropSolver(const SolverParameter& param, + Solver *root_solver = NULL) + : SGDSolver(param, root_solver) + { constructor_sanity_check(); } + explicit RMSPropSolver(const string& param_file, + Solver *root_solver = NULL) + : SGDSolver(param_file, root_solver) + { constructor_sanity_check(); } virtual inline const char* type() const { return "RMSProp"; } protected: @@ -106,10 +118,12 @@ class RMSPropSolver : public SGDSolver { template class AdaDeltaSolver : public SGDSolver { public: - explicit AdaDeltaSolver(const SolverParameter& param) - : SGDSolver(param) { AdaDeltaPreSolve(); } - explicit AdaDeltaSolver(const string& param_file) - : SGDSolver(param_file) { AdaDeltaPreSolve(); } + explicit AdaDeltaSolver(const SolverParameter& param, + Solver *root_solver = NULL) + : SGDSolver(param, root_solver) { AdaDeltaPreSolve(); } + explicit AdaDeltaSolver(const string& param_file, + Solver *root_solver = NULL) + : SGDSolver(param_file, root_solver) { AdaDeltaPreSolve(); } virtual inline const char* type() const { return "AdaDelta"; } protected: @@ -130,10 +144,12 @@ class AdaDeltaSolver : public SGDSolver { template class AdamSolver : public SGDSolver { public: - explicit AdamSolver(const SolverParameter& param) - : SGDSolver(param) { AdamPreSolve();} - explicit AdamSolver(const string& param_file) - : SGDSolver(param_file) { AdamPreSolve(); } + explicit AdamSolver(const SolverParameter& param, + Solver *root_solver = NULL) + : SGDSolver(param, root_solver) { AdamPreSolve();} + explicit AdamSolver(const string& param_file, + Solver *root_solver = NULL) + : SGDSolver(param_file, root_solver) { AdamPreSolve(); } virtual inline const char* type() const { return "Adam"; } protected: diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 38259edad9f..0c431436873 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -6,6 +6,7 @@ #include "caffe/net.hpp" #include "caffe/solver_factory.hpp" +#include "caffe/util/benchmark.hpp" namespace caffe { @@ -76,9 +77,14 @@ class Solver { // Invoked at specific points during an iteration class Callback { + public: + virtual void allreduce(int param_id) = 0; + virtual void syncCommStream() = 0; + protected: virtual void on_start() = 0; - virtual void on_gradients_ready() = 0; + virtual void allreduce() = 0; + virtual void soft_barrier() = 0; template friend class Solver; @@ 
-129,6 +135,10 @@ class Solver { // True iff a request to stop early was received. bool requested_early_exit_; + // Timing information + Timer iteration_timer_; + float iterations_last_; + DISABLE_COPY_AND_ASSIGN(Solver); }; diff --git a/include/caffe/solver_factory.hpp b/include/caffe/solver_factory.hpp index cfff721af40..c8cb8b7fca3 100644 --- a/include/caffe/solver_factory.hpp +++ b/include/caffe/solver_factory.hpp @@ -53,7 +53,7 @@ class Solver; template class SolverRegistry { public: - typedef Solver* (*Creator)(const SolverParameter&); + typedef Solver* (*Creator)(const SolverParameter&, Solver*); typedef std::map CreatorRegistry; static CreatorRegistry& Registry() { @@ -70,12 +70,13 @@ class SolverRegistry { } // Get a solver using a SolverParameter. - static Solver* CreateSolver(const SolverParameter& param) { + static Solver* CreateSolver(const SolverParameter& param, + Solver* root_solver = NULL) { const string& type = param.type(); CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type << " (known types: " << SolverTypeListString() << ")"; - return registry[type](param); + return registry[type](param, root_solver); } static vector SolverTypeList() { @@ -112,7 +113,7 @@ template class SolverRegisterer { public: SolverRegisterer(const string& type, - Solver* (*creator)(const SolverParameter&)) { + Solver* (*creator)(const SolverParameter&, Solver*)) { // LOG(INFO) << "Registering solver type: " << type; SolverRegistry::AddCreator(type, creator); } @@ -126,9 +127,10 @@ class SolverRegisterer { #define REGISTER_SOLVER_CLASS(type) \ template \ Solver* Creator_##type##Solver( \ - const SolverParameter& param) \ + const SolverParameter& param, \ + Solver* root_solver) \ { \ - return new type##Solver(param); \ + return new type##Solver(param, root_solver); \ } \ REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver) diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index 9cfd6b550ce..a75eff04e05 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -44,20 +44,20 @@ void caffe_cpu_eltwise_min(const int N, const Dtype alpha, const Dtype* X, const Dtype beta, Dtype* Y); template -void caffe_copy(const int N, const Dtype *X, Dtype *Y); +void caffe_copy(const int N, const Dtype* X, Dtype* Y); template -void caffe_set(const int N, const Dtype alpha, Dtype *X); +void caffe_set(const int N, const Dtype alpha, Dtype* X); inline void caffe_memset(const size_t N, const int alpha, void* X) { memset(X, alpha, N); // NOLINT(caffe/alt_fn) } template -void caffe_add_scalar(const int N, const Dtype alpha, Dtype *X); +void caffe_add_scalar(const int N, const Dtype alpha, Dtype* X); template -void caffe_scal(const int N, const Dtype alpha, Dtype *X); +void caffe_scal(const int N, const Dtype alpha, Dtype* X); template void caffe_sqr(const int N, const Dtype* a, Dtype* y); @@ -150,7 +150,7 @@ DEFINE_CAFFE_CPU_UNARY_FUNC(sgnbit, \ DEFINE_CAFFE_CPU_UNARY_FUNC(fabs, y[i] = std::fabs(x[i])); template -void caffe_cpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y); +void caffe_cpu_scale(const int n, const Dtype alpha, const Dtype* x, Dtype* y); #ifndef CPU_ONLY // GPU @@ -176,10 +176,10 @@ template void caffe_gpu_axpby(const int N, const Dtype alpha, const Dtype* X, const Dtype beta, Dtype* Y); -void caffe_gpu_memcpy(const size_t N, const void *X, void *Y); +void caffe_gpu_memcpy(const size_t N, const void* X, void* Y); template -void caffe_gpu_set(const int 
N, const Dtype alpha, Dtype *X); +void caffe_gpu_set(const int N, const Dtype alpha, Dtype* X); inline void caffe_gpu_memset(const size_t N, const int alpha, void* X) { #ifndef CPU_ONLY @@ -190,10 +190,15 @@ inline void caffe_gpu_memset(const size_t N, const int alpha, void* X) { } template -void caffe_gpu_add_scalar(const int N, const Dtype alpha, Dtype *X); +void caffe_gpu_add_scalar(const int N, const Dtype alpha, Dtype* X); template -void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X); +void caffe_gpu_scal(const int N, const Dtype alpha, Dtype* X); + +#ifndef CPU_ONLY +template +void caffe_gpu_scal(const int N, const Dtype alpha, Dtype* X, cudaStream_t str); +#endif template void caffe_gpu_add(const int N, const Dtype* a, const Dtype* b, Dtype* y); @@ -254,7 +259,7 @@ template void caffe_gpu_fabs(const int n, const Dtype* x, Dtype* y); template -void caffe_gpu_scale(const int n, const Dtype alpha, const Dtype *x, Dtype* y); +void caffe_gpu_scale(const int n, const Dtype alpha, const Dtype* x, Dtype* y); // y[i] = max(a * x[i], b * y[i]) template diff --git a/include/caffe/util/nccl.hpp b/include/caffe/util/nccl.hpp new file mode 100644 index 00000000000..e01fb7451e8 --- /dev/null +++ b/include/caffe/util/nccl.hpp @@ -0,0 +1,37 @@ +#ifndef CAFFE_UTIL_NCCL_H_ +#define CAFFE_UTIL_NCCL_H_ +#ifdef USE_NCCL + +#include + +#include "caffe/common.hpp" + +#define NCCL_CHECK(condition) \ +{ \ + ncclResult_t result = condition; \ + CHECK_EQ(result, ncclSuccess) << " " \ + << ncclGetErrorString(result); \ +} + +namespace caffe { + +namespace nccl { + +template class dataType; + +template<> class dataType { + public: + static const ncclDataType_t type = ncclFloat; +}; +template<> class dataType { + public: + static const ncclDataType_t type = ncclDouble; +}; + +} // namespace nccl + +} // namespace caffe + +#endif // end USE_NCCL + +#endif // CAFFE_UTIL_NCCL_H_ diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp index 713787d4227..012438eaf09 100644 --- a/src/caffe/layer_factory.cpp +++ b/src/caffe/layer_factory.cpp @@ -72,7 +72,6 @@ shared_ptr > GetConvolutionLayer( } else { LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; } - return shared_ptr >(); // [-Wreturn-type] } REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer); diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 1815293d65e..1aa52ef2e91 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -273,6 +273,13 @@ void Net::Init(const NetParameter& in_param) { layer_names_index_[layer_names_[layer_id]] = layer_id; } ShareWeights(); + + // invert param_layer_indices_ to give map of + // (level_id, local param_id) -> global param_id + for (int i = 0; i < param_layer_indices_.size(); ++i) { + layer_index_params_[param_layer_indices_[i]] = i; + } + debug_info_ = param.debug_info(); LOG_IF(INFO, Caffe::root_solver()) << "Network initialization done."; } @@ -589,6 +596,35 @@ void Net::BackwardFromTo(int start, int end) { layers_[i]->Backward( top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i]); if (debug_info_) { BackwardDebugInfo(i); } + + // reduce gradients as soon as they are ready + if (Caffe::solver_count() > 1) { +#ifndef CPU_ONLY + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); +#endif + for (int j = 0; j < layers_[i]->blobs().size(); ++j) { + int param_id = layer_index_params_[make_pair(i, j)]; + + // check if we need to synchronize after reduction + bool need_sync = false; + // If param has been split, update owner and sync + if (param_owners_[param_id] >= 0) { + 
param_id = param_owners_[param_id]; + need_sync = true; + } + + for (int k = 0; k < solver_->callbacks().size(); ++k) { + solver_->callbacks()[k]->allreduce(param_id); + } + + // perform synchronization if needed + if (need_sync) { + for (int k = 0; k < solver_->callbacks().size(); ++k) { + solver_->callbacks()[k]->syncCommStream(); + } + } + } + } } } } diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp index 9b659561942..4b0effd6a7a 100644 --- a/src/caffe/parallel.cpp +++ b/src/caffe/parallel.cpp @@ -9,12 +9,18 @@ #include #include "boost/thread.hpp" +#include "boost/thread/latch.hpp" #include "caffe/caffe.hpp" #include "caffe/parallel.hpp" #include "caffe/util/gpu_memory.hpp" +#ifdef USE_NCCL +#include "caffe/util/nccl.hpp" +#endif namespace caffe { +shared_ptr bar; + enum Op { copy, replace_cpu, @@ -70,8 +76,8 @@ static size_t total_size(const vector*>& params) { template Params::Params(shared_ptr > root_solver) : size_(total_size(root_solver->net()->learnable_params())), - data_(), - diff_() { + data_(NULL), + diff_(NULL) { } template @@ -106,13 +112,13 @@ template GPUParams::~GPUParams() { #ifndef CPU_ONLY int initial_device; - cudaGetDevice(&initial_device); - cudaSetDevice(buffer_device_); + CUDA_CHECK(cudaGetDevice(&initial_device)); + CUDA_CHECK(cudaSetDevice(buffer_device_)); GPUMemoryManager::deallocate(data_); GPUMemoryManager::deallocate(diff_); data_ = NULL; diff_ = NULL; - cudaSetDevice(initial_device); + CUDA_CHECK(cudaSetDevice(initial_device)); #endif } @@ -126,80 +132,9 @@ void GPUParams::configure(Solver* solver) const { void DevicePair::compute(const vector devices, vector* pairs) { #ifndef CPU_ONLY - vector remaining(devices); - - // Depth for reduction tree - int remaining_depth = static_cast(ceil(log2(remaining.size()))); - - // Group GPUs by board - for (int d = 0; d < remaining_depth; ++d) { - for (int i = 0; i < remaining.size(); ++i) { - for (int j = i + 1; j < remaining.size(); ++j) { - cudaDeviceProp a, b; - CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); - CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); - if (a.isMultiGpuBoard && b.isMultiGpuBoard) { - if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { - pairs->push_back(DevicePair(remaining[i], remaining[j])); - DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j]; - remaining.erase(remaining.begin() + j); - break; - } - } - } - } - } - ostringstream s; - for (int i = 0; i < remaining.size(); ++i) { - s << (i ? ", " : "") << remaining[i]; - } - DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str(); - - // Group by P2P accessibility - remaining_depth = ceil(log2(remaining.size())); - for (int d = 0; d < remaining_depth; ++d) { - for (int i = 0; i < remaining.size(); ++i) { - for (int j = i + 1; j < remaining.size(); ++j) { - int access; - CUDA_CHECK( - cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j])); - if (access) { - pairs->push_back(DevicePair(remaining[i], remaining[j])); - DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j]; - remaining.erase(remaining.begin() + j); - break; - } - } - } - } - s.str(""); - for (int i = 0; i < remaining.size(); ++i) { - s << (i ? 
", " : "") << remaining[i]; - } - DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str(); - - // Group remaining - remaining_depth = ceil(log2(remaining.size())); - for (int d = 0; d < remaining_depth; ++d) { - for (int i = 0; i < remaining.size(); ++i) { - pairs->push_back(DevicePair(remaining[i], remaining[i + 1])); - DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" - << remaining[i + 1]; - remaining.erase(remaining.begin() + i + 1); - } - } - - // Should only be the parent node remaining - CHECK_EQ(remaining.size(), 1); - - pairs->insert(pairs->begin(), DevicePair(-1, remaining[0])); - - CHECK(pairs->size() == devices.size()); - for (int i = 0; i < pairs->size(); ++i) { - CHECK((*pairs)[i].parent() != (*pairs)[i].device()); - for (int j = i + 1; j < pairs->size(); ++j) { - CHECK((*pairs)[i].device() != (*pairs)[j].device()); - } + pairs->push_back(DevicePair(-1, devices[0])); + for (int i = 0; i < devices.size() - 1; ++i) { + pairs->push_back(DevicePair(devices[i], devices[i + 1])); } #else NO_GPU; @@ -210,71 +145,84 @@ void DevicePair::compute(const vector devices, vector* pairs) { template P2PSync::P2PSync(shared_ptr > root_solver, - P2PSync* parent, const SolverParameter& param) + int rank, int nranks, const SolverParameter& param) : GPUParams(root_solver, param.device_id()), - parent_(parent), + rank_(rank), + nranks_(nranks), + parent_(), children_(), queue_(), initial_iter_(root_solver->iter()), - solver_() { + solver_(), + params_(param), + per_parameter_reduce_(param.per_parameter_reduce()) { +#ifndef USE_NCCL + LOG(FATAL) << "USE_NCCL := 1 must be specified for multi-GPU"; +#endif + #ifndef CPU_ONLY int initial_device; CUDA_CHECK(cudaGetDevice(&initial_device)); const int self = param.device_id(); CUDA_CHECK(cudaSetDevice(self)); - if (parent == NULL) { + if (rank == 0) { solver_ = root_solver; } else { Caffe::set_root_solver(false); - solver_.reset(new WorkerSolver(param, root_solver.get())); + solver_.reset(caffe::SolverRegistry::CreateSolver(param, + root_solver.get())); Caffe::set_root_solver(true); } this->configure(solver_.get()); solver_->add_callback(this); - if (parent) { - // Enable p2p access between devices - const int peer = parent->solver_->param().device_id(); - int access; - CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); - if (access) { - CUDA_CHECK(cudaDeviceEnablePeerAccess(peer, 0)); - } else { - LOG(INFO)<< "GPU " << self << " does not have p2p access to GPU " << peer; - } - // Allocate receiving buffer on parent - CUDA_CHECK(cudaSetDevice(peer)); - GPUMemoryManager::allocate(reinterpret_cast(&parent_grads_), - size_ * sizeof(Dtype)); - CUDA_CHECK(cudaSetDevice(self)); - } +#if defined(USE_NCCL) + nccl_comms_.resize(1); +#endif + comm_streams_.resize(1); + CUDA_CHECK(cudaStreamCreateWithFlags(&comm_streams_[0], + cudaStreamNonBlocking)); + CHECK_GT(comm_streams_.size(), 0); CUDA_CHECK(cudaSetDevice(initial_device)); #else NO_GPU; #endif } +#ifndef CPU_ONLY +#ifdef USE_NCCL +template +void P2PSync::setNCCLComm(ncclComm_t comm) { + this->nccl_comms_[0] = comm; +} + +template +ncclComm_t P2PSync::getNCCLComm() { + return this->nccl_comms_[0]; +} +#endif + +template +cudaStream_t P2PSync::getCommStream() { + return this->comm_streams_[0]; +} +#endif + template P2PSync::~P2PSync() { #ifndef CPU_ONLY - if (parent_) { - int initial_device; - CUDA_CHECK(cudaGetDevice(&initial_device)); - const int self = solver_->param().device_id(); - const int peer = parent_->solver_->param().device_id(); - CUDA_CHECK(cudaSetDevice(peer)); - 
GPUMemoryManager::deallocate(parent_grads_); - parent_grads_ = NULL; - int access; - cudaSetDevice(self); - CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); - if (access) { - CUDA_CHECK(cudaDeviceDisablePeerAccess(peer)); - } - CUDA_CHECK(cudaSetDevice(initial_device)); + for (int i = 0; i < comm_streams_.size(); ++i) { + cudaStreamDestroy(comm_streams_[i]); + } + +#ifdef USE_NCCL + for (int i = 0; i < nccl_comms_.size(); ++i) { + ncclCommDestroy(nccl_comms_[i]); } +#endif // USE_NCCL + #endif } @@ -295,143 +243,98 @@ void P2PSync::InternalThreadEntry() { } template -void P2PSync::on_start() { +void P2PSync::soft_barrier() { #ifndef CPU_ONLY -#ifdef DEBUG - int device; - CUDA_CHECK(cudaGetDevice(&device)); - CHECK(device == solver_->param().device_id()); -#else -// CHECK(false); + // CPU barrier to avoid busy-polling on the GPU. + bar->wait(); #endif +} - // Wait for update from parent - if (parent_) { - P2PSync *parent = queue_.pop(); - CHECK(parent == parent_); - } - - // Update children - for (int i = children_.size() - 1; i >= 0; i--) { - Dtype* src = data_; - Dtype* dst = children_[i]->data_; - -#ifdef DEBUG - cudaPointerAttributes attributes; - CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); - CHECK(attributes.device == device); - CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); - CHECK(attributes.device == children_[i]->solver_->param().device_id()); +template +void P2PSync::on_start() { +#ifndef CPU_ONLY +#ifdef USE_NCCL + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + NCCL_CHECK(ncclBcast(data_, size_, nccl::dataType::type, 0, + getNCCLComm(), getCommStream())); + CUDA_CHECK(cudaStreamSynchronize(getCommStream())); +#endif // USE_NCCL #endif +} - CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), - cudaMemcpyDeviceToDevice, cudaStreamDefault)); +template +void P2PSync::allreduce() { +#ifndef CPU_ONLY +#ifdef USE_NCCL + // only reduce if we haven't in the bwd pass + if (!per_parameter_reduce_) { + bar->wait(); CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); - children_[i]->queue_.push(this); + NCCL_CHECK(ncclAllReduce(diff_, diff_, size_, nccl::dataType::type, + ncclSum, getNCCLComm(), getCommStream())); + caffe_gpu_scal(size_, (Dtype)1.0 / Caffe::solver_count(), diff_, + getCommStream()); } -#endif +#endif // USE_NCCL +#endif // CPU_ONLY } template -void P2PSync::on_gradients_ready() { +void P2PSync::allreduce(int param_id) { #ifndef CPU_ONLY -#ifdef DEBUG - int device; - CUDA_CHECK(cudaGetDevice(&device)); - CHECK(device == solver_->param().device_id()); -#endif - - // Sum children gradients as they appear in the queue - for (int i = 0; i < children_.size(); ++i) { - P2PSync *child = queue_.pop(); - Dtype* src = child->parent_grads_; - Dtype* dst = diff_; - -#ifdef DEBUG - bool ok = false; - for (int j = 0; j < children_.size(); ++j) { - if (child == children_[j]) { - ok = true; - } - } - CHECK(ok); - cudaPointerAttributes attributes; - CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); - CHECK(attributes.device == device); - CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); - CHECK(attributes.device == device); -#endif - - caffe_gpu_add(size_, src, dst, dst); +#ifdef USE_NCCL + // reduce aynchronously in the bwd path + if (per_parameter_reduce_) { + bar->wait(); + const vector > >& params = solver_->net()->params(); + NCCL_CHECK(ncclAllReduce(params[param_id]->gpu_diff(), + params[param_id]->mutable_gpu_diff(), + params[param_id]->count(), + nccl::dataType::type, + ncclSum, + getNCCLComm(), + getCommStream())); + 
caffe_gpu_scal(params[param_id]->count(), (Dtype)1. / Caffe::solver_count(), + params[param_id]->mutable_gpu_diff(), getCommStream()); } +#endif // USE_NCCL +#endif // CPU_ONLY +} - // Send gradients to parent - if (parent_) { - Dtype* src = diff_; - Dtype* dst = parent_grads_; - -#ifdef DEBUG - cudaPointerAttributes attributes; - CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); - CHECK(attributes.device == device); - CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); - CHECK(attributes.device == parent_->solver_->param().device_id()); -#endif - - CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // - cudaMemcpyDeviceToDevice, cudaStreamDefault)); - CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); - parent_->queue_.push(this); - } else { - // Loss functions divide gradients by the batch size, so to compensate - // for split batch, the root solver divides by number of solvers. - caffe_gpu_scal(size_, Dtype(1.0 / Caffe::solver_count()), diff_); - } +template +void P2PSync::syncCommStream() { +#ifndef CPU_ONLY + CUDA_CHECK(cudaStreamSynchronize(comm_streams_[0])); #endif } template -void P2PSync::Prepare(const vector& gpus, - vector > >* syncs) { - // Pair devices for map-reduce synchronization - vector pairs; - DevicePair::compute(gpus, &pairs); - ostringstream s; - for (int i = 1; i < pairs.size(); ++i) { - s << (i == 1 ? "" : ", ") << pairs[i].parent() << ":" << pairs[i].device(); +void P2PSync::Run(const vector& gpus) { + vector > > syncs(gpus.size()); + bar.reset(new boost::barrier(gpus.size())); + SolverParameter param = solver_->param(); + for (int i = 1; i < gpus.size(); ++i) { + param.set_device_id(gpus[i]); + syncs[i].reset(new P2PSync(solver_, i, gpus.size(), param)); } - LOG(INFO)<< "GPUs pairs " << s.str(); - - SolverParameter param(solver_->param()); - - // Build the GPU tree by finding the parent for each solver - for (int attempts = 0; attempts < pairs.size(); ++attempts) { - for (int i = 1; i < pairs.size(); ++i) { - if (!syncs->at(i).get()) { - P2PSync* parent = NULL; - for (int j = 0; j < syncs->size(); ++j) { - P2PSync* sync = j == 0 ? 
this : syncs->at(j).get(); - if (sync) { - const SolverParameter& p = sync->solver()->param(); - if (p.device_id() == pairs[i].parent()) { - parent = sync; - } - } - } - if (parent) { - param.set_device_id(pairs[i].device()); - syncs->at(i).reset(new P2PSync(solver_, parent, param)); - parent->children_.push_back((P2PSync*) syncs->at(i).get()); - } - } - } +#ifdef USE_NCCL + ncclComm_t *comms = new ncclComm_t[nranks_]; + int *gpu_list = new int[nranks_]; + for (int i = 0; i < nranks_; ++i) { + gpu_list[i] = gpus[i]; } -} + NCCL_CHECK(ncclCommInitAll(comms, nranks_, gpu_list)); -template -void P2PSync::Run(const vector& gpus) { - vector > > syncs(gpus.size()); - Prepare(gpus, &syncs); + this->setNCCLComm(comms[0]); + + for (int i = 1; i < nranks_; ++i) { + syncs[i]->setNCCLComm(comms[i]); + } + delete[] comms; + delete[] gpu_list; +#else + LOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL"; +#endif // USE_NCCL LOG(INFO)<< "Starting Optimization"; @@ -440,7 +343,7 @@ void P2PSync::Run(const vector& gpus) { } // Run root solver on current thread - solver_->Solve(); + this->solver_->Solve(); for (int i = 1; i < syncs.size(); ++i) { syncs[i]->StopInternalThread(); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 106c787de74..bfc3669abaa 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -98,7 +98,7 @@ message NetParameter { // NOTE // Update the next available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 41 (last added: type) +// SolverParameter next available ID: 42 (last added: per_parameter_reduce) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -239,6 +239,9 @@ message SolverParameter { } // DEPRECATED: use type instead of solver_type optional SolverType solver_type = 30 [default = SGD]; + + // Reduce parameter gradients individually + optional bool per_parameter_reduce = 41 [default = true]; } // A message that stores the solver snapshots diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index ece3913e88a..a2d7b951ea6 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -28,14 +28,14 @@ SolverAction::Enum Solver::GetRequestedAction() { template Solver::Solver(const SolverParameter& param, const Solver* root_solver) : net_(), callbacks_(), root_solver_(root_solver), - requested_early_exit_(false) { + requested_early_exit_(false), iteration_timer_(), iterations_last_() { Init(param); } template Solver::Solver(const string& param_file, const Solver* root_solver) : net_(), callbacks_(), root_solver_(root_solver), - requested_early_exit_(false) { + requested_early_exit_(false), iteration_timer_(), iterations_last_() { SolverParameter param; ReadSolverParamsFromTextFileOrDie(param_file, ¶m); Init(param); @@ -197,22 +197,34 @@ void Solver::Step(int iters) { int average_loss = this->param_.average_loss(); losses_.clear(); smoothed_loss_ = 0; + iteration_timer_.Start(); + + for (int i = 0; i < callbacks_.size(); ++i) { + // we need to sync all threads before starting, otherwise some cuda init, + // malloc or other cuda stuff could interlock with in-loop cuda GPU sync + // called in on_start. 
+ callbacks_[i]->soft_barrier(); + // Initial bcast of parameters + callbacks_[i]->on_start(); + } + + net_->SetSolver(this); while (iter_ < stop_iter) { // zero-init the params net_->ClearParamDiffs(); if (param_.test_interval() && iter_ % param_.test_interval() == 0 - && (iter_ > 0 || param_.test_initialization()) - && Caffe::root_solver()) { - TestAll(); + && (iter_ > 0 || param_.test_initialization())) { + if (Caffe::root_solver()) { + TestAll(); + } if (requested_early_exit_) { // Break out of the while loop because stop was requested while testing. break; } - } - - for (int i = 0; i < callbacks_.size(); ++i) { - callbacks_[i]->on_start(); + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->soft_barrier(); + } } const bool display = param_.display() && iter_ % param_.display() == 0; net_->set_debug_info(display && param_.debug_info()); @@ -225,8 +237,13 @@ void Solver::Step(int iters) { // average the loss across iterations for smoothed reporting UpdateSmoothedLoss(loss, start_iter, average_loss); if (display) { + float lapse = iteration_timer_.Seconds(); + float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1); LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_ - << ", loss = " << smoothed_loss_; + << " (" << per_s << " iter/s, " << lapse << "s/" + << param_.display() <<" iter), loss = " << smoothed_loss_; + iteration_timer_.Start(); + iterations_last_ = iter_; const vector*>& result = net_->output_blobs(); int score_index = 0; for (int j = 0; j < result.size(); ++j) { @@ -247,9 +264,17 @@ void Solver::Step(int iters) { } } } +#ifndef CPU_ONLY + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); +#endif for (int i = 0; i < callbacks_.size(); ++i) { - callbacks_[i]->on_gradients_ready(); + callbacks_[i]->allreduce(); } + // Make sure all gradient exchanges have finished in per-level scheme + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->syncCommStream(); + } + ApplyUpdate(); // Increment the internal iter_ counter -- its value should always indicate diff --git a/src/caffe/solvers/adagrad_solver.cpp b/src/caffe/solvers/adagrad_solver.cpp index e78eadca141..d8107e1e623 100644 --- a/src/caffe/solvers/adagrad_solver.cpp +++ b/src/caffe/solvers/adagrad_solver.cpp @@ -12,7 +12,6 @@ void adagrad_update_gpu(int N, Dtype* g, Dtype* h, Dtype delta, template void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { - CHECK(Caffe::root_solver()); const vector*>& net_params = this->net_->learnable_params(); const vector& net_params_lr = this->net_->params_lr(); Dtype delta = this->param_.delta(); diff --git a/src/caffe/solvers/nesterov_solver.cpp b/src/caffe/solvers/nesterov_solver.cpp index 23ab2d4369a..7c1fac1f884 100644 --- a/src/caffe/solvers/nesterov_solver.cpp +++ b/src/caffe/solvers/nesterov_solver.cpp @@ -12,7 +12,6 @@ void nesterov_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum, template void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { - CHECK(Caffe::root_solver()); const vector*>& net_params = this->net_->learnable_params(); const vector& net_params_lr = this->net_->params_lr(); Dtype momentum = this->param_.momentum(); diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp index f30f316d1a0..09ddaaff915 100644 --- a/src/caffe/solvers/sgd_solver.cpp +++ b/src/caffe/solvers/sgd_solver.cpp @@ -100,7 +100,6 @@ void SGDSolver::ClipGradients() { template void SGDSolver::ApplyUpdate() { - CHECK(Caffe::root_solver()); Dtype rate = GetLearningRate(); if (this->param_.display() && 
this->iter_ % this->param_.display() == 0) { LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index efd7a7fbb8c..ba4efa4bb74 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -36,7 +36,6 @@ class GradientBasedSolverTest : public MultiDeviceTest { string snapshot_prefix_; shared_ptr > solver_; - shared_ptr > sync_; int seed_; // Dimensions are determined by generate_sample_data.py // TODO this is brittle and the hdf5 file should be checked instead. @@ -47,11 +46,13 @@ class GradientBasedSolverTest : public MultiDeviceTest { // Test data: check out generate_sample_data.py in the same directory. string* input_file_; + virtual const char* solver_type() = 0; virtual void InitSolver(const SolverParameter& param) = 0; virtual void InitSolverFromProtoString(const string& proto) { SolverParameter param; CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m)); + param.set_type(solver_type()); // Set the solver_mode according to current Caffe::mode. switch (Caffe::mode()) { case Caffe::CPU: @@ -202,9 +203,9 @@ class GradientBasedSolverTest : public MultiDeviceTest { gpus.push_back(i); } Caffe::set_solver_count(gpus.size()); - this->sync_.reset(new P2PSync( - this->solver_, NULL, this->solver_->param())); - this->sync_->Run(gpus); + P2PSync sync(this->solver_, 0, gpus.size(), + this->solver_->param()); + sync.Run(gpus); Caffe::set_solver_count(1); } if (snapshot) { @@ -574,6 +575,9 @@ class SGDSolverTest : public GradientBasedSolverTest { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new SGDSolver(param)); } + virtual const char *solver_type() { + return "SGD"; + } }; TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices); @@ -710,6 +714,9 @@ class AdaGradSolverTest : public GradientBasedSolverTest { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new AdaGradSolver(param)); } + virtual const char *solver_type() { + return "AdaGrad"; + } }; TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices); @@ -810,6 +817,9 @@ class NesterovSolverTest : public GradientBasedSolverTest { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new NesterovSolver(param)); } + virtual const char *solver_type() { + return "Nesterov"; + } }; TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices); @@ -943,6 +953,9 @@ class AdaDeltaSolverTest : public GradientBasedSolverTest { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new AdaDeltaSolver(param)); } + virtual const char *solver_type() { + return "AdaDelta"; + } }; TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices); @@ -1077,6 +1090,9 @@ class AdamSolverTest : public GradientBasedSolverTest { new_param.set_momentum2(momentum2); this->solver_.reset(new AdamSolver(new_param)); } + virtual const char *solver_type() { + return "Adam"; + } }; TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices); @@ -1177,6 +1193,9 @@ class RMSPropSolverTest : public GradientBasedSolverTest { new_param.set_rms_decay(rms_decay); this->solver_.reset(new RMSPropSolver(new_param)); } + virtual const char *solver_type() { + return "RMSProp"; + } }; TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices); diff --git a/src/caffe/util/math_functions.cu b/src/caffe/util/math_functions.cu index 0074601f4da..cfdcac868f1 100644 --- a/src/caffe/util/math_functions.cu +++ 
b/src/caffe/util/math_functions.cu @@ -82,15 +82,35 @@ void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) { } template <> -void caffe_gpu_scal(const int N, const float alpha, float *X) { +void caffe_gpu_scal(const int N, const float alpha, float* X) { CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1)); } template <> -void caffe_gpu_scal(const int N, const double alpha, double *X) { +void caffe_gpu_scal(const int N, const double alpha, double* X) { CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1)); } +template <> +void caffe_gpu_scal(const int N, const float alpha, float* X, + cudaStream_t str) { + cudaStream_t initial_stream; + CUBLAS_CHECK(cublasGetStream(Caffe::cublas_handle(), &initial_stream)); + CUBLAS_CHECK(cublasSetStream(Caffe::cublas_handle(), str)); + CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1)); + CUBLAS_CHECK(cublasSetStream(Caffe::cublas_handle(), initial_stream)); +} + +template <> +void caffe_gpu_scal(const int N, const double alpha, double* X, + cudaStream_t str) { + cudaStream_t initial_stream; + CUBLAS_CHECK(cublasGetStream(Caffe::cublas_handle(), &initial_stream)); + CUBLAS_CHECK(cublasSetStream(Caffe::cublas_handle(), str)); + CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1)); + CUBLAS_CHECK(cublasSetStream(Caffe::cublas_handle(), initial_stream)); +} + template <> void caffe_gpu_axpby(const int N, const float alpha, const float* X, const float beta, float* Y) { @@ -128,14 +148,14 @@ void caffe_gpu_asum(const int n, const double* x, double* y) { } template <> -void caffe_gpu_scale(const int n, const float alpha, const float *x, +void caffe_gpu_scale(const int n, const float alpha, const float* x, float* y) { CUBLAS_CHECK(cublasScopy(Caffe::cublas_handle(), n, x, 1, y, 1)); CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), n, &alpha, y, 1)); } template <> -void caffe_gpu_scale(const int n, const double alpha, const double *x, +void caffe_gpu_scale(const int n, const double alpha, const double* x, double* y) { CUBLAS_CHECK(cublasDcopy(Caffe::cublas_handle(), n, x, 1, y, 1)); CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), n, &alpha, y, 1)); diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 1f599e9b5c7..305cfda54b8 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -224,7 +224,7 @@ int train() { } if (gpus.size() > 1) { - caffe::P2PSync sync(solver, NULL, solver->param()); + caffe::P2PSync sync(solver, 0, gpus.size(), solver->param()); sync.Run(gpus); } else { LOG(INFO) << "Starting Optimization"; @@ -424,7 +424,7 @@ RegisterBrewFunction(time); int main(int argc, char** argv) { // Print output to stderr (while still logging). - FLAGS_alsologtostderr = 1; + FLAGS_alsologtostderr = true; // Set version gflags::SetVersionString(STRINGIZE2(CAFFE_VERSION)); // Usage message.
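Editor's note: a usage sketch of the reworked multi-GPU entry point. The old parent/child GPU tree is gone; the caller now constructs the root P2PSync with an explicit rank and rank count, and Run() creates the worker solvers and the per-rank NCCL communicators (via ncclCommInitAll). The snippet mirrors the updated tools/caffe.cpp and test_gradient_based_solver.cpp; the surrounding setup (reading the prototxt, selecting GPUs) is illustrative glue, not part of the patch.

#include <vector>
#include <boost/shared_ptr.hpp>
#include "caffe/caffe.hpp"
#include "caffe/parallel.hpp"

void train_on_gpus(caffe::SolverParameter solver_param,
                   const std::vector<int>& gpus) {
  caffe::Caffe::SetDevice(gpus[0]);
  caffe::Caffe::set_mode(caffe::Caffe::GPU);
  caffe::Caffe::set_solver_count(gpus.size());
  solver_param.set_device_id(gpus[0]);
  boost::shared_ptr<caffe::Solver<float> > solver(
      caffe::SolverRegistry<float>::CreateSolver(solver_param));
  // Rank 0 wraps the root solver; Run() spawns one worker thread per
  // remaining GPU, hands each its NCCL communicator, and solves.
  caffe::P2PSync<float> sync(solver, 0 /*rank*/, gpus.size(), solver->param());
  sync.Run(gpus);
  caffe::Caffe::set_solver_count(1);
}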
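The gradient exchange itself is the same in both reduction modes: an in-place ncclAllReduce on the solver's dedicated non-blocking stream, followed by a rescale by 1/solver_count (each worker's loss layer has already divided by its share of the split batch). A minimal per-blob sketch, assuming a communicator and stream created as in the P2PSync constructor; grad and count stand in for one parameter blob's diff pointer and element count.

#include <cuda_runtime.h>
#include <nccl.h>
#include "caffe/util/math_functions.hpp"
#include "caffe/util/nccl.hpp"

void reduce_one_blob(float* grad, int count, int nranks,
                     ncclComm_t comm, cudaStream_t stream) {
  // Sum this blob's gradient across all ranks, in place, on the comm stream.
  NCCL_CHECK(ncclAllReduce(grad, grad, count,
                           caffe::nccl::dataType<float>::type, ncclSum,
                           comm, stream));
  // Average over ranks to restore the full-batch gradient, using the
  // stream-aware caffe_gpu_scal overload added by this patch.
  caffe::caffe_gpu_scal<float>(count, 1.0f / nranks, grad, stream);
}

With per-parameter reduction this call is issued from Net::BackwardFromTo as soon as each layer's backward pass finishes; otherwise the same call runs once over the flat diff_ buffer in P2PSync::allreduce() before ApplyUpdate().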
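The new SolverParameter field per_parameter_reduce (ID 41, default true) selects between those two paths: true overlaps per-blob reduction with the backward pass, false falls back to a single whole-buffer allreduce per iteration. It can be set as per_parameter_reduce: false in the solver prototxt, or programmatically through the generated protobuf setter; FLAGS_solver below is the flag tools/caffe.cpp already defines.

caffe::SolverParameter solver_param;
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);
// Disable the per-blob path to get one allreduce over the flat gradient
// buffer per iteration instead of one per parameter blob.
solver_param.set_per_parameter_reduce(false);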