I have a setup where I need to generate some random number that is consumed by vmap and then lax.scan later on:
def generate_random(key: Array, lower_bound: int, upper_bound: int) -> int:
    ...
    return num.astype(int)

def forward(key: Array, input: Array) -> Array:
    k = generate_random(key, 1, 5)
    computation = model(.., k, ..)
    ...

# Computing the forward pass
output = jax.vmap(forward, in_axes=.....
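The body of generate_random is elided above. Assuming it samples uniformly with jax.random.randint (an assumption; the question does not show it), a minimal sketch could look like this. Note that maxval is exclusive and that the result is already a 0-d integer jax.Array rather than a Python int, so no extra cast is needed:

```python
import jax
import jax.numpy as jnp
from jax import Array


def generate_random(key: Array, lower_bound: int, upper_bound: int) -> Array:
    # Hypothetical body (assumption): draw one integer in [lower_bound, upper_bound).
    # maxval is exclusive, so upper_bound itself is never returned.
    num = jax.random.randint(key, shape=(), minval=lower_bound, maxval=upper_bound)
    # Already an integer dtype (int32 by default); it stays a jax.Array.
    return num


k = generate_random(jax.random.PRNGKey(0), 1, 5)
print(k, k.dtype, k.shape)
```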
But attempting to turn num from a jax.Array into a concrete int causes a ConcretizationTypeError.
This can be reproduced with this minimal example:
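import jax
import jax.numpy as jnp

@jax.jit
def t():
    # .item() needs a concrete value, but inside jit the array is a tracer,
    # so this raises a ConcretizationTypeError during tracing
    return jnp.zeros((1,)).item().astype(int)

o = t()
o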
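For comparison, here is a minimal sketch of the same reproduction that stays inside JAX while tracing and only converts to a Python int after the jitted function has returned (t_fixed and o_int are illustrative names, not from the question):

```python
import jax
import jax.numpy as jnp


@jax.jit
def t_fixed():
    # Index and cast with JAX ops instead of .item(): while jit is tracing,
    # the result is still a (traced) jax.Array, which is allowed.
    return jnp.zeros((1,))[0].astype(jnp.int32)


o = t_fixed()   # a concrete jax.Array once the compiled function has run
o_int = int(o)  # converting to a Python int is fine outside the traced code
print(o, o_int)
```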
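Under jit, all the manipulations must stay JAX operations on (traced) arrays; a traced value cannot be pulled out as a concrete Python number. vmap has the same restriction, because it also traces the function instead of running it on concrete values. And I would prefer to keep vmap for performance reasons.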
My Attempt
This was my hacky attempt:
from functools import partial

import jax
import jax.numpy as jnp
from jax import Array

@partial(jax.jit, static_argnums=(1, 2))
def get_rand_num(key: Array, lower_bound: int, upper_bound: int) -> int:
    key, subkey = jax.random.split(key)
    random_number = jax.random.randint(subkey, shape=(), minval=lower_bound, maxval=upper_bound)
    # Note: under vmap this is still a traced jax.Array, not a Python int
    return random_number.astype(int)

def react_forward(key: Array, input: Array) -> Array:
    k = get_rand_num(key, 1, MAX_ITERS)
    # forward pass the model without tracking grads
    intermediate_array = jax.lax.stop_gradient(model(input, k))  # THIS LINE ERRORS OUT
    ...
    return ...

a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
jax.vmap(react_forward, in_axes=(0, 0))(rndm_keys, a).shape
This creates batch_size (a.shape[0]) subkeys, one for each example in the vmap, so that each example gets its own random number.
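The key-splitting step on its own, as a small sketch with illustrative sizes (batch_size = 8 and the range [1, 10) are assumptions): jax.random.split returns one subkey per row, and vmap over those rows gives each example its own k, which is still a traced value inside the mapped function:

```python
import jax
import jax.numpy as jnp

batch_size = 8  # illustrative; the question uses a.shape[0] == 300
key = jax.random.PRNGKey(42)
# One independent subkey per example: shape (batch_size, 2) for raw uint32 keys.
subkeys = jax.random.split(key, batch_size)

# vmap over the leading axis of subkeys: each example draws its own k in [1, 10).
ks = jax.vmap(lambda sk: jax.random.randint(sk, shape=(), minval=1, maxval=10))(subkeys)
print(subkeys.shape, ks)
```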
But it doesn't work, because k would have to be cast from a jax.Array to an int, and inside vmap it is still a traced value.
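Since the cast is the part that cannot work under vmap, one workaround is to never convert k and let it drive a loop whose trip count may be traced. The sketch below assumes model can be written as repeated application of a per-step function; step, MAX_ITERS = 10 and react_forward_loop are hypothetical stand-ins. jax.lax.fori_loop accepts a traced upper bound (it then lowers to a while_loop, which supports forward-mode but not reverse-mode differentiation):

```python
import jax
import jax.numpy as jnp
from jax import Array

MAX_ITERS = 10  # hypothetical upper bound (exclusive for randint)


def step(x: Array) -> Array:
    # Hypothetical stand-in for one iteration of the model.
    return jnp.tanh(x + 1.0)


def react_forward_loop(key: Array, inp: Array) -> Array:
    # k stays a traced int32 scalar; it is never converted to a Python int.
    k = jax.random.randint(key, shape=(), minval=1, maxval=MAX_ITERS)
    # fori_loop accepts a traced upper bound, so k can set the number of steps.
    out = jax.lax.fori_loop(0, k, lambda i, x: step(x), inp.astype(jnp.float32))
    return jax.lax.stop_gradient(out)


key = jax.random.PRNGKey(0)
a = jnp.zeros((300, 32)).astype(int)
rndm_keys = jax.random.split(key, a.shape[0])
out = jax.vmap(react_forward_loop, in_axes=(0, 0))(rndm_keys, a)
print(out.shape)  # (300, 32)
```

Under vmap, the batched loop keeps running until the largest k in the batch is reached, and examples that are already done keep their carry unchanged.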
But making this change:
- k = get_rand_num(key, 1, MAX_ITERS)
+ k = 5  # any hardcoded int
works perfectly. So the random sampling is clearly what causes the problem.
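The hardcoded version works because 5 is a static Python int, and lax.scan needs a static length. If k has to stay random per example, another option (a sketch, not the only approach) is to scan for a fixed MAX_ITERS steps and mask out every step whose index is >= k using jnp.where. MAX_ITERS, step and masked_forward are hypothetical placeholders:

```python
import jax
import jax.numpy as jnp
from jax import Array

MAX_ITERS = 10  # hypothetical static upper bound on k


def step(x: Array) -> Array:
    # Hypothetical stand-in for one iteration of the model.
    return jnp.tanh(x + 1.0)


def masked_forward(k: Array, x: Array) -> Array:
    def body(carry, i):
        # Apply the step only while i < k; otherwise keep the carry unchanged.
        new = jnp.where(i < k, step(carry), carry)
        return new, None

    # The scan length is static (MAX_ITERS); k is a traced per-example value.
    out, _ = jax.lax.scan(body, x, jnp.arange(MAX_ITERS))
    return out


ks = jnp.array([1, 3, 5])
xs = jnp.zeros((3, 4))
out = jax.vmap(masked_forward)(ks, xs)
print(out)
```

The trade-off is that every example pays for MAX_ITERS steps of compute, but the loop length is static, so it composes with jit, vmap and reverse-mode grad.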