Flask — Sample Code

1. Overview

This page is a code-heavy cookbook for building production Flask services, with a bias toward ML-serving use cases that AI/ML data engineers hit day to day: loading a scikit-learn or PyTorch model at startup, validating request payloads, offloading heavy inference to Celery, and exposing Kubernetes-ready liveness and readiness probes. Generic CRUD, auth, uploads, and pagination are included because a real ML service needs them too.

Conventions:

Every snippet below is runnable as shown. Imports are complete. No ... placeholders.

2. Minimal Hello World

The smallest Flask app — a single file, a single route, dev server via flask run.


# app.py
from flask import Flask, jsonify

app = Flask(__name__)


@app.route("/")
def index():
    """Hello endpoint: JSON body with a greeting and a status flag, HTTP 200."""
    body = {"message": "hello from flask", "status": "ok"}
    return jsonify(**body)


if __name__ == "__main__":
    # dev only; use gunicorn in prod
    app.run(host="0.0.0.0", port=5000, debug=True)

export FLASK_APP=app.py
flask run --host 0.0.0.0 --port 5000
# or
python app.py

For production use a WSGI server — never app.run():


gunicorn -w 4 -b 0.0.0.0:5000 app:app

3. CRUD REST API with SQLAlchemy

A full User resource with SQLite for dev, SQLAlchemy ORM, and Marshmallow schemas for validation and serialization. Notice the explicit 404/400 handling and the use of db.session.get (SQLAlchemy 2.x style).


# crud_app.py
"""User CRUD REST API: SQLAlchemy ORM + Marshmallow validation/serialization."""
from datetime import datetime
from flask import Flask, jsonify, request, abort
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
from marshmallow import fields, validate, ValidationError

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///users.db"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False

db = SQLAlchemy(app)
ma = Marshmallow(app)


class User(db.Model):
    """ORM row for the users table."""
    __tablename__ = "users"
    id = db.Column(db.Integer, primary_key=True)
    # unique + index: email doubles as the natural lookup key
    email = db.Column(db.String(255), unique=True, nullable=False, index=True)
    name = db.Column(db.String(120), nullable=False)
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; a
    # timezone-aware default would be the modern replacement — confirm the
    # rest of the stack tolerates aware datetimes before switching.
    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)


class UserSchema(ma.SQLAlchemyAutoSchema):
    """Validates request payloads and (de)serializes User rows."""
    class Meta:
        model = User
        load_instance = True      # .load() returns a User instance, not a dict
        sqla_session = db.session

    # Override the auto-generated fields to tighten validation.
    email = fields.Email(required=True)
    name = fields.Str(required=True, validate=validate.Length(min=1, max=120))


user_schema = UserSchema()             # single-object (de)serialization
users_schema = UserSchema(many=True)   # list serialization


@app.errorhandler(ValidationError)
def handle_validation(err):
    """Turn any marshmallow validation failure into a 400 with field errors."""
    return jsonify(errors=err.messages), 400


@app.route("/users", methods=["POST"])
def create_user():
    """Create a user; invalid payloads become 400 via the handler above."""
    data = request.get_json(silent=True) or {}
    user = user_schema.load(data)             # validates + hydrates User instance
    db.session.add(user)
    db.session.commit()
    return user_schema.dump(user), 201


@app.route("/users", methods=["GET"])
def list_users():
    """Return every user, ordered by primary key ascending."""
    q = User.query.order_by(User.id.asc()).all()
    return jsonify(users_schema.dump(q))


@app.route("/users/<int:user_id>", methods=["GET"])
def get_user(user_id):
    """Fetch one user or 404 (abort() raises, so the `or` never yields None)."""
    user = db.session.get(User, user_id) or abort(404)
    return user_schema.dump(user)


@app.route("/users/<int:user_id>", methods=["PUT"])
def update_user(user_id):
    """Partial update: only fields present in the payload are changed."""
    user = db.session.get(User, user_id) or abort(404)
    data = request.get_json(silent=True) or {}
    updated = user_schema.load(data, instance=user, partial=True)
    db.session.commit()
    return user_schema.dump(updated)


