{
"cells": [
{
"cell_type": "markdown",
"id": "245845f5",
"metadata": {},
"source": [
"# Exercise 3 — Multi-view transit detection in light curves\n",
"\n",
"**Reference paper.** Rao et al. 2021, *Detection of exoplanets in TESS data*,\n",
"[arXiv:2101.09227](https://arxiv.org/abs/2101.09227).\n",
"\n",
"### Why this matters\n",
"A planet transit is a small dip in flux, periodic, lasting hours. An eclipsing\n",
"binary makes a *much deeper* and often asymmetric dip. The trick: feed the CNN\n",
"**two views** of the same light curve — a \"global\" full-period view, and a \"local\"\n",
"zoom on the transit window. The local view is what separates real planets from EBs.\n",
"\n",
"### What you'll do\n",
"1. Generate synthetic phase-folded light curves: planet, eclipsing binary, pure noise.\n",
"2. Build a CNN with **two input branches** (global + local).\n",
"3. Train and look at where the model fails.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dae7908a",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"rng = np.random.default_rng(0)\n",
"torch.manual_seed(0)\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n"
]
},
{
"cell_type": "markdown",
"id": "998154d9",
"metadata": {},
"source": [
"## Generate three-class synthetic light curves\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da066ca4",
"metadata": {},
"outputs": [],
"source": [
"N_GLOBAL = 201\n",
"N_LOCAL = 81\n",
"\n",
"def transit_dip(phase, depth, duration):\n",
" # box-shaped dip centred at phase=0.5, width = duration\n",
" return np.where(np.abs(phase - 0.5) < duration / 2, -depth, 0.0)\n",
"\n",
"def synth_curve(kind, noise=2.0e-4):\n",
" # kind: 0 noise, 1 planet, 2 eclipsing binary\n",
" phase = np.linspace(0, 1, N_GLOBAL)\n",
" if kind == 0:\n",
" flux = np.zeros_like(phase)\n",
" elif kind == 1:\n",
" depth = rng.uniform(5e-4, 1.5e-3) # 2.5-7.5x noise -> some hard\n",
" dur = rng.uniform(0.015, 0.05)\n",
" flux = transit_dip(phase, depth, dur)\n",
" else: # eclipsing binary: deeper dip + secondary\n",
" depth = rng.uniform(3e-3, 1.5e-2)\n",
" dur = rng.uniform(0.03, 0.09)\n",
" flux = transit_dip(phase, depth, dur)\n",
" flux += transit_dip((phase + 0.5) % 1.0, rng.uniform(0.2, 0.5) * depth, dur)\n",
" flux = flux + rng.normal(0, noise, N_GLOBAL)\n",
" return phase.astype(np.float32), flux.astype(np.float32)\n",
"\n",
"def local_view(phase, flux):\n",
" # zoom +/- 5% around the dip centre\n",
" mask = (phase > 0.45) & (phase < 0.55)\n",
" f = flux[mask]\n",
" # resample to length N_LOCAL by linear interpolation\n",
" xi = np.linspace(0, len(f) - 1, N_LOCAL)\n",
" return np.interp(xi, np.arange(len(f)), f).astype(np.float32)\n",
"\n",
"N_PER = 400\n",
"NOISE_STD = 2.0e-4\n",
"GLOBALS = []; LOCALS = []; LBL = []\n",
"for kind in (0, 1, 2):\n",
" for _ in range(N_PER):\n",
" p, f = synth_curve(kind, noise=NOISE_STD)\n",
" GLOBALS.append(f)\n",
" LOCALS.append(local_view(p, f))\n",
" LBL.append(kind)\n",
"# Normalise to unit noise std — gives the CNN a CNN-friendly dynamic range\n",
"G = (np.stack(GLOBALS) / NOISE_STD)[:, None, :] # (N, 1, 201)\n",
"L = (np.stack(LOCALS) / NOISE_STD)[:, None, :] # (N, 1, 81)\n",
"Y = np.array(LBL, dtype=np.int64)\n",
"print(G.shape, L.shape, Y.shape)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2206cad",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(3, 2, figsize=(9, 6))\n",
"for cls in (0, 1, 2):\n",
" idx = np.where(Y == cls)[0][0]\n",
" axes[cls, 0].plot(G[idx, 0]); axes[cls, 0].set_title(f\"global — class {cls}\")\n",
" axes[cls, 1].plot(L[idx, 0]); axes[cls, 1].set_title(f\"local — class {cls}\")\n",
"plt.tight_layout(); plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "ce8b620f",
"metadata": {},
"source": [
"## Task 1 — Multi-input CNN\n",
"\n",
"Build a network with **two 1-D conv branches** — one for the global view (length 201),\n",
"one for the local view (length 81). Concatenate the branch outputs, then a final\n",
"linear layer with 3 outputs.\n",
"\n",
"💡 Hint (click to expand)
\n",
"\n",
"- `nn.Conv1d(1, 16, 5)` followed by `nn.MaxPool1d(2)` works well.\\n- After 2-3 conv blocks, `AdaptiveAvgPool1d(1)` collapses each branch.\\n- `forward` takes two arguments (global, local); concat then linear.\n",
"\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bda3cc4c",
"metadata": {},
"outputs": [],
"source": [
"class MultiViewNet(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" # TODO: two branches + a head\n",
" pass\n",
" def forward(self, g, l):\n",
" # TODO\n",
" raise NotImplementedError\n"
]
},
{
"cell_type": "markdown",
"id": "fddc91d6",
"metadata": {},
"source": [
"## Task 2 — Train\n",
"\n",
"3 classes, cross-entropy loss. Track accuracy on a 20% held-out set.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "439f715d",
"metadata": {},
"outputs": [],
"source": [
"# TODO: split, DataLoaders, training loop\n"
]
},
{
"cell_type": "markdown",
"id": "3e1b4059",
"metadata": {},
"source": [
"## Task 3 — Confusion matrix\n",
"\n",
"Plot a **3×3 confusion matrix** on the test set with class labels\n",
"(\"noise\", \"planet\", \"EB\"). Which class is most often confused with planets?\n",
"\n",
"💡 Hint (click to expand)
\n",
"\n",
"- Collect predictions and truths across all test batches.\\n- Build a 3×3 numpy array indexed by `(true, pred)`.\\n- Plot with `ax.imshow(cm, cmap='Blues')` and annotate each cell with its count.\n",
"\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df39308c",
"metadata": {},
"outputs": [],
"source": [
"# TODO: collect test predictions, build confusion matrix, plot it with labels\n"
]
},
{
"cell_type": "markdown",
"id": "bf9fd6f1",
"metadata": {},
"source": [
"## Task 4 — Inspect correct and incorrect examples\n",
"\n",
"Plot 3 **correctly** classified and 3 **incorrectly** classified test examples.\n",
"For each, show **both** views side by side (global on the left, local on the right)\n",
"and annotate with the true / predicted class.\n",
"\n",
"The interesting cases are usually planets predicted as EBs and EBs predicted as\n",
"planets — the local view is what the network used (or failed to use) to tell them\n",
"apart. What looks different about the misclassified ones?\n",
"\n",
"💡 Hint (click to expand)
\n",
"\n",
"- Re-use the `pred` and `true` arrays from Task 3.\\n- `np.where(pred == true)[0]` gives indices of correct test examples.\\n- For each example, plot `Gte[i, 0]` and `Lte[i, 0]` (after `.numpy()`).\n",
"\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57cfb842",
"metadata": {},
"outputs": [],
"source": [
"# TODO:\n",
"# 1. Find indices of correct and incorrect test examples\n",
"# 2. For 3 of each, plot (global, local) views side by side with true/pred in titles\n"
]
},
{
"cell_type": "markdown",
"id": "64f95045",
"metadata": {},
"source": [
"## Stretch\n",
"\n",
"- Add a third \"half-phase\" view (centred at phase=0 instead of 0.5) — does\n",
" it further reduce EB confusion?\n",
"- Inject high-frequency stellar variability on top of every curve and re-train.\n",
"- Compare to a one-branch model that sees only the global view.\n",
"- For misclassified examples, are the dips genuinely ambiguous, or did the model\n",
" miss something obvious? Inspect a handful by eye.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}