Wednesday, 25 January 2023

JAX minimization with stochastically estimated gradients

I'm trying to use the BFGS optimizer from tensorflow_probability.substrates.jax and from jax.scipy.optimize.minimize to minimize a function f that is estimated from pseudo-random samples and takes a jax.random.PRNGKey as an argument. To use this function with the JAX/TFP BFGS minimizer, I wrap it in a lambda:

import jax

seed = 100
key = jax.random.PRNGKey(seed)
fun = lambda x: f(x, key)   # a lambda body is an expression, so no `return`; the key stays fixed for every call
result = jax.scipy.optimize.minimize(fun=fun, ...)

What is the best way to update the key when the minimization routine calls the objective, so that each call uses different pseudo-random numbers in a reproducible way? Maybe a global key variable? If so, is there an example I could follow?
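
For concreteness, here is the kind of reproducible pattern I have in mind (just a sketch; f and x0 are assumed defined as above). Since everything derives from one seed, results stay reproducible; a global key mutated inside fun seems problematic because the function is traced under jit, so Python-side side effects fire only at trace time:

import jax

key = jax.random.PRNGKey(100)

# Split off a fresh subkey per restart and keep that subkey fixed for the
# whole minimization run (common random numbers within one run).
for restart in range(5):
    key, subkey = jax.random.split(key)        # deterministic given the seed
    fun = lambda x, k=subkey: f(x, k)          # bind the current subkey
    result = jax.scipy.optimize.minimize(fun=fun, x0=x0, method="BFGS")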

Secondly, is there a way to make the optimization stop after a certain amount of time, as one could do with a callback in scipy? I could use the scipy implementation of BFGS/L-BFGS-B/etc. directly and use JAX only to estimate the function and its gradients, which seems to work. Is there a difference between the scipy, jax.scipy and tfp.jax BFGS implementations?
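
To make the scipy route concrete, this is the wiring I mean (a sketch; fun and x0 as above, and the 60-second budget is a placeholder). JAX supplies the value and gradient, scipy does the line search, and an exception raised from the callback aborts the run:

import time
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

value_and_grad = jax.jit(jax.value_and_grad(fun))   # fun as above: R^n -> R

def fun_np(x):
    # scipy passes numpy arrays in; return a plain float and numpy gradient
    v, g = value_and_grad(jnp.asarray(x))
    return float(v), np.asarray(g, dtype=np.float64)

class TimeLimit(Exception):
    pass

t0 = time.time()
def stop_after(xk, budget=60.0):
    if time.time() - t0 > budget:    # only checked between iterations
        raise TimeLimit

try:
    result = minimize(fun_np, x0, jac=True, method="L-BFGS-B", callback=stop_after)
except TimeLimit:
    pass   # scipy's OptimizeResult is lost here, so stash the last xk inside the callback if needed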

Finally, is there a way to print the values of the arguments of fun during the BFGS optimization in jax.scipy or tfp, given that f is jitted?
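
The only mechanism I know of for printing under jit is jax.debug.print; a minimal sketch of what I mean (the body of f is a placeholder objective):

import jax
import jax.numpy as jnp

@jax.jit
def f(x, key):
    # jax.debug.print emits at run time even under jit; a plain print()
    # here would fire only once, while the function is being traced.
    jax.debug.print("x = {x}", x=x)
    noise = jax.random.normal(key, x.shape)
    return jnp.sum((x + 0.01 * noise) ** 2)   # placeholder objective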

Thank you!



