{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Supervised Classification Module\n", "\n", "**Lecturer:** Ashish Mahabal<br>\n", "**Jupyter Notebook Authors:** Ashish Mahabal & Yuhan Yao\n", "\n", "This Jupyter notebook lesson extends one from the GROWTH Summer School 2019 and is adapted for the NARIT-EACOA 2019 summer workshop.\n", "\n", "## Objective\n", "Classify objects into different classes using (a) decision trees and (b) random forests.\n", "\n", "## Key steps\n", "- Pick the variable types\n", "- Select a training sample\n", "- Select a method\n", "- Examine the confusion matrix and other details\n", "\n", "## Required dependencies\n", "\n", "See the GROWTH school webpage for detailed instructions on installing these modules and packages. Nominally, you should be able to install the Python modules with `pip install <module name>`. The external astromatic packages are most easily installed using package managers (e.g., rpm, apt-get).\n", "\n", "### Python modules\n", "* Python 3\n", "* astropy\n", "* numpy\n", "* astroquery\n", "* pandas\n", "* matplotlib\n", "* pydotplus\n", "* IPython (for `IPython.display`)\n", "* scikit-learn (`sklearn`)\n", "\n", "### External packages\n", "None\n", "\n", "### Partial Credits\n", "Pavlos Protopapas (LSSDS notebook)" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# For inline plots\n", "%matplotlib inline\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "# For drawing decision trees as images\n", "import io\n", "import pydotplus\n", "from IPython.display import Image\n", "\n", "# Various scikit-learn modules\n", "from sklearn.model_selection import train_test_split, GridSearchCV\n", "from sklearn import tree\n", "from sklearn.metrics import confusion_matrix\n", "from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier, export_graphviz\n", "from sklearn.ensemble import RandomForestClassifier" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Define the data directory and the files we will use" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "datadir = 'data'\n", "\n", "# For the decision tree\n", "catalog = datadir + '/CatalinaVars.tbl.gz'\n", "lightcurves = datadir + '/CRTS_6varclasses.csv.gz'\n", "\n", "# For the random forest\n", "featuresfile = datadir + '/cvs_and_blazars.dat'" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Read the light curves" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n", "      <th>ID</th>\n", "      <th>MJD</th>\n", "      <th>Mag</th>\n", "      <th>magerr</th>\n", "      <th>RA</th>\n", "      <th>Dec</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>1109065026725</td>\n", "      <td>53705.501925</td>\n", "      <td>16.943797</td>\n", "      <td>0.082004</td>\n", "      <td>182.25871</td>\n", "      <td>9.76580</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>1109065026725</td>\n", "      <td>53731.483314</td>\n", "      <td>16.645102</td>\n", "      <td>0.075203</td>\n", "      <td>182.25867</td>\n", "      <td>9.76585</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>1109065026725</td>\n", "      <td>53731.491406</td>\n", "      <td>16.693791</td>\n", "      <td>0.076497</td>\n", "      <td>182.25870</td>\n", "      <td>9.76574</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>1109065026725</td>\n", "      <td>53731.499465</td>\n", "      <td>16.793651</td>\n", "      <td>0.078755</td>\n", "      <td>182.25869</td>\n", "      <td>9.76576</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>1109065026725</td>\n", "      <td>53731.507529</td>\n", "      <td>16.767817</td>\n", "      <td>0.077436</td>\n", "      <td>182.25878</td>\n", "      <td>9.76581</td>\n", "    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
 ], "text/plain": [ "              ID           MJD        Mag    magerr         RA      Dec\n", "0  1109065026725  53705.501925  16.943797  0.082004  182.25871  9.76580\n", "1  1109065026725  53731.483314  16.645102  0.075203  182.25867  9.76585\n", "2  1109065026725  53731.491406  16.693791  0.076497  182.25870  9.76574\n", "3  1109065026725  53731.499465  16.793651  0.078755  182.25869  9.76576\n", "4  1109065026725  53731.507529  16.767817  0.077436  182.25878  9.76581" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lcs = pd.read_csv(lightcurves,\n", "                  compression='gzip',\n", "                  header=1,  # take column names from the second line (renamed below)\n", "                  sep=',',\n", "                  skipinitialspace=True,\n", "                  nrows=100000)  # read only the first 100,000 rows\n", "\n", "lcs.columns = ['ID', 'MJD', 'Mag', 'magerr', 'RA', 'Dec']\n", "lcs.head()\n" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "301" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Number of distinct objects (unique IDs) among the rows read\n", "len(lcs.groupby('ID'))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Read the catalog with class information for the variables" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n", "      <th>Catalina_Surveys_ID</th>\n", "      <th>Numerical_ID</th>\n", "      <th>RA_J2000</th>\n", "      <th>Dec</th>\n", "      <th>V_mag</th>\n", "      <th>Period_days</th>\n", "      <th>Amplitude</th>\n", "      <th>Number_Obs</th>\n", "      <th>Var_Type</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>CSS_J000020.4+103118</td>\n", "      <td>1109001041232</td>\n", "      <td>00:00:20.41</td>\n", "      <td>+10:31:18.9</td>\n", "      <td>14.62</td>\n", "      <td>1.491758</td>\n", "      <td>2.39</td>\n", "      <td>223</td>\n", "      <td>2</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>CSS_J000031.5-084652</td>\n", "      <td>1009001044997</td>\n", "      <td>00:00:31.50</td>\n", "      <td>-08:46:52.3</td>\n", "      <td>14.14</td>\n", "      <td>0.404185</td>\n", "      <td>0.12</td>\n", "      <td>163</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>CSS_J000036.9+412805</td>\n", "      <td>1140001063366</td>\n", "      <td>00:00:36.94</td>\n", "      <td>+41:28:05.7</td>\n", "      <td>17.39</td>\n", "      <td>0.274627</td>\n", "      <td>0.73</td>\n", "      <td>158</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>CSS_J000037.5+390308</td>\n", "      <td>1138001069849</td>\n", "      <td>00:00:37.55</td>\n", "      <td>+39:03:08.1</td>\n", "      <td>17.74</td>\n", "      <td>0.30691</td>\n", "      <td>0.23</td>\n", "      <td>219</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>CSS_J000103.3+105724</td>\n", "      <td>1109001050739</td>\n", "      <td>00:01:03.37</td>\n", "      <td>+10:57:24.4</td>\n", "      <td>15.25</td>\n", "      <td>1.5837582</td>\n", "      <td>0.11</td>\n", "      <td>223</td>\n", "      <td>8</td>\n", "    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
 ], "text/plain": [ "    Catalina_Surveys_ID   Numerical_ID     RA_J2000          Dec  V_mag  \\\n", "0  CSS_J000020.4+103118  1109001041232  00:00:20.41  +10:31:18.9  14.62   \n", "1  CSS_J000031.5-084652  1009001044997  00:00:31.50  -08:46:52.3  14.14   \n", "2  CSS_J000036.9+412805  1140001063366  00:00:36.94  +41:28:05.7  17.39   \n", "3  CSS_J000037.5+390308  1138001069849  00:00:37.55  +39:03:08.1  17.74   \n", "4  CSS_J000103.3+105724  1109001050739  00:01:03.37  +10:57:24.4  15.25   \n", "\n", "  Period_days  Amplitude  Number_Obs  Var_Type  \n", "0    1.491758       2.39         223         2  \n", "1    0.404185       0.12         163         1  \n", "2    0.274627       0.73         158         1  \n", "3     0.30691       0.23         219         1  \n", "4   1.5837582       0.11         223         8  " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cat = pd.read_csv(catalog,\n", "                  compression='gzip',\n", "                  header=5,\n", "                  sep=' ',\n", "                  skipinitialspace=True)\n", "\n", "# The parsed column names end up shifted by one relative to the data, so drop\n", "# the last column and shift the names one position to the left.\n", "columns = cat.columns[1:]\n", "cat = cat[cat.columns[:-1]]\n", "cat.columns = columns\n", "\n", "cat.head()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Select a subset of variable types with a minimum light-curve length\n", "\n", "The classes are from Drake et al. 2014 and Mahabal et al. 2017:\n", "\n", "- 2: EA (detached binaries)\n", "- 4: RRab\n", "- 5: RRc\n", "- 6: RRd\n", "- 8: RS CVn\n", "- 13: LPV" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
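<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\">\n",
"      <th></th>\n", "      <th>Catalina_Surveys_ID</th>\n", "      <th>Numerical_ID</th>\n", "      <th>RA_J2000</th>\n", "      <th>Dec</th>\n", "      <th>V_mag</th>\n", "      <th>Period_days</th>\n", "      <th>Amplitude</th>\n", "      <th>Number_Obs</th>\n", "      <th>Var_Type</th>\n",
"    </tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>CSS_J000020.4+103118</td>\n", "      <td>1109001041232</td>\n", "      <td>00:00:20.41</td>\n", "      <td>+10:31:18.9</td>\n", "      <td>14.62</td>\n", "      <td>1.491758</td>\n", "      <td>2.39</td>\n", "      <td>223</td>\n", "      <td>2</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>CSS_J000103.3+105724</td>\n", "      <td>1109001050739</td>\n", "      <td>00:01:03.37</td>\n", "      <td>+10:57:24.4</td>\n", "      <td>15.25</td>\n", "      <td>1.5837582</td>\n", "      <td>0.11</td>\n", "      <td>223</td>\n", "      <td>8</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>8</th>\n", "      <td>CSS_J000131.5+324913</td>\n", "      <td>1132001052010</td>\n", "      <td>00:01:31.54</td>\n", "      <td>+32:49:13.1</td>\n", "      <td>14.71</td>\n", "      <td>13.049549</td>\n", "      <td>0.17</td>\n", "      <td>188</td>\n", "      <td>8</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>16</th>\n", "      <td>CSS_J000216.1-165109</td>\n", "      <td>1015001002091</td>\n", "      <td>00:02:16.16</td>\n", "      <td>-16:51:09.7</td>\n", "      <td>16.07</td>\n", "      <td>0.30487</td>\n", "      <td>0.17</td>\n", "      <td>124</td>\n", "      <td>5</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>23</th>\n", "      <td>CSS_J000309.5+193816</td>\n", "      <td>1118001060639</td>\n", "      <td>00:03:09.56</td>\n", "      <td>+19:38:16.6</td>\n", "      <td>17.82</td>\n", "      <td>1.12582</td>\n", "      <td>0.59</td>\n", "      <td>206</td>\n", "      <td>2</td>\n", "    </tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>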