{ "cells": [ { "cell_type": "markdown", "id": "245845f5", "metadata": {}, "source": [ "# Exercise 3 — Multi-view transit detection in light curves\n", "\n", "**Reference paper.** Rao et al. 2021, *Detection of exoplanets in TESS data*,\n", "[arXiv:2101.09227](https://arxiv.org/abs/2101.09227).\n", "\n", "### Why this matters\n", "A planetary transit is a small, periodic dip in flux that lasts for hours. An eclipsing\n", "binary makes a *much deeper* and often asymmetric dip. The trick: feed the CNN\n", "**two views** of the same light curve — a \"global\" full-period view, and a \"local\"\n", "zoom on the transit window. The local view is what separates real planets from EBs.\n", "\n", "### What you'll do\n", "1. Generate synthetic phase-folded light curves: planet, eclipsing binary, pure noise.\n", "2. Build a CNN with **two input branches** (global + local).\n", "3. Train and look at where the model fails.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "dae7908a", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, TensorDataset\n", "\n", "rng = np.random.default_rng(0)\n", "torch.manual_seed(0)\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" ] }, { "cell_type": "markdown", "id": "998154d9", "metadata": {}, "source": [ "## Generate three-class synthetic light curves\n" ] }, { "cell_type": "code", "execution_count": null, "id": "da066ca4", "metadata": {}, "outputs": [], "source": [ "N_GLOBAL = 201\n", "N_LOCAL = 81\n", "\n", "def transit_dip(phase, depth, duration):\n", "    # box-shaped dip centred at phase=0.5, width = duration\n", "    return np.where(np.abs(phase - 0.5) < duration / 2, -depth, 0.0)\n", "\n", "def synth_curve(kind, noise=2.0e-4):\n", "    # kind: 0 noise, 1 planet, 2 eclipsing binary\n", "    phase = np.linspace(0, 1, N_GLOBAL)\n", "    if kind == 0:\n", "        flux = np.zeros_like(phase)\n", "    elif kind == 1:\n", "        depth = rng.uniform(5e-4, 1.5e-3)  # 2.5-7.5x the noise std, so some are hard to detect\n", "        dur = rng.uniform(0.015, 0.05)\n", "        flux = transit_dip(phase, depth, dur)\n", "    else:  # eclipsing binary: deeper dip + secondary\n", "        depth = rng.uniform(3e-3, 1.5e-2)\n", "        dur = rng.uniform(0.03, 0.09)\n", "        flux = transit_dip(phase, depth, dur)\n", "        flux += transit_dip((phase + 0.5) % 1.0, rng.uniform(0.2, 0.5) * depth, dur)\n", "    flux = flux + rng.normal(0, noise, N_GLOBAL)\n", "    return phase.astype(np.float32), flux.astype(np.float32)\n", "\n", "def local_view(phase, flux):\n", "    # zoom +/- 5% around the dip centre\n", "    mask = (phase > 0.45) & (phase < 0.55)\n", "    f = flux[mask]\n", "    # resample to length N_LOCAL by linear interpolation\n", "    xi = np.linspace(0, len(f) - 1, N_LOCAL)\n", "    return np.interp(xi, np.arange(len(f)), f).astype(np.float32)\n", "\n", "N_PER = 400\n", "NOISE_STD = 2.0e-4\n", "GLOBALS = []; LOCALS = []; LBL = []\n", "for kind in (0, 1, 2):\n", "    for _ in range(N_PER):\n", "        p, f = synth_curve(kind, noise=NOISE_STD)\n", "        GLOBALS.append(f)\n", "        LOCALS.append(local_view(p, f))\n", "        LBL.append(kind)\n", "# Normalise to unit noise std so the inputs have a CNN-friendly dynamic range\n", "G = (np.stack(GLOBALS) / NOISE_STD)[:, None, :]  # (N, 1, 201)\n", "L = (np.stack(LOCALS) / NOISE_STD)[:, None, :]  # (N, 1, 81)\n", "Y = np.array(LBL, dtype=np.int64)\n", "print(G.shape, L.shape, Y.shape)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c2206cad", "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(3, 2, figsize=(9, 6))\n", "for cls in (0, 1, 2):\n", "    idx = np.where(Y == cls)[0][0]\n", "    axes[cls, 0].plot(G[idx, 0]); axes[cls, 0].set_title(f\"global — class {cls}\")\n", "    axes[cls, 1].plot(L[idx, 0]); axes[cls, 1].set_title(f\"local — class {cls}\")\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "id": "ce8b620f", "metadata": {}, "source": [ "## Task 1 — Multi-input 
CNN\n", "\n", "Build a network with **two 1-D conv branches** — one for the global view (length 201),\n", "one for the local view (length 81). Concatenate the branch outputs, then apply a final\n", "linear layer with 3 outputs (one logit per class).\n", "\n", "
<details><summary>💡 Hint (click to expand)</summary>\n", "\n", "- `nn.Conv1d(1, 16, 5)` followed by `nn.MaxPool1d(2)` works well.\n", "- After 2-3 conv blocks, `nn.AdaptiveAvgPool1d(1)` collapses each branch to one value per channel.\n", "- `forward` takes two arguments (global, local); concatenate the two branch outputs, then apply the linear layer.\n", "\n", "</details>
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "bda3cc4c", "metadata": {}, "outputs": [], "source": [ "class MultiViewNet(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # TODO: two branches + a head\n", "        pass\n", "    def forward(self, g, l):\n", "        # TODO\n", "        raise NotImplementedError\n" ] }, { "cell_type": "markdown", "id": "fddc91d6", "metadata": {}, "source": [ "## Task 2 — Train\n", "\n", "3 classes, cross-entropy loss. Track accuracy on a 20% held-out set.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "439f715d", "metadata": {}, "outputs": [], "source": [ "# TODO: split, DataLoaders, training loop\n" ] }, { "cell_type": "markdown", "id": "3e1b4059", "metadata": {}, "source": [ "## Task 3 — Confusion matrix\n", "\n", "Plot a **3×3 confusion matrix** on the test set with class labels\n", "(\"noise\", \"planet\", \"EB\"). Which class is most often confused with planets?\n", "\n", "
<details><summary>💡 Hint (click to expand)</summary>\n", "\n", "- Collect predictions and truths across all test batches.\n", "- Build a 3×3 numpy array indexed by `(true, pred)`.\n", "- Plot with `ax.imshow(cm, cmap='Blues')` and annotate each cell with its count.\n", "\n", "</details>
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "df39308c", "metadata": {}, "outputs": [], "source": [ "# TODO: collect test predictions, build confusion matrix, plot it with labels\n" ] }, { "cell_type": "markdown", "id": "bf9fd6f1", "metadata": {}, "source": [ "## Task 4 — Inspect correct and incorrect examples\n", "\n", "Plot 3 **correctly** classified and 3 **incorrectly** classified test examples.\n", "For each, show **both** views side by side (global on the left, local on the right)\n", "and annotate with the true / predicted class.\n", "\n", "The interesting cases are usually planets predicted as EBs and EBs predicted as\n", "planets — the local view is what the network used (or failed to use) to tell them\n", "apart. What looks different about the misclassified ones?\n", "\n", "
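<details><summary>💡 Hint (click to expand)</summary>\n", "\n", "- Re-use the `pred` and `true` arrays from Task 3.\n", "- `np.where(pred == true)[0]` gives indices of correct test examples; `np.where(pred != true)[0]` gives the incorrect ones.\n", "- For each example, plot `Gte[i, 0]` and `Lte[i, 0]` (after `.numpy()`), where `Gte` and `Lte` stand for whatever you named the held-out global and local tensors in Task 2.\n", "\n", "</details>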
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "57cfb842", "metadata": {}, "outputs": [], "source": [ "# TODO:\n", "# 1. Find indices of correct and incorrect test examples\n", "# 2. For 3 of each, plot (global, local) views side by side with true/pred in titles\n" ] }, { "cell_type": "markdown", "id": "64f95045", "metadata": {}, "source": [ "## Stretch\n", "\n", "- Add a third \"half-phase\" view (centred at phase=0 instead of 0.5) — does\n", " it further reduce EB confusion?\n", "- Inject high-frequency stellar variability on top of every curve and re-train.\n", "- Compare to a one-branch model that sees only the global view.\n", "- For misclassified examples, are the dips genuinely ambiguous, or did the model\n", " miss something obvious? Inspect a handful by eye.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }