Friday, August 4, 2023

JAX: generating random numbers under **JIT**

I have a setup where I need to generate a random number that is consumed under vmap, and later by lax.scan:

def generate_random(key: Array, lower_bound: int, upper_bound: int) -> int:
    ...
    return num.astype(int)

def forward(key: Array, input: Array) -> Array:
    k = generate_random(key, 1, 5)
    computation = model(.., k, ..)
    ...

# Computing the forward pass
output = jax.vmap(forward, in_axes=.....

But attempting to convert num from a jax.Array to a concrete Python int inside the traced function raises a ConcretizationTypeError.

This can be reproduced through this minimal example:

import jax
import jax.numpy as jnp

@jax.jit
def t():
  # .item() tries to pull a concrete Python scalar out of a traced array
  return jnp.zeros((1,)).astype(int).item()

o = t()  # raises ConcretizationTypeError

JIT traces the function with abstract values, so every intermediate has to stay a JAX type; you cannot pull a concrete Python int out of a traced array.

But vmap traces the function in the same way JIT does, so the same restriction applies. And I would prefer to keep vmap for performance reasons.
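
To convince myself the restriction comes from tracing itself rather than from jit specifically, here is a small sketch of my own (not part of the original setup): forcing a Python int out of a value inside a vmap-ed function fails the same way.

import jax
import jax.numpy as jnp

def f(x):
  return int(x)  # tries to pull a concrete Python int out of the batch tracer

jax.vmap(f)(jnp.arange(3))  # fails with the same concretization-style tracer error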


My Attempt

This was my hacky attempt:

from functools import partial

import jax
import jax.numpy as jnp
from jax import Array

@partial(jax.jit, static_argnums=(1, 2))
def get_rand_num(key: Array, lower_bound: int, upper_bound: int) -> int:
  key, subkey = jax.random.split(key)
  random_number = jax.random.randint(subkey, shape=(), minval=lower_bound, maxval=upper_bound)
  return random_number.astype(int)  # still a traced jax.Array, not a Python int

def react_forward(key: Array, input: Array) -> Array:
  k = get_rand_num(key, 1, MAX_ITERS)
  # forward pass the model without tracking grads
  intermediate_array = jax.lax.stop_gradient(model(input, k))  # THIS LINE ERRORS OUT
  ...
  return ...

a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape

This creates a.shape[0] subkeys, one per batch element, so every vmap-ed call gets its own key and therefore its own random number.
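
As a quick sanity check on that splitting pattern (my own snippet, reusing the batch size of 300 from above):

import jax

key = jax.random.PRNGKey(0)
batch_keys = jax.random.split(key, 300)  # one subkey per batch element
batch_keys.shape  # (300, 2); vmap maps over this leading axis, so each call sees one key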

But it doesn't work, because of k being cast from a jax.Array to an int.

But making these changes:

-  k = get_rand_num(key, 1, MAX_ITERS)
+  k = 5 # any hardcoded int

Works perfectly. Clearly, the sampling is causing the problem here...
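
For what it's worth, here is a toy sketch of why the hardcoded constant behaves differently. toy_model is a hypothetical stand-in for model (I don't know what the real one does with k): it consumes k as a Python-level loop bound, which must be concrete at trace time, so a static int works while a value sampled under vmap does not.

import jax
import jax.numpy as jnp

def toy_model(x, k):
  # hypothetical stand-in: k is used as a Python-level loop bound
  for _ in range(k):
    x = x * 2.0
  return x

toy_model(jnp.ones((3,)), 5)  # fine: 5 is a concrete Python int

def bad(key, x):
  k = jax.random.randint(key, shape=(), minval=1, maxval=5)  # traced under vmap
  return toy_model(x, k)  # range(k) needs a concrete int

keys = jax.random.split(jax.random.PRNGKey(0), 4)
jax.vmap(bad)(keys, jnp.ones((4, 3)))  # fails: k is a tracer, not a Python int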



