[RELAY,TOPI] Threefry PRNG: splittable and stateless #7083

Merged: 15 commits, Jan 9, 2021

tkonolige
Contributor

This PR adds a fast PRNG to Relay for use in dropout and batch norm. The PRNG is stateless: for a given key, it always returns the same random number. It is also splittable: for a given key, we can split the key to generate two new ones.

JAX provides a good explanation of stateless and splittable: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JAX-PRNG.
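To make the two properties concrete, here is a minimal sketch in plain Python. It is not the Threefry kernel from this PR: `toy_random` and `toy_split` are hypothetical names, and SHA-256 stands in for the real hash purely for illustration.

```python
import hashlib
import struct
from typing import Tuple


def toy_random(key: bytes) -> int:
    """Return a 64-bit value that depends only on `key` (no hidden state)."""
    digest = hashlib.sha256(b"bits" + key).digest()
    return struct.unpack("<Q", digest[:8])[0]


def toy_split(key: bytes) -> Tuple[bytes, bytes]:
    """Derive two new keys from `key`; the original should not be reused."""
    left = hashlib.sha256(b"left" + key).digest()
    right = hashlib.sha256(b"right" + key).digest()
    return left, right


root = b"seed-0"

# Stateless: the same key always yields the same number.
assert toy_random(root) == toy_random(root)

# Splittable: one key becomes two, e.g. one per consumer.
dropout_key, batch_norm_key = toy_split(root)
assert dropout_key != batch_norm_key
print(toy_random(dropout_key), toy_random(batch_norm_key))
```

Because nothing is mutated, two call sites that need different random numbers must be handed different keys, which is what splitting provides.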

@altanh

@altanh
Contributor

altanh commented Dec 10, 2020

awesome work, cc @tqchen @junrushao1994 @MarisaKirisame @eric-haibin-lin who may be interested

I think it may be worth discussing high vs low level API for this, and what namespace it should live in. I wrote a few examples for how we might use this here https://discuss.tvm.apache.org/t/rfc-handling-effect-in-tvm-and-relay/5946/25?u=altanh

@altanh
Contributor

altanh commented Dec 11, 2020

Worth noting that threefry_generate can only generate outputs for shapes whose total size is a multiple of 4. Can you update the examples and documentation in the Relay Python op to reflect this?

@altanh
Contributor

altanh commented Dec 15, 2020

Naming

I propose we move everything PRNG to a new random namespace/module in Relay, so relay.random.threefry_generate etc.

Handling different PRNG kernels

The splitting, keygen, and bit-gen operations will be kernel-specific. However, AFAIK, most if not all of the commonly used random ops simply require random bits as inputs (i.e. don't care about how those bits were generated). I propose the following approach to handling different kernels (thanks to @jroesch for helpful discussion):

  • For each kernel K (e.g. K = threefry, K = philox, etc.), define relay.random.K_key, relay.random.K_split, and relay.random.K_generate.
  • In Python (and any other host language where we plan on writing Relay code directly), define relay.random.key(seed: int, K: RandomAlg) (where RandomAlg is some kind of enum with members RandomAlg.THREEFRY etc.) and relay.random.split(key, K: RandomAlg). We can set a default K (for now obviously Threefry) to make it easier to use.
  • Now, for each random op rop that we care about, define relay.random.rop_from_bits, and in Python define relay.random.rop(key, K) which internally calls relay.random.rop_from_bits(relay.random.K_generate(key)). This should suffice to hide the algorithm from the user-facing API.
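The three bullets above could be sketched roughly as follows. This is plain Python rather than Relay, and everything in it is an assumption for illustration: the `Kernel` table, the `_toy_words` hash stand-in, and the `uniform` op are hypothetical and only mirror the proposed naming.

```python
import enum
import hashlib
from typing import Callable, Dict, List, NamedTuple, Tuple

Key = Tuple[int, ...]


class RandomAlg(enum.Enum):
    THREEFRY = "threefry"


def _toy_words(tag: str, key: Key, n: int) -> List[int]:
    """Hash (tag, key, i) into n 64-bit words. A stand-in, not real Threefry."""
    words = []
    for i in range(n):
        data = f"{tag}|{key}|{i}".encode()
        words.append(int.from_bytes(hashlib.sha256(data).digest()[:8], "little"))
    return words


class Kernel(NamedTuple):
    """The kernel-specific pieces: K_key, K_split, K_generate."""
    key: Callable[[int], Key]
    split: Callable[[Key], Tuple[Key, Key]]
    generate: Callable[[Key, int], List[int]]


def _threefry_key(seed: int) -> Key:
    return (seed,)


def _threefry_split(k: Key) -> Tuple[Key, Key]:
    a, b = _toy_words("split", k, 2)
    return (a,), (b,)


def _threefry_generate(k: Key, n: int) -> List[int]:
    return _toy_words("bits", k, n)


KERNELS: Dict[RandomAlg, Kernel] = {
    RandomAlg.THREEFRY: Kernel(_threefry_key, _threefry_split, _threefry_generate),
}


def key(seed: int, alg: RandomAlg = RandomAlg.THREEFRY) -> Key:
    return KERNELS[alg].key(seed)


def split(k: Key, alg: RandomAlg = RandomAlg.THREEFRY) -> Tuple[Key, Key]:
    return KERNELS[alg].split(k)


def uniform_from_bits(bits: List[int]) -> List[float]:
    """Kernel-agnostic op: map 64-bit words to floats in [0, 1)."""
    return [(b >> 11) / float(1 << 53) for b in bits]


def uniform(k: Key, n: int, alg: RandomAlg = RandomAlg.THREEFRY) -> List[float]:
    """User-facing op: hides which kernel produced the bits."""
    return uniform_from_bits(KERNELS[alg].generate(k, n))


k1, k2 = split(key(42))
print(uniform(k1, 4))
```

Only the table lookup knows about the kernel; `uniform_from_bits` sees nothing but bits, which is the separation the proposal is after.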

Problems.

  • This might fail silently if the user inconsistently uses a key for kernel K with an op that uses kernel L ≠ K, if the keys are the same shape. This seems quite bad. Ideally, the keys would somehow be typed (perhaps we can use a Relay ADT?), but I'm not sure how well supported this kind of use case currently is.
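One way to picture the "typed key" idea is to carry the kernel name alongside the key words, so a mismatch is rejected instead of silently producing numbers. The sketch below is plain Python, not a Relay ADT; `TaggedKey` and `require_alg` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TaggedKey:
    alg: str                # which kernel created this key
    words: Tuple[int, ...]  # the raw key data


def require_alg(key: TaggedKey, expected: str) -> Tuple[int, ...]:
    """Unwrap the key only if it belongs to the expected kernel."""
    if key.alg != expected:
        raise TypeError(
            f"key was created for {key.alg!r}, but this op expects {expected!r}"
        )
    return key.words


threefry_key = TaggedKey("threefry", (0,) * 10)
philox_key = TaggedKey("philox", (0,) * 10)  # same shape, different kernel

require_alg(threefry_key, "threefry")  # accepted
try:
    require_alg(philox_key, "threefry")
except TypeError as err:
    print("rejected:", err)
```

With untagged tensors the two keys above would be indistinguishable, since their shapes and dtypes match.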

Other notes

  • I will open a follow-up PR that adds some initial operators (currently I have a uniform random op), along with support for non-multiple of 4 sizes.
  • We should think about how to generalize over dtypes - currently, only uint64 is supported for random bit generation in Threefry. Do we want to support more in the bit generation (is it possible?), or should we delegate this responsibility to auxiliary ops that split/join the uint64s as needed (e.g. converting 5 uint64 to 10 uint32)? I think we need to come up with some kind of standardized approach for all the future random ops.
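As an illustration of the "auxiliary ops" option, here is a small sketch of splitting 64-bit words into 32-bit halves and joining them back. The function names and the low-half-first ordering are assumptions, not anything decided in this PR.

```python
from typing import List

MASK32 = 0xFFFFFFFF


def split_u64_to_u32(words: List[int]) -> List[int]:
    """Turn each 64-bit word into two 32-bit words, low half first."""
    out = []
    for w in words:
        if not 0 <= w < 2**64:
            raise ValueError(f"{w} is not a valid uint64")
        out.append(w & MASK32)  # low 32 bits
        out.append(w >> 32)     # high 32 bits
    return out


def join_u32_to_u64(halves: List[int]) -> List[int]:
    """Inverse of split_u64_to_u32; needs an even number of inputs."""
    if len(halves) % 2 != 0:
        raise ValueError("need an even number of uint32 values")
    return [lo | (hi << 32) for lo, hi in zip(halves[0::2], halves[1::2])]


bits64 = [0x0000000100000002, 2**64 - 1, 0, 7, 2**63]
bits32 = split_u64_to_u32(bits64)
assert len(bits32) == 10  # 5 uint64 become 10 uint32
assert join_u32_to_u64(bits32) == bits64
```

Narrower dtypes would follow the same pattern with a different mask and shift.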

@tqchen
Member

tqchen commented Dec 16, 2020

cc @antinucleon

@altanh
Contributor

altanh commented Dec 16, 2020

I've moved everything to a new random namespace (for Relay and TOPI) and fixed a small bug. This should be good to go I think.

@@ -25,21 +25,23 @@ namespace relay {

TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs);

static const TensorType THREEFRY_KEY_TYPE = TensorType({10}, tvm::DataType::UInt(64));
Member

@tqchen tqchen Dec 17, 2020

As a rule of thumb, try to avoid static variables, as they sometimes run into static-initialization-order issues. Use static functions that return these values instead; the construction is not that costly.

Contributor

Got it, is the naming style OK for a static function? THREEFRY_KEY_TYPE()?

Member

Usually CamelCase is preferred for a global or static function

@junrushao junrushao linked an issue Dec 31, 2020 that may be closed by this pull request
@altanh
Contributor

altanh commented Jan 5, 2021

bump @antinucleon @junrushao1994 @eric-haibin-lin @MarisaKirisame

(please cc anyone else interested in PRNG for review, thanks!)

Contributor

@altanh altanh left a comment

@tkonolige I had some additional questions about TODOs in the Threefry kernel, maybe you can clarify?

python/tvm/relay/op/random/kernel.py (outdated)
# number of rounds is even, so out always contains the result
(out_buf, tmp) = (tmp, out_buf)
(out_offset, tmp_offset) = (tmp_offset, out_offset)

Contributor

I see some TODO in this function (_threefry), do they affect the correctness of the algorithm?

Contributor Author

The only one that matters is the "should be wrapping" TODO. I do not know whether TVM guarantees that unsigned integer arithmetic wraps (instead of saturating).
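For readers unfamiliar with the distinction, the difference between wrapping and saturating 64-bit addition can be shown in a few lines of Python. This only illustrates the arithmetic; it says nothing about what TVM itself guarantees.

```python
MASK64 = (1 << 64) - 1


def wrapping_add_u64(a: int, b: int) -> int:
    """Addition modulo 2**64: overflow discards the carry."""
    return (a + b) & MASK64


def saturating_add_u64(a: int, b: int) -> int:
    """Addition clamped at the largest uint64 value."""
    return min(a + b, MASK64)


a, b = MASK64, 2
print(wrapping_add_u64(a, b))    # 1
print(saturating_add_u64(a, b))  # 18446744073709551615
```

A hash built on modular addition gives different (and biased) results if overflow clamps instead of wrapping, which is why this TODO bears on correctness.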

python/tvm/topi/random/kernel.py (resolved)
@tmoreau89
Contributor

Thanks for the reviews, @junrushao1994 @MarisaKirisame @eric-haibin-lin if you have some cycles, input would be appreciated!

Contributor

@jwfromm jwfromm left a comment

Really well written and commented code, great work!


TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs);

static TensorType ThreefryKeyType() { return TensorType({10}, tvm::DataType::UInt(64)); }
Contributor

Sorry for not pointing this out earlier. Maybe a good thing to do is to wrap this in a newtype? (E.g. define a type ThreefryKey that you cannot use in any way except in random operations. This will avoid doing arithmetic on the random seed.)

Contributor Author

Adding a new opaque type to TVM seems really involved. We would have to add a new case to each type visitor, which seems like it may cause issues with some passes. We'd also have to add a no-op function with implementations to satisfy the type checker, or add a wrapper struct with all the proper conversion functions. Given all this complication, I don't think it is a good idea.

@tqchen
Member

tqchen commented Jan 8, 2021

also cc @yzhliu @hzfan @comaniac who might be interested

Contributor

@electriclilies electriclilies left a comment

Overall this looks good to me! I just have a few design questions and notes on comments and your tests.

:py:func:`threefry_generate`. **Do not use this key again after calling
this function.**

shape : Sequence[int]
Contributor

Why does the total number of outputs need to be a multiple of four?

Contributor Author

It is an implementation detail. Basically, threefry uses 4 64-bit words as its state, inputs, and outputs.
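A rough sketch of why the output size comes in multiples of four, and how a caller could work around it by over-generating and truncating. `toy_block` is a hypothetical stand-in that returns four 64-bit words per call, like the kernel described above, but it uses SHA-256 instead of Threefry.

```python
import hashlib
from typing import List

WORDS_PER_BLOCK = 4


def toy_block(key: int, counter: int) -> List[int]:
    """Stand-in for one kernel invocation: returns exactly four 64-bit words."""
    digest = hashlib.sha256(f"{key}:{counter}".encode()).digest()  # 32 bytes
    return [int.from_bytes(digest[i:i + 8], "little") for i in range(0, 32, 8)]


def generate_exact(key: int, n: int) -> List[int]:
    """Generate n words; n must be a multiple of four."""
    if n % WORDS_PER_BLOCK != 0:
        raise ValueError(f"output size {n} is not a multiple of {WORDS_PER_BLOCK}")
    out: List[int] = []
    for counter in range(n // WORDS_PER_BLOCK):
        out.extend(toy_block(key, counter))
    return out


def generate_padded(key: int, n: int) -> List[int]:
    """Round n up to the next multiple of four, then drop the surplus."""
    padded = -(-n // WORDS_PER_BLOCK) * WORDS_PER_BLOCK
    return generate_exact(key, padded)[:n]


assert len(generate_exact(0, 8)) == 8
assert generate_padded(0, 5) == generate_exact(0, 8)[:5]
```

Truncation wastes up to three words per call, which is usually negligible next to the cost of the generation itself.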

def threefry_split(key):
"""Split an existing Threefry key into two new ones.

This is useful if you have two subsequent calls which each need their own
Contributor

Why wouldn't someone just create two separate Threefry keys using different seeds, and use them?

Contributor Author

Creating separate keys has not been theoretically proven to be as random as splitting a single key. Maybe I should add a comment that you should only really create one key. On the other hand, these details might be better handled at a higher level interface (future work).

# there is no state to maintain, we can apply it to a sequence of numbers (0..N) to generate a
# sequence of random numbers in parallel. In order to make the PRNG splittable (that is we can
# generate a sequence of random numbers in one place, and another sequence in another), we add a
# path and key in addition to the counter. The path allows us to encode a sequence of splits (a 0 in
Contributor

Can you elaborate on how path and key are used in number generation? You don't explain what the key is, either.

Contributor Author

The last sentence explains what the key is: "To avoid continuously growing the path, we can compress an existing path into the key portion of the generator by hashing the current key, path, and counter to create the new key (this same technique is used if we run out of room for the counter)." I've added a comment on how it is initialized.

I've also added an explanation of how random numbers are generated (we apply the hash to key, path, and counter).
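The key/path/counter scheme described here can be modelled in a few lines. This is a toy model under stated assumptions: SHA-256 replaces the Threefry hash, `PATH_BITS` is an arbitrarily small capacity, which bit denotes which branch is a guess, and letting children inherit the counter is a choice made for this sketch. The real layout lives in the kernel source.

```python
import hashlib
from dataclasses import dataclass
from typing import List, Tuple

PATH_BITS = 8  # hypothetical capacity, kept tiny so compression is easy to trigger


@dataclass(frozen=True)
class ToyGen:
    key: bytes
    path: Tuple[int, ...] = ()  # one bit per split taken so far
    counter: int = 0            # index of the next number to generate


def _hash(g: ToyGen) -> bytes:
    """Hash key, path, and counter together."""
    data = g.key + b"|" + bytes(g.path) + b"|" + g.counter.to_bytes(8, "little")
    return hashlib.sha256(data).digest()


def compress(g: ToyGen) -> ToyGen:
    """Fold key, path, and counter into a fresh key with an empty path."""
    return ToyGen(key=_hash(g))


def split(g: ToyGen) -> Tuple[ToyGen, ToyGen]:
    """Extend the path by one bit per child; compress first if the path is full."""
    if len(g.path) >= PATH_BITS:
        g = compress(g)
    return (
        ToyGen(g.key, g.path + (0,), g.counter),
        ToyGen(g.key, g.path + (1,), g.counter),
    )


def generate(g: ToyGen, n: int) -> Tuple[List[int], ToyGen]:
    """Hash (key, path, counter + i) for each output; return the advanced generator."""
    out = []
    for i in range(n):
        h = _hash(ToyGen(g.key, g.path, g.counter + i))
        out.append(int.from_bytes(h[:8], "little"))
    return out, ToyGen(g.key, g.path, g.counter + n)


left, right = split(ToyGen(b"seed"))
numbers, left_next = generate(left, 3)
print(numbers, left_next.counter)
```

Each output depends only on the triple (key, path, counter), so the two children of a split draw from disjoint inputs to the hash.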

tests/python/relay/test_prng.py (resolved)
@tmoreau89 tmoreau89 merged commit 701bcc2 into apache:main Jan 9, 2021
@tmoreau89
Contributor

Thank you @altanh @electriclilies @tqchen @jwfromm @MarisaKirisame for the reviews, the PR has been merged.

tkonolige added a commit to tkonolige/incubator-tvm that referenced this pull request Jan 11, 2021
* [RELAY,TOPI] Threefry PRNG: splittable and stateless

* Fix sphinx?

* Lint fixes

* sphinx fixes round 2

* fix inputs for tests

* reorganize to random, fix uninitialized memory bug

* silence linter

* silence linter even further

* s

* strengthen Threefry key type checking, add tests

* replace static variable with function for Threefry key type

* lint fix

* Remove old todos, improve assert messages

* describe how random number is generated

* add tests for incorrect output size. also vary test sizes

Co-authored-by: Altan Haan
masahi pushed a commit to masahi/tvm that referenced this pull request Jan 14, 2021
masahi pushed a commit to masahi/tvm that referenced this pull request Jan 18, 2021
TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021
@yzhliu
Member

yzhliu commented Mar 5, 2021

Nice work. Do we have a plan to support CUDA? It looks like it works on CPU only at the moment.

@tkonolige
Contributor Author

It should be easy to use on the GPU, just parallelize the outer loop.
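The reason the outer loop parallelizes is that each output block depends only on the key and its own counter. The sketch below shows that property with Python threads standing in for GPU threads; `toy_block` is a hypothetical SHA-256 stand-in, and this is not a TVM schedule.

```python
import hashlib
from concurrent.futures import ThreadPoolExecutor
from typing import List


def toy_block(key: int, counter: int) -> List[int]:
    """One independent unit of work: four 64-bit words from (key, counter)."""
    digest = hashlib.sha256(f"{key}:{counter}".encode()).digest()
    return [int.from_bytes(digest[i:i + 8], "little") for i in range(0, 32, 8)]


def generate_serial(key: int, n_blocks: int) -> List[int]:
    """Reference version: a plain outer loop over block counters."""
    return [w for c in range(n_blocks) for w in toy_block(key, c)]


def generate_parallel(key: int, n_blocks: int, workers: int = 4) -> List[int]:
    """Same outer loop, with iterations handed to a pool of workers."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # map() returns results in input order regardless of completion order.
        blocks = list(pool.map(lambda c: toy_block(key, c), range(n_blocks)))
    return [w for block in blocks for w in block]


# No iteration reads another iteration's result, so both orders agree.
assert generate_parallel(7, 16) == generate_serial(7, 16)
```

A stateful generator would not allow this, because each number would depend on the state left behind by the previous one.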

Development

Successfully merging this pull request may close these issues.

[RELAY] Support Random Number Generator
8 participants