I'm trying to use the BFGS optimizer from tensorflow_probability.substrates.jax and from jax.scipy.optimize.minimize to minimize a function f, which is estimated from pseudo-random samples and takes a jax.random.PRNGKey as an argument. To use this function with the JAX/TFP BFGS minimizers, I wrap it inside a lambda function:
import jax
import jax.scipy.optimize

seed = 100
key = jax.random.PRNGKey(seed)
fun = lambda x: f(x, key)  # a lambda body is an expression, so no "return"
result = jax.scipy.optimize.minimize(fun = fun, ...)
What is the best way to update the key when the minimization routine calls the function to be minimized so that I use different pseudo-random numbers in a reproducible way? Maybe a global key variable? If yes, is there an example I could follow?
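A minimal sketch of one reproducible option, assuming it is acceptable to keep the key fixed within a single BFGS run. BFGS assumes a deterministic objective (its line search and curvature updates compare values and gradients across calls), and jax.scipy.optimize.minimize traces fun inside lax.while_loop, so a global key mutated from Python inside fun would only change once, at trace time. Instead, the sketch runs BFGS in stages and derives a fresh key per stage with jax.random.fold_in. The objective f below and the helper staged_minimize are hypothetical stand-ins, not the code from the question.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Hypothetical stand-in for the noisy objective f(x, key): a Monte Carlo
# estimate of E[||x - z||^2] with z ~ N(1, I), whose true minimizer is x = 1.
def f(x, key, n_samples=1000):
    z = 1.0 + jax.random.normal(key, (n_samples,) + x.shape)
    return jnp.mean(jnp.sum((x - z) ** 2, axis=-1))

# Hypothetical helper: restart BFGS several times, each stage with its own key.
def staged_minimize(x0, base_key, n_stages=5):
    x = x0
    for stage in range(n_stages):
        # Deterministically derive this stage's key from the base key,
        # so the whole sequence of samples is reproducible.
        key = jax.random.fold_in(base_key, stage)
        # Within one BFGS run the key is fixed, so the objective is deterministic.
        result = minimize(lambda x: f(x, key), x, method="BFGS")
        x = result.x  # warm-start the next stage from this stage's solution
    return x

x_opt = staged_minimize(jnp.zeros(2), jax.random.PRNGKey(100))
```

Rerunning with the same base key reproduces the same result; changing n_stages or the base key gives a different but still reproducible sample sequence.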
Secondly, is there a way to make the optimization stop after a certain amount of time, as one could do with a callback in SciPy? I could directly use the SciPy implementation of BFGS / L-BFGS-B / etc. and use JAX only for evaluating the function and its gradients, which seems to work. Is there a difference between the SciPy, jax.scipy and tfp.jax BFGS implementations?
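As far as I know, the JAX-side minimizers expose iteration caps rather than wall-clock limits (jax.scipy.optimize.minimize accepts options={"maxiter": n}, and TFP's bfgs_minimize has a max_iterations argument). For a time limit, a sketch of the SciPy route mentioned above: JAX computes the value and gradient, and a SciPy callback raises an exception once a time budget is spent, keeping the last iterate. The objective, the TimeLimitReached exception and minimize_with_time_limit are hypothetical names chosen for illustration; a custom exception is used rather than relying on version-specific SciPy termination behavior.

```python
import time
import numpy as np
import jax
import jax.numpy as jnp
import scipy.optimize

# Hypothetical deterministic objective (e.g. f(x, key) with a fixed key).
def objective(x):
    return jnp.sum((x - 1.0) ** 2) + 0.1 * jnp.sum(x ** 4)

# Compile value and gradient once; SciPy calls this with NumPy arrays.
value_and_grad = jax.jit(jax.value_and_grad(objective))

def scipy_fun(x):
    v, g = value_and_grad(jnp.asarray(x))
    # SciPy expects a Python float and a float64 NumPy gradient.
    return float(v), np.asarray(g, dtype=np.float64)

class TimeLimitReached(Exception):
    """Raised from the callback to abort the optimization."""

def minimize_with_time_limit(x0, max_seconds):
    start = time.monotonic()
    last = {"x": np.asarray(x0, dtype=np.float64)}

    def callback(xk):
        # Called by SciPy after each iteration with the current iterate.
        last["x"] = np.copy(xk)
        if time.monotonic() - start > max_seconds:
            raise TimeLimitReached

    try:
        res = scipy.optimize.minimize(
            scipy_fun, x0, jac=True, method="L-BFGS-B", callback=callback
        )
        return res.x
    except TimeLimitReached:
        return last["x"]  # best-known iterate when time ran out

x_best = minimize_with_time_limit(np.zeros(3), max_seconds=5.0)
```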
Finally, is there a way to print the values of the arguments of fun during the BFGS optimization in jax.scipy or TFP, given that f is jitted?
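A short sketch using jax.debug.print, which, unlike Python's print (that fires only while tracing), prints runtime values from inside jitted code and loops such as the one jax.scipy.optimize.minimize runs. The toy f below is a hypothetical stand-in for the jitted objective from the question.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Hypothetical jitted objective standing in for f(x, key).
@jax.jit
def f(x, key):
    noise = 0.01 * jax.random.normal(key, x.shape)
    return jnp.sum((x - 2.0 + noise) ** 2)

key = jax.random.PRNGKey(100)

def fun(x):
    # Runs at execution time on every evaluation, including line-search calls.
    jax.debug.print("fun called with x = {x}", x=x)
    return f(x, key)

result = minimize(fun, jnp.zeros(3), method="BFGS")
```

Each objective evaluation prints its argument, so line-search trial points show up alongside the accepted iterates.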
Thank you!