Tuesday, 23 April 2019

How to compile a numba jit'ed function with variable input type?

Say I have a function that can accept either an int or None as an input argument:

import numba as nb
import numpy as np

jitkw = {"nopython": True, "nogil": True, "error_model": "numpy", "fastmath": True}


@nb.jit("f8(i8)", **jitkw)
def get_random(seed=None):
    np.random.seed(seed)  # seeds Numba's internal generator, not NumPy's global one
    out = np.random.normal()
    return out

I want the function to simply return a normally distributed random number. If I want reproducible results, seed should be an int.
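For comparison, here is a minimal sketch of the intended behavior in plain NumPy, without Numba (the name `get_random_reference` is hypothetical). Outside Numba, `np.random.seed(None)` is valid and reseeds the legacy global generator from OS entropy, while an integer seed makes the draw reproducible.

```python
import numpy as np


def get_random_reference(seed=None):
    """Pure-NumPy version of get_random (hypothetical helper, no Numba)."""
    # seed=None reseeds from OS entropy; an int seed gives a reproducible draw.
    np.random.seed(seed)
    return np.random.normal()


print(get_random_reference(42))  # same value on every call with seed 42
print(get_random_reference())    # differs from call to call
```

The difficulty in the question is only that the compiled `"f8(i8)"` signature has no type that matches `None` or an omitted argument.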

get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327
get_random(42)
>>> 0.4967141530112327

If I want non-reproducible random numbers, seed should be left as None. However, if I do not pass an argument (so seed defaults to None) or explicitly pass seed=None, Numba raises a TypeError:

get_random()
>>> TypeError: No matching definition for argument type(s) omitted(default=None)
get_random(None)
>>> TypeError: No matching definition for argument type(s) omitted(default=None)

How can I write this function so that it handles both cases while still declaring an explicit signature and using nopython mode?

My Numba version is 0.43.1.



