Wednesday, October 30, 2019

How to find the largest multiple of n that fits in a 32-bit integer

I am reading Functional Programming in Scala and am having trouble understanding a piece of code. I have checked the errata for the book and the passage in question does not have a misprint. (Actually, it does have a misprint, but the misprint does not affect the code that I have a question about.)

The code in question calculates a pseudo-random, non-negative integer that is less than some upper bound. The function that does this is called nonNegativeLessThan.

trait RNG {
  def nextInt: (Int, RNG) // Should generate a random `Int`. 
}

case class Simple(seed: Long) extends RNG {
  def nextInt: (Int, RNG) = {
    val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL // `&` is bitwise AND. We use the current seed to generate a new seed.
    val nextRNG = Simple(newSeed) // The next state, which is an `RNG` instance created from the new seed.
    val n = (newSeed >>> 16).toInt // `>>>` is right binary shift with zero fill. The value `n` is our new pseudo-random integer.
    (n, nextRNG) // The return value is a tuple containing both a pseudo-random integer and the next `RNG` state.
  }
}

type Rand[+A] = RNG => (A, RNG)

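def nonNegativeInt(rng: RNG): (Int, RNG) = {
  val (myInt, nextRng) = rng.nextInt
  // Map negatives with -(myInt + 1): this handles Int.MinValue without overflow,
  // and every result in 0..Int.MaxValue is produced by exactly two inputs.
  val nnInt = if (myInt < 0) -(myInt + 1) else myInt
  (nnInt, nextRng)
}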

def nonNegativeLessThan(n: Int): Rand[Int] = { rng =>
  val (i, rng2) = nonNegativeInt(rng)
  val mod = i % n
  if (i + (n-1) - mod >= 0) (mod, rng2)
  else nonNegativeLessThan(n)(rng2)
}

The part of nonNegativeLessThan that I have trouble understanding is the condition if (i + (n-1) - mod >= 0) (mod, rng2), together with its else branch.

The book explains that this if-else expression is necessary because a naive implementation that simply takes the result of nonNegativeInt modulo n would be slightly skewed toward lower values, since Int.MaxValue is not guaranteed to be a multiple of n. The code is therefore meant to check whether the output of nonNegativeInt is larger than the largest multiple of n that fits in a 32-bit integer. If it is, the function discards that value and generates a new pseudo-random number.
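To make the check concrete, here is a minimal sketch of one way to read it; it is an interpretation, not the book's own explanation. The names OverflowCheckSketch, inCompleteBlock, and inCompleteBlockLong are hypothetical helpers introduced only for illustration. The idea is that i - mod is the largest multiple of n not exceeding i, so i + (n-1) - mod is the last value of the block of n consecutive integers that contains i. If that last value does not fit in an Int, the addition wraps around to a negative number, which is what the >= 0 test detects.

```scala
object OverflowCheckSketch {
  // Hypothetical helper: the same condition as in nonNegativeLessThan.
  // Assumes i >= 0 and n > 0. Int arithmetic wraps around on overflow,
  // so the sum is negative exactly when the block containing i would
  // extend past Int.MaxValue.
  def inCompleteBlock(i: Int, n: Int): Boolean = {
    val mod = i % n
    i + (n - 1) - mod >= 0
  }

  // The same question asked with Long arithmetic, where nothing overflows.
  def inCompleteBlockLong(i: Int, n: Int): Boolean = {
    val blockStart = i.toLong - (i % n)      // largest multiple of n <= i
    blockStart + (n - 1) <= Int.MaxValue.toLong
  }

  def main(args: Array[String]): Unit = {
    val n = 10
    // Int.MaxValue is 2147483647, so the last complete block of 10
    // is 2147483630..2147483639; 2147483640..2147483647 is incomplete.
    println(inCompleteBlock(2147483639, n)) // true: value is kept
    println(inCompleteBlock(2147483640, n)) // false: value is rejected
  }
}
```

Under this reading, values in the final, incomplete block are rejected and redrawn, so every remainder 0..n-1 is backed by the same number of accepted values.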

To elaborate, the naive implementation would look like this:

def naiveNonNegativeLessThan(n: Int): Rand[Int] = map(nonNegativeInt){_ % n}

where map is defined as follows:

def map[A,B](s: Rand[A])(f: A => B): Rand[B] = {
  rng => 
    val (a, rng2) = s(rng)
    (f(a), rng2)
}

To repeat, this naive implementation is not desirable because of a slight skew towards lower values when Int.MaxValue is not a perfect multiple of n.
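The skew is easier to see on a scaled-down range. The sketch below is only an illustration: remainderCounts is a hypothetical helper, and the small maxValue of 7 stands in for Int.MaxValue. It counts how many values in 0..maxValue land on each remainder modulo n.

```scala
object ModuloSkewSketch {
  // Hypothetical helper: for every value in 0..maxValue, take it modulo n
  // and count how many values produce each remainder.
  // `maxValue` is a small stand-in for Int.MaxValue (an assumption made
  // purely to keep the example readable).
  def remainderCounts(maxValue: Int, n: Int): Map[Int, Int] =
    (0 to maxValue).groupBy(_ % n).map { case (r, vs) => (r, vs.size) }

  def main(args: Array[String]): Unit = {
    val counts = remainderCounts(maxValue = 7, n = 3)
    // 0..7 holds 8 values, which is not a multiple of 3:
    // remainder 0: 3 values (0, 3, 6)
    // remainder 1: 3 values (1, 4, 7)
    // remainder 2: 2 values (2, 5)
    (0 until 3).foreach(r => println(s"remainder $r: ${counts(r)} values"))
  }
}
```

The same thing happens at full scale: there are 2147483648 non-negative Int values, so for n = 10 the remainders 0 through 7 each occur 214748365 times while 8 and 9 occur 214748364 times.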

So, to reiterate the question: what does the following code do, and how does it help us determine whether a number is smaller than the largest multiple of n that fits inside a 32-bit integer? I am talking about this code inside nonNegativeLessThan:

if (i + (n-1) - mod >= 0) (mod, rng2)
else nonNegativeLessThan(n)(rng2)


