Friday, May 20, 2022

Is the NumPy RNG thread-safe?

I implemented a function that uses the numpy random generator to simulate some process. Here is a minimal example of such a function:

def thread_func(cnt, gen):
    s = 0.0
    for _ in range(cnt):
        s += gen.integers(6)
    return s

Next, I wrote a function that calls thread_func through Pool.starmap. If I write it like this, passing the same generator to every worker:

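import numpy as np
from multiprocessing import Pool  # assumed; the original snippet omits its imports

def evaluate(total_cnt, thread_cnt):
    gen = np.random.default_rng()  # one generator, passed to every task
    cnt_per_thread = total_cnt // thread_cnt
    with Pool(thread_cnt) as p:
        vals = p.starmap(thread_func, [(cnt_per_thread, gen) for _ in range(thread_cnt)])
    return vals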

The result of evaluate(100000, 5) is a list of 5 identical values, for example:

[49870.0, 49870.0, 49870.0, 49870.0, 49870.0]

However, if I pass a different generator to each worker, for example by doing:

vals = p.starmap(thread_func, [(cnt_per_thread, np.random.default_rng()) for _ in range(thread_cnt)])

I get the expected result (5 different values), for example:

[49880.0, 49474.0, 50232.0, 50038.0, 50191.0]
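A reproducible variant of the "one generator per worker" idea can be sketched with np.random.SeedSequence, whose spawn method derives independent child seeds from a single parent. The helper name make_generators and the seed 12345 are hypothetical, chosen only for illustration; the sketch runs the tasks sequentially so that it stays self-contained.

```python
import numpy as np

def make_generators(n, seed=None):
    """Return n independent generators derived from one parent seed.

    Hypothetical helper, not part of the original code.
    """
    # spawn() yields child SeedSequence objects meant for parallel streams.
    children = np.random.SeedSequence(seed).spawn(n)
    return [np.random.default_rng(child) for child in children]

def thread_func(cnt, gen):
    # Same as the function in the question: sum of cnt draws from {0, ..., 5}.
    s = 0.0
    for _ in range(cnt):
        s += gen.integers(6)
    return s

gens = make_generators(5, seed=12345)  # the seed value is arbitrary
vals = [thread_func(1000, g) for g in gens]
print(vals)
```

With a pool, the same list of generators could be passed as p.starmap(thread_func, [(cnt_per_thread, g) for g in gens]), so that every task receives its own stream while the whole run stays repeatable from one seed.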

Why does sharing one generator give identical results, while separate generators do not?
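One likely explanation, assuming Pool here is multiprocessing.Pool: the arguments of each task are pickled and sent to a worker process, so every worker receives its own copy of the generator with the same internal state, and each copy then produces the same sequence. The sketch below imitates that with an explicit pickle round trip instead of a pool.

```python
import pickle
import numpy as np

gen = np.random.default_rng()

# Imitate sending the same generator to several worker processes:
# each round trip through pickle yields a separate copy with identical state.
copies = [pickle.loads(pickle.dumps(gen)) for _ in range(3)]

# Every copy starts from the same state, so the draws coincide.
draws = [c.integers(6, size=10).tolist() for c in copies]
all_equal = all(d == draws[0] for d in draws)
print(all_equal)  # prints True
```

If that is what happens, the identical sums come from process-level copying, not from thread safety: the workers never touch the same generator object.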



