I am trying to speed up a function that randomly samples combinations of category values across a number of records and ensures the sampled combinations are unique (e.g. assume there are 3 records, each of which can be either 0 or 1, and I want 10 random, unique combinations of those records).
If I were not using numba, I might do something like this:
import numpy as np

def myfunc(categories, NumberOfRecords, maxsamples):
    return np.unique(np.random.choice(np.arange(categories), size=(maxsamples * 10, NumberOfRecords), replace=True), axis=0)[0:maxsamples]
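For concreteness, calling it with the toy numbers from the example above would look like this (note that np.unique sorts the rows, so the result is not in random order, and it can return fewer than maxsamples rows if the 10x oversampling still produces too many collisions):

print(myfunc(categories=2, NumberOfRecords=3, maxsamples=10))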
Annoyingly, numba does not support the axis argument of np.unique, so the best I can do is something like the following, but then some of the sampled rows may turn out to be duplicates.
from numba import njit, int64
import numpy as np
@njit(int64[:,:](int64, int64, int64), cache=True)
def myfunc(categories, NumberOfRecords, maxsamples):
    return np.random.choice(np.arange(categories), size=(maxsamples, NumberOfRecords), replace=True)
myfunc(categories=2, NumberOfRecords=3, maxsamples=10)
E.g. in one call (there is obviously some randomness here), I got the result below, in which rows 1 and 6, rows 3 and 4, and rows 7 and 9 are identical:
array([[0, 1, 1],
       [1, 1, 0],
       [0, 1, 0],
       [1, 0, 1],
       [1, 0, 1],
       [1, 1, 1],
       [1, 1, 0],
       [1, 0, 0],
       [0, 0, 0],
       [1, 0, 0]])
My questions are:
- Is this something where I would even expect a speed up from numba?
- If so, how can I get unique rows (this seems rather difficult with numba, but presumably there's a way)?
- Perhaps there's a way to get at this more efficiently, ideally without generating more random samples than I need in the end (a rough sketch of what I mean is below)?
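To illustrate the last point: since each row is just a base-categories number with NumberOfRecords digits, one could sample distinct integer codes and decode them into rows, so no oversampling or row-wise uniqueness check would be needed. Here is a rough pure-NumPy sketch (the name myfunc_codes is just illustrative, it assumes categories ** NumberOfRecords fits in an int64 and is small enough to sample from directly, and I have not checked how much of it numba would accept):

import numpy as np

def myfunc_codes(categories, NumberOfRecords, maxsamples):
    # each row corresponds to one integer code in [0, categories ** NumberOfRecords)
    total = categories ** NumberOfRecords
    # distinct codes guarantee distinct rows; there can never be more than `total` of them
    codes = np.random.choice(total, size=min(maxsamples, total), replace=False)
    # decode each code into its base-categories digits, one digit per record
    return (codes[:, None] // categories ** np.arange(NumberOfRecords)) % categories

Whether something equivalent can be written under @njit (e.g. whether replace=False is supported there) is essentially part of what I am asking.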