@app.route("/users/<int:user_id>", methods=["DELETE"])
def delete_user(user_id):
    """Delete a user; 204 empty body on success, 404 if absent."""
    user = db.session.get(User, user_id) or abort(404)
    db.session.delete(user)
    db.session.commit()
    return "", 204


if __name__ == "__main__":
    with app.app_context():
        db.create_all()   # dev convenience; use migrations in production
    app.run(debug=True)

The underlying SQLite schema SQLAlchemy emits is equivalent to:


-- Illustrative schema for the User model above.
-- NOTE(review): SQLAlchemy does not actually emit AUTOINCREMENT for a plain
-- Integer primary key; SQLite's INTEGER PRIMARY KEY auto-assigns rowids
-- anyway, so observable behavior matches.
CREATE TABLE users (
    id         INTEGER PRIMARY KEY AUTOINCREMENT,
    email      VARCHAR(255) NOT NULL UNIQUE,
    name       VARCHAR(120) NOT NULL,
    created_at DATETIME     NOT NULL
);
-- index from index=True on the email column
CREATE INDEX ix_users_email ON users(email);

4. JWT Authentication

flask-jwt-extended with access + refresh tokens and refresh rotation. Access tokens are short-lived (15 min), refresh tokens longer (30 days); every /refresh mints a new refresh token and the old one should be denylisted in a real deployment (Redis set with TTL).


# auth_app.py
"""JWT auth: short-lived access tokens, long-lived rotating refresh tokens."""
from datetime import timedelta
from flask import Flask, jsonify, request
from flask_sqlalchemy import SQLAlchemy
from flask_jwt_extended import (
    JWTManager, create_access_token, create_refresh_token,
    jwt_required, get_jwt_identity, get_jwt,
)
from werkzeug.security import generate_password_hash, check_password_hash

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///auth.db"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
app.config["JWT_SECRET_KEY"] = "change-me-in-prod"
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)

db = SQLAlchemy(app)
jwt = JWTManager(app)

# simple in-memory refresh-token denylist; use Redis in prod
# (process-local: resets on restart and is NOT shared across gunicorn workers)
REFRESH_DENYLIST: set[str] = set()


class Account(db.Model):
    """Credential row: email + salted password hash (plaintext never stored)."""
    id = db.Column(db.Integer, primary_key=True)
    email = db.Column(db.String(255), unique=True, nullable=False)
    pw_hash = db.Column(db.String(255), nullable=False)


@jwt.token_in_blocklist_loader
def is_revoked(jwt_header, jwt_payload):
    """Reject any token whose jti was denylisted by /refresh rotation.

    Runs for every protected request; only refresh jtis are ever added
    below, so access tokens always pass this check.
    """
    return jwt_payload["jti"] in REFRESH_DENYLIST


@app.post("/register")
def register():
    """Create an account.

    Returns 201 + id on success, 400 on a missing/malformed payload,
    409 when the email is already registered.
    """
    # silent=True + .get(): the original indexed data["email"] directly, so a
    # missing field (or non-JSON body) surfaced as a 500 instead of a 400.
    data = request.get_json(silent=True) or {}
    email = data.get("email")
    password = data.get("password")
    if not email or not password:
        return jsonify(error="email and password required"), 400
    if Account.query.filter_by(email=email).first():
        return jsonify(error="email exists"), 409
    acct = Account(email=email, pw_hash=generate_password_hash(password))
    db.session.add(acct)
    db.session.commit()
    return jsonify(id=acct.id), 201


@app.post("/login")
def login():
    """Verify credentials; return a fresh access/refresh token pair or 401."""
    data = request.get_json() or {}
    acct = Account.query.filter_by(email=data.get("email")).first()
    # same 401 for unknown email and wrong password: no account enumeration
    if not acct or not check_password_hash(acct.pw_hash, data.get("password", "")):
        return jsonify(error="bad credentials"), 401
    return jsonify(
        access_token=create_access_token(identity=str(acct.id)),
        refresh_token=create_refresh_token(identity=str(acct.id)),
    )


