Saturday, January 4, 2020

Sharing a class among processes in multiprocessing

I implemented the following code:

import numpy as np
import time
import multiprocessing as mp

class RandomStater():

    def __init__(self, num_per_chunks=1e7, seed=42):
        # Keep the generator: calling np.random.RandomState(seed) alone discards it
        self.rng = np.random.RandomState(seed)
        self.num_per_chunks = int(num_per_chunks)
        self.d = {}  # n -> buffer of pre-computed samples in range(n)
        self.i = {}  # n -> number of samples consumed so far

    def _choice_n(self, n):
        # Refill the buffer on first use and whenever it is exhausted
        if self.i[n] % self.num_per_chunks == 0:
            self.d[n] = self.rng.choice(n, size=self.num_per_chunks)
        value = self.d[n][self.i[n] % self.num_per_chunks]
        self.i[n] += 1
        return value

    def choice(self, actions, p=None):
        # p is accepted for compatibility with np.random.choice but ignored
        num_actions = len(actions)
        if num_actions not in self.d:
            print('Creating for action', num_actions)
            self.i[num_actions] = 0
        return actions[self._choice_n(num_actions)]

    def shuffle(self, l):
        self.rng.shuffle(l)


class SampleNumber():
    def __init__(self):
        self._random_start = RandomStater()

    def sample_number(self, _):
        return self._random_start.choice([1,2,3])

Basically, the class RandomStater is a batched version of np.random.choice: it pre-computes a large number of random samples and hands them out one at a time on later calls. The second class, SampleNumber, is a test class that collects these random samples in sequence.
I would like to call the method sample_number from several processes in parallel. What should happen is that self.d[num_actions] and self.i[num_actions] are created only the first time any process executes the method choice; on every later call, the counter self.i[num_actions] is simply incremented, so the result is returned very quickly.
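A minimal sketch of the shared-counter behaviour described above, assuming a single fixed list of actions and that all samples can be drawn up front in the parent process: the draws live in a shared-memory array (`multiprocessing.Array`), and a `multiprocessing.Value` plays the role of the counter `i`, incremented under its lock so that no two workers take the same index. The names `_init_worker`, `sample_shared` and `run` are hypothetical, and unlike `RandomStater` the buffer is never refilled.

```python
import multiprocessing as mp

import numpy as np

# Shared objects, installed in each worker by the pool initializer.
_buffer = None
_counter = None


def _init_worker(buffer, counter):
    """Keep references to the shared buffer and counter in this worker."""
    global _buffer, _counter
    _buffer = buffer
    _counter = counter


def sample_shared(actions):
    """Claim the next unused index under the counter's lock, map it to an action."""
    with _counter.get_lock():
        idx = _counter.value
        _counter.value += 1
    # No refill: assumes the buffer holds at least as many draws as there are calls.
    return actions[_buffer[idx]]


def run(num_samples=2000, actions=(1, 2, 3), seed=42, processes=4):
    """Draw all samples once in the parent, then let the workers consume them."""
    rng = np.random.RandomState(seed)
    draws = rng.choice(len(actions), size=num_samples)
    # Only read after creation, so the array itself needs no lock.
    buffer = mp.Array('i', draws.tolist(), lock=False)
    counter = mp.Value('i', 0)  # synchronized wrapper with its own lock
    with mp.Pool(processes, initializer=_init_worker,
                 initargs=(buffer, counter)) as pool:
        got = pool.map(sample_shared, [actions] * num_samples)
    return got, draws, counter.value


if __name__ == '__main__':
    got, draws, used = run()
    print(got[:10], used)
```

Because the counter is shared, each pre-computed draw is used exactly once across all workers, although which task receives which draw depends on scheduling.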

I guess the instance should somehow be shared among the processes, but I am new to multiprocessing and do not really know how to do that.
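If the goal is literally to share one instance, `multiprocessing.managers.BaseManager` can host a single object in a separate server process and hand out proxies to it. A sketch, assuming a simplified, hypothetical `BufferedChoice` class in place of `RandomStater` (same chunked-buffer idea), with a `threading.Lock` added because the manager serves each client connection in its own thread:

```python
import multiprocessing as mp
import threading
from functools import partial
from multiprocessing.managers import BaseManager

import numpy as np


class BufferedChoice:
    """Hypothetical, simplified stand-in for RandomStater (one buffer per n)."""

    def __init__(self, num_per_chunk=100_000, seed=42):
        self.rng = np.random.RandomState(seed)
        self.num_per_chunk = int(num_per_chunk)
        self.buffers = {}
        self.counters = {}
        # Calls from different workers may run concurrently in the manager.
        self._lock = threading.Lock()

    def choice(self, actions):
        n = len(actions)
        with self._lock:
            i = self.counters.get(n, 0)
            if i % self.num_per_chunk == 0:
                # First call for this n, or the chunk is used up: draw a new one.
                self.buffers[n] = self.rng.choice(n, size=self.num_per_chunk)
            self.counters[n] = i + 1
            idx = int(self.buffers[n][i % self.num_per_chunk])
        return actions[idx]

    def count(self, n):
        return self.counters.get(n, 0)


class SamplerManager(BaseManager):
    pass


# Method calls on the returned proxy run on the single instance in the manager.
SamplerManager.register('BufferedChoice', BufferedChoice)


def sample_number(sampler, _):
    return sampler.choice([1, 2, 3])


def run(num_samples=2000, processes=4):
    with SamplerManager() as manager:
        sampler = manager.BufferedChoice(num_per_chunk=500)
        with mp.Pool(processes) as pool:
            got = pool.map(partial(sample_number, sampler), range(num_samples))
        total = sampler.count(3)
    return got, total


if __name__ == '__main__':
    got, total = run()
    print(got[:10], total)
```

Every `choice` call becomes a round trip to the manager process, so this is likely to be slower than drawing locally; it trades speed for a single shared counter.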

So far I have tried the following, without any luck:

if __name__ == '__main__':
    # The guard is needed where workers are started with 'spawn' (e.g. Windows)
    ns = range(0, 2000)
    sn = SampleNumber()
    s = time.time()

    with mp.Pool(4) as p:
        got = p.map(sn.sample_number, ns)
    print(got)
    print(time.time() - s)
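As far as I understand, the attempt above shares nothing: `p.map` pickles the bound method `sn.sample_number`, and with it `sn` and its `RandomStater`, so the workers update copies of `self.d` and `self.i` that never propagate back. If speed matters more than one global sequence, an alternative sketch is to give each worker its own buffered sampler, built once in a pool initializer. Copies that start from the same random state would produce the same draws, so here each worker gets an independent stream derived from `np.random.SeedSequence` (NumPy 1.17 or later). `ChunkedChoice`, `_init_worker` and the key queue are hypothetical names, and the sketch assumes the pool never replaces a worker, since a replacement would block waiting for a key.

```python
import multiprocessing as mp

import numpy as np

_sampler = None  # per-worker sampler, created by the initializer


class ChunkedChoice:
    """Hypothetical per-process sampler: one pre-computed buffer per n."""

    def __init__(self, rng, num_per_chunk=100_000):
        self.rng = rng
        self.num_per_chunk = num_per_chunk
        self.buffers = {}
        self.counters = {}

    def choice(self, actions):
        n = len(actions)
        i = self.counters.get(n, 0)
        if i % self.num_per_chunk == 0:
            # First call for this n in this worker, or buffer used up.
            self.buffers[n] = self.rng.integers(n, size=self.num_per_chunk)
        self.counters[n] = i + 1
        return actions[self.buffers[n][i % self.num_per_chunk]]


def _init_worker(key_queue, seed):
    """Build this worker's sampler from its own child SeedSequence."""
    global _sampler
    key = key_queue.get()  # each worker takes a different spawn key
    seq = np.random.SeedSequence(seed, spawn_key=(key,))
    _sampler = ChunkedChoice(np.random.default_rng(seq))


def sample_number(_):
    return _sampler.choice([1, 2, 3])


def run(num_samples=2000, processes=4, seed=42):
    key_queue = mp.Queue()
    for key in range(processes):
        key_queue.put(key)
    with mp.Pool(processes, initializer=_init_worker,
                 initargs=(key_queue, seed)) as pool:
        return pool.map(sample_number, range(num_samples))


if __name__ == '__main__':
    print(run()[:10])
```

Each worker fills its buffer once and then serves later calls from memory, which is the fast path the question is after, at the cost of having one independent sequence per worker instead of one shared sequence.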



