Saturday, April 29, 2023

JAX random generator: normal random numbers seem to be returned sorted and not completely random

I'm trying to learn JAX from scratch. I've been generating some synthetic data with a linear relationship, so that

Y = B1*X + B0 + epsilon
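For reference, here is a minimal sketch of sampling from this model in JAX, assuming Gaussian noise with scale 0.3 as in the code below. The function name `make_linear_data` and the coefficient values `b0=1.0`, `b1=2.0` are hypothetical, chosen only for illustration. The one design choice that matters is that the key is split with `random.split`, so X and epsilon are drawn from different subkeys.

```python
import jax.numpy as jnp
from jax import random


def make_linear_data(key, n=100, b0=1.0, b1=2.0, sigma=0.3):
    """Sample (x, y) with y = b1*x + b0 + eps, eps ~ Normal(0, sigma^2).

    Hypothetical helper: b0, b1 and sigma are illustrative defaults.
    """
    # One fresh subkey per random draw, so x and eps are independent.
    key_x, key_eps = random.split(key)
    x = random.uniform(key_x, (n,))
    eps = sigma * random.normal(key_eps, (n,))
    y = b1 * x + b0 + eps
    return x, y, eps


x, y, eps = make_linear_data(random.PRNGKey(0))
```

Because each draw consumes its own subkey, the noise should show no systematic relationship with x.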

What surprised me is that, with the raw epsilons, the noise seems to be ordered: points at the extremes of X get a larger error than points near the center. The shuffled version behaves as expected, with errors normally distributed around the linear relationship.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import matplotlib.pyplot as plt

n = 100
p = 2  # number of parameters including linear bias
key = random.PRNGKey(0)  # note: this one key is reused for every random draw below

X = random.uniform(key, (n, p - 1))
X = jnp.concatenate([jnp.ones((n, 1)), X], axis=1)
B = random.randint(key, (p, 1), 0, 10)

epsilon = random.normal(key, (n, 1))*0.3

y = X.dot(B) + epsilon
y_s = X.dot(B) + random.permutation(key, epsilon)  # shuffles the rows of epsilon; random.shuffle is deprecated
plt.scatter(X[:,1], y, label='Original order')
plt.scatter(X[:,1], y_s, label='Shuffled')
plt.legend()
plt.show()

[Figure: scatter plot of y and y_s against X[:,1], labelled 'Original order' and 'Shuffled'.]

Why is this happening?
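One likely explanation, sketched below: JAX keys are not stateful, so every call that receives the same `key` reuses the same underlying random bits. In the JAX source, `random.normal` appears to be built by drawing uniform numbers from the key and mapping them through the inverse error function. With the same key and the same shape `(n, 1)` as in the code above, `epsilon` would then be an increasing function of `X[:,1]`: small X gets large negative noise, large X gets large positive noise. The check below (the helper `sorted_by` is hypothetical) orders the noise by x once with a shared key and once with split keys. If this explanation is right, the shared-key check should report True and the split-key check False.

```python
import jax.numpy as jnp
from jax import random

n = 100
key = random.PRNGKey(0)

# Same key and same shape for both draws, mirroring the code in the question.
x = random.uniform(key, (n, 1))[:, 0]
eps_same = 0.3 * random.normal(key, (n, 1))[:, 0]

# Independent draws: split the key first, one subkey per draw.
key_x, key_eps = random.split(key)
x_ind = random.uniform(key_x, (n, 1))[:, 0]
eps_ind = 0.3 * random.normal(key_eps, (n, 1))[:, 0]


def sorted_by(values, by):
    """Return `values` reordered so that `by` is ascending (hypothetical helper)."""
    return values[jnp.argsort(by)]


# Shared key: is the noise non-decreasing once the points are ordered by x?
same_key_sorted = bool(jnp.all(jnp.diff(sorted_by(eps_same, x)) >= 0))
# Split keys: the same check on independently drawn noise.
split_key_sorted = bool(jnp.all(jnp.diff(sorted_by(eps_ind, x_ind)) >= 0))
corr_same = float(jnp.corrcoef(x, eps_same)[0, 1])

print("shared key -> noise sorted by x:", same_key_sorted)
print("split keys -> noise sorted by x:", split_key_sorted)
print("shared key -> corr(x, noise):", corr_same)
```

Shuffling `epsilon` breaks this link only by accident. Splitting the key (or using `random.fold_in`) for every draw is the usual way to get independent samples.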



