Star Schema Data Migration: Snowflake to Databricks
Star Schema Design
Dimension Tables
-- Products Dimension Table --
CREATE TABLE retail_db.products (
    product_id INT NOT NULL PRIMARY KEY,
    product_name VARCHAR(255) NOT NULL,
    category VARCHAR(100),
    price DECIMAL(10,2) NOT NULL,
    created_at TIMESTAMP NOT NULL
);
-- Customers Dimension Table --
CREATE TABLE retail_db.customers (
    customer_id INT NOT NULL PRIMARY KEY,
    first_name VARCHAR(100) NOT NULL,
    last_name VARCHAR(100) NOT NULL,
    email VARCHAR(255) NOT NULL,
    country VARCHAR(100),
    created_at TIMESTAMP NOT NULL
);
Fact Table
-- Sales Fact Table --
CREATE TABLE retail_db.sales (
    sale_id INT NOT NULL PRIMARY KEY,
    customer_id INT NOT NULL REFERENCES retail_db.customers(customer_id),
    product_id INT NOT NULL REFERENCES retail_db.products(product_id),
    quantity INT NOT NULL,
    unit_price DECIMAL(10,2) NOT NULL,
    total_amount DECIMAL(10,2) NOT NULL,
    sale_date TIMESTAMP NOT NULL
);
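Snowflake accepts the PRIMARY KEY and REFERENCES clauses above but does not enforce them, and Delta Lake on Databricks treats key constraints the same way: they are informational only, which is why the migration script below validates them explicitly. If the target workspace uses Unity Catalog, the same keys can still be recorded on the Delta fact table; a hedged sketch (constraint names are illustrative, and the dimension tables are assumed to declare their primary keys the same way):
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Informational key constraints require Unity Catalog; they document the star
# schema for catalogs and BI tools but are not checked on write.
spark.sql("""
    CREATE TABLE IF NOT EXISTS retail_db.sales (
        sale_id INT NOT NULL,
        customer_id INT NOT NULL,
        product_id INT NOT NULL,
        quantity INT NOT NULL,
        unit_price DECIMAL(10,2) NOT NULL,
        total_amount DECIMAL(10,2) NOT NULL,
        sale_date TIMESTAMP NOT NULL,
        CONSTRAINT sales_pk PRIMARY KEY (sale_id),
        CONSTRAINT sales_customers_fk FOREIGN KEY (customer_id) REFERENCES retail_db.customers (customer_id),
        CONSTRAINT sales_products_fk FOREIGN KEY (product_id) REFERENCES retail_db.products (product_id)
    )
""")
Outside Unity Catalog, the CONSTRAINT clauses can simply be dropped; the anti-join validation in the script below carries the integrity check on its own.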
Schema Definitions
# Sales Schema
sales_schema = StructType([
    StructField("sale_id", IntegerType(), False),
    StructField("customer_id", IntegerType(), False),
    StructField("product_id", IntegerType(), False),
    StructField("quantity", IntegerType(), False),
    StructField("unit_price", DecimalType(10,2), False),
    StructField("total_amount", DecimalType(10,2), False),
    StructField("sale_date", TimestampType(), False)
])
Example Star Schema Queries
Sales Analysis Query
SELECT
    p.category,
    c.country,
    COUNT(DISTINCT s.sale_id) as total_sales,
    SUM(s.total_amount) as total_revenue
FROM retail_db.sales s
JOIN retail_db.products p ON s.product_id = p.product_id
JOIN retail_db.customers c ON s.customer_id = c.customer_id
GROUP BY p.category, c.country;
Data Validation Query
-- Check for orphaned sales records --
SELECT COUNT(*) as invalid_products
FROM retail_db.sales s
LEFT JOIN retail_db.products p ON s.product_id = p.product_id
WHERE p.product_id IS NULL;
Migration Process
- Dimension tables (products, customers) are migrated first
- The fact table (sales) is migrated after the dimensions, so its foreign keys can be checked against data that already exists in Databricks
- Foreign key relationships are validated with anti-joins against the dimension tables
- A star schema analysis query is run to confirm the joined model produces the expected aggregates
Important Notes
- Maintain referential integrity between fact and dimension tables
- Consider partitioning the sales table by date for better performance (see the first sketch after this list)
- Monitor memory usage when joining large fact tables (see the broadcast-join sketch below)
- Consider incremental loads for the fact table in production (see the MERGE sketch below)
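For the partitioning note, a variant of write_to_databricks can derive a calendar-day column from sale_date and partition the Delta table on it. A minimal sketch, assuming the sales DataFrame from the migration script (the sale_date_day column name is illustrative and not part of the original model):
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import to_date

spark = SparkSession.builder.getOrCreate()

def write_sales_partitioned(sales_df: DataFrame) -> None:
    """Write the sales fact table as a Delta table partitioned by calendar day."""
    (sales_df
        # Partition on a low-cardinality date column rather than the raw timestamp
        .withColumn("sale_date_day", to_date("sale_date"))
        .write
        .format("delta")
        .mode("overwrite")
        .partitionBy("sale_date_day")
        .saveAsTable("retail_db.sales"))
Partitioning on the raw timestamp would create one partition per distinct value; the derived date column keeps the partition count manageable and lets date-filtered queries prune files.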
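The memory note mostly concerns the shuffle triggered by joining against the large sales table. Because the product and customer dimensions are small, one common mitigation is a broadcast hint, sketched below with the tables produced by the migration:
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

spark = SparkSession.builder.getOrCreate()

sales_df = spark.table("retail_db.sales")
products_df = spark.table("retail_db.products")
customers_df = spark.table("retail_db.customers")

# Broadcasting the small dimension tables avoids shuffling the much larger fact table
enriched_sales = (sales_df
    .join(broadcast(products_df), "product_id")
    .join(broadcast(customers_df), "customer_id"))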
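For incremental loads, one option is a Delta MERGE keyed on sale_id, so repeated runs upsert only newly extracted rows instead of overwriting the whole table. A sketch assuming the incoming DataFrame has already been read from Snowflake and filtered on a watermark such as the latest sale_date previously migrated:
from delta.tables import DeltaTable
from pyspark.sql import SparkSession, DataFrame

spark = SparkSession.builder.getOrCreate()

def upsert_sales(new_sales_df: DataFrame) -> None:
    """Merge newly extracted sales rows into the existing Delta fact table."""
    target = DeltaTable.forName(spark, "retail_db.sales")
    (target.alias("t")
        .merge(new_sales_df.alias("s"), "t.sale_id = s.sale_id")
        .whenMatchedUpdateAll()     # refresh rows that were re-extracted
        .whenNotMatchedInsertAll()  # append genuinely new sales
        .execute())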
Sample Output
Starting migration for table: products
Records read from Snowflake: 1000
Records written to Databricks: 1000
Successfully migrated table: products
Starting migration for table: customers
Records read from Snowflake: 5000
Records written to Databricks: 5000
Successfully migrated table: customers
Starting migration for table: sales
Records read from Snowflake: 50000
Records written to Databricks: 50000
Successfully migrated table: sales
All foreign key relationships are valid
Analyzing star schema...
+-------------+----------+-------------+---------------+
| category | country | total_sales | total_revenue |
+-------------+----------+-------------+---------------+
| Electronics | USA | 15,000 | 2,500,000.00 |
| Clothing | Canada | 12,000 | 1,800,000.00 |
| Books | UK | 8,000 | 400,000.00 |
...
Migration Script
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DecimalType, TimestampType
from pyspark.sql.functions import col, count

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("Snowflake to Databricks Migration") \
    .getOrCreate()
# Snowflake connection parameters
snowflake_options = {
    "sfURL": "your_snowflake_account.snowflakecomputing.com",
    "sfUser": "your_username",
    "sfPassword": "your_password",
    "sfDatabase": "RETAIL_DB",
    "sfSchema": "PUBLIC",
    "sfWarehouse": "COMPUTE_WH"
}
# Define schemas for our tables
product_schema = StructType([
    StructField("product_id", IntegerType(), False),
    StructField("product_name", StringType(), False),
    StructField("category", StringType(), True),
    StructField("price", DecimalType(10,2), False),
    StructField("created_at", TimestampType(), False)
])

customer_schema = StructType([
    StructField("customer_id", IntegerType(), False),
    StructField("first_name", StringType(), False),
    StructField("last_name", StringType(), False),
    StructField("email", StringType(), False),
    StructField("country", StringType(), True),
    StructField("created_at", TimestampType(), False)
])

sales_schema = StructType([
    StructField("sale_id", IntegerType(), False),
    StructField("customer_id", IntegerType(), False),
    StructField("product_id", IntegerType(), False),
    StructField("quantity", IntegerType(), False),
    StructField("unit_price", DecimalType(10,2), False),
    StructField("total_amount", DecimalType(10,2), False),
    StructField("sale_date", TimestampType(), False)
])
def read_from_snowflake(table_name, schema):
    """Read data from a Snowflake table."""
    return spark.read \
        .format("snowflake") \
        .options(**snowflake_options) \
        .option("dbtable", table_name) \
        .schema(schema) \
        .load()

def write_to_databricks(df, table_name):
    """Write a DataFrame to a Databricks Delta table."""
    df.write \
        .format("delta") \
        .mode("overwrite") \
        .saveAsTable(f"retail_db.{table_name}")
def validate_foreign_keys(sales_df, products_df, customers_df):
    """Validate foreign key relationships between the fact and dimension tables."""
    # Anti-joins return sales rows whose keys have no match in the dimension tables
    orphaned_products = sales_df.join(products_df, "product_id", "left_anti").count()
    orphaned_customers = sales_df.join(customers_df, "customer_id", "left_anti").count()
    if orphaned_products > 0:
        print(f"Warning: Found {orphaned_products} sales with invalid product_id")
    if orphaned_customers > 0:
        print(f"Warning: Found {orphaned_customers} sales with invalid customer_id")
    return orphaned_products == 0 and orphaned_customers == 0
def analyze_star_schema():
    """Perform analysis on the star schema."""
    # Example analysis query
    analysis_query = """
        SELECT
            p.category,
            c.country,
            COUNT(DISTINCT s.sale_id) as total_sales,
            SUM(s.total_amount) as total_revenue
        FROM retail_db.sales s
        JOIN retail_db.products p ON s.product_id = p.product_id
        JOIN retail_db.customers c ON s.customer_id = c.customer_id
        GROUP BY p.category, c.country
    """
    analysis_results = spark.sql(analysis_query)
    return analysis_results
def migrate_table(table_name, schema):
    """Migrate a single table from Snowflake to Databricks."""
    print(f"Starting migration for table: {table_name}")
    try:
        # Read from Snowflake
        df = read_from_snowflake(table_name, schema)
        # Print some statistics
        print(f"Records read from Snowflake: {df.count()}")
        # Write to Databricks
        write_to_databricks(df, table_name)
        # Verify the write
        verification_df = spark.sql(f"SELECT COUNT(*) as count FROM retail_db.{table_name}")
        print(f"Records written to Databricks: {verification_df.collect()[0]['count']}")
        print(f"Successfully migrated table: {table_name}")
        return df
    except Exception as e:
        print(f"Error migrating table {table_name}: {str(e)}")
        raise
def main():
    # Create database if it doesn't exist
    spark.sql("CREATE DATABASE IF NOT EXISTS retail_db")
    # Define tables to migrate with their schemas
    tables = {
        "products": product_schema,
        "customers": customer_schema,
        "sales": sales_schema
    }
    # Migrate dimension tables first
    products_df = migrate_table("products", product_schema)
    customers_df = migrate_table("customers", customer_schema)
    # Migrate fact table
    sales_df = migrate_table("sales", sales_schema)
    # Validate foreign key relationships
    if validate_foreign_keys(sales_df, products_df, customers_df):
        print("All foreign key relationships are valid")
    # Perform star schema analysis
    print("\nAnalyzing star schema...")
    analysis_results = analyze_star_schema()
    analysis_results.show()

if __name__ == "__main__":
    main()