Saturday, May 13, 2023

Spark: randomize the values of a primary key column into another unique column, ensuring no collisions

So I have this problem that has been bugging me for a while... I need to generate all possible values of the SSN format xx-xxx-xxxx (where x is any digit from 0 to 9) in one column of a Spark DataFrame, and then add a second column containing the same set of values, but shuffled so that on every row it differs from the first column.

Practical example with a small set where only the digits 0 and 1 are allowed:

+-----------+-----------+
|ssn        |false_ssn  |
+-----------+-----------+
|00-000-0000|01-001-0000|
|00-000-0001|00-001-0001|
|00-001-0000|00-000-0001|
|00-001-0001|01-000-0000|
|01-000-0000|01-001-0001|
|01-000-0001|00-000-0000|
|01-001-0000|00-001-0000|
|01-001-0001|01-000-0001|
+-----------+-----------+

I managed to do this with the following Python code:

import random

from pyspark.sql import DataFrame, SparkSession


def generate_ssn_lookup_data(range_group_1: int, range_group_2: int, range_group_3: int,
                             spark: SparkSession) -> DataFrame:
    # Generate every ssn in the cartesian product of the three groups,
    # entirely on the driver.
    sensitive_values = [f"{i:02d}-{j:03d}-{k:04d}"
                        for i in range(range_group_1)
                        for j in range(range_group_2)
                        for k in range(range_group_3)]
    mock_values = []
    used_values = set()
    for sensitive in sensitive_values:
        # Re-draw until the mock differs from this row's ssn and has not
        # already been handed out to an earlier row. (Note: this can spin
        # forever if the only value left unused is this row's own ssn.)
        mock = sensitive
        while mock == sensitive or mock in used_values:
            mock = (f"{random.choice(range(range_group_1)):02d}-"
                    f"{random.choice(range(range_group_2)):03d}-"
                    f"{random.choice(range(range_group_3)):04d}")
        mock_values.append(mock)
        used_values.add(mock)
    data = list(zip(sensitive_values, mock_values))

    df = spark.createDataFrame(data, ["ssn", "false_ssn"])
    return df

# Call
df = generate_ssn_lookup_data(2, 2, 2, some_spark_session)

But you can imagine how this performs when trying to generate 1 billion records with generate_ssn_lookup_data(100, 1000, 10000, some_spark_session): everything happens in a single Python loop on the driver.
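
Generating the ssn values themselves at that scale is not the hard part, by the way. A sketch like the one below (my own, using spark.range and format_string, with sizes mirroring the call above) would enumerate the full space without ever materializing it on the driver:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, floor, format_string

spark = SparkSession.builder.getOrCreate()

n1, n2, n3 = 100, 1000, 10000  # sizes of the three ssn groups

# One row per ssn: decompose the sequential id into the three groups
# and format it as xx-xxx-xxxx, all of it running distributed.
ssn_df = (spark.range(n1 * n2 * n3)
          .withColumn("g1", floor(col("id") / (n2 * n3)))
          .withColumn("g2", floor(col("id") / n3) % n2)
          .withColumn("g3", col("id") % n3)
          .select(format_string("%02d-%03d-%04d", "g1", "g2", "g3").alias("ssn")))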

So I also tried using Spark-native functions with the code below, but I can't avoid collisions where false_ssn = ssn (I included the whole debug code here; for actual values use a call like generate_ssn_lookup_data(0, 9)). I'm guessing that making the two random columns differ is not enough, since a row can still get the same row_number under both orderings and be joined to itself... and that loop is worrying anyway.

from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.functions import expr, rand, col, lit, concat, floor, when, row_number

spark = SparkSession.builder.master("local[*]") \
    .config("spark.driver.memory", "8g") \
    .config("spark.executor.memory", "8g") \
    .appName("App").getOrCreate()


def generate_ssn_lookup_data(start: int, end: int) -> DataFrame:
    # Build the full ssn space as the cross join of nine one-digit columns.
    df = spark.range(1).select(expr("null").alias("dummy"))
    ssn_characters = 9
    for i in range(ssn_characters):
        df = df.crossJoin(spark.range(start, end + 1).select(col("id").alias(f"col{i}")))

    # Assemble the formatted ssn and attach two independent random keys.
    df = df.withColumn("ssn", concat(
        col("col0"), col("col1"), lit("-"),
        col("col2"), col("col3"), col("col4"), lit("-"),
        col("col5"), col("col6"), col("col7"), col("col8")
    )) \
        .withColumn("random", floor(rand() * pow(10, 15))) \
        .withColumn("random2", floor(rand() * pow(10, 15)))

    # Rank every row under each random ordering and pair rank i of the
    # first ordering with rank i of the second.
    df = ensure_unique_random(df, random_col1="random", random_col2="random2")
    left = df.withColumn("rnd", row_number().over(Window.orderBy("random")))
    right = df.withColumnRenamed("ssn", "false_ssn").withColumn("rnd", row_number().over(Window.orderBy("random2")))
    df = left.alias("l").join(right.alias("r"), left.rnd == right.rnd).drop("rnd")

    return df.select("l.ssn", "r.false_ssn")


def ensure_unique_random(df: DataFrame, random_col1, random_col2) -> DataFrame:
    # Patch rows where the two random keys happen to be equal. (Rewriting
    # every such row to the constant 100 can itself create duplicate keys,
    # and it does nothing about a row getting the same *rank* twice.)
    while df.where(f"{random_col1} = {random_col2}").count() != 0:
        df.where(f"{random_col1} = {random_col2}").show(truncate=False)
        df = df.withColumn(random_col2,
                           when(col(random_col1) == col(random_col2), 100).otherwise(
                               col(random_col2)))
    return df


df = generate_ssn_lookup_data(0, 1)
df.cache()
print("Generated Df ")
df.show(truncate=False)
df_count = df.count()
print(f"Generated Df Count: {df_count}")
unique_ssn_count = df.select("ssn").distinct().count()
print(f"Distinct SSN Count: {unique_ssn_count}")
false_ssn_count = df.select("false_ssn").distinct().count()
print(f"Distinct False Count: {false_ssn_count} ")
false_non_false_collisions = df.where("ssn = false_ssn")
collision_count = false_non_false_collisions.count()
print(f"False Non False Collisions: {collision_count}")
false_non_false_collisions.show(truncate=False)
assert collision_count == 0
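
To convince myself that this self-pairing is really the failure mode, here is a tiny hand-crafted repro (reusing the session and imports above; the key values are rigged so that row "b" gets rank 2 under both orderings):

# "b" ranks 2nd by "random" AND 2nd by "random2", so the rank join
# pairs it with itself and ssn == false_ssn on that row.
toy = spark.createDataFrame([("a", 1, 6), ("b", 2, 5), ("c", 3, 4)],
                            ["ssn", "random", "random2"])
l = toy.withColumn("rnd", row_number().over(Window.orderBy("random")))
r = toy.withColumnRenamed("ssn", "false_ssn") \
       .withColumn("rnd", row_number().over(Window.orderBy("random2")))
l.alias("l").join(r.alias("r"), "rnd").select("l.ssn", "r.false_ssn").show()
# the "b" row comes back as (b, b): a collision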

Basically, the problem is shuffling the values of column A into column B so that no duplicates appear in column B and no row's column B equals its column A.
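
One direction I'm considering (a minimal sketch; the helper name add_false_ssn and the fixed seed are mine) is a cyclic shift over a random ordering, which by construction has no fixed points once there are at least two rows:

from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import coalesce, first, lead, rand


def add_false_ssn(df: DataFrame) -> DataFrame:
    # Materialize the random sort key first; ordering a window by rand()
    # directly is rejected as non-deterministic by some Spark versions.
    df = df.withColumn("sort_key", rand(seed=42))
    w = Window.orderBy("sort_key")
    whole = w.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    # Each row takes the ssn of the next row in the random order; the last
    # row wraps around to the first. A cyclic shift of distinct values has
    # no fixed points, so false_ssn != ssn on every row and every ssn is
    # reused exactly once.
    return df.withColumn(
        "false_ssn",
        coalesce(lead("ssn", 1).over(w), first("ssn").over(whole))
    ).drop("sort_key")

The catch is that a window without partitionBy pulls everything into a single partition, so at a billion rows I would presumably have to apply the shift per group (say, per leading digit pair) rather than globally.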

Thanks in advance.



