{
"cells": [
{
"cell_type": "markdown",
"id": "bfb26e82",
"metadata": {},
"source": [
"# Exercise 4 — Glitch or chirp? 2-D CNN on time-frequency spectrograms\n",
"\n",
"**Reference paper.** Cabero, Mahabal, McIver 2020, *GWSkyNet: a real-time classifier\n",
"for public gravitational-wave candidates*, [arXiv:2010.11829](https://arxiv.org/abs/2010.11829).\n",
"\n",
"### Why this matters\n",
"LIGO/Virgo data is dominated by **glitches** — transient noise that can mimic\n",
"real GW signals. Converted to a time-frequency Q-transform image, real chirps\n",
"sweep up in frequency while glitches blob in a localised time-frequency patch.\n",
"Distinguishing them is a pure image-classification problem.\n",
"\n",
"### What you'll do\n",
"1. Simulate two classes of toy spectrograms:\n",
" - **chirp**: a frequency-sweeping signal,\n",
" - **blip glitch**: a localised Gaussian blob in time-frequency.\n",
"2. Build a 2-D CNN to classify them.\n",
"3. Train and visualise the misclassifications.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3406060",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"rng = np.random.default_rng(0)\n",
"torch.manual_seed(0)\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n"
]
},
{
"cell_type": "markdown",
"id": "55c058ca",
"metadata": {},
"source": [
"## Generate toy spectrograms\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f3ae959",
"metadata": {},
"outputs": [],
"source": [
"H, W = 32, 64 # frequency bins × time bins\n",
"\n",
"def make_chirp(noise=0.8):\n",
" img = np.zeros((H, W), dtype=np.float32)\n",
" f0 = rng.uniform(2, 12)\n",
" rate = rng.uniform(0.1, 0.5)\n",
" t0 = rng.integers(5, W - 15)\n",
" amp = rng.uniform(1.0, 2.2)\n",
" length = rng.integers(8, W - t0)\n",
" for k in range(length):\n",
" j = t0 + k\n",
" f = f0 + rate * k\n",
" fi = int(np.clip(f, 0, H - 1))\n",
" for df in (-1, 0, 1):\n",
" ff = fi + df\n",
" if 0 <= ff < H:\n",
" img[ff, j] += amp * (1.0 if df == 0 else 0.4)\n",
" img += rng.normal(0, noise, img.shape)\n",
" return img\n",
"\n",
"def make_blip(noise=0.8):\n",
" img = np.zeros((H, W), dtype=np.float32)\n",
" t0 = rng.uniform(10, W - 10)\n",
" f0 = rng.uniform(5, H - 5)\n",
" sigma_t = rng.uniform(1.5, 3.5)\n",
" sigma_f = rng.uniform(2.0, 4.5)\n",
" amp = rng.uniform(1.5, 3.0)\n",
" drift = rng.uniform(-0.1, 0.1)\n",
" yy, xx = np.mgrid[0:H, 0:W]\n",
" # add a slight frequency drift so some blips look chirp-ish\n",
" f_t = f0 + drift * (xx - t0)\n",
" img += amp * np.exp(-((xx - t0) ** 2 / (2 * sigma_t ** 2)\n",
" + (yy - f_t) ** 2 / (2 * sigma_f ** 2)))\n",
" img += rng.normal(0, noise, img.shape)\n",
" return img\n",
"\n",
"N_PER = 500\n",
"X = np.zeros((2 * N_PER, 1, H, W), dtype=np.float32)\n",
"y = np.zeros(2 * N_PER, dtype=np.int64)\n",
"for i in range(N_PER):\n",
" X[i, 0] = make_chirp(); y[i] = 1\n",
" X[N_PER + i, 0] = make_blip(); y[N_PER + i] = 0\n",
"print(X.shape)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61e0712a",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(2, 3, figsize=(9, 5))\n",
"for row, (name, base) in enumerate([(\"chirp\", 0), (\"blip\", N_PER)]):\n",
" for col in range(3):\n",
" axes[row, col].imshow(X[base + col, 0], aspect=\"auto\", origin=\"lower\", cmap=\"viridis\")\n",
" axes[row, col].set_title(name)\n",
" axes[row, col].axis(\"off\")\n",
"plt.tight_layout(); plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "f81436f8",
"metadata": {},
"source": [
"## Task 1 — Build a 2-D CNN\n",
"\n",
"The input is a single-channel image of shape `(1, 32, 64)`. Output one logit (binary).\n",
"\n",
"💡 Hint (click to expand)
\n",
"\n",
"- Two `Conv2d -> ReLU -> MaxPool` blocks, then `AdaptiveAvgPool2d(1)`, then `Linear -> 1`.\\n- Use `BCEWithLogitsLoss`.\\n- Spectrograms are 'images with structure': don't over-augment, but a small horizontal-flip can be useful for some applications (not this one — direction of the chirp matters).\n",
"\n",
" \n"
]
},
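{
"cell_type": "markdown",
"id": "sketch-shapes",
"metadata": {},
"source": [
"A quick way to sanity-check the architecture from the hint is to follow the spatial size through each block. The sketch below is plain Python arithmetic rather than a solution to the task. The kernel size 3, padding 1 and 2×2 pooling are assumptions, and `conv_out` is a hypothetical helper that applies the usual output-size formula `floor((n + 2p - k) / s) + 1`.\n",
"\n",
"```python\n",
"def conv_out(n, kernel, stride=1, padding=0):\n",
"    # Standard output-size formula for convolution and pooling layers.\n",
"    return (n + 2 * padding - kernel) // stride + 1\n",
"\n",
"h, w = 32, 64  # input spectrogram: frequency bins x time bins\n",
"for block in range(2):\n",
"    # Assumed Conv2d(kernel_size=3, padding=1): spatial size is unchanged.\n",
"    h, w = conv_out(h, 3, padding=1), conv_out(w, 3, padding=1)\n",
"    # Assumed MaxPool2d(2): each spatial dimension is halved.\n",
"    h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)\n",
"    print(f'after block {block + 1}: {h} x {w}')\n",
"```\n",
"\n",
"`AdaptiveAvgPool2d(1)` then collapses whatever spatial size is left to 1 × 1, so the final `Linear` layer only needs to know the channel count of the last convolution.\n"
]
},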
{
"cell_type": "code",
"execution_count": null,
"id": "a1be833f",
"metadata": {},
"outputs": [],
"source": [
"class GlitchNet(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" # TODO\n",
" pass\n",
" def forward(self, x):\n",
" # TODO\n",
" raise NotImplementedError\n"
]
},
{
"cell_type": "markdown",
"id": "66ccdbb5",
"metadata": {},
"source": [
"## Task 2 — Train and evaluate\n",
"\n",
"Train for ~15 epochs, report test accuracy.\n"
]
},
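{
"cell_type": "markdown",
"id": "sketch-split",
"metadata": {},
"source": [
"One possible way to build the held-out split, sketched in NumPy only. The 80/20 ratio, the separate seed and the names `split_rng`, `idx_train` and `idx_test` are assumptions for illustration, and `y_demo` is a stand-in with the same layout as `y` above (all chirps first, then all blips).\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"split_rng = np.random.default_rng(1)  # separate generator, so the data stream is untouched\n",
"n = 1000                               # stand-in for len(X)\n",
"y_demo = np.r_[np.ones(n // 2, dtype=np.int64), np.zeros(n // 2, dtype=np.int64)]\n",
"\n",
"perm = split_rng.permutation(n)        # shuffle first: the classes are stored in two blocks\n",
"n_test = n // 5                        # hold out 20 %\n",
"idx_test, idx_train = perm[:n_test], perm[n_test:]\n",
"\n",
"print(len(idx_train), len(idx_test), y_demo[idx_test].mean())\n",
"```\n",
"\n",
"Indexing with `X[idx_train]` and `X[idx_test]` gives arrays that can be wrapped in `TensorDataset` and `DataLoader`. Without the shuffle, slicing off the last 20 % would put only blips in the test set.\n"
]
},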
{
"cell_type": "code",
"execution_count": null,
"id": "0e2d172f",
"metadata": {},
"outputs": [],
"source": [
"# TODO: split + training loop\n"
]
},
{
"cell_type": "markdown",
"id": "9f9b79ac",
"metadata": {},
"source": [
"## Task 3 — Where does the model trip up?\n",
"\n",
"Plot a few misclassified examples from each class. Do you see what\n",
"confused the network? (Hint: short chirps look quite blip-like.)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6cedf9c4",
"metadata": {},
"outputs": [],
"source": [
"# TODO: confusion matrix + plot misclassifications\n"
]
},
{
"cell_type": "markdown",
"id": "f1eba148",
"metadata": {},
"source": [
"## Task 4 — Confusion matrix + confident correct examples\n",
"\n",
"You've seen the failures (Task 3). Now look at the **successes**:\n",
"\n",
"1. Plot the 2×2 confusion matrix at threshold 0.5.\n",
"2. Find the **3 most confident chirps** — `true = 1` with the highest scores.\n",
"3. Find the **3 most confident blips** — `true = 0` with the lowest scores.\n",
"4. Plot each as a spectrogram with its score in the title.\n",
"\n",
"💡 Hint (click to expand)
\n",
"\n",
"- Score the whole test set at once: `torch.sigmoid(model(Xte.to(DEVICE)))`.\\n- `np.argsort(scores)` sorts ascending; reverse with `[::-1]` for descending.\\n- For each side, pick the top-3 indices where the true class matches.\n",
"\n",
" \n"
]
},
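{
"cell_type": "markdown",
"id": "sketch-topk",
"metadata": {},
"source": [
"The counting and selection logic can be tried on made-up numbers before running it on real model outputs. In this sketch `scores` and `labels` are invented toy values, not results from this notebook, and the variable names are only suggestions.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"scores = np.array([0.95, 0.10, 0.80, 0.40, 0.60, 0.05, 0.99, 0.30])\n",
"labels = np.array([1, 0, 1, 1, 0, 0, 1, 0])\n",
"\n",
"pred = (scores >= 0.5).astype(int)\n",
"# Rows: true class (0 = blip, 1 = chirp); columns: predicted class.\n",
"cm = np.zeros((2, 2), dtype=int)\n",
"np.add.at(cm, (labels, pred), 1)\n",
"\n",
"order = np.argsort(scores)  # ascending\n",
"top_chirps = [i for i in order[::-1] if labels[i] == 1][:3]  # highest scores with true = 1\n",
"top_blips = [i for i in order if labels[i] == 0][:3]         # lowest scores with true = 0\n",
"print(cm, top_chirps, top_blips)\n",
"```\n",
"\n",
"Filtering by the true label after sorting keeps the ranking by confidence intact, which is easier to get right than sorting two masked sub-arrays and mapping the indices back.\n"
]
},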
{
"cell_type": "code",
"execution_count": null,
"id": "1a05e4ec",
"metadata": {},
"outputs": [],
"source": [
"# TODO:\n",
"# 1. Compute scores on the whole test set\n",
"# 2. Plot a 2x2 confusion matrix at threshold 0.5\n",
"# 3. Find top-3 confident chirps and top-3 confident blips\n",
"# 4. Plot the 6 spectrograms with score in title\n"
]
},
{
"cell_type": "markdown",
"id": "d42531ca",
"metadata": {},
"source": [
"## Stretch\n",
"\n",
"- Add a third class: \"gaussian noise only\" — does it change the decision boundary\n",
" for the other two?\n",
"- Make blips that drift slightly in frequency. Where's the line between a\n",
" drifting blip and a chirp?\n",
"- Replace the CNN with a simple linear classifier on flattened pixels. How much\n",
" worse does it do, and why?\n",
"- For the most confident examples, blank out half the image and re-score. Which\n",
" half matters?\n"
]
}
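,
{
"cell_type": "markdown",
"id": "sketch-occlusion",
"metadata": {},
"source": [
"For the last stretch item, blanking half of an image is a single slice assignment. The sketch below assumes the halves are split along the time axis and that the blanked region is filled with the image mean rather than zero. Both are choices made here for illustration, and `blank_half` is a hypothetical helper.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def blank_half(img, side):\n",
"    # Return a copy of a (H, W) image with the left or right half set to the image mean.\n",
"    out = img.copy()\n",
"    mid = img.shape[1] // 2\n",
"    if side == 'left':\n",
"        out[:, :mid] = img.mean()\n",
"    elif side == 'right':\n",
"        out[:, mid:] = img.mean()\n",
"    else:\n",
"        raise ValueError(f'unknown side: {side}')\n",
"    return out\n",
"\n",
"demo = np.arange(32 * 64, dtype=np.float32).reshape(32, 64)\n",
"left_blank = blank_half(demo, 'left')\n",
"right_blank = blank_half(demo, 'right')\n",
"```\n",
"\n",
"Re-scoring both variants with the trained model and comparing against the original score shows which half carries the evidence. Splitting along the frequency axis instead only needs the slices moved to the first index.\n"
]
}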
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}