Friday, April 29, 2022

How to model random neurons in Jax?

Suppose I want to model random neurons in Jax, for instance by adding a small random value to each neuron's bias every time the neuron is used. What is the best way to do this?

The issue here is that random numbers in Jax are generated statelessly from an explicit key (seed). I can therefore think of two approaches, neither of which seems particularly appealing; my question is whether there is a better approach (and, indeed, what the recommended approach is for this kind of thing).

The first approach would be to generate a matrix of random numbers as big as the neuron layer, and send each neuron both its inputs and its slice of these random numbers. This might work, but the problem I imagine is that generating a large random matrix from a single key is a sequential rather than parallel operation (I assume?), so by generating a big random matrix on every forward step I would forgo much of the benefit of parallelization: the neuron computations themselves can run in parallel, but generating the random matrix each time might not be parallelizable.
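To make the first approach concrete, here is a minimal sketch (the function name `noisy_dense` and the `noise_scale` parameter are my own, not from any particular library): draw one fresh noise vector per forward pass from a key and add it to the layer's biases.

```python
import jax
import jax.numpy as jnp

def noisy_dense(params, x, key, noise_scale=0.01):
    """Dense layer whose biases are perturbed by fresh Gaussian noise."""
    w, b = params
    noise = noise_scale * jax.random.normal(key, shape=b.shape)
    return jnp.dot(x, w) + b + noise

key = jax.random.PRNGKey(0)
key, init_key, noise_key = jax.random.split(key, 3)
w = jax.random.normal(init_key, (4, 3))
b = jnp.zeros(3)
x = jnp.ones((2, 4))

y = noisy_dense((w, b), x, noise_key)  # one key consumed per forward step
```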

The second approach would be to endow each neuron with its own key, which would act as neuron-level random state. Each time a neuron computes its output, it also splits its key and outputs the new key, which it receives as input the next time it is called. That is, for a layer with N neurons, I would keep an array of N keys, and I would vmap over the N (key, input) pairs a function that produces the N outputs and N new split keys, as sketched below.
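A hedged sketch of this second approach (the names `neuron_forward` and `layer_forward` are illustrative, not a JAX API): each neuron carries its own PRNG key, splits it on every call, and returns the fresh key alongside its output.

```python
import jax
import jax.numpy as jnp

def neuron_forward(key, w, b, x, noise_scale=0.01):
    """Single neuron: split its key, add noise to its bias, return the new key."""
    new_key, noise_key = jax.random.split(key)
    noise = noise_scale * jax.random.normal(noise_key)
    y = jnp.dot(x, w) + b + noise
    return new_key, y

def layer_forward(keys, W, b, x):
    # vmap over the per-neuron axes (keys, weight columns, biases),
    # broadcasting the shared input x.
    return jax.vmap(neuron_forward, in_axes=(0, 1, 0, None))(keys, W, b, x)

n_in, n_out = 4, 3
master = jax.random.PRNGKey(0)
keys = jax.random.split(master, n_out)   # one key per neuron
W = jax.random.normal(jax.random.PRNGKey(1), (n_in, n_out))
b = jnp.zeros(n_out)
x = jnp.ones(n_in)

keys, y = layer_forward(keys, W, b, x)   # carry `keys` forward to the next call
```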

Is either of these two approaches the recommended one? Or is there a different approach that works better?



