{ "cells": [ { "cell_type": "markdown", "id": "26146195", "metadata": {}, "source": [
 "# Exercise 2 — Real or bogus? Image-triplet classification\n",
 "\n",
 "**Reference paper.** Duev et al. 2019, *Real-bogus classification for the Zwicky Transient Facility\n",
 "using deep learning*, [arXiv:1907.11259](https://arxiv.org/abs/1907.11259).\n",
 "\n",
 "### Why this matters\n",
 "ZTF generates ~1M alerts/night. Most are artifacts. BRAAI, the real-bogus classifier described in the\n",
 "reference paper, sits at the front of the pipeline and decides what's worth a human's time.\n",
 "The input is a *triplet*: science image, reference image, and difference image, stacked into 3 channels.\n",
 "\n",
 "### What you'll do\n",
 "1. Generate synthetic triplets: real (consistent point source in the difference) vs bogus (hot pixels, cosmic-ray streaks, subtraction residuals).\n",
 "2. Build a small VGG-style CNN on 3-channel inputs.\n",
 "3. Train, plot the ROC, and pick a threshold that keeps the false-positive rate (FPR) at or below 1%.\n"
 ] }, { "cell_type": "code", "execution_count": null, "id": "e9e33ba7", "metadata": {}, "outputs": [], "source": [
 "import numpy as np\n",
 "import matplotlib.pyplot as plt\n",
 "import torch\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
 "from torch.utils.data import DataLoader, TensorDataset\n",
 "\n",
 "rng = np.random.default_rng(0)\n",
 "torch.manual_seed(0)\n",
 "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n"
 ] }, { "cell_type": "markdown", "id": "dfcefc72", "metadata": {}, "source": [
 "## Generate synthetic triplets\n"
 ] }, { "cell_type": "code", "execution_count": null, "id": "5d19260e", "metadata": {}, "outputs": [], "source": [
 "SIZE = 21  # image side, pixels\n",
 "\n",
 "def gaussian2d(x0, y0, sigma=1.5, flux=1.0):\n",
 "    yy, xx = np.mgrid[0:SIZE, 0:SIZE]\n",
 "    return flux * np.exp(-((xx - x0) ** 2 + (yy - y0) ** 2) / (2 * sigma ** 2))\n",
 "\n",
 "NOISE = 0.35\n",
 "\n",
 "def make_real():\n",
 "    cx, cy = SIZE / 2, SIZE / 2\n",
 "    # reference: extended galaxy + faint sub-structure\n",
 "    ref = gaussian2d(cx, cy, sigma=3.0, flux=2.5)\n",
 "    ref += gaussian2d(cx + rng.uniform(-2, 2), cy + rng.uniform(-2, 2),\n",
 "                      sigma=1.5, flux=rng.uniform(0.3, 0.8))\n",
 "    # science: same galaxy plus a near-centred point source (faint sometimes)\n",
 "    flux = rng.uniform(0.5, 1.8)\n",
 "    dx, dy = rng.uniform(-1.0, 1.0, 2)\n",
 "    sci = ref + gaussian2d(cx + dx, cy + dy, sigma=1.2, flux=flux)\n",
 "    # independent noise in each image, added after the shared galaxy model\n",
 "    sci += rng.normal(0, NOISE, sci.shape)\n",
 "    ref += rng.normal(0, NOISE, ref.shape)\n",
 "    diff = sci - ref\n",
 "    return np.stack([sci, ref, diff]).astype(np.float32)\n",
 "\n",
 "def make_bogus():\n",
 "    cx, cy = SIZE / 2, SIZE / 2\n",
 "    ref = gaussian2d(cx, cy, sigma=3.0, flux=2.5)\n",
 "    ref += gaussian2d(cx + rng.uniform(-2, 2), cy + rng.uniform(-2, 2),\n",
 "                      sigma=1.5, flux=rng.uniform(0.3, 0.8))\n",
 "    sci = ref.copy()\n",
 "    # artifact type: 0 = hot pixels, 1 = streak, 2 = subtraction residual;\n",
 "    # types 0 and 1 sometimes land right on top of the galaxy core\n",
 "    kind = rng.integers(0, 3)\n",
 "    if kind == 0:\n",
 "        # hot pixels scattered\n",
 "        for _ in range(rng.integers(1, 4)):\n",
 "            x = rng.integers(0, SIZE); y = rng.integers(0, SIZE)\n",
 "            sci[y, x] += rng.uniform(1.5, 3.5)\n",
 "    elif kind == 1:\n",
 "        # short bright streak through the field\n",
 "        x0 = rng.integers(0, SIZE); y0 = rng.integers(0, SIZE)\n",
 "        dx_, dy_ = rng.choice([-1, 0, 1], 2)\n",
 "        for step in range(rng.integers(3, 7)):\n",
 "            xs, ys = x0 + dx_ * step, y0 + dy_ * step\n",
 "            if 0 <= xs < SIZE and 0 <= ys < SIZE:\n",
 "                sci[ys, xs] += rng.uniform(1.5, 3.0)\n",
 "    else:\n",
 "        # subtraction-residual: galaxy slightly mis-aligned\n",
 "        shift = rng.uniform(-0.7, 0.7, 2)\n",
 "        sci = gaussian2d(cx + shift[0], cy + shift[1], sigma=3.0, flux=2.5)\n",
 "        sci += gaussian2d(cx + rng.uniform(-2, 2), cy + rng.uniform(-2, 2),\n",
 "                          sigma=1.5, flux=rng.uniform(0.3, 0.8))\n",
 "    sci += rng.normal(0, NOISE, sci.shape)\n",
 "    ref += rng.normal(0, NOISE, ref.shape)\n",
 "    diff = sci - ref\n",
 "    return np.stack([sci, ref, diff]).astype(np.float32)\n",
 "\n",
 "N = 800\n",
 "X = np.zeros((2 * N, 3, SIZE, SIZE), dtype=np.float32)\n",
 "y = np.zeros(2 * N, dtype=np.int64)\n",
 "# first N examples are real (label 1), last N are bogus (label 0)\n",
 "for i in range(N):\n",
 "    X[i] = make_real(); y[i] = 1\n",
 "    X[N + i] = make_bogus(); y[N + i] = 0\n",
 "print(X.shape, y.mean())\n"
 ] }, { "cell_type": "code", "execution_count": null, "id": "f402cc41", "metadata": {}, "outputs": [], "source": [
 "# Visualise a couple of real and bogus triplets\n",
 "fig, axes = plt.subplots(2, 3, figsize=(8, 5.5))\n",
 "titles = [\"science\", \"reference\", \"difference\"]\n",
 "for row, (label, idx) in enumerate([(\"real\", 0), (\"bogus\", N)]):\n",
 "    for col in range(3):\n",
 "        axes[row, col].imshow(X[idx, col], cmap=\"gray\")\n",
 "        axes[row, col].set_title(f\"{label} — {titles[col]}\")\n",
 "        axes[row, col].axis(\"off\")\n",
 "plt.tight_layout(); plt.show()\n"
 ] }, { "cell_type": "markdown", "id": "cc27cd21", "metadata": {}, "source": [
 "## Task 1 — Train/test split + DataLoader\n",
 "\n",
 "Make an 80/20 train/test split and wrap each part in a `DataLoader`. Convert the NumPy arrays to tensors first.\n",
 "`X` and `y` are ordered by class (reals first, then boguses), so shuffle before you split.\n",
 "\n",
 "
💡 Hint\n",
 "\n",
 "- `torch.from_numpy(X)` keeps the dtype.\n",
 "- Use `torch.randperm(len(X))` to shuffle indices.\n",
 "- `TensorDataset` + `DataLoader(batch_size=32, shuffle=True)` is enough.\n",
 "\n",
 "
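One possible shape for the split, as a minimal sketch rather than the reference solution. The helper name `make_loaders`, its parameters, and the stand-in arrays `X_demo` and `y_demo` are assumptions made for illustration; in the notebook you would pass the real `X` and `y`.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in arrays with the same dtypes and class ordering as X and y above.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 3, 21, 21)).astype(np.float32)
y_demo = np.concatenate([np.ones(50), np.zeros(50)]).astype(np.int64)

def make_loaders(X, y, train_frac=0.8, batch_size=32, seed=0):
    # Hypothetical helper: shuffle once so both classes reach both splits,
    # then cut the permutation at train_frac.
    Xt, yt = torch.from_numpy(X), torch.from_numpy(y)
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(Xt), generator=gen)
    n_train = int(train_frac * len(Xt))
    tr, te = perm[:n_train], perm[n_train:]
    train_dl = DataLoader(TensorDataset(Xt[tr], yt[tr]),
                          batch_size=batch_size, shuffle=True)
    # The test loader keeps a fixed order so scores line up with labels.
    test_dl = DataLoader(TensorDataset(Xt[te], yt[te]),
                         batch_size=batch_size, shuffle=False)
    return train_dl, test_dl

train_dl, test_dl = make_loaders(X_demo, y_demo)
xb, yb = next(iter(train_dl))
print(xb.shape, yb.dtype, len(train_dl.dataset), len(test_dl.dataset))
```

Only the training loader reshuffles each epoch; leaving the test loader unshuffled makes it easier to match scores back to test images later.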
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ca28f486", "metadata": {}, "outputs": [], "source": [ "# TODO: split into train/test, build DataLoaders\n" ] }, { "cell_type": "markdown", "id": "0fcfd51d", "metadata": {}, "source": [ "## Task 2 — Build a small VGG-like CNN\n", "\n", "3 input channels (science, reference, difference), one logit output (binary).\n", "\n", "
💡 Hint\n",
 "\n",
 "- Two or three `Conv2d -> ReLU -> MaxPool` blocks are plenty for 21×21 inputs.\n",
 "- Use `nn.AdaptiveAvgPool2d(1)` then a linear layer to one output.\n",
 "- For binary classification use `BCEWithLogitsLoss` (numerically safer than sigmoid + BCE). It expects float targets, so cast the labels with `.float()`.\n",
 "\n",
 "
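A minimal sketch of a network that fits the hints above. The class name `SmallVGG` and the channel widths (16, 32) are assumptions chosen for illustration; this is not the architecture from the reference paper.

```python
import torch
import torch.nn as nn

class SmallVGG(nn.Module):
    # Hypothetical name and layer widths, chosen only for this sketch.
    def __init__(self, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 21x21 -> 10x10
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),  # 10x10 -> 5x5
        )
        self.pool = nn.AdaptiveAvgPool2d(1)  # average each feature map to 1x1
        self.head = nn.Linear(32, 1)

    def forward(self, x):
        h = self.pool(self.features(x)).flatten(1)  # shape (batch, 32)
        return self.head(h).squeeze(1)              # one logit per image

model = SmallVGG()
xb = torch.randn(8, 3, 21, 21)
yb = torch.randint(0, 2, (8,))
logits = model(xb)
# BCEWithLogitsLoss takes raw logits and float targets.
loss = nn.BCEWithLogitsLoss()(logits, yb.float())
print(logits.shape, loss.item())
```

Returning a raw logit and leaving the sigmoid to the loss (and to scoring time) is what lets `BCEWithLogitsLoss` use its numerically stable formulation.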
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "78f65e9d", "metadata": {}, "outputs": [], "source": [
 "class RBNet(nn.Module):\n",
 "    def __init__(self):\n",
 "        super().__init__()\n",
 "        # TODO\n",
 "        pass\n",
 "    def forward(self, x):\n",
 "        # TODO\n",
 "        raise NotImplementedError\n"
 ] }, { "cell_type": "markdown", "id": "3a40b186", "metadata": {}, "source": [
 "## Task 3 — Train, plot ROC, pick a threshold\n",
 "\n",
 "Train for ~15 epochs. Then:\n",
 "1. Score the test set, sort scores.\n",
 "2. Sweep thresholds and compute (FPR, TPR) at each.\n",
 "3. Pick the lowest threshold that still gives FPR ≤ 1% (this maximises TPR under that constraint) and report the corresponding TPR.\n",
 "\n",
 "
💡 Hint\n",
 "\n",
 "- For binary classification with a single logit, `prob = torch.sigmoid(logit)`.\n",
 "- Sweep ~200 thresholds between 0 and 1.\n",
 "- TPR = TP / (TP + FN); FPR = FP / (FP + TN).\n",
 "\n",
 "
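The sweep and the threshold choice can be sketched in NumPy alone. The function names `roc_sweep` and `pick_threshold` are assumptions, and the scores below are synthetic stand-ins, not output from a trained model.

```python
import numpy as np

def roc_sweep(scores, truth, n_thresh=200):
    # Predict real when score >= threshold; one (FPR, TPR) pair per threshold.
    thresholds = np.linspace(0.0, 1.0, n_thresh)
    pos, neg = truth == 1, truth == 0
    pred = scores[None, :] >= thresholds[:, None]   # shape (n_thresh, n_samples)
    tpr = (pred & pos).sum(axis=1) / pos.sum()      # TP / (TP + FN)
    fpr = (pred & neg).sum(axis=1) / neg.sum()      # FP / (FP + TN)
    return thresholds, fpr, tpr

def pick_threshold(thresholds, fpr, tpr, max_fpr=0.01):
    # Thresholds ascend, so the first index inside the FPR budget is the
    # lowest admissible threshold and therefore the one with the highest TPR.
    ok = np.flatnonzero(fpr <= max_fpr)
    if ok.size == 0:
        raise ValueError('no threshold meets the FPR budget')
    i = ok[0]
    return thresholds[i], fpr[i], tpr[i]

# Synthetic stand-in scores: reals cluster near 0.8, boguses near 0.2.
rng = np.random.default_rng(0)
truth = np.repeat([1, 0], 500)
raw = np.where(truth == 1, rng.normal(0.8, 0.1, 1000), rng.normal(0.2, 0.1, 1000))
scores = np.clip(raw, 0.0, 1.0)

thresholds, fpr, tpr = roc_sweep(scores, truth)
thr, f, t = pick_threshold(thresholds, fpr, tpr)
print(thr, f, t)
```

With a small test set the FPR moves in coarse steps (one false positive among 160 boguses is already 0.6%), so the achieved FPR usually sits below the 1% target rather than exactly on it.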
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "dad49c8c", "metadata": {}, "outputs": [], "source": [ "# TODO: training loop + ROC + threshold selection\n" ] }, { "cell_type": "markdown", "id": "13f896f8", "metadata": {}, "source": [ "## Task 4 — Confusion matrix + confident examples\n", "\n", "Numbers are easy to over-trust. Look at the actual decisions:\n", "\n", "1. At a fixed threshold (start with 0.5), build the 2×2 confusion matrix on the test set.\n", "2. Find the **3 most confident reals**: test examples where `true = real` and `score` is highest.\n", "3. Find the **3 most confident boguses**: test examples where `true = bogus` and `score` is lowest.\n", "4. For each, plot the (science, reference, difference) triplet with the score in the title.\n", "\n", "
💡 Hint\n",
 "\n",
 "- You already have `scores` and `truth` from Task 3 — reuse them.\n",
 "- `np.argsort(scores)` gives indices sorted ascending.\n",
 "- `(truth == 1)` is the mask of real test examples; intersect with sorted scores.\n",
 "- For each example chosen, you need the test-set image `Xte[i]` to plot the 3 channels.\n",
 "\n",
 "
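A minimal sketch of the bookkeeping, leaving the plotting to you. The helper names `confusion_matrix` and `most_confident` are assumptions, and the eight scores are made-up values for illustration; `truth` must be an integer array for the indexing to work.

```python
import numpy as np

def confusion_matrix(scores, truth, threshold=0.5):
    # Rows: true class (0 = bogus, 1 = real); columns: predicted class.
    pred = (scores >= threshold).astype(int)
    cm = np.zeros((2, 2), dtype=int)
    np.add.at(cm, (truth, pred), 1)   # count each (true, predicted) pair
    return cm

def most_confident(scores, truth, label, k=3):
    # Indices of the k examples of the given true class with the most
    # extreme scores: highest for reals, lowest for boguses.
    order = np.argsort(scores)        # ascending
    if label == 1:
        order = order[::-1]           # highest scores first
    return order[truth[order] == label][:k]

# Made-up scores and labels, for illustration only.
scores = np.array([0.95, 0.10, 0.80, 0.40, 0.02, 0.60, 0.99, 0.30])
truth = np.array([1, 0, 1, 1, 0, 0, 1, 0])

cm = confusion_matrix(scores, truth)
top_real = most_confident(scores, truth, label=1)
top_bogus = most_confident(scores, truth, label=0)
print(cm)
print(top_real, top_bogus)
```

The returned indices point into the test set, so `Xte[i]` for each `i` gives the three channels to plot with the score in the title.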
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "df937aa4", "metadata": {}, "outputs": [], "source": [ "# TODO:\n", "# 1. Build a 2x2 confusion matrix at threshold 0.5\n", "# 2. Find 3 most confident reals (highest score, true=1)\n", "# 3. Find 3 most confident boguses (lowest score, true=0)\n", "# 4. Plot each as a (science, reference, difference) triplet\n" ] }, { "cell_type": "markdown", "id": "6e023f60", "metadata": {}, "source": [ "## Stretch\n", "\n", "- Make the bogus class harder: include cases where the cosmic ray sits exactly on the galaxy core.\n", "- Inject class imbalance (90% real, 10% bogus). What metric should you optimise — accuracy or recall at fixed FPR?\n", "- Replace 3-channel input with just the difference image. How much do you lose?\n", "- Look at examples right at the decision boundary (score ≈ 0.5). What do they look like?\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }