
Conversation

@yzhliu (Member) commented Dec 18, 2017

Description

The current implementation of the random samplers (#8179) draws a new (random) seed every time it starts generating random numbers, which is incorrect. @asmushetzel

Although std::mt19937 seems to tolerate this approach, samplers backed by cuRAND produce low-quality randomness, and the generated numbers are probably correlated.

First, I noticed that training with SGLD collapses to low accuracy (#8958). Second, @sxjscience has written new test cases for the random samplers using mean/variance/chi-square tests, and none of them passes with the current implementation.

According to the NVIDIA documentation for the device-side random API, we need to maintain global random states and reuse them. Moreover, a random state is not thread-safe.

The implementation here maintains a fixed number of global random states that can be accessed through Resource. In case they are accessed by multiple GPU streams, 4 independent GPU generators are created in the global Resource by default.

I tested example/bayesian-methods and it now converges to reasonable results. It should pass @sxjscience's new test cases as well.

Memory usage and speed barely change.
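
To make the design concrete, here is a minimal host-side sketch of the idea (illustrative only, not the MXNet classes): a fixed pool of generator states is seeded once and then reused by every sampling call, so no call ever reseeds.

#include <cstdint>
#include <random>
#include <vector>

// Hypothetical pool: seeded once at startup, then handed out by index.
// Reusing these states is what replaces the old "new seed per call" behavior.
class GlobalRandPool {
 public:
  GlobalRandPool(size_t n, uint32_t seed) : states_(n) {
    std::seed_seq seq{seed};
    std::vector<uint32_t> seeds(n);
    seq.generate(seeds.begin(), seeds.end());
    for (size_t i = 0; i < n; ++i) states_[i].seed(seeds[i]);
  }
  // Each worker grabs its own state; states advance but are never reseeded.
  std::mt19937 &state(size_t i) { return states_[i % states_.size()]; }

 private:
  std::vector<std::mt19937> states_;
};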

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage: Waiting for @sxjscience 's test cases
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

Changes

  • Global seeds for GPU & CPU sampler.
  • Fix SGLD optimizer arguments.

@yzhliu yzhliu self-assigned this Dec 18, 2017
@yzhliu yzhliu added the Bug label Dec 18, 2017
@sxjscience (Member):

@Javelinjs I've uploaded the script here: https://gist.github.com/sxjscience/453605a1ea3102bc0010f9fb16df8238. For now we should rely on the results of the chi-square test and the mean test, as the variance test needs far more samples. The chi-square test performs best of the three.

}

template<>
RandGenerator<cpu, float> *NewRandGenerator<cpu, float>() {
Contributor:

Return a value so you don't need DeleteRandGenerator.

Member Author:

But for NewRandGenerator<gpu>, the generator is allocated on the device.

Contributor:

You should return a class that holds an internal pointer to the device allocation.

@sxjscience (Member):

@Javelinjs The updated gist here https://gist.github.com/sxjscience/453605a1ea3102bc0010f9fb16df8238 tests all the available random ops in MXNet: normal, uniform, gamma, exponential, poisson, negative_binomial, generalized_negative_binomial, multinomial

const int kGPURndStateNum = 32768;

// uniform number generation in Cuda made consistent with stl (include 0 but exclude 1)
// by using 1.0-curand_uniform(). Needed as some samplers below won't be able to deal with
Contributor:

This comment was taken from prior code. It is a bit misleading, as it references "samplers below", but there are none in this file.
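
For context, a minimal device-side sketch of the convention the quoted comment describes (hypothetical helper, not MXNet code): curand_uniform returns values in (0, 1], so subtracting from 1.0f maps them to [0, 1), matching the STL convention used on the CPU path.

#include <curand_kernel.h>

// Map curand's (0, 1] output to the STL-style [0, 1) range.
__device__ float UniformStlConvention(curandStatePhilox4_32_10_t *state) {
  return 1.0f - curand_uniform(state);
}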

inline static void LaunchNativeRandomGenerator(mshadow::Stream<cpu> *,
common::random::RandGenerator<cpu, GType> *rnd,
const int N, Args... args) {
// do not use openmp since it does not guarantee the output order.
Contributor:

It is not clear to me what "does not guarantee the output order" means. I really think we should support OpenMP here, as sampling on the CPU is just as important as sampling on the GPU, and we should not leave a potential 4-8x speedup on the table. Sampling on the CPU is really slow anyway.
Wouldn't it be natural to use the exact same design pattern as in the GPU case, i.e. a set of preallocated samplers in a global pool that are then assigned to the different threads?

Member Author:

I mean, assume the actual sequence from std::mt19937 is 0.1, 0.2, 0.3, 0.4; when it is used to fill arr[4] with OpenMP, the result could become arr = {0.2, 0.1, 0.4, 0.3}.

Since std::mt19937 is thread-safe, I didn't preallocate one generator per thread. But you're right, the same design could be adopted.

Contributor:

Yes, I think we should use OpenMP.

Contributor:

Adopting the same design would solve this ordering issue. And we don't need to preallocate thousands of CPU samplers; 256 would certainly be enough.
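
A minimal sketch of the pattern being proposed, assuming a preallocated pool of CPU engines (names are illustrative, not the MXNet API): each OpenMP thread always uses the same engine and always writes the same contiguous chunk, so the output is deterministic regardless of scheduling.

#include <algorithm>
#include <random>
#include <vector>

// Thread t always consumes engine t and writes chunk t, so the result only
// depends on the engine states, not on how OpenMP schedules the threads.
void ParallelUniform(std::vector<std::mt19937> *engines, std::vector<float> *out) {
  const int nthreads = static_cast<int>(engines->size());
  const int n = static_cast<int>(out->size());
  const int chunk = (n + nthreads - 1) / nthreads;
  #pragma omp parallel for num_threads(nthreads)
  for (int t = 0; t < nthreads; ++t) {
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    std::mt19937 &eng = (*engines)[t];
    for (int i = t * chunk; i < std::min(n, (t + 1) * chunk); ++i) {
      (*out)[i] = dist(eng);
    }
  }
}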

@asmushetzel (Contributor):

Nice catch and implementation. I was thinking about the same pattern (using a sufficiently large pool of pre-allocated random generators) in the initial implementation as well, but wasn't sure about blocking that much memory for the entire runtime. In fact it is the far better solution, and not only in terms of sampling accuracy.
I left some more comments in the code.
One thing that should be added are unit tests in test_random.py that also verify the chi-square statistic. We need to add the test that exhibited the problem with the prior implementation and ensure that from now on we never regress again.

@yzhliu (Member Author) commented Dec 19, 2017

@asmushetzel I'll merge the test in #9129

Commits: fix lint; fix lint; fix typo; fix docstring; fix docstring

kTempSpace
kTempSpace,
/*! \brief common::RandGenerator<xpu> object, which can be used in GPU kernel functions */
kNativeRandom
Contributor:

Why use a new enum? Can this be merged with kRandom?

Member Author:

kRandom returns mshadow::Random, whose behavior is different from the new one.

Contributor:

I think it should be called kParallelRandom

@asmushetzel (Contributor) Dec 23, 2017:

+1 (kParallelRandom would be a name that expresses what it is good for). The same applies to all functions that have "native" in their name.


// (non-thread-safe) random generator stores global states,
// always use mxnet_op::LaunchNativeRandomGenerator for launching a multi-threaded kernel.
template<typename DType>
class RandGeneratorGlobal<gpu, DType> : public RandGenerator<gpu, DType> {
Contributor:

why do you need this?

* \param args Varargs to eventually pass to the OP::Map() functoion
*/
template<typename GType, typename ...Args>
inline static void LaunchNativeRandomGenerator(mshadow::Stream<cpu> *,
Contributor:

LaunchRNG

const int N, Args... args) {
using namespace mshadow::cuda;
const int nloop(1 + (N - 1) / common::random::kGPUMinRndNumberPerThread);
int ngrid = std::min(common::random::kGPURndStateNum / kBaseThreadNum,
Contributor:

common::random::kGPURndStateNum / kBaseThreadNum could be 0
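
A small self-contained sketch of the guard this comment asks for (hypothetical helper, not the MXNet code): clamp the block count so the launch never ends up with zero blocks if the state pool is smaller than one thread block.

#include <algorithm>

// Never return fewer than one block, even if num_states / threads_per_block
// rounds down to zero.
inline int SafeGridSize(int num_states, int threads_per_block, int blocks_needed) {
  const int max_blocks = std::max(1, num_states / threads_per_block);
  return std::min(max_blocks, blocks_needed);
}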

for (int i = id * kGPUMinRndNumberPerThread;
i < N;
i += nthread * kGPUMinRndNumberPerThread) {
for (int j = 0; j < kGPUMinRndNumberPerThread && i + j < N; ++j) {
Contributor:

These two loops look weird. Are you sure it should be i<N?
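
For reference, a host-side simulation of the access pattern in the quoted loops (assumed intent, not the actual kernel): each thread emits a chunk of kGPUMinRndNumberPerThread consecutive values from its own state and then strides by nthread * kGPUMinRndNumberPerThread, so the outer i < N condition terminates the stride walk while the inner i + j < N check handles the tail.

#include <vector>

// Host-side check of the chunked stride pattern: every index in [0, N)
// is written exactly once, which is why the outer bound i < N is correct.
std::vector<int> SimulateChunkedStride(int N, int nthread, int chunk) {
  std::vector<int> hits(N, 0);
  for (int id = 0; id < nthread; ++id) {                  // one pass per GPU thread
    for (int i = id * chunk; i < N; i += nthread * chunk) {
      for (int j = 0; j < chunk && i + j < N; ++j) {
        ++hits[i + j];
      }
    }
  }
  return hits;  // e.g. SimulateChunkedStride(1000, 7, 64) yields all ones
}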

class RandGenerator;

template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
class RandGeneratorGlobal;
Contributor:

Looks like this should be the internal implementation of RandGenerator. It doesn't need to be a top level public class

@piiswrong (Contributor):

I don't think we need LaunchRNG. Why not Launch with N = rnd->size()?

@yzhliu (Member Author) commented Dec 20, 2017

@piiswrong By using LaunchRNG I want to hide the underlying implementation from users. Otherwise:

  • Users need to understand that the GPU state is not thread-safe, pick a curand state in Op::Map, and then loop over the array carefully.
  • The implementation of Launch implies that two adjacent array entries are accessed by two adjacent threads in one block. But in my understanding, we should generate as many successive numbers from one state, i.e. one thread, as we can. With Launch, users would have to compute the array element indices themselves, which is error-prone and makes the code awkward (see the sketch after this list).
  • The curand states are allocated in global memory. For efficiency, we copy a state to local memory when launching a kernel (and copy it back to global memory at the end, as suggested in NVIDIA's documentation). Users probably do not want to do this in every Op::Map function.

I think LaunchRNG is a convenient helper function. Users can still use Launch if they want to control everything.
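
A minimal CUDA sketch of the boilerplate that LaunchRNG is meant to hide (hypothetical kernel; the chunked indexing and state write-back follow the description above): load a per-thread state from global memory into a local copy, generate a contiguous chunk of numbers from it, then store the advanced state back.

#include <curand_kernel.h>

// What every Op::Map would otherwise have to repeat: state pick-up, local
// copy, chunked generation, and state write-back.
__global__ void FillUniform(curandStatePhilox4_32_10_t *states, float *out,
                            int N, int chunk) {
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  const int nthread = gridDim.x * blockDim.x;
  curandStatePhilox4_32_10_t local = states[id];   // copy state to local memory
  for (int i = id * chunk; i < N; i += nthread * chunk) {
    for (int j = 0; j < chunk && i + j < N; ++j) {
      out[i + j] = 1.0f - curand_uniform(&local);  // [0, 1) convention
    }
  }
  states[id] = local;                              // save advanced state back
}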

@yzhliu (Member Author) commented Dec 20, 2017

I have refactored RandGenerator for readability and added OpenMP support on the CPU path. Ping @piiswrong @asmushetzel
I also merged @sxjscience's PR #9129 in.

@piiswrong (Contributor):

Adding a LaunchXX interface complicates the design.
You can add a helper function in the random operators file instead.


#if MXNET_USE_CUDA

// at least how many random numbers should be generated by one GPU thread.
Contributor:

why do we need this?

Member Author:

Because a chunk of contiguous random numbers should be generated from a single state.

// at least how many random numbers should be generated by one CPU thread.
const int kCPUMinRndNumberPerThread = 64;
// store how many global random states for CPU.
const int kCPURndStateNum = 1024;
Contributor:

These should be
RandGenerator<cpu, DType>::kNumRandomStates

Member Author:

changed.


// Free the allocated GPU memory.
// For global singleton,
// calling this in destructor may cause undefined behavior.
Contributor:

why?

Member Author:

fixed

// Will use float data type whenever instantiated for half_t or any other non
// standard real type.
template<typename Device, typename DType MSHADOW_DEFAULT_DTYPE>
class RandGeneratorImpl;
Contributor:

Why do you need RandGenerator and RandGeneratorImpl? Why not just one?

public:
// Copy state to local memory for efficiency.
__device__ explicit RandGeneratorImpl(curandStatePhilox4_32_10_t *state)
: state_(*state) {}
Contributor:

So you are copying state to state_ by value. Then wouldn't the next call of the same random operator give you the same results?

@sxjscience (Member) Dec 22, 2017:

I haven't tested the case of multiple runs of the same generator. I should add that.

Contributor:

Two choices:

  • Follow the exact same pattern on CPU and GPU, i.e. copy the state by value but then also save it back into the RandGenerator at the end of LaunchRNG (which isn't the case currently).
  • Copy the state by reference in the CPU case.

I would prefer consistent handling for both cases (i.e. the first option).

@yzhliu (Member Author) Dec 27, 2017:

The implementation here was correct: it did save the state back in LaunchRNG.
I have now refactored it into a more readable version.

RandGenerator() {
cudaError_t e = cudaMalloc(&states_, kGPURndStateNum * sizeof(curandStatePhilox4_32_10_t));
if (e != cudaSuccess && e != cudaErrorCudartUnloading) {
throw std::bad_alloc();
Contributor:

Why not use the existing macros for interpreting CUDA errors (i.e. CUDA_CALL(cudaMalloc(.....)))? That would also tell the user that something went wrong on the device, whereas throwing std::bad_alloc is misleading, as it suggests a memory allocation failure on the host.

Member Author:

fixed.
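
For reference, a sketch of the suggested pattern, assuming MXNet's CUDA_CALL macro (the include path below is an assumption): the macro checks the returned cudaError_t and reports the CUDA error string instead of a host-side std::bad_alloc.

#include <curand_kernel.h>

#include "../common/cuda_utils.h"  // assumed location of CUDA_CALL

// Allocate the state pool and let CUDA_CALL surface any device-side error.
curandStatePhilox4_32_10_t *AllocStatePool(int num_states) {
  curandStatePhilox4_32_10_t *states = nullptr;
  CUDA_CALL(cudaMalloc(&states, num_states * sizeof(curandStatePhilox4_32_10_t)));
  return states;
}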

#ifdef _OPENMP
const int omp_threads = std::min(kCPURndStateNum,
engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
if (omp_threads < 2) {
Contributor:

I don't think you need special code for omp_threads < 2. The general loop below would work there as well and would not create any overhead either.

Member Author:

refactored

Contributor:

This was intentional.
OpenMP disables some compiler optimizations, so an OpenMP loop with one thread is slower than a plain loop without OpenMP.
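
A minimal sketch of the guard being defended here (illustrative only): fall back to a plain serial loop when only one thread would be used, so the compiler can optimize it without the OpenMP machinery.

#include <vector>

// With fewer than two threads, skip OpenMP entirely; a single-threaded
// omp loop can be slower than the equivalent plain loop.
void ScaleInPlace(std::vector<float> *v, float factor, int omp_threads) {
  const int n = static_cast<int>(v->size());
  if (omp_threads < 2) {
    for (int i = 0; i < n; ++i) (*v)[i] *= factor;
  } else {
    #pragma omp parallel for num_threads(omp_threads)
    for (int i = 0; i < n; ++i) (*v)[i] *= factor;
  }
}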

common::random::RandGenerator<cpu, GType> *rnd,
const int N, Args... args) {
using namespace mxnet::common::random;
#ifdef _OPENMP
@asmushetzel (Contributor) Dec 23, 2017:

I don't think you need that ifdef. Even if we compile without OpenMP support, GetRecommendedOMPThreadCount() and omp_get_thread_num() are replaced by appropriate stubs, so your code does not need any special handling (see dmlc/omp.h and engine/openmp.cc).

@@ -34,6 +34,7 @@
import numpy as np
import numpy.testing as npt
import numpy.random as rnd
import scipy.stats as ss
Contributor:

we don't depend on scipy

Member:

How should I revise this? Move it to be inside the functions?

Contributor:

see mx.nd.sparse

Member:

@Javelinjs @piiswrong I've added one commit to solve the problem. I also added tests for the case in which the generator is triggered multiple times. 199fabd

@yzhliu (Member Author) commented Dec 28, 2017

merged changes from #9129

@piiswrong piiswrong merged commit 34a5195 into apache:master Dec 28, 2017
yzhliu added a commit to yzhliu/mxnet that referenced this pull request Dec 29, 2017
* add tests for distribution generators (fix lint; fix lint; fix typo; fix docstring; fix docstring)
* [Bugfix] fix random generator: do not gen seed each time
* gen samplers on gpu for test_softmax
* fix test cases
* remove unnecessary prints
* refactor RandGenerator
* get_native_random -> get_parallel_random
* revise test cases + remove dependency of scipy
* raise warning
meissnereric pushed a commit to meissnereric/incubator-mxnet that referenced this pull request Jan 2, 2018
yuxiangw pushed a commit to yuxiangw/incubator-mxnet that referenced this pull request Jan 25, 2018
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018