Star Schema Data Migration: Snowflake to Databricks

Star Schema Design

Dimension Tables

-- Products Dimension Table --
CREATE TABLE retail_db.products (
    product_id INT NOT NULL PRIMARY KEY,
    product_name VARCHAR(255) NOT NULL,
    category VARCHAR(100),
    price DECIMAL(10,2) NOT NULL,
    created_at TIMESTAMP NOT NULL
);

-- Customers Dimension Table --
CREATE TABLE retail_db.customers (
    customer_id INT NOT NULL PRIMARY KEY,
    first_name VARCHAR(100) NOT NULL,
    last_name VARCHAR(100) NOT NULL,
    email VARCHAR(255) NOT NULL,
    country VARCHAR(100),
    created_at TIMESTAMP NOT NULL
);

Fact Table

-- Sales Fact Table --
CREATE TABLE retail_db.sales (
    sale_id INT NOT NULL PRIMARY KEY,
    customer_id INT NOT NULL REFERENCES retail_db.customers(customer_id),
    product_id INT NOT NULL REFERENCES retail_db.products(product_id),
    quantity INT NOT NULL,
    unit_price DECIMAL(10,2) NOT NULL,
    total_amount DECIMAL(10,2) NOT NULL,
    sale_date TIMESTAMP NOT NULL
);
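
Note that neither Snowflake nor Databricks enforces the PRIMARY KEY and
FOREIGN KEY constraints above; they are informational metadata. Delta Lake
does enforce NOT NULL and CHECK constraints, so simple row-level rules can be
pushed into the table itself. A minimal sketch (the constraint name is
illustrative):

# Delta Lake rejects writes that violate a CHECK constraint
spark.sql("""
    ALTER TABLE retail_db.sales
    ADD CONSTRAINT sales_qty_positive CHECK (quantity > 0)
""")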

Schema Definitions

The migration script defines a matching PySpark schema for each table; the
sales fact table schema is shown here as an example.

# Sales fact table schema
from pyspark.sql.types import (StructType, StructField, IntegerType,
                               DecimalType, TimestampType)

sales_schema = StructType([
    StructField("sale_id",      IntegerType(),     False),
    StructField("customer_id",  IntegerType(),     False),
    StructField("product_id",   IntegerType(),     False),
    StructField("quantity",     IntegerType(),     False),
    StructField("unit_price",   DecimalType(10,2), False),
    StructField("total_amount", DecimalType(10,2), False),
    StructField("sale_date",    TimestampType(),   False)
])

Example Star Schema Queries

Sales Analysis Query

SELECT
    p.category,
    c.country,
    COUNT(DISTINCT s.sale_id) as total_sales,
    SUM(s.total_amount) as total_revenue
FROM retail_db.sales s
JOIN retail_db.products p ON s.product_id = p.product_id
JOIN retail_db.customers c ON s.customer_id = c.customer_id
GROUP BY p.category, c.country;
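
The same aggregation can be expressed with the DataFrame API once the tables
are registered in the metastore; a sketch equivalent to the SQL above:

from pyspark.sql import functions as F

(spark.table("retail_db.sales")
    .join(spark.table("retail_db.products"), "product_id")
    .join(spark.table("retail_db.customers"), "customer_id")
    .groupBy("category", "country")
    .agg(F.countDistinct("sale_id").alias("total_sales"),
         F.sum("total_amount").alias("total_revenue"))
    .show())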

Data Validation Query

-- Check for orphaned sales records --
SELECT COUNT(*) as invalid_products
FROM retail_db.sales s
LEFT JOIN retail_db.products p ON s.product_id = p.product_id
WHERE p.product_id IS NULL;
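
The migration script below performs the same check with a left anti-join,
which keeps only the sales rows that have no matching dimension row:

orphaned = spark.table("retail_db.sales") \
    .join(spark.table("retail_db.products"), "product_id", "left_anti")
print(f"invalid_products: {orphaned.count()}")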

Migration Process

The full script (see "Full Migration Script" below) migrates the schema in
dependency order: it creates the retail_db database, migrates the products and
customers dimension tables, then the sales fact table, validates the foreign
key relationships with anti-joins, and finishes with the star schema analysis
query.

Important Notes

- Dimension tables are migrated before the fact table so that foreign key
  validation has reference data to check against.
- The script writes with mode("overwrite"), so re-running it replaces any
  existing Delta tables of the same name.
- Neither Snowflake nor Delta Lake enforces PRIMARY KEY or FOREIGN KEY
  constraints, which is why the script validates referential integrity
  explicitly after the load.
- Credentials appear inline in snowflake_options for readability only; in
  practice they should come from a secrets manager (see the note in the
  script).

Sample Output


Starting migration for table:  products
Records read from Snowflake:   1000
Records written to Databricks: 1000
Successfully migrated table:   products

Starting migration for table:  customers
Records read from Snowflake:   5000
Records written to Databricks: 5000
Successfully migrated table:   customers

Starting migration for table:  sales
Records read from Snowflake:   50000
Records written to Databricks: 50000
Successfully migrated table:   sales

All foreign key relationships are valid

Analyzing star schema...
+-------------+----------+-------------+---------------+
|   category  |  country | total_sales | total_revenue |
+-------------+----------+-------------+---------------+
| Electronics | USA      |   15,000    | 2,500,000.00  |
| Clothing    | Canada   |   12,000    | 1,800,000.00  |
| Books       | UK       |    8,000    |   400,000.00  |
...

Full Migration Script

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DecimalType, TimestampType
from pyspark.sql.functions import col

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("Snowflake to Databricks Migration") \
    .getOrCreate()
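
# NOTE: on Databricks the Snowflake connector ships with the runtime. Running
# this script elsewhere requires the connector on the classpath, e.g. by adding
# .config("spark.jars.packages", "net.snowflake:spark-snowflake_2.12:2.11.0-spark_3.3")
# to the builder above (the version shown is illustrative; match it to your
# Spark version).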

# Snowflake connection parameters
snowflake_options = {
    "sfURL": "your_snowflake_account.snowflakecomputing.com",
    "sfUser": "your_username",
    "sfPassword": "your_password",
    "sfDatabase": "RETAIL_DB",
    "sfSchema": "PUBLIC",
    "sfWarehouse": "COMPUTE_WH"
}
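
# NOTE: credentials are hard-coded above only to keep the example readable.
# On Databricks they can be pulled from a secret scope instead (the scope and
# key names below are illustrative):
#   snowflake_options["sfUser"] = dbutils.secrets.get(scope="snowflake", key="user")
#   snowflake_options["sfPassword"] = dbutils.secrets.get(scope="snowflake", key="password")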

# Define schemas for our tables
product_schema = StructType([
    StructField("product_id", IntegerType(), False),
    StructField("product_name", StringType(), False),
    StructField("category", StringType(), True),
    StructField("price", DecimalType(10,2), False),
    StructField("created_at", TimestampType(), False)
])

customer_schema = StructType([
    StructField("customer_id", IntegerType(), False),
    StructField("first_name", StringType(), False),
    StructField("last_name", StringType(), False),
    StructField("email", StringType(), False),
    StructField("country", StringType(), True),
    StructField("created_at", TimestampType(), False)
])

sales_schema = StructType([
    StructField("sale_id", IntegerType(), False),
    StructField("customer_id", IntegerType(), False),
    StructField("product_id", IntegerType(), False),
    StructField("quantity", IntegerType(), False),
    StructField("unit_price", DecimalType(10,2), False),
    StructField("total_amount", DecimalType(10,2), False),
    StructField("sale_date", TimestampType(), False)
])

def read_from_snowflake(table_name, schema):
    """Read data from a Snowflake table.

    The Snowflake connector derives column names and types from Snowflake
    metadata, so the expected schema is applied by casting after the load
    rather than being passed to the reader.
    """
    df = spark.read \
        .format("snowflake") \
        .options(**snowflake_options) \
        .option("dbtable", table_name) \
        .load()
    # Align the result with the target schema (column names and types)
    return df.select([col(f.name).cast(f.dataType).alias(f.name)
                      for f in schema.fields])

def write_to_databricks(df, table_name):
    """Write DataFrame to Databricks Delta table"""
    df.write \
        .format("delta") \
        .mode("overwrite") \
        .saveAsTable(f"retail_db.{table_name}")
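
# NOTE: for a large fact table it can help to compact and co-locate files
# after the load, e.g.:
#   spark.sql("OPTIMIZE retail_db.sales ZORDER BY (customer_id, product_id)")
# OPTIMIZE ... ZORDER BY is Databricks-specific Delta maintenance.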

def validate_foreign_keys(sales_df, products_df, customers_df):
    """Validate foreign key relationships"""
    # Anti-joins return the sales rows with no matching dimension row (orphans)
    orphaned_products = sales_df.join(products_df, "product_id", "left_anti").count()
    orphaned_customers = sales_df.join(customers_df, "customer_id", "left_anti").count()
    
    if orphaned_products > 0:
        print(f"Warning: Found {orphaned_products} sales with invalid product_id")
    if orphaned_customers > 0:
        print(f"Warning: Found {orphaned_customers} sales with invalid customer_id")
    
    return orphaned_products == 0 and orphaned_customers == 0

def analyze_star_schema():
    """Perform analysis on the star schema"""
    # Example analysis query
    analysis_query = """
    SELECT 
        p.category,
        c.country,
        COUNT(DISTINCT s.sale_id) as total_sales,
        SUM(s.total_amount) as total_revenue
    FROM retail_db.sales s
    JOIN retail_db.products p ON s.product_id = p.product_id
    JOIN retail_db.customers c ON s.customer_id = c.customer_id
    GROUP BY p.category, c.country
    """
    
    analysis_results = spark.sql(analysis_query)
    return analysis_results

def migrate_table(table_name, schema):
    """Migrate a single table from Snowflake to Databricks"""
    print(f"Starting migration for table: {table_name}")
    
    try:
        # Read from Snowflake and cache, since the count below and the write
        # would otherwise each trigger a separate read from Snowflake
        df = read_from_snowflake(table_name, schema).cache()
        
        # Print some statistics
        print(f"Records read from Snowflake: {df.count()}")
        
        # Write to Databricks
        write_to_databricks(df, table_name)
        
        # Verify the write
        verification_df = spark.sql(f"SELECT COUNT(*) as count FROM retail_db.{table_name}")
        print(f"Records written to Databricks: {verification_df.collect()[0]['count']}")
        
        print(f"Successfully migrated table: {table_name}")
        return df
        
    except Exception as e:
        print(f"Error migrating table {table_name}: {str(e)}")
        raise

def main():
    # Create database if it doesn't exist
    spark.sql("CREATE DATABASE IF NOT EXISTS retail_db")
    
    # Migrate dimension tables first
    products_df = migrate_table("products", product_schema)
    customers_df = migrate_table("customers", customer_schema)
    
    # Migrate fact table
    sales_df = migrate_table("sales", sales_schema)
    
    # Validate foreign key relationships
    if validate_foreign_keys(sales_df, products_df, customers_df):
        print("All foreign key relationships are valid")
    else:
        print("Warning: foreign key validation found orphaned sales records")
    
    # Perform star schema analysis
    print("\nAnalyzing star schema...")
    analysis_results = analyze_star_schema()
    analysis_results.show()

if __name__ == "__main__":
    main()