Data Migration from Snowflake to Databricks

Overview

This guide walks through a simple table-level migration from Snowflake to Databricks. Each source table is read with the Spark Snowflake connector, written out as a managed Delta table, and verified with a row count. Example queries, sample output, and the full PySpark script are included below.

Prerequisites

- A Snowflake account with read access to the source database and a running warehouse
- A Databricks workspace, or another Spark environment with the Snowflake connector on the classpath
- The Snowflake password available in the SNOWFLAKE_PASSWORD environment variable, as the script below expects

Example Queries

Snowflake Source Query

SELECT 
    product_id,
    product_name,
    category,
    price,
    created_at
FROM RETAIL_DB.PUBLIC.products;
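
If only a subset of rows or columns is needed, the Snowflake Spark connector can also push the query itself down to Snowflake via its query option instead of reading a whole table with dbtable. A sketch in PySpark, reusing the snowflake_options dictionary defined in the script below (the WHERE filter is purely illustrative):

df = (spark.read
      .format("snowflake")
      .options(**snowflake_options)
      .option("query", """
          SELECT product_id, product_name, category, price, created_at
          FROM RETAIL_DB.PUBLIC.products
          WHERE price > 0
      """)
      .load())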

Databricks Target Query

SELECT COUNT(*) AS record_count
FROM retail_db.products;
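
To check a migration programmatically rather than by eyeballing counts, the two sides can be compared directly. A sketch that reuses the read_from_snowflake helper and product_schema defined in the script below:

source_count = read_from_snowflake("products", product_schema).count()
target_count = spark.sql(
    "SELECT COUNT(*) AS record_count FROM retail_db.products"
).collect()[0]["record_count"]
assert source_count == target_count, \
    f"Count mismatch: {source_count} in Snowflake vs {target_count} in Databricks"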

Sample Output

Starting migration for table: products
Records read from Snowflake: 1000
Records written to Databricks: 1000
Successfully migrated table: products

Starting migration for table: customers
Records read from Snowflake: 5000
Records written to Databricks: 5000
Successfully migrated table: customers

Important Notes

- The script overwrites the target tables on every run; incremental loads would need a different write mode (see the note after write_to_databricks below).
- Credentials are read from an environment variable rather than hard-coded; on Databricks, a secret scope is the more typical home for them.
- Verification is by row count only; it does not compare column values between the two systems.

Migration Script

from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType,
    DecimalType, TimestampType,
)
import os

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("Snowflake to Databricks Migration") \
    .getOrCreate()

# Snowflake connection parameters (URL, user, and warehouse are
# placeholders; the password is read from the environment rather than
# hard-coded)
snowflake_options = {
    "sfURL":       "snowflake_account.snowflakecomputing.com",
    "sfUser":      "username",
    "sfPassword":  os.environ.get("SNOWFLAKE_PASSWORD", ""),
    "sfDatabase":  "RETAIL_DB",
    "sfSchema":    "PUBLIC",
    "sfWarehouse": "COMPUTE_WH"
}
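
# On Databricks, the password would more typically come from a secret
# scope instead of an environment variable (scope and key names here
# are hypothetical):
#   snowflake_options["sfPassword"] = dbutils.secrets.get(
#       scope="snowflake", key="password")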

# Define schemas for our tables
product_schema = StructType([
    StructField("product_id",   IntegerType(),     False),
    StructField("product_name", StringType(),      False),
    StructField("category",     StringType(),      True),
    StructField("price",        DecimalType(10,2), False),
    StructField("created_at",   TimestampType(),   False)
])

customer_schema = StructType([
    StructField("customer_id", IntegerType(),   False),
    StructField("first_name",  StringType(),    False),
    StructField("last_name",   StringType(),    False),
    StructField("email",       StringType(),    False),
    StructField("country",     StringType(),    True),
    StructField("created_at",  TimestampType(), False)
])

def read_from_snowflake(table_name, schema):
    """Read data from Snowflake table"""
    return spark.read \
        .format("snowflake") \
        .options(**snowflake_options) \
        .option("dbtable", table_name) \
        .schema(schema) \
        .load()

def write_to_databricks(df, table_name):
    """Write DataFrame to Databricks Delta table"""
    # Write as managed delta table
    df.write \
        .format("delta") \
        .mode("overwrite") \
        .saveAsTable(f"retail_db.{table_name}")
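
# Note: mode("overwrite") replaces the target table on every run, which
# suits a one-off migration. For incremental loads, mode("append") or a
# Delta MERGE would be used instead.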

def migrate_table(table_name, schema):
    """Migrate a single table from Snowflake to Databricks"""
    print(f"Starting migration for table: {table_name}")
    
    try:
        # Read from Snowflake
        df = read_from_snowflake(table_name, schema)

        # Cache so the count below and the write don't each trigger a
        # separate full read from Snowflake
        df.cache()

        # Print some statistics
        print(f"Records read from Snowflake: {df.count()}")

        # Write to Databricks
        write_to_databricks(df, table_name)
        df.unpersist()
        
        # Verify the write
        verification_df = spark.sql(f"SELECT COUNT(*) as count FROM retail_db.{table_name}")
        print(f"Records written to Databricks: {verification_df.collect()[0]['count']}")
        
        print(f"Successfully migrated table: {table_name}")
        
    except Exception as e:
        print(f"Error migrating table {table_name}: {str(e)}")
        raise

def main():
    # Create database if it doesn't exist
    spark.sql("CREATE DATABASE IF NOT EXISTS retail_db")
    
    # Define tables to migrate with their schemas
    tables = {
        "products": product_schema,
        "customers": customer_schema
    }
    
    # Migrate each table
    for table_name, schema in tables.items():
        migrate_table(table_name, schema)

if __name__ == "__main__":
    main()
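
Verifying the Delta Tables

After a run, the Delta transaction log offers a quick confirmation that each table was written. DESCRIBE HISTORY is a standard Delta Lake command; the select below trims its output to the most relevant columns:

spark.sql("DESCRIBE HISTORY retail_db.products") \
    .select("version", "timestamp", "operation") \
    .show(truncate=False)

The operation column shows how the latest version was produced, for example a WRITE or CREATE TABLE AS SELECT from the migration run.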