@app.post("/refresh")
@jwt_required(refresh=True)
def refresh():
    """Refresh-token rotation: denylist the presented token, mint a new pair."""
    # rotate: deny the old refresh jti and mint a fresh pair
    old = get_jwt()
    REFRESH_DENYLIST.add(old["jti"])
    identity = get_jwt_identity()
    return jsonify(
        access_token=create_access_token(identity=identity),
        refresh_token=create_refresh_token(identity=identity),
    )


@app.get("/me")
@jwt_required()
def me():
    """Echo the authenticated user's identity from the access token."""
    return jsonify(user_id=get_jwt_identity())


if __name__ == "__main__":
    with app.app_context():
        db.create_all()
    app.run(debug=True)

5. File Upload

Safe upload with secure_filename, a 16 MB ceiling via MAX_CONTENT_LENGTH, extension + magic-byte MIME validation. Never trust the client-supplied Content-Type — sniff the bytes.


# uploads.py
"""Upload endpoint hardened three ways: size cap, extension allowlist, byte sniffing."""
import os
import magic            # pip install python-magic (needs libmagic)
from pathlib import Path
from flask import Flask, request, jsonify, abort
from werkzeug.utils import secure_filename

UPLOAD_DIR = Path("/tmp/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Both gates must pass: extension (cheap, cosmetic) and sniffed MIME (real).
ALLOWED_EXT = {".png", ".jpg", ".jpeg", ".pdf", ".csv"}
ALLOWED_MIME = {"image/png", "image/jpeg", "application/pdf", "text/csv", "text/plain"}

app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 16 * 1024 * 1024   # 16 MB


@app.errorhandler(413)
def too_big(_):
    """Werkzeug raises 413 automatically when MAX_CONTENT_LENGTH is exceeded."""
    return jsonify(error="file too large, max 16MB"), 413


@app.post("/upload")
def upload():
    """Validate and persist one multipart file field named 'file'.

    Returns 201 with the stored name, sniffed MIME type, and byte size;
    400 on any validation failure; 413 (handler above) if the body is too big.
    """
    if "file" not in request.files:
        abort(400, "no file part")
    f = request.files["file"]
    if not f.filename:
        abort(400, "empty filename")

    # secure_filename strips path separators and other traversal vectors
    safe = secure_filename(f.filename)
    ext = os.path.splitext(safe)[1].lower()
    if ext not in ALLOWED_EXT:
        abort(400, f"extension {ext} not allowed")

    # Sniff the leading bytes instead of trusting the client Content-Type,
    # then rewind so save() writes the stream from the start.
    head = f.stream.read(2048)
    f.stream.seek(0)
    mime = magic.from_buffer(head, mime=True)
    if mime not in ALLOWED_MIME:
        abort(400, f"mime {mime} not allowed")

    dest = UPLOAD_DIR / safe
    f.save(dest)
    return jsonify(filename=safe, mime=mime, size=dest.stat().st_size), 201


if __name__ == "__main__":
    app.run(debug=True)

6. ML Model Serving — scikit-learn

Two files: a training script that persists the model to disk, and a serving app that loads it once at startup (via the app factory — before_first_request is removed in Flask 2.3+). Input is validated by Pydantic; the response includes both the predicted class and a confidence score from predict_proba (note: these probabilities are not calibrated unless you add an explicit calibration step).


# train.py  (run offline)
"""Train a scaled logistic-regression iris classifier and persist it.

Produces iris_model.joblib containing {"model": fitted Pipeline,
"labels": class-name list} in the format serve_sklearn.py expects.
"""
import joblib
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
# multi_class="multinomial" removed: the parameter is deprecated
# (scikit-learn >= 1.5) and multinomial is already the default behavior
# for the lbfgs solver, so the fitted model is unchanged.
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(max_iter=500)),
])
pipe.fit(X, y)
joblib.dump({"model": pipe, "labels": ["setosa", "versicolor", "virginica"]},
            "iris_model.joblib")
print("saved iris_model.joblib")

