{ "cells": [ { "cell_type": "markdown", "id": "fe3a0b38", "metadata": {}, "source": [ "# Exercise 1 — DMDT: turning irregular light curves into images\n", "\n", "**Reference paper.** Mahabal et al. 2017, *Deep-learnt classification of light curves*,\n", "[arXiv:1709.06257](https://arxiv.org/abs/1709.06257).\n", "\n", "### Why this matters\n", "A light curve is sampled irregularly in time, has gaps, and has heteroscedastic errors.\n", "A CNN expects a fixed-size grid, so it can't eat a light curve directly. The DMDT trick: for every pair of points, compute\n", "`(Δmag, Δt)` and bin the pairs into a 2-D histogram. Now you have a fixed-size image, and any CNN\n", "works out of the box.\n", "\n", "### What you'll do\n", "1. Generate synthetic light curves of two variable-star classes.\n", "2. Implement the DMDT representation yourself.\n", "3. Train a small CNN to tell the two classes apart from their DMDT images.\n", "4. Inspect what the trained model gets right and wrong.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0ac7a984", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, TensorDataset\n", "\n", "rng = np.random.default_rng(0)\n", "torch.manual_seed(0)\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(\"device:\", DEVICE)\n" ] }, { "cell_type": "markdown", "id": "a1700920", "metadata": {}, "source": [ "## Generate synthetic light curves\n" ] }, { "cell_type": "code", "execution_count": null, "id": "427e936a", "metadata": {}, "outputs": [], "source": [ "# Two toy variable-star classes:\n", "#   class 0: RR Lyrae-like, sinusoidal, period 0.2-1.0 d\n", "#   class 1: eclipsing binary-like, double-dip Gaussian profile, period 0.4-1.5 d\n", "def synth_lightcurve(cls, n_points=60, t_max=30.0, noise=0.25):\n", "    # irregular sampling\n", "    t = np.sort(rng.uniform(0, t_max, n_points))\n", "    if cls == 0:\n", "        P = rng.uniform(0.2, 1.0)  # overlap a bit with class 1\n", "        amp = rng.uniform(0.3, 0.6)\n", "        m = amp * np.sin(2 * np.pi * t / P + rng.uniform(0, 2 * np.pi))\n", "    else:\n", "        P = rng.uniform(0.4, 1.5)\n", "        phase = (t / P) % 1.0\n", "        # two dips per period at phase 0 and 0.5\n", "        d1 = rng.uniform(0.7, 1.0)\n", "        d2 = rng.uniform(0.3, 0.7) * d1\n", "        # phase is periodic, so measure the distance to phase 0 around the circle;\n", "        # otherwise the primary dip is cut off just before phase 1\n", "        dphi0 = np.minimum(phase, 1.0 - phase)\n", "        dip1 = d1 * np.exp(-(dphi0 / 0.07) ** 2)\n", "        dip2 = d2 * np.exp(-((phase - 0.5) / 0.07) ** 2)\n", "        m = -(dip1 + dip2)\n", "    m += rng.normal(0, noise, size=n_points)\n", "    return t.astype(np.float32), m.astype(np.float32)\n", "\n", "# Build a dataset\n", "N_PER_CLASS = 400\n", "data = []\n", "labels = []\n", "for cls in (0, 1):\n", "    for _ in range(N_PER_CLASS):\n", "        data.append(synth_lightcurve(cls))\n", "        labels.append(cls)\n", "labels = np.array(labels)\n", "print(f\"{len(data)} light curves, balanced 2 classes\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "52920df3", "metadata": {}, "outputs": [], "source": [ "# Plot one of each class\n", "fig, axes = plt.subplots(1, 2, figsize=(10, 3))\n", "for ax, cls in zip(axes, (0, 1)):\n", "    idx = np.where(labels == cls)[0][0]\n", "    t, m = data[idx]\n", "    ax.scatter(t, m, s=8)\n", "    ax.set_title(f\"class {cls}\")\n", "    ax.set_xlabel(\"time [d]\"); ax.set_ylabel(\"mag\")\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "id": "278449a6", "metadata": {}, "source": [ "## Task 1 — Implement the DMDT representation\n", "\n", "For one light curve with times `t` and magnitudes `m` (both length N):\n", "\n", "1. Form every pair `(i, j)` with `i < j` (there are `N(N-1)/2` such pairs).\n", "2. Compute `dt_ij = t[j] - t[i]` and `dm_ij = m[j] - m[i]`.\n", "3. Bin `(dt, dm)` into a 2-D histogram of fixed shape.\n", "4. Normalise the resulting image so it has a comparable scale across light curves.\n", "\n", "
💡 Hint (click to expand)\n", "\n", "- Use `np.triu_indices(N, k=1)` to get all (i, j) with j > i.\n", "- `np.histogram2d(dt, dm, bins=[dt_edges, dm_edges])` returns the counts you need, with shape `(n_dt, n_dm)`; transpose them to get the `(n_dm, n_dt)` image that `dmdt_image` should return.\n", "- For edges, use log-spaced `dt` (e.g. logspace from 0.01 to 30 d) and linear `dm` (e.g. -2 to 2 mag).\n", "- Normalise by dividing by the maximum count or by `N(N-1)/2`.\n", "\n", "
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "87e7cf62", "metadata": {}, "outputs": [], "source": [ "DT_EDGES = np.logspace(np.log10(0.01), np.log10(30.0), 23 + 1)\n", "DM_EDGES = np.linspace(-2.0, 2.0, 24 + 1)\n", "\n", "def dmdt_image(t, m):\n", "    # TODO: implement the DMDT representation.\n", "    # Should return a 2-D numpy array of shape (len(DM_EDGES)-1, len(DT_EDGES)-1).\n", "    raise NotImplementedError\n", "\n", "# When you've implemented dmdt_image, this should plot two distinct-looking images:\n", "fig, axes = plt.subplots(1, 2, figsize=(10, 3.5))\n", "for ax, cls in zip(axes, (0, 1)):\n", "    idx = np.where(labels == cls)[0][0]\n", "    t, m = data[idx]\n", "    img = dmdt_image(t, m)\n", "    # pcolormesh draws each bin between its true edges, so the log-spaced dt bins\n", "    # land correctly on the log axis (imshow would assume evenly spaced pixels)\n", "    ax.pcolormesh(DT_EDGES, DM_EDGES, img)\n", "    ax.set_xscale(\"log\")\n", "    ax.set_title(f\"DMDT — class {cls}\")\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "id": "28a29cf9", "metadata": {}, "source": [ "## Task 2 — Build a small CNN\n", "\n", "The dataset gives you 800 images, each of shape `(24, 23)` (or whatever your bin grid is).\n", "Build a small CNN: two conv layers, a pool, a linear head, and a 2-class output.\n", "\n", "
💡 Hint (click to expand)\n", "\n", "- Treat the DMDT image as a single-channel input: shape `(B, 1, H, W)`.\n", "- `nn.Conv2d(1, 16, 3, padding=1)` is a fine first layer.\n", "- An `nn.AdaptiveAvgPool2d(1)` after the conv stack avoids hand-computing the flatten size.\n", "- Output two logits and use `nn.CrossEntropyLoss`, which expects raw logits (no softmax in `forward`).\n", "\n", "
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "fb0e5d44", "metadata": {}, "outputs": [], "source": [ "class DMDTNet(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # TODO: define a small conv stack and a linear head with 2 outputs.\n", "        pass\n", "\n", "    def forward(self, x):\n", "        # TODO\n", "        raise NotImplementedError\n" ] }, { "cell_type": "markdown", "id": "b6b3e471", "metadata": {}, "source": [ "## Task 3 — Train and evaluate\n", "\n", "Compute DMDT images for every light curve, do an 80/20 split, train for\n", "~20 epochs, and report test accuracy. CPU should be fine — the dataset\n", "is tiny.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "19affe14", "metadata": {}, "outputs": [], "source": [ "# 1. Pre-compute DMDT for the whole dataset\n", "# 2. Stack into a tensor X of shape (N, 1, H, W) and labels y\n", "# 3. Train/test split\n", "# 4. Training loop with optimizer = Adam, lr=1e-3\n", "# 5. Report test accuracy\n", "\n", "# TODO: write the training loop. Helpful skeleton below — fill in the gaps.\n", "def train_model(model, train_loader, test_loader, epochs=20):\n", "    opt = torch.optim.Adam(model.parameters(), lr=1e-3)\n", "    for ep in range(epochs):\n", "        model.train()\n", "        for xb, yb in train_loader:\n", "            # TODO: forward, loss, backward, step\n", "            pass\n", "        # TODO: evaluate on test set, print accuracy\n" ] }, { "cell_type": "markdown", "id": "1d17097f", "metadata": {}, "source": [ "## Task 4 — Inspect the results\n", "\n", "A bare accuracy number tells you very little. Now look at *what* the model gets right and wrong:\n", "\n", "1. Run the trained model on the test set and collect predictions.\n", "2. Plot the 2×2 confusion matrix.\n", "3. Plot 3 **correctly** classified test examples — for each one, show the **light curve** and the **DMDT image** side by side.\n", "4. Plot 3 **incorrectly** classified test examples the same way. 
What features of the light curve or DMDT seem to have fooled the model?\n", "\n", "
💡 Hint (click to expand)\n", "\n", "- Use `torch.no_grad()` and `model.eval()` for inference.\n", "- Build the confusion matrix with a small loop or `np.add.at`.\n", "- To plot a test example, you need its **original** index in `data` (to grab `t, m`). Save the permutation `perm` you used for the train/test split: the test indices are `perm[ntr:]`, where `ntr` is the number of training examples.\n", "- `ax.pcolormesh(DT_EDGES, DM_EDGES, img)` followed by `ax.set_xscale('log')` puts the DMDT image on physical axes (`imshow` with `extent` assumes evenly spaced pixels, which the log-spaced `dt` bins are not).\n", "\n", "
\n" ] }, { "cell_type": "code", "execution_count": null, "id": "266285b4", "metadata": {}, "outputs": [], "source": [ "# TODO:\n", "# 1. Run model on the test set, collect predictions + truth\n", "# 2. Build and plot the 2x2 confusion matrix\n", "# 3. Find indices of 3 correctly and 3 incorrectly classified test examples\n", "# 4. For each, plot (light curve, DMDT image) side by side\n" ] }, { "cell_type": "markdown", "id": "cf3b900e", "metadata": {}, "source": [ "## Stretch goals\n", "\n", "- What happens to accuracy if you halve the resolution of the bin grid in each dimension (e.g. 12×11)? Why?\n", "- Increase the Gaussian noise on the synthetic curves (the `noise` argument of `synth_lightcurve`) and see when the classifier breaks down.\n", "- Replace the CNN with a logistic regression on flattened DMDT pixels — how does it compare?\n", "- Do the misclassified examples cluster around a particular period or amplitude? Plot a histogram to find out (you will need to record the period and amplitude that `synth_lightcurve` draws, since it does not return them).\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }