#!/usr/bin/env python
#
# Groups spectra so that minimum S/N
#  is achieved in all bins after taking
#  into account background subtraction.
#  Unlike the older version, Poisson errors
#  are calculated properly here, using the
#  approximate (but accurate) expressions
#  from Gehrels 1986 (ApJ, 303, 336).
#
# NOTE: Background file name and keywords
#  for background scaling are read in from
#  the FITS header of the spectrum!
#
# VERSION: 150913
#
# USAGE: $> snrgrppha src.pha SNR [--heaccept]
#



import pyfits as pf
from pylab import *
import sys
import os



EMIN = 3.0 # NuSTAR band lower limit [keV]
EMAX = 78.5 # NuSTAR band upper limit [keV]
NCHAN = 4095 # highest channel number (upper end of the grppha 'bad' range)
# minimum S/N, given as the second command-line argument:
if len(sys.argv) < 3:
  sys.exit('USAGE: snrgrppha src.pha SNR [--heaccept]')
try:
  SNR = float(sys.argv[2])
except ValueError:
  sys.exit('USAGE: snrgrppha src.pha SNR [--heaccept]  (SNR must be a number)')



def chan(EN):
  # Maps an energy EN [keV] onto a PHA channel number using the linear
  #  scale E = 1.599 + 0.04*CHN (taken from the NuSTAR CALDB ARF file).
  # int() drops the fractional part, so for EN >= 1.599 this is (up to
  #  float rounding) the largest channel CHN with en(CHN) <= EN.
  width_kev = 0.04
  offset_kev = EN - 1.599
  return int(offset_kev / width_kev)

def en(CHN):
  # Inverse of chan(): the energy [keV] that the linear scale
  #  E = 1.599 + 0.04*CHN (NuSTAR CALDB ARF file) assigns to channel CHN.
  slope_kev, zero_kev = 0.04, 1.599
  return zero_kev + slope_kev*CHN

def poisson_uunc(N):
  # Approximate upper 1-sigma confidence LIMIT (84% single-sided, S=1)
  #  for an observation of N counts, from Equation 9 of Gehrels (1986):
  #    lambda_u = (N+1) * (1 - 1/(9(N+1)) + 1/(3 sqrt(N+1)))**3
  # This is the limit itself; the upper uncertainty is poisson_uunc(N)-N.
  n1 = 1.0 + N
  factor = 1.0 - 1.0/(9.0*n1) + 1.0/(3.0*sqrt(n1))
  return n1 * (factor*factor*factor)

def poisson_lunc(N):
  # returns an approximation of the lower 1-sigma confidence LIMIT
  #  (not the uncertainty itself) for an observation of N counts, based
  #  on Equation 12 in Gehrels (1986) with S=1 (confidence level 84%);
  #  the lower uncertainty is N-poisson_lunc(N).
  # For N <= 0 the lower limit is 0 (Eq. 12 would divide by zero here),
  #  and the result is clipped at 0 for small non-integer N where the
  #  cubic factor turns negative.
  if N <= 0:
    return 0.0
  t = 1.0-1.0/(9.0*N)-1.0/(3.0*sqrt(1.0*N))
  return max(0.0, t*t*t*N)



def _bin_snr(scounts, bcounts, fbkg):
  # returns the S/N of the background-subtracted counts in one bin
  #  scounts: raw source counts summed over the bin
  #  bcounts: raw (unscaled) background counts summed over the bin
  #  fbkg:    factor that scales background counts to the source spectrum
  # The upper Poisson errors are evaluated on the RAW counts and only then
  #  scaled by fbkg; negative net counts give S/N=0.
  net = scounts-fbkg*bcounts
  if (net<0.0):
    return 0.0
  su = poisson_uunc(scounts)-scounts # upper Poisson uncertainty (source)
  bu = fbkg*(poisson_uunc(bcounts)-bcounts) # scaled background uncertainty
  return net/sqrt(su**2+bu**2)



### MAIN ###

# use regular 'grppha' to create extra columns (QUALITY, GROUPING) and to
#  flag channels 0-chan(EMIN) and chan(EMAX)-NCHAN as bad:
print('\n -(snrgrppha)-> Using regular grppha to create extra columns ...')
os.system('rm -rf snrgrp.pha')
os.system("grppha "+sys.argv[1]+" snrgrp.pha comm='bad 0-"+str(chan(EMIN))+" & bad "+str(chan(EMAX))+"-"+str(NCHAN)+" & group min 1 & exit'")