# serve_sklearn.py
"""Serve the persisted iris pipeline; the model is loaded once per process."""
import joblib
import numpy as np
from flask import Flask, jsonify, request
from pydantic import BaseModel, Field, ValidationError, conlist


class PredictIn(BaseModel):
    """Request body for /predict (pydantic v2 conlist keywords)."""
    # 4 iris features: sepal length/width, petal length/width (cm)
    features: conlist(float, min_length=4, max_length=4) = Field(...)


def create_app(model_path: str = "iris_model.joblib") -> Flask:
    """App factory: load the model bundle once, register /predict.

    model_path: joblib file holding {"model": pipeline, "labels": [...]},
    as written by train.py.  Raises FileNotFoundError if the artifact
    is missing — fail fast at startup, not on the first request.
    """
    app = Flask(__name__)
    bundle = joblib.load(model_path)     # loaded ONCE at process start
    app.config["MODEL"] = bundle["model"]
    app.config["LABELS"] = bundle["labels"]

    @app.post("/predict")
    def predict():
        """Validate the payload, run predict_proba, return label + per-class scores."""
        try:
            payload = PredictIn(**(request.get_json() or {}))
        except ValidationError as e:
            return jsonify(errors=e.errors()), 400

        # reshape to (1, 4): one sample for the pipeline
        x = np.asarray(payload.features, dtype=np.float64).reshape(1, -1)
        proba = app.config["MODEL"].predict_proba(x)[0]
        idx = int(np.argmax(proba))
        return jsonify(
            label=app.config["LABELS"][idx],
            class_index=idx,
            confidence=float(proba[idx]),
            proba={app.config["LABELS"][i]: float(p) for i, p in enumerate(proba)},
        )

    return app


# Module-level instance so `gunicorn serve_sklearn:app` works.
app = create_app()

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)

curl -s -XPOST localhost:5000/predict \
     -H 'content-type: application/json' \
     -d '{"features": [5.1, 3.5, 1.4, 0.2]}'
# {"label":"setosa","class_index":0,"confidence":0.97,...}

7. ML Model Serving — PyTorch (batched, CUDA-aware)

A minimal torch.nn.Module, eval mode, torch.no_grad(), device selection, and a batched endpoint that accepts an array of input vectors and returns an array of predictions. Batching is critical for GPU throughput — one forward pass per HTTP call leaves the GPU idle 99% of the time.


# serve_torch.py
import torch
import torch.nn as nn
from flask import Flask, jsonify, request
from pydantic import BaseModel, ValidationError, conlist


