{
"cells": [
{
"cell_type": "markdown",
"id": "26146195",
"metadata": {},
"source": [
"# Exercise 2 — Real or bogus? Image-triplet classification\n",
"\n",
"**Reference paper.** Duev et al. 2019, *Real-bogus classification for the Zwicky Transient Facility\n",
"using deep learning*, [arXiv:1907.11259](https://arxiv.org/abs/1907.11259).\n",
"\n",
"### Why this matters\n",
"ZTF generates ~1M alerts/night, and most are artifacts. `braai`, the classifier from the\n",
"reference paper, sits at the front of the pipeline and decides what's worth a human's time.\n",
"Its input is a *triplet*: science image, reference image, and difference image, stacked into 3 channels.\n",
"\n",
"### What you'll do\n",
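"1. Generate synthetic triplets: real (a near-centred point source in the difference) vs bogus (hot pixels, short streaks, subtraction residuals).\n",
"2. Build a small VGG-style CNN on 3-channel inputs.\n",
"3. Train, plot the ROC, and pick a threshold that keeps the FPR at or below 1%.\n",
"4. Inspect the confusion matrix and the most confident examples of each class.\n"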
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9e33ba7",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"rng = np.random.default_rng(0)\n",
"torch.manual_seed(0)\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n"
]
},
{
"cell_type": "markdown",
"id": "dfcefc72",
"metadata": {},
"source": [
"## Generate synthetic triplets\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d19260e",
"metadata": {},
"outputs": [],
"source": [
"SIZE = 21  # image side, pixels\n",
"NOISE = 0.35  # sigma of the Gaussian pixel noise added to science and reference\n",
"\n",
"def gaussian2d(x0, y0, sigma=1.5, flux=1.0):\n",
"    # 2-D Gaussian centred at (x0, y0); `flux` is the peak value, not the integral\n",
"    yy, xx = np.mgrid[0:SIZE, 0:SIZE]\n",
"    return flux * np.exp(-((xx - x0) ** 2 + (yy - y0) ** 2) / (2 * sigma ** 2))\n",
"\n",
"def make_real():\n",
"    cx, cy = SIZE / 2, SIZE / 2\n",
"    # reference: extended galaxy + faint sub-structure\n",
"    ref = gaussian2d(cx, cy, sigma=3.0, flux=2.5)\n",
"    ref += gaussian2d(cx + rng.uniform(-2, 2), cy + rng.uniform(-2, 2),\n",
"                      sigma=1.5, flux=rng.uniform(0.3, 0.8))\n",
"    # science: same galaxy plus a near-centred point source (sometimes faint)\n",
"    flux = rng.uniform(0.5, 1.8)\n",
"    dx, dy = rng.uniform(-1.0, 1.0, 2)\n",
"    sci = ref + gaussian2d(cx + dx, cy + dy, sigma=1.2, flux=flux)\n",
"    # independent noise realisations for the two images\n",
"    sci += rng.normal(0, NOISE, sci.shape)\n",
"    ref += rng.normal(0, NOISE, ref.shape)\n",
"    diff = sci - ref\n",
"    # channel order: science, reference, difference\n",
"    return np.stack([sci, ref, diff]).astype(np.float32)\n",
"\n",
"def make_bogus():\n",
"    cx, cy = SIZE / 2, SIZE / 2\n",
"    ref = gaussian2d(cx, cy, sigma=3.0, flux=2.5)\n",
"    ref += gaussian2d(cx + rng.uniform(-2, 2), cy + rng.uniform(-2, 2),\n",
"                      sigma=1.5, flux=rng.uniform(0.3, 0.8))\n",
"    sci = ref.copy()\n",
"    # pick one of three artifact types; hits can land on top of the galaxy core\n",
"    kind = rng.integers(0, 3)\n",
"    if kind == 0:\n",
"        # 1 to 3 hot pixels scattered over the stamp\n",
"        for _ in range(rng.integers(1, 4)):\n",
"            x, y = rng.integers(0, SIZE, 2)\n",
"            sci[y, x] += rng.uniform(1.5, 3.5)\n",
"    elif kind == 1:\n",
"        # short bright streak through the field\n",
"        x0, y0 = rng.integers(0, SIZE, 2)\n",
"        dx_, dy_ = rng.choice([-1, 0, 1], 2)\n",
"        # a (0, 0) direction would pile every hit onto one pixel, so redraw it\n",
"        while dx_ == 0 and dy_ == 0:\n",
"            dx_, dy_ = rng.choice([-1, 0, 1], 2)\n",
"        for step in range(rng.integers(3, 7)):\n",
"            xs, ys = x0 + dx_ * step, y0 + dy_ * step\n",
"            if 0 <= xs < SIZE and 0 <= ys < SIZE:\n",
"                sci[ys, xs] += rng.uniform(1.5, 3.0)\n",
"    else:\n",
"        # subtraction residual: galaxy slightly mis-aligned between science and reference\n",
"        shift = rng.uniform(-0.7, 0.7, 2)\n",
"        sci = gaussian2d(cx + shift[0], cy + shift[1], sigma=3.0, flux=2.5)\n",
"        # the sub-structure is redrawn independently, so it leaves a residual too\n",
"        sci += gaussian2d(cx + rng.uniform(-2, 2), cy + rng.uniform(-2, 2),\n",
"                          sigma=1.5, flux=rng.uniform(0.3, 0.8))\n",
"    sci += rng.normal(0, NOISE, sci.shape)\n",
"    ref += rng.normal(0, NOISE, ref.shape)\n",
"    diff = sci - ref\n",
"    return np.stack([sci, ref, diff]).astype(np.float32)\n",
"\n",
"N = 800  # examples per class\n",
"X = np.zeros((2 * N, 3, SIZE, SIZE), dtype=np.float32)\n",
"y = np.zeros(2 * N, dtype=np.int64)\n",
"for i in range(N):\n",
"    # first N entries are real (label 1), last N are bogus (label 0)\n",
"    X[i] = make_real()\n",
"    y[i] = 1\n",
"    X[N + i] = make_bogus()\n",
"    y[N + i] = 0\n",
"print(X.shape, y.mean())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f402cc41",
"metadata": {},
"outputs": [],
"source": [
"# Visualise one real and one bogus triplet\n",
"fig, axes = plt.subplots(2, 3, figsize=(8, 5.5))\n",
"titles = [\"science\", \"reference\", \"difference\"]\n",
"for row, (label, idx) in enumerate([(\"real\", 0), (\"bogus\", N)]):\n",
"    for col in range(3):\n",
"        axes[row, col].imshow(X[idx, col], cmap=\"gray\")\n",
"        axes[row, col].set_title(f\"{label}: {titles[col]}\")\n",
"        axes[row, col].axis(\"off\")\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "cc27cd21",
"metadata": {},
"source": [
"## Task 1 — Train/test split + DataLoader\n",
"\n",
"Make an 80/20 split and wrap as `DataLoader`s. Don't forget to convert to tensors.\n",
"\n",
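"<details><summary>💡 Hint (click to expand)</summary>\n",
"\n",
"- `torch.from_numpy(X)` keeps the dtype. `y` is `int64`, so cast it with `.float()` if you plan to use `BCEWithLogitsLoss`.\n",
"- The data are ordered by class (all reals first), so shuffle before splitting: `torch.randperm(len(X))` gives shuffled indices.\n",
"- `TensorDataset` + `DataLoader(batch_size=32, shuffle=True)` is enough for training. Leave the test loader unshuffled so its scores line up with the test images.\n",
"\n",
"</details>\n"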
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca28f486",
"metadata": {},
"outputs": [],
"source": [
"# TODO: split into train/test, build DataLoaders\n"
]
},
{
"cell_type": "markdown",
"id": "0fcfd51d",
"metadata": {},
"source": [
"## Task 2 — Build a small VGG-like CNN\n",
"\n",
"3 input channels (science, reference, difference), one logit output (binary).\n",
"\n",
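"As a rough size check, assume 3×3 convolutions with `padding=1` (which keep the spatial size) and `nn.MaxPool2d(2)` (which halves it, rounding down). These layer settings are an assumption for illustration, not a requirement of the exercise. A 21×21 input then shrinks block by block as\n",
"\n",
"$$21 \\to \\lfloor 21/2 \\rfloor = 10 \\to \\lfloor 10/2 \\rfloor = 5 \\to \\lfloor 5/2 \\rfloor = 2$$\n",
"\n",
"so three blocks leave a 2×2 feature map and a fourth would reduce it to a single pixel. Other kernel sizes or paddings change these numbers.\n",
"\n",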
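"<details><summary>💡 Hint (click to expand)</summary>\n",
"\n",
"- Two or three `Conv2d -> ReLU -> MaxPool` blocks are plenty for 21×21 inputs.\n",
"- Use `nn.AdaptiveAvgPool2d(1)`, flatten, then a linear layer to one output.\n",
"- For binary classification use `nn.BCEWithLogitsLoss` (numerically safer than a separate sigmoid followed by `nn.BCELoss`). It expects logits and float targets of the same shape, so squeeze the model output to shape `(batch,)`.\n",
"\n",
"</details>\n"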
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78f65e9d",
"metadata": {},
"outputs": [],
"source": [
"class RBNet(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # TODO\n",
"        pass\n",
"\n",
"    def forward(self, x):\n",
"        # TODO\n",
"        raise NotImplementedError\n"
]
},
{
"cell_type": "markdown",
"id": "3a40b186",
"metadata": {},
"source": [
"## Task 3 — Train, plot ROC, pick a threshold\n",
"\n",
"Train for ~15 epochs. Then:\n",
"1. Score the test set, sort scores.\n",
"2. Sweep thresholds and compute (FPR, TPR) at each.\n",
"3. Pick the lowest threshold that gives FPR ≤ 1% and report the corresponding TPR.\n",
"\n",
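"One way to write the sweep down, with $s_i$ the score of test example $i$, $y_i$ its label (1 = real, 0 = bogus) and $t$ the threshold (notation chosen here for illustration):\n",
"\n",
"$$\\mathrm{TPR}(t) = \\frac{|\\{ i : s_i \\ge t,\\ y_i = 1 \\}|}{|\\{ i : y_i = 1 \\}|}, \\qquad \\mathrm{FPR}(t) = \\frac{|\\{ i : s_i \\ge t,\\ y_i = 0 \\}|}{|\\{ i : y_i = 0 \\}|}$$\n",
"\n",
"Neither rate can increase as $t$ rises, so the lowest threshold with $\\mathrm{FPR}(t) \\le 0.01$ has the highest TPR among the allowed ones. With 1600 examples and an 80/20 split the test set holds 320 triplets. If about 160 of them are bogus (the exact count depends on the random split), then\n",
"\n",
"$$\\frac{1}{160} = 0.625\\% \\le 1\\% < \\frac{2}{160} = 1.25\\%$$\n",
"\n",
"and the 1% operating point allows a single false positive, so the reported TPR may shift between random seeds.\n",
"\n",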
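"<details><summary>💡 Hint (click to expand)</summary>\n",
"\n",
"- For binary classification with a single logit, `prob = torch.sigmoid(logit)`.\n",
"- Sweep ~200 thresholds between 0 and 1.\n",
"- TPR = TP / (TP + FN); FPR = FP / (FP + TN).\n",
"- Keep the test-set scores and labels as NumPy arrays named `scores` and `truth`; Task 4 reuses them.\n",
"\n",
"</details>\n"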
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dad49c8c",
"metadata": {},
"outputs": [],
"source": [
"# TODO: training loop + ROC + threshold selection\n"
]
},
{
"cell_type": "markdown",
"id": "13f896f8",
"metadata": {},
"source": [
"## Task 4 — Confusion matrix + confident examples\n",
"\n",
"Numbers are easy to over-trust. Look at the actual decisions:\n",
"\n",
"1. At a fixed threshold (start with 0.5), build the 2×2 confusion matrix on the test set.\n",
"2. Find the **3 most confident reals**: test examples where `true = real` and `score` is highest.\n",
"3. Find the **3 most confident bogus examples**: test examples where `true = bogus` and `score` is lowest.\n",
"4. For each, plot the (science, reference, difference) triplet with the score in the title.\n",
"\n",
"<details><summary>💡 Hint (click to expand)</summary>\n",
"\n",
"- You already have `scores` and `truth` from Task 3; reuse them.\n",
"- `np.argsort(scores)` gives indices sorted by ascending score.\n",
"- `(truth == 1)` is the mask of real test examples; use it to filter the sorted indices.\n",
"- To plot the 3 channels of a chosen example you need its test-set image, e.g. `Xte[i]` if you named the test images `Xte` in Task 1.\n",
"\n",
"</details>\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df937aa4",
"metadata": {},
"outputs": [],
"source": [
"# TODO:\n",
"# 1. Build a 2x2 confusion matrix at threshold 0.5\n",
"# 2. Find 3 most confident reals (highest score, true=1)\n",
"# 3. Find 3 most confident bogus examples (lowest score, true=0)\n",
"# 4. Plot each as a (science, reference, difference) triplet\n"
]
},
{
"cell_type": "markdown",
"id": "6e023f60",
"metadata": {},
"source": [
"## Stretch\n",
"\n",
"- Make the bogus class harder: include cases where the cosmic ray sits exactly on the galaxy core.\n",
"- Inject class imbalance (90% real, 10% bogus). What metric should you optimise — accuracy or recall at fixed FPR?\n",
"- Replace the 3-channel input with just the difference image. How much do you lose?\n",
"- Look at examples right at the decision boundary (score ≈ 0.5). What do they look like?\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}