Monday, April 25, 2022

Optimizing sampling with varying sample sizes in JAX

I'm looking for ideas on how to optimize the sampling of a varying number of guests for a varying number of hosts. Let me clarify what I'm trying to do.

Given a number of hosts, "n_hosts", each with a different number of possible guests, "n_possible_guests_per_host", I want to sample "n_guests" guests from each host's list of possible guests. "n_guests" also differs from host to host (it is called "n_to_sample" in the code). I'm finding this challenging because JAX requires fixed input and output shapes. Here is a code sample of the brute-force approach; it took about 2 seconds on my laptop.

import numpy as np
import jax.numpy as jnp
from jax import random

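n_possible_guests = 50_000
n_hosts = 1000
# Every host except the last gets between 0 and 99 possible guests.
n_possible_guests_per_host = [np.random.randint(low=0, high=100) for _ in range(n_hosts - 1)]
# The last host takes whatever is left, so the counts sum to n_possible_guests.
n_possible_guests_per_host += [max(0, n_possible_guests - sum(n_possible_guests_per_host))]
guest_idx = np.arange(n_possible_guests)
host_idx = np.arange(n_hosts)

# For each host, sample strictly fewer guests than it has (0 if it has none).
n_to_sample = [np.random.randint(low=0, high=n) if n != 0 else 0 for n in n_possible_guests_per_host]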


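def brute_force(guest_idx, host_idx, n_possible_guests_per_host, n_to_sample):
    # Index of each host's first guest in the flat guest_idx array.
    first_guest_idx = np.cumsum(n_possible_guests_per_host) - n_possible_guests_per_host
    # Use one subkey per host; reusing a single key would correlate the draws.
    keys = random.split(random.PRNGKey(0), len(n_to_sample))
    # Integer dtype, since the entries are guest indices.
    chosen_guests = jnp.zeros(sum(n_to_sample), dtype=jnp.int32)
    n_chosen = 0
    for host_id, n_sample, key in zip(host_idx, n_to_sample, keys):
        possible_guests_in_host = guest_idx[
            first_guest_idx[host_id]:first_guest_idx[host_id] + n_possible_guests_per_host[host_id]
        ]
        # random.choice samples with replacement by default (replace=True).
        chosen_idx = random.choice(
            key, possible_guests_in_host, shape=(n_sample,)
        )
        chosen_guests = chosen_guests.at[
            n_chosen:n_chosen+len(chosen_idx)
        ].set(chosen_idx)
        n_chosen += len(chosen_idx)
    return chosen_guests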

brute_force(guest_idx, host_idx, n_possible_guests_per_host, n_to_sample)
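
One possible direction for avoiding the Python loop is to pad every host to the same number of sample slots and carry a mask, so that all shapes are fixed. The sketch below is only an illustration under some assumptions: each host's guests are consecutive indices (as in the setup above), sampling is with replacement (the default behavior of random.choice in the brute-force version), and the names sample_padded and max_sample are hypothetical helpers, not part of JAX.

```python
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax import jit, random


@partial(jit, static_argnames="max_sample")
def sample_padded(key, first_guest_idx, n_possible, n_to_sample, max_sample):
    """Draw up to `max_sample` guests per host (with replacement) into a padded array."""
    n_hosts = n_possible.shape[0]
    # One random offset per (host, slot), drawn from [0, n_possible[host]).
    # The upper bound is clamped to 1 so empty hosts still get a valid range;
    # their slots are masked out below.
    upper = jnp.maximum(n_possible, 1)[:, None]
    offsets = random.randint(key, (n_hosts, max_sample), 0, upper)
    guests = first_guest_idx[:, None] + offsets
    # Slot j of host i is real only if j < n_to_sample[i].
    valid = jnp.arange(max_sample)[None, :] < n_to_sample[:, None]
    return jnp.where(valid, guests, -1), valid


# Small hypothetical example: 3 hosts owning 3, 0 and 5 consecutive guests.
n_possible = np.array([3, 0, 5])
n_to_sample = np.array([2, 0, 4])
first_guest_idx = np.cumsum(n_possible) - n_possible

padded, valid = sample_padded(
    random.PRNGKey(0),
    jnp.asarray(first_guest_idx),
    jnp.asarray(n_possible),
    jnp.asarray(n_to_sample),
    max_sample=int(n_to_sample.max()),
)
# Flatten to the ragged result outside of jit, where dynamic shapes are fine.
chosen_guests = np.asarray(padded)[np.asarray(valid)]
```

The trade-off is memory: the padded array has n_hosts * max_sample entries, so a single host with a very large sample count inflates the whole array. Sampling without replacement would need a different scheme, which this sketch does not cover.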