class MLP(nn.Module):
    """Small fully-connected classifier: in_dim -> hidden -> hidden -> out_dim."""

    def __init__(self, in_dim=10, hidden=64, out_dim=3):
        super().__init__()
        # Build the stack as an explicit list; layer construction order is
        # unchanged, so parameter initialization is identical.
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits of shape (batch, out_dim)."""
        return self.net(x)


class BatchIn(BaseModel):
    """Request body for /predict: a batch of fixed-width input vectors."""
    # list of 10-dim vectors, max 512 per request
    inputs: conlist(conlist(float, min_length=10, max_length=10),
                    min_length=1, max_length=512)


def create_app(weights_path: str | None = None) -> Flask:
    """App factory: build the model once, pick a device, register /predict.

    weights_path: optional state_dict file; when None the model keeps its
    random initialization (useful only for smoke tests).
    """
    app = Flask(__name__)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MLP().to(device)
    if weights_path:
        # map_location lets CUDA-trained weights load on a CPU-only host
        model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()                  # inference mode for dropout/batchnorm layers
    app.config["MODEL"] = model
    app.config["DEVICE"] = device

    @app.post("/predict")
    def predict():
        """Batched inference: one forward pass for up to 512 input vectors."""
        try:
            payload = BatchIn(**(request.get_json() or {}))
        except ValidationError as e:
            return jsonify(errors=e.errors()), 400

        x = torch.tensor(payload.inputs, dtype=torch.float32,
                         device=app.config["DEVICE"])
        with torch.no_grad():     # inference only: skip autograd bookkeeping
            logits = app.config["MODEL"](x)
            probs = torch.softmax(logits, dim=-1)
            conf, cls = probs.max(dim=-1)

        return jsonify(
            device=str(app.config["DEVICE"]),
            batch_size=x.shape[0],
            predictions=[
                {"class": int(c), "confidence": float(p)}
                for c, p in zip(cls.tolist(), conf.tolist())
            ],
        )

    return app


# Module-level instance so `gunicorn serve_torch:app` works.
app = create_app()

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5001)

8. Async Task Offloading with Celery

Flask should never block a gunicorn worker on a 30-second inference call. Push the job to Celery, return a task id, and let the client poll. Redis is both broker and result backend here.


# celery_app.py  -- shared Celery instance
"""Celery wired to Redis: db 0 as broker, db 1 as result backend."""
from celery import Celery

celery = Celery(
    "ml_jobs",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/1",
)
celery.conf.update(
    # JSON-only serialization on both directions
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    # hard kill at 300s; the soft limit fires 30s earlier inside the task
    task_time_limit=300,
    task_soft_time_limit=270,
)

# tasks.py  -- worker-side
import time
import joblib
import numpy as np
from celery_app import celery

_BUNDLE = joblib.load("iris_model.joblib")   # loaded per worker process


@celery.task(name="tasks.batch_predict", bind=True, max_retries=2)
def batch_predict(self, rows: list[list[float]]) -> list[dict]:
    """Run predict_proba over a batch of feature rows.

    rows: list of 4-feature iris vectors.  Returns one
    {"label", "confidence"} dict per row, in input order.

    NOTE(review): the blanket except retries *every* failure, including
    malformed input that can never succeed; consider retrying only
    transient errors.
    """
    try:
        X = np.asarray(rows, dtype=np.float64)
        proba = _BUNDLE["model"].predict_proba(X)
        idx = proba.argmax(axis=1)
        return [
            {"label": _BUNDLE["labels"][i], "confidence": float(proba[r, i])}
            for r, i in enumerate(idx)
        ]
    except Exception as exc:
        # up to 2 retries, 5s apart, then the task is marked FAILED
        raise self.retry(exc=exc, countdown=5)

# flask_dispatch.py  -- web-side
from flask import Flask, jsonify, request
from celery.result import AsyncResult
from celery_app import celery
import tasks  # noqa: F401  ensures task is registered

app = Flask(__name__)


@app.post("/jobs/predict")
def enqueue():
    """Queue a batch-predict job; returns 202 + task_id without blocking."""
    rows = (request.get_json() or {}).get("rows", [])
    # send_task dispatches by name, so the web process never runs the model
    async_res = celery.send_task("tasks.batch_predict", args=[rows])
    return jsonify(task_id=async_res.id), 202


@app.get("/jobs/<task_id>")
def status(task_id):
    """Poll a job: state always; result on success, error string on failure."""
    res = AsyncResult(task_id, app=celery)
    body = {"task_id": task_id, "state": res.state}
    if res.successful():
        body["result"] = res.result
    elif res.failed():
        body["error"] = str(res.result)   # res.result holds the exception here
    return jsonify(body)


if __name__ == "__main__":
    app.run(debug=True)

# run worker
celery -A celery_app.celery worker --loglevel=info --concurrency=4

# run flask
python flask_dispatch.py

# enqueue
curl -s -XPOST localhost:5000/jobs/predict -H 'content-type: application/json' \
     -d '{"rows": [[5.1,3.5,1.4,0.2], [6.2,3.4,5.4,2.3]]}'

9. WebSocket with Flask-SocketIO

Bi-directional messaging for chat, live prediction streams, or training progress. Use the Redis message queue so multiple gunicorn/eventlet workers share state.


# socket_chat.py
"""Room-based chat over Socket.IO; Redis message queue links worker processes."""
from flask import Flask, render_template_string
from flask_socketio import SocketIO, emit, join_room, leave_room

app = Flask(__name__)
app.config["SECRET_KEY"] = "dev"
# message_queue lets events emitted in one worker reach clients on another
socketio = SocketIO(app, message_queue="redis://localhost:6379/2",
                    cors_allowed_origins="*")

# Minimal inline client: joins 'lobby' on connect, logs incoming messages.
PAGE = """
<script src="https://cdn.socket.io/4.7.5/socket.io.min.js"></script>
<script>
  const s = io();
  s.on('connect', () => s.emit('join', {room: 'lobby'}));
  s.on('message', m => console.log(m));
  function send(text){ s.emit('chat', {room:'lobby', text}); }
</script>
"""


@app.get("/")
def page():
    """Serve the inline demo client."""
    return render_template_string(PAGE)


@socketio.on("join")
def on_join(data):
    """Add this client to data['room'] and announce the join to that room."""
    join_room(data["room"])
    emit("message", {"sys": f"joined {data['room']}"}, room=data["room"])


@socketio.on("leave")
def on_leave(data):
    """Remove this client from data['room'] (no announcement)."""
    leave_room(data["room"])


@socketio.on("chat")
def on_chat(data):
    """Relay a chat line to everyone in data['room']."""
    emit("message", {"text": data["text"]}, room=data["room"])


if __name__ == "__main__":
    # socketio.run wraps app.run with the configured async server
    socketio.run(app, host="0.0.0.0", port=5000)

10. Pagination with Cursor Tokens

Offset pagination breaks down at scale (deep pages scan huge amounts of data and shift under concurrent writes). Cursor pagination uses the last row's sort key, base64-encoded as an opaque token.


# paginate.py
"""Cursor (keyset) pagination demo over a Post feed."""
import base64
import json
from flask import Flask, jsonify, request
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///feed.db"
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(app)


class Post(db.Model):
    """Feed item; created_at is indexed because the keyset filter scans on it."""
    id = db.Column(db.Integer, primary_key=True)
    created_at = db.Column(db.DateTime, nullable=False, index=True)
    title = db.Column(db.String(200), nullable=False)


def encode_cursor(created_at, pk) -> str:
    """Pack (created_at, id) into an opaque, URL-safe base64 token."""
    payload = {"ts": created_at.isoformat(), "id": pk}
    blob = base64.urlsafe_b64encode(json.dumps(payload).encode())
    # strip '=' padding so the token stays clean inside query strings
    return blob.decode().rstrip("=")


def decode_cursor(token: str) -> "tuple[datetime, int]":
    """Decode an opaque cursor token back into (created_at, id).

    Fix: the timestamp is now parsed back into a datetime.  Returning the
    raw ISO string (as before) made the /posts filter compare a str against
    a DateTime column; on SQLite the 'T' separator emitted by isoformat()
    never matches the space-separated stored format, so the tie-breaker
    equality could not hit and ordering comparisons were wrong.

    Raises ValueError/KeyError on a malformed token — callers should map
    that to a 400.
    """
    from datetime import datetime  # local import keeps this snippet self-contained

    pad = "=" * (-len(token) % 4)  # restore the '=' padding encode stripped
    raw = base64.urlsafe_b64decode(token + pad)
    d = json.loads(raw)
    return datetime.fromisoformat(d["ts"]), d["id"]


@app.get("/posts")
def feed():
    """Cursor-paginated feed, newest first.

    Query params:
      limit  -- page size, clamped to 1..100 (default 20; non-numeric -> 20)
      cursor -- opaque token from a previous response's next_cursor

    Returns items plus next_cursor (null on the last page); 400 on a
    malformed cursor token.
    """
    # Fix: int() on a non-numeric ?limit= previously raised ValueError -> 500.
    try:
        limit = int(request.args.get("limit", 20))
    except (TypeError, ValueError):
        limit = 20
    limit = max(1, min(limit, 100))

    cursor = request.args.get("cursor")

    # Keyset predicate: strictly older rows, with id as the tie-breaker for
    # rows sharing the same created_at.
    q = Post.query.order_by(Post.created_at.desc(), Post.id.desc())
    if cursor:
        # Fix: a garbage/truncated token previously raised out of base64/json
        # decoding -> 500; surface it as a client error instead.
        try:
            ts, pk = decode_cursor(cursor)
        except (ValueError, KeyError):
            return jsonify(error="invalid cursor"), 400
        # NOTE(review): verify decode_cursor's timestamp type compares
        # correctly against the DateTime column on the target backend.
        q = q.filter(
            (Post.created_at < ts)
            | ((Post.created_at == ts) & (Post.id < pk))
        )

    rows = q.limit(limit + 1).all()   # fetch one extra row to detect another page
    has_more = len(rows) > limit
    rows = rows[:limit]
    next_cursor = (
        encode_cursor(rows[-1].created_at, rows[-1].id) if has_more and rows else None
    )
    return jsonify(
        items=[{"id": r.id, "title": r.title,
                "created_at": r.created_at.isoformat()} for r in rows],
        next_cursor=next_cursor,
    )

11. Healthcheck & Readiness

Liveness and readiness are not the same probe. Liveness says "the process is alive — don't restart me." Readiness says "I've loaded my model and my dependencies are reachable — send me traffic." In Kubernetes, returning 200 on /ready before the model is loaded causes the service mesh to route traffic that will 500.


# health.py
"""Separate liveness (/health) and readiness (/ready) probes."""
import time
from flask import Flask, jsonify
from sqlalchemy import text
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///app.db"
db = SQLAlchemy(app)

# Process-local readiness state; model_loaded flips once warmup completes.
STATE = {"model_loaded": False, "started_at": time.time()}


def _load_model_at_startup():
    """Warm the process before it reports ready."""
    # simulate a slow warmup (downloading weights, compiling, etc.)
    time.sleep(2)
    STATE["model_loaded"] = True


# Runs at import time, so each worker blocks here before serving traffic.
with app.app_context():
    _load_model_at_startup()


@app.get("/health")
def liveness():
    """Liveness: 200 iff the process answers; never touches dependencies."""
    # cheap, no external deps — just confirms the process responds
    return jsonify(status="alive", uptime_s=int(time.time() - STATE["started_at"]))


@app.get("/ready")
def readiness():
    """Readiness: 200 only when the model is loaded AND the DB answers SELECT 1."""
    checks = {"model": STATE["model_loaded"], "db": False}
    try:
        db.session.execute(text("SELECT 1"))
        checks["db"] = True
    except Exception:
        pass  # any DB error leaves checks["db"] False -> 503 below
    ok = all(checks.values())
    return jsonify(ready=ok, checks=checks), (200 if ok else 503)

Matching Kubernetes probe config:


# Liveness: restart the container only after 3 consecutive failures.
livenessProbe:
  httpGet:
    path: /health
    port: 5000
  initialDelaySeconds: 10
  periodSeconds: 15
  failureThreshold: 3
# Readiness: pull the pod from Service endpoints after 2 failures;
# probed more frequently so traffic shifts away quickly.
readinessProbe:
  httpGet:
    path: /ready
    port: 5000
  initialDelaySeconds: 5
  periodSeconds: 5
  failureThreshold: 2

12. Full App Factory Pattern

The real-world layout: a create_app() factory, config classes per environment, blueprints for modularity, extensions instantiated at module scope and bound inside the factory. This is the pattern every non-trivial Flask service should start from.


# config.py
"""Per-environment config classes consumed via app.config.from_object()."""
import os


class BaseConfig:
    """Settings shared by every environment."""
    # default is for dev only; ProdConfig overrides from the environment
    SECRET_KEY = os.environ.get("SECRET_KEY", "dev-secret")
    SQLALCHEMY_TRACK_MODIFICATIONS = False
    JSON_SORT_KEYS = False


class DevConfig(BaseConfig):
    """Local development: debugger on, file-backed SQLite."""
    DEBUG = True
    SQLALCHEMY_DATABASE_URI = "sqlite:///dev.db"


class TestConfig(BaseConfig):
    """Test runs: in-memory DB, Flask TESTING semantics."""
    TESTING = True
    SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"


class ProdConfig(BaseConfig):
    """Production: everything comes from the environment.

    Fix: the class body previously did os.environ["DATABASE_URL"], which
    raised KeyError at *import time* — merely importing this module in a
    dev or test shell (where prod variables are unset) crashed.  Using
    .get defers the failure to where it belongs: SQLAlchemy rejects the
    empty URI the moment the prod config is actually used.
    """
    DEBUG = False
    SQLALCHEMY_DATABASE_URI = os.environ.get("DATABASE_URL", "")
    SECRET_KEY = os.environ.get("SECRET_KEY", "")


# APP_ENV value -> config class; create_app() resolves environments here.
CONFIG_MAP = {"dev": DevConfig, "test": TestConfig, "prod": ProdConfig}

# extensions.py
from flask_sqlalchemy import SQLAlchemy
from flask_jwt_extended import JWTManager

# Unbound extension singletons: created once at module scope, attached to a
# concrete app later via init_app() inside the factory.
db = SQLAlchemy()
jwt = JWTManager()

# blueprints/health_bp.py
from flask import Blueprint, jsonify

health_bp = Blueprint("health", __name__)


@health_bp.get("/health")
def health():
    """Liveness endpoint: 200 + {"status": "ok"} whenever the process is up."""
    body = {"status": "ok"}
    return jsonify(**body)

# blueprints/predict_bp.py
from flask import Blueprint, jsonify, request, current_app
import numpy as np

predict_bp = Blueprint("predict", __name__, url_prefix="/api/v1")


@predict_bp.post("/predict")
def predict():
    """Single-row inference against the model loaded by create_app().

    Fix: the original indexed payload["features"] directly, so a missing
    field or malformed list surfaced as a 500 (KeyError/ValueError) instead
    of a 400 — now validated up front, consistent with the other endpoints.
    """
    payload = request.get_json(silent=True) or {}
    features = payload.get("features")
    if features is None:
        return jsonify(error="'features' is required"), 400
    try:
        x = np.asarray(features, dtype=np.float64).reshape(1, -1)
    except (TypeError, ValueError):
        return jsonify(error="'features' must be a flat list of numbers"), 400
    model = current_app.config["MODEL"]
    labels = current_app.config["LABELS"]
    proba = model.predict_proba(x)[0]
    idx = int(np.argmax(proba))
    return jsonify(label=labels[idx], confidence=float(proba[idx]))

# app/__init__.py
import os
import joblib
from flask import Flask
from config import CONFIG_MAP
from extensions import db, jwt
from blueprints.health_bp import health_bp
from blueprints.predict_bp import predict_bp


def create_app(env: str | None = None) -> Flask:
    """Build a fully wired app for the given environment.

    env: "dev" | "test" | "prod"; defaults to $APP_ENV, then "dev".
    Raises KeyError if env is not a CONFIG_MAP key.
    """
    env = env or os.environ.get("APP_ENV", "dev")
    app = Flask(__name__)
    app.config.from_object(CONFIG_MAP[env])

    # extensions (module-level singletons bound to this app instance)
    db.init_app(app)
    jwt.init_app(app)

    # ML model: loaded once at startup, stored on app.config
    # (skipped under test so unit tests don't need the artifact on disk)
    if env != "test":
        bundle = joblib.load(os.environ.get("MODEL_PATH", "iris_model.joblib"))
        app.config["MODEL"] = bundle["model"]
        app.config["LABELS"] = bundle["labels"]

    # blueprints
    app.register_blueprint(health_bp)
    app.register_blueprint(predict_bp)

    # create tables for dev/test convenience; production should use migrations
    with app.app_context():
        db.create_all()

    return app

# wsgi.py  -- entry point for gunicorn
from app import create_app

# gunicorn targets this module-level instance: `gunicorn wsgi:app`
app = create_app()

# dev
APP_ENV=dev python -m flask --app wsgi:app run

# prod
APP_ENV=prod DATABASE_URL=postgresql://user:pw@db/app \
  SECRET_KEY=$(openssl rand -hex 32) \
  gunicorn -w 4 -k gthread --threads 8 -b 0.0.0.0:5000 wsgi:app

From here you bolt on the pieces above as blueprints: auth routes, upload routes, Celery dispatch, Socket.IO, pagination. The factory stays small; the blueprints grow.