# load grouped PHA:
HDULISTSRC = pf.open('snrgrp.pha')
_SRC = HDULISTSRC[1].data
os.system('rm -rf snrgrp.pha')

# access columns by name, so extra columns in the input cannot shift them;
#  writing into QUALITY/GROUPING edits _SRC (same path as _SRC[row][col]=...):
SRCCOUNTS = 1.0*_SRC.field('COUNTS')
QUALITY = _SRC.field('QUALITY')
GROUPING = _SRC.field('GROUPING')

# get relevant data from FITS header:
HDRSRC = HDULISTSRC[1].header
BACKFILE = HDRSRC['BACKFILE']
BACKSCALSRC = HDRSRC['BACKSCAL'] # unitless
EXPOSURESRC = HDRSRC['EXPOSURE'] # [s]

# load the background spectrum (counts are copied so the file can be closed):
HDULISTBKG = pf.open(BACKFILE)
HDRBKG = HDULISTBKG[1].header
BACKSCALBKG = HDRBKG['BACKSCAL'] # unitless
EXPOSUREBKG = HDRBKG['EXPOSURE'] # [s]
BKGCOUNTS = 1.0*HDULISTBKG[1].data.field('COUNTS')
HDULISTBKG.close()

# calculate background scaling factor:
# NOTE: Assuming AREASCAL=1, which is ok for NuSTAR but maybe not others!
FBKG = (BACKSCALSRC/BACKSCALBKG)*(EXPOSURESRC/EXPOSUREBKG)
print('\n -(snrgrppha)-> Background scaling factor:  %s' % FBKG)

# build bins over channels CHMIN..CHMAX-1; channel CHMAX and above stay bad.
# OGIP convention: GROUPING=1 starts a bin, GROUPING=-1 continues it,
#  QUALITY=0 is good and QUALITY=5 is bad.
CHMIN = chan(EMIN)
CHMAX = chan(EMAX)
_binsnr = [] # S/N of every bin kept in the output
chn = CHMIN
while (chn<CHMAX):
  # open a new bin at this channel:
  chnstart = chn
  QUALITY[chnstart] = 0
  GROUPING[chnstart] = 1
  s = SRCCOUNTS[chn] # raw source counts in the bin
  bb = BKGCOUNTS[chn] # raw background counts in the bin
  snr = _bin_snr(s, bb, FBKG)
  chn += 1
  # grow the bin until the S/N requirement is met or the band runs out:
  while (snr<SNR and chn<CHMAX):
    s += SRCCOUNTS[chn]
    bb += BKGCOUNTS[chn]
    snr = _bin_snr(s, bb, FBKG)
    QUALITY[chn] = 0
    GROUPING[chn] = -1
    chn += 1
  if (snr>=SNR):
    _binsnr.append(snr)
  else: # Desired SNR was not achieved in the highest-energy bin!
    print(' -(snrgrppha)-> The highest-energy bin reaches only SNR=%.2f!'%snr)
    # option 1: use lower SNR, if command line switch is given
    if '--heaccept' in sys.argv:
      print(' -(snrgrppha)-> Due to --heaccept, the highest-energy bin will be used anyway.')
      _binsnr.append(snr)
    # option 2 (default): flag the channels
    else:
      print(' -(snrgrppha)-> The highest-energy bin (above %.2f keV) will not be used.'%en(chnstart))
      print(' -(snrgrppha)-> Re-run with --heaccept to include the highest-energy bin anyway.')
      for i in range(chnstart,CHMAX):
        QUALITY[i] = 5
        GROUPING[i] = 1

# mark high-energy channels as bad:
for i in range(chn,NCHAN):
  QUALITY[i] = 5
  GROUPING[i] = 1

# report some statistics:
if _binsnr:
  print(' -(snrgrppha)-> The median SNR per bin is %.2f.'%median(_binsnr))
  print(' -(snrgrppha)-> The minimum SNR per bin is %.2f.'%min(_binsnr))
else:
  print(' -(snrgrppha)-> No bin in the band reached the requested SNR.')

# save the new grouping:
HDULISTSRC.writeto('snrgrp.pha')
HDULISTSRC.close()
print(' -(snrgrppha)-> Output saved to snrgrp.pha ... \n')