@@ -34,6 +34,7 @@
#include "gradient_compression.h"
#include "../ndarray/ndarray_function.h"
#include "../operator/tensor/sparse_retain-inl.h"
+#include "./utils.h"
namespace mxnet {
namespace kvstore {
/**
@@ -176,17 +177,17 @@ class CommCPU : public Comm {
reduce[i] = buf.copy_buf [i];
const_vars[i] = reduce[i].var ();
}
- auto result = buf.merged ;
+ NDArray result = buf.merged ;
+ Resource rsc = ResourceManager::Get ()->Request (result.ctx (),
+ ResourceRequest (ResourceRequest::kTempSpace ));
Engine::Get ()->PushAsync (
- [reduce, result, this ](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ [reduce, result, rsc, this ](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray out = result;
- Resource rsc = ResourceManager::Get ()->Request (rctx.ctx ,
- ResourceRequest (ResourceRequest::kTempSpace ));
is_serial_push_ ?
ReduceSumCPUExSerial (reduce, &out)
: mxnet::ndarray::ElementwiseSum (rctx.get_stream <cpu>(), rsc, reduce, &out);
on_complete ();
- }, Context::CPU (), const_vars, {result.var ()},
+ }, Context::CPU (), const_vars, {result.var (), rsc.var },
FnProperty::kCPUPrioritized , priority, PROFILER_MESSAGE (" KVStoreReduce" ));
}
@@ -491,11 +492,7 @@ class CommDevice : public Comm {
void Init (int key, const NDArrayStorageType stype, const TShape& shape,
int dtype = mshadow::kFloat32 ) override {
- if (stype == kDefaultStorage ) {
- sorted_key_attrs_.push_back (std::make_tuple (key, shape, dtype));
- } else {
- LOG (FATAL) << " storage type " << stype << " not implemented for device yet" ;
- }
+ sorted_key_attrs_.emplace_back (key, shape, dtype, stype);
}
void InitBuffersAndComm (const std::vector<NDArray>& src) {
@@ -528,26 +525,42 @@ class CommDevice : public Comm {
InitBuffersAndComm (src);
auto & buf = merge_buf_[key];
std::vector<NDArray> reduce (src.size ());
- CopyFromTo (src[0 ], &(buf.merged ), priority);
- reduce[0 ] = buf.merged ;
- if (buf.copy_buf .empty ()) {
- // TODO(mli) this results in large device memory usage for huge ndarray,
- // such as the largest fullc in VGG. consider to do segment reduce with
- // NDArray.Slice or gpu direct memory access. for the latter, we need to
- // remove some ctx check, and also it reduces 20% perf
- buf.copy_buf .resize (src.size ()-1 );
+ const NDArrayStorageType stype = buf.merged .storage_type ();
+ if (stype == kDefaultStorage ) {
+ CopyFromTo (src[0 ], &(buf.merged ), priority);
+ reduce[0 ] = buf.merged ;
+
+ if (buf.copy_buf .empty ()) {
+ // TODO(mli) this results in large device memory usage for huge ndarray,
+ // such as the largest fullc in VGG. consider to do segment reduce with
+ // NDArray.Slice or gpu direct memory access. for the latter, we need to
+ // remove some ctx check, and also it reduces 20% perf
+ buf.copy_buf .resize (src.size ()-1 );
+ for (size_t i = 0 ; i < src.size ()-1 ; ++i) {
+ buf.copy_buf [i] = NDArray (
+ buf.merged .shape (), buf.merged .ctx (), false , buf.merged .dtype ());
+ }
+ }
for (size_t i = 0 ; i < src.size ()-1 ; ++i) {
- buf.copy_buf [i] = NDArray (
- buf.merged .shape (), buf.merged .ctx (), false , buf.merged .dtype ());
+ CopyFromTo (src[i+1 ], &(buf.copy_buf [i]), priority);
+ reduce[i+1 ] = buf.copy_buf [i];
+ }
+ } else {
+ if (buf.copy_buf .empty ()) {
+ buf.copy_buf .resize (src.size ());
+ for (size_t j = 0 ; j < src.size (); ++j) {
+ buf.copy_buf [j] = NDArray (
+ buf.merged .storage_type (), buf.merged .shape (), buf.merged .ctx (),
+ true , buf.merged .dtype ());
+ }
+ }
+ for (size_t i = 0 ; i < src.size (); ++i) {
+ CopyFromTo (src[i], &(buf.copy_buf [i]), priority);
+ reduce[i] = buf.copy_buf [i];
}
}
- for (size_t i = 0 ; i < src.size ()-1 ; ++i) {
- CopyFromTo (src[i+1 ], &(buf.copy_buf [i]), priority);
- reduce[i+1 ] = buf.copy_buf [i];
- }
-
- ElementwiseSum (reduce, &buf.merged );
+ ElementwiseSum (reduce, &buf.merged , priority);
return buf.merged ;
}
@@ -621,7 +634,53 @@ class CommDevice : public Comm {
const std::vector<std::pair<NDArray*, NDArray>>& dst,
const bool use_copy,
const int priority) override {
- LOG (FATAL) << " Not implemented yet" ;
+ CHECK_EQ (src.storage_type (), kRowSparseStorage )
+ << " BroadcastRowSparse expects row-sparse src NDArray" ;
+
+ for (size_t i = 0 ; i < dst.size (); ++i) {
+ NDArray* out = dst[i].first ;
+ NDArray row_id = dst[i].second ;
+ if (use_copy) {
+ CopyFromTo (src, out, priority);
+ } else {
+ CHECK_EQ (out->storage_type (), kRowSparseStorage )
+ << " BroadcastRowSparse expects row_sparse dst NDArray" ;
+
+ const bool is_diff_ctx = out->ctx () != src.ctx ();
+ NDArray out_gpu = is_diff_ctx ? NDArray (kRowSparseStorage , out->shape (),
+ src.ctx (), true , out->dtype (), out->aux_types ()) : *out;
+
+ CHECK_EQ (row_id.ctx (), src.ctx ())
+ << " row_id and src are expected to be on the same context" ;
+
+ Engine::Get ()->PushAsync ([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ NDArray temp = out_gpu;
+ const TBlob& indices = row_id.data ();
+ switch (temp.ctx ().dev_mask ()) {
+ case cpu::kDevMask : {
+ mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream <cpu>(),
+ src, indices, kWriteTo , &temp);
+ break ;
+ }
+#if MXNET_USE_CUDA
+ case gpu::kDevMask : {
+ mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream <gpu>(),
+ src, indices, kWriteTo , &temp);
+ // wait for GPU operations to complete
+ rctx.get_stream <gpu>()->Wait ();
+ break ;
+ }
+#endif
+ default : LOG (FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+ }
+ on_complete ();
+ }, out_gpu.ctx (), {src.var (), row_id.var ()}, {out_gpu.var ()},
+ FnProperty::kNormal , priority, PROFILER_MESSAGE (" KVStoreSparseRetain" ));
+ if (is_diff_ctx) {
+ CopyFromTo (out_gpu, out, priority);
+ }
+ }
+ }
}
private:
@@ -667,7 +726,7 @@ class CommDevice : public Comm {
#endif
}
- using KeyAttrs = std::tuple<int , TShape, int >;
+ using KeyAttrs = std::tuple<int , TShape, int , NDArrayStorageType >;
// try to allocate buff on device evenly
void InitMergeBuffer (const std::vector<Context>& devs) {
std::sort (sorted_key_attrs_.begin (), sorted_key_attrs_.end (), [](
@@ -680,9 +739,10 @@ class CommDevice : public Comm {
ctx_info[d.dev_id ] = std::make_pair (d, 0 );
}
for (size_t i = 0 ; i < sorted_key_attrs_.size (); ++i) {
- int key = std::get<0 >(sorted_key_attrs_[i]);
- TShape s = std::get<1 >(sorted_key_attrs_[i]);
- int type = std::get<2 >(sorted_key_attrs_[i]);
+ const int key = std::get<0 >(sorted_key_attrs_[i]);
+ const TShape& shape = std::get<1 >(sorted_key_attrs_[i]);
+ const int type = std::get<2 >(sorted_key_attrs_[i]);
+ const NDArrayStorageType stype = std::get<3 >(sorted_key_attrs_[i]);
auto & buf = merge_buf_[key];
Context ctx;
size_t min_size = std::numeric_limits<size_t >::max ();
@@ -693,8 +753,12 @@ class CommDevice : public Comm {
min_size = size;
}
}
- buf.merged = NDArray (s, ctx, false , type);
- ctx_info[ctx.dev_id ].second += s.Size ();
+ if (stype == kDefaultStorage ) {
+ buf.merged = NDArray (shape, ctx, false , type);
+ } else {
+ buf.merged = NDArray (stype, shape, ctx, true , type);
+ }
+ ctx_info[ctx.dev_id ].second += shape.Size ();
}
inited_ = true ;
}
0 comments on commit 786e376