I'm trying to learn JAX from scratch. I've been generating synthetic data with a linear relationship,
Y = B1*X + B0 + epsilon
My surprise has been that, when I use the raw epsilons, the noise looks ordered: points at the extremes of X get larger errors than the ones in the middle. The shuffled version works as expected, with errors normally distributed around the linear relationship. Here is a minimal reproduction:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import matplotlib.pyplot as plt

n = 100
p = 2  # number of parameters, including the intercept

key = random.PRNGKey(0)
X = random.uniform(key, (n, p - 1))                 # predictors in [0, 1)
X = jnp.concatenate([jnp.ones((n, 1)), X], axis=1)  # prepend the bias column
B = random.randint(key, (p, 1), 0, 10)              # integer coefficients
epsilon = random.normal(key, (n, 1)) * 0.3          # Gaussian noise

y = X.dot(B) + epsilon                              # raw noise
y_s = X.dot(B) + random.shuffle(key, epsilon)       # shuffled noise (random.shuffle is deprecated in newer JAX)

plt.scatter(X[:, 1], y, label='Original order')
plt.scatter(X[:, 1], y_s, label='Shuffled')
plt.legend()
plt.show()
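
For comparison, here is a minimal sketch of the same setup where every draw gets its own subkey via random.split, assuming the pattern comes from reusing the single PRNGKey for X, B, and epsilon (JAX keys are stateless, so reusing one reproduces the same underlying random bits). Variable names mirror the snippet above, and random.permutation stands in for the deprecated random.shuffle:

# Sketch: same data-generating process, but X, B and epsilon are drawn
# from independent subkeys instead of one shared key.
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

n = 100
p = 2

key = random.PRNGKey(0)
x_key, b_key, eps_key, perm_key = random.split(key, 4)  # one subkey per draw

X = random.uniform(x_key, (n, p - 1))
X = jnp.concatenate([jnp.ones((n, 1)), X], axis=1)
B = random.randint(b_key, (p, 1), 0, 10)
epsilon = random.normal(eps_key, (n, 1)) * 0.3

y = X.dot(B) + epsilon
# random.permutation plays the role of the deprecated random.shuffle
y_s = X.dot(B) + random.permutation(perm_key, epsilon)

plt.scatter(X[:, 1], y, label='Independent keys')
plt.scatter(X[:, 1], y_s, label='Independent keys, shuffled noise')
plt.legend()
plt.show()

With separate subkeys the raw-noise and shuffled-noise scatters should look statistically the same, since epsilon is no longer derived from the same random bits as X.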