{ "cells": [ { "cell_type": "markdown", "id": "bfb26e82", "metadata": {}, "source": [ "# Exercise 4 — Glitch or chirp? 2-D CNN on time-frequency spectrograms\n", "\n", "**Reference paper.** Cabero, Mahabal, McIver 2020, *GWSkyNet: a real-time classifier\n", "for public gravitational-wave candidates*, [arXiv:2010.11829](https://arxiv.org/abs/2010.11829).\n", "\n", "### Why this matters\n", "LIGO/Virgo data contain frequent **glitches**: short bursts of transient, non-Gaussian noise that can mimic\n", "real GW signals. In a time-frequency (Q-transform) image, a real chirp sweeps upward in frequency as time\n", "advances, while many glitches, such as blips, show up as a compact blob confined to a small time-frequency patch.\n", "Telling the two apart can therefore be framed as an image-classification problem.\n", "\n", "### What you'll do\n", "1. Simulate two classes of toy spectrograms:\n", "   - **chirp**: a signal whose frequency sweeps upward over time,\n", "   - **blip glitch**: a localised Gaussian blob in time-frequency.\n", "2. Build a 2-D CNN to classify them.\n", "3. Train it, then inspect both the misclassifications and the most confident correct predictions.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b3406060", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, TensorDataset\n", "\n", "rng = np.random.default_rng(0)\n", "torch.manual_seed(0)\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" ] }, { "cell_type": "markdown", "id": "55c058ca", "metadata": {}, "source": [ "## Generate toy spectrograms\n", "\n", "Each sample is a single `H × W = 32 × 64` image (rows are frequency bins, columns are time bins) with additive\n", "Gaussian noise of standard deviation `noise = 0.8`. Chirps get label 1 and blips get label 0.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "9f3ae959", "metadata": {}, "outputs": [], "source": [ "H, W = 32, 64  # frequency bins × time bins\n", "\n", "def make_chirp(noise=0.8):\n", "    img = np.zeros((H, W), dtype=np.float32)\n", "    f0 = rng.uniform(2, 12)\n", "    rate = rng.uniform(0.1, 0.5)\n", "    t0 = rng.integers(5, W - 15)\n", "    amp = rng.uniform(1.0, 2.2)\n", "    length = rng.integers(8, W - t0)\n", "    for k in
range(length):\n", "        j = t0 + k\n", "        f = f0 + rate * k\n", "        # frequencies above the top bin are clipped to row H - 1\n", "        fi = int(np.clip(f, 0, H - 1))\n", "        # 3-bin-wide track: full amplitude in the centre bin, 40% in each neighbour\n", "        for df in (-1, 0, 1):\n", "            ff = fi + df\n", "            if 0 <= ff < H:\n", "                img[ff, j] += amp * (1.0 if df == 0 else 0.4)\n", "    img += rng.normal(0, noise, img.shape)\n", "    return img\n", "\n", "def make_blip(noise=0.8):\n", "    img = np.zeros((H, W), dtype=np.float32)\n", "    t0 = rng.uniform(10, W - 10)\n", "    f0 = rng.uniform(5, H - 5)\n", "    sigma_t = rng.uniform(1.5, 3.5)\n", "    sigma_f = rng.uniform(2.0, 4.5)\n", "    amp = rng.uniform(1.5, 3.0)\n", "    drift = rng.uniform(-0.1, 0.1)\n", "    yy, xx = np.mgrid[0:H, 0:W]\n", "    # a slight frequency drift (at most 0.1 frequency bins per time bin) makes some blips look chirp-ish\n", "    f_t = f0 + drift * (xx - t0)\n", "    img += amp * np.exp(-((xx - t0) ** 2 / (2 * sigma_t ** 2)\n", "                          + (yy - f_t) ** 2 / (2 * sigma_f ** 2)))\n", "    img += rng.normal(0, noise, img.shape)\n", "    return img\n", "\n", "N_PER = 500\n", "X = np.zeros((2 * N_PER, 1, H, W), dtype=np.float32)\n", "y = np.zeros(2 * N_PER, dtype=np.int64)\n", "for i in range(N_PER):\n", "    X[i, 0] = make_chirp(); y[i] = 1                # chirp -> label 1\n", "    X[N_PER + i, 0] = make_blip(); y[N_PER + i] = 0  # blip -> label 0\n", "print(X.shape)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "61e0712a", "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(2, 3, figsize=(9, 5))\n", "for row, (name, base) in enumerate([(\"chirp\", 0), (\"blip\", N_PER)]):\n", "    for col in range(3):\n", "        # low frequencies at the bottom, time increasing to the right\n", "        axes[row, col].imshow(X[base + col, 0], aspect=\"auto\", origin=\"lower\", cmap=\"viridis\")\n", "        axes[row, col].set_title(name)\n", "        axes[row, col].axis(\"off\")\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "id": "f81436f8", "metadata": {}, "source": [ "## Task 1 — Build a 2-D CNN\n", "\n", "The input is a single-channel image of shape `(1, 32, 64)`, i.e. (channel, frequency, time), so a mini-batch\n", "has shape `(B, 1, 32, 64)`. Output one logit per image (binary classification: chirp = 1, blip = 0).\n", "\n", "
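<details><summary>💡 Hint (click to expand)</summary>\n", "\n", "- Two `Conv2d -> ReLU -> MaxPool2d` blocks, then `AdaptiveAvgPool2d(1)`, then `Linear -> 1`.\n", "- Use `BCEWithLogitsLoss`. It applies the sigmoid internally, so `forward` should return raw logits.\n", "- The labels are stored as `int64`: convert them with `.float()` and match the logit shape, e.g. `(B, 1)`, before passing them to the loss.\n", "- Spectrograms are images with physical axes, so augment with care. Horizontal (time) flips are a common image augmentation, but avoid them here: a flip turns a rising chirp into a falling one, and the direction of the sweep is exactly what identifies a chirp.\n", "\n", "</details>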
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a1be833f", "metadata": {}, "outputs": [], "source": [ "class GlitchNet(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # TODO: define the convolutional blocks, pooling and output layer\n", "        pass\n", "\n", "    def forward(self, x):\n", "        # TODO: map x of shape (B, 1, 32, 64) to logits of shape (B, 1)\n", "        raise NotImplementedError\n" ] }, { "cell_type": "markdown", "id": "66ccdbb5", "metadata": {}, "source": [ "## Task 2 — Train and evaluate\n", "\n", "Split the data into training and test sets, train for ~15 epochs, and report the test accuracy.\n", "Shuffle before splitting (or use a stratified split): `X` stores all chirps first and all blips second.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0e2d172f", "metadata": {}, "outputs": [], "source": [ "# TODO: shuffle + train/test split, training loop, test accuracy\n" ] }, { "cell_type": "markdown", "id": "9f9b79ac", "metadata": {}, "source": [ "## Task 3 — Where does the model trip up?\n", "\n", "Plot a few misclassified test examples from each class. Can you see what\n", "confused the network? (Hint: short chirps look quite blip-like, and blips with a\n", "noticeable frequency drift can look chirp-like.)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6cedf9c4", "metadata": {}, "outputs": [], "source": [ "# TODO: find and plot misclassified test examples from each class\n" ] }, { "cell_type": "markdown", "id": "f1eba148", "metadata": {}, "source": [ "## Task 4 — Confusion matrix + confident correct examples\n", "\n", "You've seen the failures (Task 3). Now look at the **successes**:\n", "\n", "1. Plot the 2×2 confusion matrix at a score threshold of 0.5 (score = sigmoid of the logit).\n", "2. Find the **3 most confident chirps**: test examples with `true = 1` and the highest scores.\n", "3. Find the **3 most confident blips**: test examples with `true = 0` and the lowest scores.\n", "4. Plot each as a spectrogram with its score in the title.\n", "\n", "
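<details><summary>💡 Hint (click to expand)</summary>\n", "\n", "- Score the whole test set in one pass, in eval mode and without gradients: call `model.eval()`, then inside `with torch.no_grad():` compute `scores = torch.sigmoid(model(Xte.to(DEVICE))).squeeze(1).cpu().numpy()`.\n", "- `np.argsort(scores)` sorts ascending; reverse with `[::-1]` for descending.\n", "- For chirps, walk the descending order; for blips, walk the ascending order. In each case keep only the indices whose true label matches, then take the first three.\n", "\n", "</details>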
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1a05e4ec", "metadata": {}, "outputs": [], "source": [ "# TODO:\n", "# 1. Compute scores on the whole test set\n", "# 2. Plot a 2x2 confusion matrix at threshold 0.5\n", "# 3. Find top-3 confident chirps and top-3 confident blips\n", "# 4. Plot the 6 spectrograms with score in title\n" ] }, { "cell_type": "markdown", "id": "d42531ca", "metadata": {}, "source": [ "## Stretch\n", "\n", "- Add a third class, \"Gaussian noise only\" (this needs three output logits and `CrossEntropyLoss`\n", "  instead of one logit and `BCEWithLogitsLoss`). Does it change the decision boundary between the other two?\n", "- `make_blip` already adds a slight frequency drift (`drift` drawn uniformly from [-0.1, 0.1]). Widen that\n", "  range: where is the line between a drifting blip and a chirp?\n", "- Replace the CNN with a simple linear classifier on flattened pixels. How much\n", "  worse does it do, and why?\n", "- For the most confident examples, blank out half the image (early vs. late times, or low vs. high\n", "  frequencies) and re-score. Which half matters?\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }