I'm looking for ideas on how to optimize the sampling of a varying number of guests for a varying number of hosts. Let me clarify what I'm trying to do.
Given "n_hosts" hosts, each with a different number of possible guests ("n_possible_guests_per_host"), I want to sample "n_guests" guests per host from that host's list of possible guests, where "n_guests" also differs from host to host ("n_to_sample" in the code below). I'm finding this challenging because JAX requires arrays to have fixed (static) shapes. Here is a brute-force approach; it took about 2 seconds on my laptop.
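```python
import numpy as np
import jax.numpy as jnp
from jax import random

n_possible_guests = 50_000
n_hosts = 1000

# Number of possible guests per host; the last host takes whatever is left over.
n_possible_guests_per_host = [np.random.randint(low=0, high=100) for i in range(n_hosts - 1)]
n_possible_guests_per_host += [max(0, n_possible_guests - sum(n_possible_guests_per_host))]

# Size guest_idx from the actual total: the first n_hosts - 1 hosts can already exceed
# n_possible_guests, which would otherwise leave trailing hosts with empty or truncated slices.
guest_idx = np.arange(sum(n_possible_guests_per_host))
host_idx = np.arange(n_hosts)

# Number of guests to draw for each host (0 for hosts without possible guests).
n_to_sample = [np.random.randint(low=0, high=n) if n != 0 else 0 for n in n_possible_guests_per_host]


def brute_force(guest_idx, host_idx, n_possible_guests_per_host, n_to_sample):
    # Offset of each host's first guest in guest_idx (guests are stored host by host).
    first_guest_idx = np.cumsum(n_possible_guests_per_host) - n_possible_guests_per_host
    # One independent key per host; reusing a single key correlates the draws.
    keys = random.split(random.PRNGKey(0), len(host_idx))
    chosen_guests = jnp.zeros(sum(n_to_sample), dtype=jnp.int32)
    n_chosen = 0
    for host_id, n_sample in zip(host_idx, n_to_sample):
        start = first_guest_idx[host_id]
        possible_guests_in_host = guest_idx[start:start + n_possible_guests_per_host[host_id]]
        # random.choice samples with replacement by default (replace=True).
        chosen_idx = random.choice(keys[host_id], possible_guests_in_host, shape=(n_sample,))
        chosen_guests = chosen_guests.at[n_chosen:n_chosen + len(chosen_idx)].set(chosen_idx)
        n_chosen += len(chosen_idx)
    return chosen_guests


brute_force(guest_idx, host_idx, n_possible_guests_per_host, n_to_sample)
```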
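A possible direction, sketched here but not benchmarked: since "n_possible_guests_per_host" and "n_to_sample" are known in NumPy before any JAX call, every array JAX has to produce has a size fixed in advance (the total number of draws, or the total number of guests). The sketch assumes guests are stored contiguously host by host, as "first_guest_idx" in brute_force already assumes, and that "guest_idx" has at least sum(n_possible_guests_per_host) entries. sample_with_replacement mirrors random.choice's default replace=True with a single random.randint call using per-draw bounds. sample_without_replacement is a different technique: it shuffles each host's block by sorting on (host, random key) with jnp.lexsort and keeps the first n_to_sample[h] entries of each block. Both function names and the helper host_layout are hypothetical.

```python
import numpy as np
import jax.numpy as jnp
from jax import random


def host_layout(n_possible_guests_per_host):
    """Hypothetical helper: per-host guest counts and each host's first-guest offset."""
    counts = np.asarray(n_possible_guests_per_host)
    first = np.cumsum(counts) - counts
    return counts, first


def sample_with_replacement(key, guest_idx, n_possible_guests_per_host, n_to_sample):
    counts, first = host_layout(n_possible_guests_per_host)
    n_to_sample = np.asarray(n_to_sample)
    # Host id of every draw; the total sum(n_to_sample) is known before calling JAX.
    host_of_draw = np.repeat(np.arange(len(counts)), n_to_sample)
    lo = first[host_of_draw]
    hi = lo + counts[host_of_draw]
    # One randint call with per-draw bounds: each draw lands in its host's [lo, hi).
    pos = random.randint(key, (len(host_of_draw),), minval=lo, maxval=hi)
    return jnp.asarray(guest_idx)[pos]


def sample_without_replacement(key, guest_idx, n_possible_guests_per_host, n_to_sample):
    counts, first = host_layout(n_possible_guests_per_host)
    n_to_sample = np.asarray(n_to_sample)
    n_total = int(counts.sum())
    # Host id of every guest slot (slots are laid out host by host).
    host_of_guest = np.repeat(np.arange(len(counts)), counts)
    # lexsort uses the last key as primary: sort by host, then by a random value,
    # which randomly shuffles the slots inside each host's block.
    u = random.uniform(key, (n_total,))
    perm = jnp.lexsort((u, jnp.asarray(host_of_guest)))
    # Position of each sorted slot inside its host's block; keep the first n_to_sample[h].
    rank = np.arange(n_total) - first[host_of_guest]
    keep = np.nonzero(rank < n_to_sample[host_of_guest])[0]
    return jnp.asarray(guest_idx)[perm[keep]]


# Usage on data generated the same way as in the question.
rng = np.random.default_rng(0)
n_hosts = 1000
n_possible_guests_per_host = rng.integers(0, 100, size=n_hosts)
n_to_sample = np.array([rng.integers(0, n) if n > 0 else 0 for n in n_possible_guests_per_host])
guest_idx = np.arange(n_possible_guests_per_host.sum())

key_a, key_b = random.split(random.PRNGKey(0))
with_repl = sample_with_replacement(key_a, guest_idx, n_possible_guests_per_host, n_to_sample)
without_repl = sample_without_replacement(key_b, guest_idx, n_possible_guests_per_host, n_to_sample)
```

Both versions return the chosen guests grouped by host in host order, like brute_force, but without a Python loop or a recompilation for each distinct n_sample. The without-replacement variant touches every possible guest once, so its cost scales with the total number of guests rather than the number of draws.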