Permalink
Browse files

fix random generator: do not gen seed each time (#9119)

* add tests for distribution generators

fix lint

fix lint

fix typo

fix docstring

fix docstring

* [Bugfix] fix random generator: do not gen seed each time

* gen samplers on gpu for test_softmax

* fix test cases

* remove unnecessary prints

* refactor RandGenerator

* get_native_random -> get_parallel_random

* revise test cases + remove dependency of scipy

* raise warning
  • Loading branch information...
1 parent 2586c66 commit 34a51959bd2bc21c6cfa93f5fe0e079ef5268261 @yzhliu yzhliu committed with piiswrong Dec 28, 2017
@@ -21,7 +21,7 @@
blacklist = [
'Windows.h', 'cublas_v2.h', 'cuda/tensor_gpu-inl.cuh',
- 'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h',
+ 'cuda_runtime.h', 'cudnn.h', 'cudnn_lrn-inl.h', 'curand.h', 'curand_kernel.h',
'glog/logging.h', 'io/azure_filesys.h', 'io/hdfs_filesys.h', 'io/s3_filesys.h',
'kvstore_dist.h', 'mach/clock.h', 'mach/mach.h',
'malloc.h', 'mkl.h', 'mkl_cblas.h', 'mkl_vsl.h', 'mkl_vsl_functions.h',
View
@@ -28,6 +28,7 @@
#include <dmlc/logging.h>
#include "./base.h"
#include "./engine.h"
+#include "../../src/common/random_generator.h"
namespace mxnet {
@@ -40,7 +41,9 @@ struct ResourceRequest {
/*! \brief mshadow::Random<xpu> object */
kRandom,
/*! \brief A dynamic temp space that can be arbitrary size */
- kTempSpace
+ kTempSpace,
+ /*! \brief common::RandGenerator<xpu> object, which can be used in GPU kernel functions */
+ kParallelRandom
};
/*! \brief type of resources */
Type type;
@@ -89,6 +92,19 @@ struct Resource {
ret->set_stream(stream);
return ret;
}
+
+ /*!
+ * \brief Get parallel random number generator.
+ * \tparam xpu the device type of random number generator.
+ * \tparam DType the return type.
+ * \return the native random number generator. for gpu, it is allocated on global memory.
+ */
+ template<typename xpu, typename DType>
+ inline common::random::RandGenerator<xpu, DType>* get_parallel_random() const {
+ CHECK_EQ(req.type, ResourceRequest::kParallelRandom);
+ return static_cast<common::random::RandGenerator<xpu, DType>*>(ptr_);
+ }
+
/*!
* \brief Get space requested as mshadow Tensor.
* The caller can request arbitrary size.
View
@@ -82,7 +82,7 @@ class Storage {
virtual void SharedIncrementRefCount(Handle handle) = 0;
/*!
* \brief Free storage.
- * \param handle Handle struect.
+ * \param handle Handle struct.
*/
virtual void Free(Handle handle) = 0;
/*!
@@ -87,7 +87,7 @@ sub check_with_device
]
},
);
- my $shape = [100, 100];
+ my $shape = [1000, 1000];
for my $symbdic (@symbols)
{
my $name = $symbdic->{name};
@@ -648,7 +648,8 @@ def update(self, index, weight, grad, state):
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)
weight[:] += - lr/2 * (grad + wd * weight) + normal(0, math.sqrt(lr),
- weight.shape, weight.context)
+ shape=weight.shape,
+ ctx=weight.context)
@register # pylint: disable=invalid-name
View
@@ -34,6 +34,10 @@
import numpy as np
import numpy.testing as npt
import numpy.random as rnd
+try:
+ import scipy.stats as ss
+except ImportError:
+ ss = None
try:
import requests
except ImportError:
@@ -1593,3 +1597,225 @@ def next(self):
The data of next batch.
"""
return self.the_batch
+
+def gen_buckets_probs_with_ppf(ppf, nbuckets):
+ """Generate the buckets and probabilities for chi_square test when the ppf (Quantile function)
+ is specified.
+
+ Parameters
+ ----------
+ ppf : function
+ The Quantile function that takes a probability and maps it back to a value.
+ It's the inverse of the cdf function
+ nbuckets : int
+ size of the buckets
+
+ Returns
+ -------
+ buckets : list of tuple
+ The generated buckets
+ probs : list
+ The generate probabilities
+ """
+ assert nbuckets > 0
+ probs = [1.0 / nbuckets for _ in range(nbuckets)]
+ buckets = [(ppf(i / float(nbuckets)), ppf((i + 1) / float(nbuckets))) for i in range(nbuckets)]
+ return buckets, probs
+
+def mean_check(generator, mu, sigma, nsamples=1000000):
+ """Test the generator by matching the mean.
+
+ We test the sample mean by checking if it falls inside the range
+ (mu - 3 * sigma / sqrt(n), mu + 3 * sigma / sqrt(n))
+
+ References::
+
+ @incollection{goucher2009beautiful,
+ title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},
+ author={Goucher, Adam and Riley, Tim},
+ year={2009},
+ chapter=10
+ }
+
+ Examples::
+
+ generator = lambda x: np.random.normal(0, 1.0, size=x)
+ mean_check_ret = mean_check(generator, 0, 1.0)
+
+ Parameters
+ ----------
+ generator : function
+ The generator function. It's expected to generate N i.i.d samples by calling generator(N).
+ mu : float
+ sigma : float
+ nsamples : int
+
+ Returns
+ -------
+ ret : bool
+ Whether the mean test succeeds
+ """
+ samples = np.array(generator(nsamples))
+ sample_mean = samples.mean()
+ ret = (sample_mean > mu - 3 * sigma / np.sqrt(nsamples)) and\
+ (sample_mean < mu + 3 * sigma / np.sqrt(nsamples))
+ return ret
+
+def var_check(generator, sigma, nsamples=1000000):
+ """Test the generator by matching the variance.
+ It will need a large number of samples and is not recommended to use
+
+ We test the sample variance by checking if it falls inside the range
+ (sigma^2 - 3 * sqrt(2 * sigma^4 / (n-1)), sigma^2 + 3 * sqrt(2 * sigma^4 / (n-1)))
+
+ References::
+
+ @incollection{goucher2009beautiful,
+ title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},
+ author={Goucher, Adam and Riley, Tim},
+ year={2009},
+ chapter=10
+ }
+
+ Examples::
+
+ generator = lambda x: np.random.normal(0, 1.0, size=x)
+ var_check_ret = var_check(generator, 0, 1.0)
+
+ Parameters
+ ----------
+ generator : function
+ The generator function. It's expected to generate N i.i.d samples by calling generator(N).
+ sigma : float
+ nsamples : int
+
+ Returns
+ -------
+ ret : bool
+ Whether the variance test succeeds
+ """
+ samples = np.array(generator(nsamples))
+ sample_var = samples.var(ddof=1)
+ ret = (sample_var > sigma ** 2 - 3 * np.sqrt(2 * sigma ** 4 / (nsamples - 1))) and\
+ (sample_var < sigma ** 2 + 3 * np.sqrt(2 * sigma ** 4 / (nsamples - 1)))
+ return ret
+
+def chi_square_check(generator, buckets, probs, nsamples=1000000):
+ """Run the chi-square test for the generator. The generator can be both continuous and discrete.
+ If the generator is continuous, the buckets should contain tuples of (range_min, range_max) and
+ the probs should be the corresponding ideal probability within the specific ranges.
+ Otherwise, the buckets should be the possible output of the discrete distribution and the probs
+ should be groud-truth probability.
+
+ Usually the user is required to specify the probs parameter.
+
+ After obtatining the p value, we could further use the standard p > 0.05 threshold to get
+ the final result.
+
+ Examples::
+ buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, 0, 1), 5)
+ generator = lambda x: np.random.normal(0, 1.0, size=x)
+ p = chi_square_check(generator=generator, buckets=buckets, probs=probs)
+ assert(p > 0.05)
+
+ Parameters
+ ----------
+ generator: function
+ A function that is assumed to generate i.i.d samples from a specific distribution.
+ generator(N) should generate N random samples.
+ buckets: list of tuple or list of number
+ The buckets to run the chi-square the test. Make sure that the buckets cover
+ the whole range of the distribution. Also, the buckets must be in ascending order and have
+ no intersection
+ probs: list or tuple
+ The ground-truth probability of the random value fall in a specific bucket.
+ nsamples:int
+ The number of samples to generate for the testing
+
+ Returns
+ -------
+ p : float
+ p value that the generator has the expected distribution.
+ A higher value indicates a larger confidence
+ obs_freq : list
+ Observed frequency of buckets
+ expected_freq : list
+ The expected (ground-truth) frequency of the buckets
+ """
+ if not ss:
+ raise ImportError("scipy is not available."
+ " Please check if the scipy python bindings are installed.")
+ assert isinstance(buckets, list)
+ samples = generator(nsamples)
+ assert len(probs) == len(buckets)
+ if isinstance(buckets[0], (list, tuple)):
+ # Check whether the buckets are valid and fill them into a npy array
+ continuous_dist = True
+ buckets_npy = np.zeros((len(buckets) * 2, ), dtype=np.float32)
+ for i, _ in enumerate(buckets):
+ assert(buckets[i][0] <= buckets[i][1])
+ if i < len(buckets) - 1:
+ assert(buckets[i][1] <= buckets[i + 1][0])
+ buckets_npy[i * 2] = buckets[i][0]
+ buckets_npy[i * 2 + 1] = buckets[i][1]
+ else:
+ continuous_dist = False
+ buckets_npy = np.array(buckets)
+ expected_freq = (nsamples * np.array(probs, dtype=np.float32)).astype(np.int32)
+ if continuous_dist:
+ sample_bucket_ids = np.searchsorted(buckets_npy, samples, side='right')
+ else:
+ sample_bucket_ids = samples
+ if continuous_dist:
+ sample_bucket_ids = sample_bucket_ids // 2
+ obs_freq = np.zeros(shape=len(buckets), dtype=np.int)
+ for i in range(len(buckets)):
+ obs_freq[i] = (sample_bucket_ids == i).sum()
+ _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq)
+ return p, obs_freq, expected_freq
+
+def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, success_rate=0.25):
+ """Verify whether the generator is correct using chi-square testing.
+
+ The test is repeated for "nrepeat" times and we check if the success rate is
+ above the threshold (25% by default).
+
+ Parameters
+ ----------
+ generator: function
+ A function that is assumed to generate i.i.d samples from a specific distribution.
+ generator(N) should generate N random samples.
+ buckets: list of tuple or list of number
+ The buckets to run the chi-square the test. Make sure that the buckets cover
+ the whole range of the distribution. Also, the buckets must be in ascending order and
+ have no intersection
+ probs: list or tuple
+ The ground-truth probability of the random value fall in a specific bucket.
+ nsamples: int
+ The number of samples to generate for the testing
+ nrepeat: int
+ The times to repeat the test
+ success_rate: float
+ The desired success rate
+
+ Returns
+ -------
+ cs_ret_l: list
+ The p values of the chi-square test.
+ """
+ cs_ret_l = []
+ obs_freq_l = []
+ expected_freq_l = []
+ for _ in range(nrepeat):
+ cs_ret, obs_freq, expected_freq = chi_square_check(generator=generator, buckets=buckets,
+ probs=probs, nsamples=nsamples)
+ cs_ret_l.append(cs_ret)
+ obs_freq_l.append(obs_freq)
+ expected_freq_l.append(expected_freq)
+ success_num = (np.array(cs_ret_l) > 0.05).sum()
+ if success_num < nrepeat * success_rate:
+ raise AssertionError("Generator test fails, Chi-square p=%s, obs_freq=%s, expected_freq=%s."
+ "\nbuckets=%s, probs=%s"
+ % (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
+ str(buckets), str(probs)))
+ return cs_ret_l
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file random_generator.cu
+ * \brief gpu implements for parallel random number generator.
+ */
+
+#include <algorithm>
+#include "./random_generator.h"
+#include "../operator/mxnet_op.h"
+
+namespace mxnet {
+namespace common {
+namespace random {
+
+__global__ void rand_generator_seed_kernel(curandStatePhilox4_32_10_t *states_,
+ const int size,
+ uint32_t seed) {
+ int id = blockIdx.x * blockDim.x + threadIdx.x;
+ if (id < size) curand_init(seed, id, 0, states_ + id);
+}
+
+template<>
+void RandGenerator<gpu, float>::Seed(Stream<gpu> *s, uint32_t seed) {
+ using namespace mshadow::cuda;
+ int ngrid = std::min(kMaxGridNum,
+ (RandGenerator<gpu, float>::kNumRandomStates + kBaseThreadNum - 1) /
+ kBaseThreadNum);
+ rand_generator_seed_kernel
+ <<<ngrid, kBaseThreadNum, 0, Stream<gpu>::GetStream(s)>>>(
+ states_,
+ RandGenerator<gpu, float>::kNumRandomStates,
+ seed);
+}
+
+} // namespace random
+} // namespace common
+} // namespace mxnet
Oops, something went wrong.

0 comments on commit 34a5195

Please sign in to comment.