Permalink
Browse files

rsp push and rsp pull for comm device, used in kvstore('device') (#8732)

* comm device for rsp push and pull

* update

* update test

* optimization for same row_ids

* add stream->wait

* remove using space

* fix race of rsc and extend ElementwiseSum to rsp cases

* add log fatal in ElementwiseSum

* direct copy rows if full rsp and put all outputs on ctx of src

* trigger

* fix

* simplify copy

* move check same rowids to utils and add test for same rowids case

* remove direct copy row by row

* fix checkSameRowid

* gpu unique impl draft

* unique

* update

* fix windows build

* trigger windows build

* support single rowid with multiple vals

* address comments

* check same row_ids and copy in fronted

* revise names and disable test for local kvstore
  • Loading branch information...
1 parent 171d717 commit 786e376651c7f6f9b05b7758d091b22a7a72ef55 @ZiyueHuang ZiyueHuang committed with eric-haibin-lin Jan 15, 2018
View
@@ -298,7 +298,8 @@ def pull(self, key, out=None, priority=0):
def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
""" Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
- from the store with specified row_ids.
+ from the store with specified row_ids. When there is only one row_id, KVStoreRowSparsePull \
+ is invoked just once and the result is broadcast to all the rest of outputs.
`row_sparse_pull` is executed asynchronously after all previous
`pull`/`row_sparse_pull` calls and the last `push` call for the
@@ -349,7 +350,17 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
"""
assert(out is not None)
assert(row_ids is not None)
- ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+ if isinstance(row_ids, NDArray):
+ row_ids = [row_ids]
+ assert(isinstance(row_ids, list)), \
+ "row_ids should be NDArray or list of NDArray"
+ first_out = out
+ # whether row_ids are the same
+ single_rowid = False
+ if len(row_ids) == 1 and isinstance(out, list):
+ single_rowid = True
+ first_out = [out[0]]
+ ckeys, cvals, use_str_keys = _ctype_key_value(key, first_out)
_, crow_ids, _ = _ctype_key_value(key, row_ids)
assert(len(crow_ids) == len(cvals)), \
"the number of row_ids doesn't match the number of values"
@@ -359,6 +370,11 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
else:
check_call(_LIB.MXKVStorePullRowSparse(
self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
+ # the result can be copied to other devices without invoking row_sparse_pull
+ # if the indices are the same
+ if single_rowid:
+ for out_i in out[1:]:
+ out[0].copyto(out_i)
def set_gradient_compression(self, compression_params):
""" Specifies type of low-bit quantization for gradient compression \
View
@@ -24,6 +24,7 @@
#include "./utils.h"
#include "../operator/tensor/cast_storage-inl.h"
+#include "../operator/tensor/sparse_retain-inl.h"
namespace mxnet {
namespace common {
@@ -34,6 +35,15 @@ void CheckFormatWrapper<cpu>(const RunContext &rctx, const NDArray &input,
CheckFormatImpl<cpu>(rctx, input, err_cpu, full_check);
}
+template<>
+void SparseRetainOpForwardRspWrapper<cpu>(mshadow::Stream<cpu> *s,
+ const NDArray& input_nd,
+ const TBlob& idx_data,
+ const OpReqType req,
+ NDArray* output_nd) {
+ mxnet::op::SparseRetainOpForwardRspImpl<cpu>(s, input_nd, idx_data, req, output_nd);
+}
+
template<>
void CastStorageDispatch<cpu>(const OpContext& ctx,
const NDArray& input,
View
@@ -24,6 +24,7 @@
#include "./utils.h"
#include "../operator/tensor/cast_storage-inl.h"
+#include "../operator/tensor/sparse_retain-inl.h"
namespace mxnet {
namespace common {
@@ -34,6 +35,15 @@ void CheckFormatWrapper<gpu>(const RunContext &rctx, const NDArray &input,
CheckFormatImpl<gpu>(rctx, input, err_cpu, full_check);
}
+template<>
+void SparseRetainOpForwardRspWrapper<gpu>(mshadow::Stream<gpu> *s,
+ const NDArray& input_nd,
+ const TBlob& idx_data,
+ const OpReqType req,
+ NDArray* output_nd) {
+ mxnet::op::SparseRetainOpForwardRspImpl<gpu>(s, input_nd, idx_data, req, output_nd);
+}
+
template<>
void CastStorageDispatch<gpu>(const OpContext& ctx,
const NDArray& input,
View
@@ -213,7 +213,18 @@ void CheckFormatImpl(const RunContext &rctx, const NDArray &input,
}
}
+/*! \brief Pick rows specified by user input index array from a row sparse ndarray
+ * and save them in the output sparse ndarray.
+ */
+template<typename xpu>
+void SparseRetainOpForwardRspWrapper(mshadow::Stream<xpu> *s,
+ const NDArray& input_nd,
+ const TBlob& idx_data,
+ const OpReqType req,
+ NDArray* output_nd);
+/* \brief Casts tensor storage type to the new type.
+ */
template<typename xpu>
void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output);
View
@@ -34,6 +34,7 @@
#include "gradient_compression.h"
#include "../ndarray/ndarray_function.h"
#include "../operator/tensor/sparse_retain-inl.h"
+#include "./utils.h"
namespace mxnet {
namespace kvstore {
/**
@@ -176,17 +177,17 @@ class CommCPU : public Comm {
reduce[i] = buf.copy_buf[i];
const_vars[i] = reduce[i].var();
}
- auto result = buf.merged;
+ NDArray result = buf.merged;
+ Resource rsc = ResourceManager::Get()->Request(result.ctx(),
+ ResourceRequest(ResourceRequest::kTempSpace));
Engine::Get()->PushAsync(
- [reduce, result, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ [reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray out = result;
- Resource rsc = ResourceManager::Get()->Request(rctx.ctx,
- ResourceRequest(ResourceRequest::kTempSpace));
is_serial_push_?
ReduceSumCPUExSerial(reduce, &out)
: mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
on_complete();
- }, Context::CPU(), const_vars, {result.var()},
+ }, Context::CPU(), const_vars, {result.var(), rsc.var},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
}
@@ -491,11 +492,7 @@ class CommDevice : public Comm {
void Init(int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32) override {
- if (stype == kDefaultStorage) {
- sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype));
- } else {
- LOG(FATAL) << "storage type " << stype << " not implemented for device yet";
- }
+ sorted_key_attrs_.emplace_back(key, shape, dtype, stype);
}
void InitBuffersAndComm(const std::vector<NDArray>& src) {
@@ -528,26 +525,42 @@ class CommDevice : public Comm {
InitBuffersAndComm(src);
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());
- CopyFromTo(src[0], &(buf.merged), priority);
- reduce[0] = buf.merged;
- if (buf.copy_buf.empty()) {
- // TODO(mli) this results in large device memory usage for huge ndarray,
- // such as the largest fullc in VGG. consider to do segment reduce with
- // NDArray.Slice or gpu direct memory access. for the latter, we need to
- // remove some ctx check, and also it reduces 20% perf
- buf.copy_buf.resize(src.size()-1);
+ const NDArrayStorageType stype = buf.merged.storage_type();
+ if (stype == kDefaultStorage) {
+ CopyFromTo(src[0], &(buf.merged), priority);
+ reduce[0] = buf.merged;
+
+ if (buf.copy_buf.empty()) {
+ // TODO(mli) this results in large device memory usage for huge ndarray,
+ // such as the largest fullc in VGG. consider to do segment reduce with
+ // NDArray.Slice or gpu direct memory access. for the latter, we need to
+ // remove some ctx check, and also it reduces 20% perf
+ buf.copy_buf.resize(src.size()-1);
+ for (size_t i = 0; i < src.size()-1; ++i) {
+ buf.copy_buf[i] = NDArray(
+ buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+ }
+ }
for (size_t i = 0; i < src.size()-1; ++i) {
- buf.copy_buf[i] = NDArray(
- buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
+ CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
+ reduce[i+1] = buf.copy_buf[i];
+ }
+ } else {
+ if (buf.copy_buf.empty()) {
+ buf.copy_buf.resize(src.size());
+ for (size_t j = 0; j < src.size(); ++j) {
+ buf.copy_buf[j] = NDArray(
+ buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
+ true, buf.merged.dtype());
+ }
+ }
+ for (size_t i = 0; i < src.size(); ++i) {
+ CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
+ reduce[i] = buf.copy_buf[i];
}
}
- for (size_t i = 0; i < src.size()-1; ++i) {
- CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
- reduce[i+1] = buf.copy_buf[i];
- }
-
- ElementwiseSum(reduce, &buf.merged);
+ ElementwiseSum(reduce, &buf.merged, priority);
return buf.merged;
}
@@ -621,7 +634,53 @@ class CommDevice : public Comm {
const std::vector<std::pair<NDArray*, NDArray>>& dst,
const bool use_copy,
const int priority) override {
- LOG(FATAL) << "Not implemented yet";
+ CHECK_EQ(src.storage_type(), kRowSparseStorage)
+ << "BroadcastRowSparse expects row-sparse src NDArray";
+
+ for (size_t i = 0; i < dst.size(); ++i) {
+ NDArray* out = dst[i].first;
+ NDArray row_id = dst[i].second;
+ if (use_copy) {
+ CopyFromTo(src, out, priority);
+ } else {
+ CHECK_EQ(out->storage_type(), kRowSparseStorage)
+ << "BroadcastRowSparse expects row_sparse dst NDArray";
+
+ const bool is_diff_ctx = out->ctx() != src.ctx();
+ NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
+ src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+
+ CHECK_EQ(row_id.ctx(), src.ctx())
+ << "row_id and src are expected to be on the same context";
+
+ Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ NDArray temp = out_gpu;
+ const TBlob& indices = row_id.data();
+ switch (temp.ctx().dev_mask()) {
+ case cpu::kDevMask: {
+ mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
+ src, indices, kWriteTo, &temp);
+ break;
+ }
+#if MXNET_USE_CUDA
+ case gpu::kDevMask: {
+ mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
+ src, indices, kWriteTo, &temp);
+ // wait for GPU operations to complete
+ rctx.get_stream<gpu>()->Wait();
+ break;
+ }
+#endif
+ default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+ }
+ on_complete();
+ }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+ if (is_diff_ctx) {
+ CopyFromTo(out_gpu, out, priority);
+ }
+ }
+ }
}
private:
@@ -667,7 +726,7 @@ class CommDevice : public Comm {
#endif
}
- using KeyAttrs = std::tuple<int, TShape, int>;
+ using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
@@ -680,9 +739,10 @@ class CommDevice : public Comm {
ctx_info[d.dev_id] = std::make_pair(d, 0);
}
for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
- int key = std::get<0>(sorted_key_attrs_[i]);
- TShape s = std::get<1>(sorted_key_attrs_[i]);
- int type = std::get<2>(sorted_key_attrs_[i]);
+ const int key = std::get<0>(sorted_key_attrs_[i]);
+ const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
+ const int type = std::get<2>(sorted_key_attrs_[i]);
+ const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
auto& buf = merge_buf_[key];
Context ctx;
size_t min_size = std::numeric_limits<size_t>::max();
@@ -693,8 +753,12 @@ class CommDevice : public Comm {
min_size = size;
}
}
- buf.merged = NDArray(s, ctx, false, type);
- ctx_info[ctx.dev_id].second += s.Size();
+ if (stype == kDefaultStorage) {
+ buf.merged = NDArray(shape, ctx, false, type);
+ } else {
+ buf.merged = NDArray(stype, shape, ctx, true, type);
+ }
+ ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
}
@@ -34,6 +34,7 @@
#include <functional>
#include <algorithm>
#include "./comm.h"
+#include "./utils.h"
namespace mxnet {
namespace kvstore {
@@ -223,12 +224,12 @@ class KVStoreLocal : public KVStore {
<< "PullRowSparse expects row_sparse src NDArray";
auto &target_val_rowids = grouped_val_rowids[i];
const size_t num_vals = target_val_rowids.size();
- for (size_t i = 0; i < num_vals; i++) {
- auto &row_id = target_val_rowids[i].second;
- NDArray indices(row_id.shape(), pinned_ctx_, false, mshadow::kInt64);
+ for (size_t j = 0; j < num_vals; j++) {
+ auto &row_id = target_val_rowids[j].second;
+ NDArray indices(row_id.shape(), local.ctx(), false, mshadow::kInt64);
CopyFromTo(row_id, &indices, 0);
Unique(&indices, priority);
- target_val_rowids[i].second = indices;
+ target_val_rowids[j].second = indices;
}
comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], false, priority);
}
@@ -354,29 +355,41 @@ class KVStoreLocal : public KVStore {
}
/**
- * \brief sort and get unique values. Output is expected to be on cpu_pinned context
+ * \brief sort and get unique values.
*/
- void Unique(NDArray *out, int priority = 0) {
- CHECK_EQ(out->ctx().dev_mask(), pinned_ctx_.dev_mask())
- << "Unique expects input with `pinned_ctx_`";
+ void Unique(NDArray *out, int priority) {
+ Resource rsc = ResourceManager::Get()->Request(out->ctx(),
+ ResourceRequest(ResourceRequest::kTempSpace));
Engine::Get()->PushAsync(
- [out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ [rsc, out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray *output = out;
CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
- const auto size = out->shape()[0];
- auto out_data = output->data();
- MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, IType, {
- auto dptr = output->data().dptr<IType>();
- common::ParallelSort(dptr, dptr + size, omp_get_max_threads());
- auto num_unique_idx = std::unique(dptr, dptr + size) - dptr;
- *output = output->Reshape(mshadow::Shape1(num_unique_idx));
- });
+ nnvm::dim_t size = out->shape()[0];
+ switch (out->ctx().dev_mask()) {
+ case cpu::kDevMask: {
+ mshadow::Stream<cpu> *s = rctx.get_stream<cpu>();
+ UniqueImpl(rsc, s, output, size);
+ break;
+ }
+ #if MXNET_USE_CUDA
+ case gpu::kDevMask: {
+ mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+ UniqueImpl(rsc, s, output, size);
+ // wait for GPU operations to complete
+ s->Wait();
+ break;
+ }
+ #endif
+ default:
+ LOG(FATAL) << "GPU not enabled.";
+ }
on_complete();
- }, pinned_ctx_, {}, {out->var()},
- FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreUnique"));
+ }, out->ctx(), {}, {out->var(), rsc.var},
+ FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreUnique"));
out->WaitToRead();
}
+
/// reducer and broadcaster
Comm* comm_;
/// pinned context
Oops, something went wrong.

0 comments on commit 786e376

Please sign in to comment.