import os
import sys
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from scipy import stats
import gzip
import csv
from datetime import date
from math import sin, cos, sqrt, atan2, radians

# Prefix of the StormEvents "details" data files this script reads.
FILE_NAME_START = "StormEvents_details-ftp_v1.0_d"

# Output switches: SHOW makes tornadoCounts() call plt.show() after saving
# each plot; PRINT is not referenced in the code shown here.
PRINT = False
SHOW = False

# Ratings on the F / EF scale that are tallied individually.
F_VALUES = list(range(6))

# Month buckets (three-letter upper-case names, matching how Tornado derives
# its month from MONTH_NAME) followed by the whole-year bucket.
MONTHS = "JAN FEB MAR APR MAY JUN JUL AUG SEP OCT NOV DEC".split() + ["ANNUAL"]

# Strong tornadoes to plot: one series per rating, with its colour and
# legend label. Counts are cumulative, hence the "+" in the labels.
TORNADO_PARAMS = {
    'f-value': [3, 4, 5],
    'color': ['blue', 'black', 'red'],
    'label': ['F3/EF3+', 'F4/EF4+', 'F5/EF5'],
}

# BEGIN_YEARMONTH,BEGIN_DAY,BEGIN_TIME,END_YEARMONTH,END_DAY,END_TIME,EPISODE_ID,EVENT_ID,STATE,STATE_FIPS,YEAR,MONTH_NAME,EVENT_TYPE,CZ_TYPE,CZ_FIPS,CZ_NAME,WFO,BEGIN_DATE_TIME,CZ_TIMEZONE,END_DATE_TIME,INJURIES_DIRECT,INJURIES_INDIRECT,DEATHS_DIRECT,DEATHS_INDIRECT,DAMAGE_PROPERTY,DAMAGE_CROPS,SOURCE,MAGNITUDE,MAGNITUDE_TYPE,FLOOD_CAUSE,CATEGORY,TOR_F_SCALE,TOR_LENGTH,TOR_WIDTH,TOR_OTHER_WFO,TOR_OTHER_CZ_STATE,TOR_OTHER_CZ_FIPS,TOR_OTHER_CZ_NAME,BEGIN_RANGE,BEGIN_AZIMUTH,BEGIN_LOCATION,END_RANGE,END_AZIMUTH,END_LOCATION,BEGIN_LAT,BEGIN_LON,END_LAT,END_LON,EPISODE_NARRATIVE,EVENT_NARRATIVE,DATA_SOURCE

# 199901,1,2233,199901,1,2255,1500560,5685639,"TEXAS",48,1999,"January","Tornado","C",339,"MONTGOMERY","HGX","01-JAN-99 22:33:00","CST","01-JAN-99 22:55:00","0","0","0","0","75K",,"TRAINED SPOTTER",,,,,"F1","8","75",,,,,,,"PORTER","5","E","SPLENDORA","30.10","-95.23","30.23","-95.08",,"Tornado touched down from Live Oak Estates (Porter area) to near Splendora.  Moderate house damage and trees down throughout the area.","PDC"

class Tornado:
    """A single tornado event built from one StormEvents details CSV row.

    Attributes:
        has_summary: always True once constructed.
        summary: dict with keys
            'episode', 'event' -- the EPISODE_ID / EVENT_ID fields as read;
            'begin_lat', 'begin_lon', 'end_lat', 'end_lon' -- track end
                points in degrees (all 0.0 if any of the four fails to parse);
            'month' -- first three letters of MONTH_NAME, upper-cased;
            'distance' -- great-circle track length in miles (0 when the
                coordinates are unusable);
            'fnum' -- F/EF rating as an int from F_VALUES, or -1 if unknown.
    """

    def __init__(self, line, columns):
        """Parse one CSV row.

        Args:
            line: sequence of field strings for one CSV row.
            columns: mapping of column name -> index into ``line``.
        """
        self.has_summary = True
        try:
            begin_lat = float(line[columns["BEGIN_LAT"]])
            begin_lon = float(line[columns["BEGIN_LON"]])
            end_lat = float(line[columns["END_LAT"]])
            end_lon = float(line[columns["END_LON"]])
            distance = self.getDistance(begin_lat, begin_lon, end_lat, end_lon)
        except ValueError:
            # An empty or malformed coordinate invalidates the whole track.
            begin_lat = begin_lon = end_lat = end_lon = 0.0
            distance = 0
        month = line[columns["MONTH_NAME"]][0:3].upper()
        fnum = self._parse_f_scale(line[columns["TOR_F_SCALE"]])
        self.summary = {'episode': line[columns["EPISODE_ID"]],
                        'event': line[columns["EVENT_ID"]],
                        'begin_lat': begin_lat,
                        'begin_lon': begin_lon,
                        'end_lat': end_lat,
                        'end_lon': end_lon,
                        'month': month,
                        'distance': distance,
                        'fnum': fnum}

    @staticmethod
    def _parse_f_scale(f):
        """Convert a TOR_F_SCALE value such as "F3" or "EF2" to an int rating.

        Returns -1 for an empty value, "EFU" (presumably "EF unknown"),
        anything whose rating digit does not parse, or a rating not in
        F_VALUES. Unexpected non-empty values other than "EFU" are reported
        with print().
        """
        if f.startswith('F'):
            try:
                fnum = int(f[1:2])
            except ValueError:
                print("unknown F value", f)
                return -1
        elif f.startswith('E'):
            try:
                fnum = int(f[2:3])
            except ValueError:
                if f != "EFU":
                    print("unknown EF value", f)
                return -1
        else:
            if len(f) > 0:
                print("unknown f value", f)
            return -1
        return fnum if fnum in F_VALUES else -1

    def getDistance(self, lat1, lon1, lat2, lon2):
        """Return the haversine great-circle distance in miles.

        All arguments are latitudes/longitudes in degrees.
        """
        # approximate radius of earth in miles
        R = 3960
        lat1 = radians(lat1)
        lon1 = radians(lon1)
        lat2 = radians(lat2)
        lon2 = radians(lon2)
        dlon = lon2 - lon1
        dlat = lat2 - lat1
        a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
        # Floating-point rounding can push `a` a hair above 1 for (nearly)
        # antipodal points; sqrt(1 - a) would then raise ValueError, which
        # __init__ would swallow by zeroing every coordinate.
        a = min(a, 1.0)
        c = 2 * atan2(sqrt(a), sqrt(1 - a))
        return R * c

def summarize_year(year):
    """Count one year's tornadoes per month and per F/EF rating.

    Stores the result in ``year['fcounts']``, a dict with one entry per item
    of MONTHS (month abbreviations plus "ANNUAL"). Each entry maps every
    rating in F_VALUES, plus "UNK", to a count.

    Counts are cumulative: a tornado rated n is added to rating n and to every
    lower rating in F_VALUES. So ``fcounts[month][3]`` is the number of
    F3/EF3-or-stronger tornadoes, and ``fcounts[month][0]`` is the number of
    rated tornadoes. Tornadoes whose rating is not in F_VALUES (e.g. -1 for
    unknown) are counted under "UNK". Every tornado is also added to the
    "ANNUAL" entry; one whose month is not in MONTHS is counted only there.

    Args:
        year: dict with a "tornadoes" mapping of event id -> Tornado (any
            object exposing ``summary['fnum']`` and ``summary['month']``).
    """
    fcounts = {m: {**dict.fromkeys(F_VALUES, 0), "UNK": 0} for m in MONTHS}
    year['fcounts'] = fcounts
    annual = fcounts["ANNUAL"]
    for tornado in year["tornadoes"].values():
        fnum = tornado.summary['fnum']
        if fnum in F_VALUES:
            # Cumulative tally: a rating-n tornado also counts toward every
            # lower rating, including rating 0.
            keys = [f for f in F_VALUES if f <= fnum]
        else:
            keys = ["UNK"]
        rows = [annual]
        monthly = fcounts.get(tornado.summary['month'])
        if monthly is not None and monthly is not annual:
            rows.append(monthly)
        for row in rows:
            for key in keys:
                row[key] += 1

def read_year(filename):
    """Read one gzipped StormEvents details CSV and collect its tornado events.

    The first non-blank row must be the column header (first field
    "BEGIN_YEARMONTH"); it maps column names to indices. Every following row
    whose EVENT_TYPE is "Tornado" becomes a Tornado object keyed by EVENT_ID.
    The first row seen for a given id wins.

    Args:
        filename: path of a gzip-compressed CSV file.

    Returns:
        A ``(yearNum, year_data)`` tuple.

        ``yearNum`` is the YEAR value of the first data row. It is None if no
        data row could be read (unexpected header, or an empty, corrupt or
        truncated file).

        ``year_data`` is always a dict with keys "yearNum" and "tornadoes"
        (event id -> Tornado). Callers can therefore summarize it even when
        it holds no tornadoes.
    """
    columns = {}
    yearNum = None
    tornadoes = {}
    with gzip.open(filename, mode='rt') as f:
        csv_reader = csv.reader(f, delimiter=',')
        try:
            for line in csv_reader:
                if not line:
                    # csv yields [] for blank lines; there is nothing to index.
                    continue
                if not columns:
                    if line[0] != "BEGIN_YEARMONTH":
                        print("unexpected first line:", line)
                        break
                    columns = {name: index for index, name in enumerate(line)}
                    continue
                if yearNum is None:
                    # Take the year from the first data row of any event type,
                    # so a file without tornadoes still reports its year.
                    yearNum = int(line[columns["YEAR"]])
                if line[columns["EVENT_TYPE"]] != "Tornado":
                    continue
                event_id = line[columns["EVENT_ID"]]
                if event_id not in tornadoes:
                    tornadoes[event_id] = Tornado(line, columns)
        except (gzip.BadGzipFile, EOFError):
            # EOFError is what gzip raises for a truncated stream.
            print("bad gzip file", filename)
    return yearNum, {"yearNum": yearNum, "tornadoes": tornadoes}
        
def tornadoCounts(years, month, y1, y2):
    """Plot yearly strong-tornado counts for one month and fit linear trends.

    For each series in TORNADO_PARAMS, draws one bar per year from y1 to y2
    (inclusive) with ``years[year]['fcounts'][month][f_value]``. That value
    is the cumulative count of tornadoes rated f_value or stronger (as built
    by summarize_year). A least-squares trend line is overlaid on the bars.

    The legend label shows the slope (tornadoes per year), marked with "*"
    when the regression p-value is below 0.05. Series that are all zero are
    skipped. Years absent from ``years`` are skipped instead of raising
    KeyError. The figure is drawn on the current pyplot figure and saved
    under ``plots/``, which is created if needed.

    Args:
        years: dict of year -> year data already processed by summarize_year.
        month: key into fcounts, e.g. "JAN" or "ANNUAL".
        y1, y2: first and last year of the range (inclusive).

    Returns:
        (slopes, png_file): ``slopes`` maps each plotted series label to
        ``[slope, p_value]``; ``png_file`` is the path of the saved image.
    """
    plt.clf()
    slopes = {}
    # Only years that were actually loaded; a missing data file no longer
    # aborts the whole plot.
    validYears = [year for year in range(y1, y2 + 1) if year in years]
    for f_value, color, label in zip(TORNADO_PARAMS['f-value'],
                                     TORNADO_PARAMS['color'],
                                     TORNADO_PARAMS['label']):
        validCounts = [years[year]['fcounts'][month][f_value] for year in validYears]
        # Nothing to draw for an all-zero series, and linregress needs at
        # least two distinct x values.
        if sum(validCounts) == 0 or len(validYears) < 2:
            continue
        slope, intercept, _, p_value, _ = stats.linregress(validYears, validCounts)
        slopes[label] = [slope, p_value]
        trend = [slope * y + intercept for y in validYears]
        # "*" marks a trend significant at the 5% level.
        marker = "*" if p_value < 0.05 else ""
        legend_label = label + ': ' + str(round(slope, 2)) + marker
        plt.bar(validYears, validCounts, color=color, label=legend_label, alpha=0.8)
        plt.plot(validYears, trend, color=color, alpha=0.9)

    if slopes:
        plt.legend(ncol=4, loc='upper left')
    plt.title("Tornado counts " + month + " " + str(y1) + "-" + str(y2))
    os.makedirs("plots", exist_ok=True)
    png_file = "plots/count" + "_" + str(y1) + "_" + month + ".png"
    plt.savefig(png_file)
    if SHOW:
        plt.show()
    return slopes, png_file

def do_plots(years, start_years=(1950, 1970, 1990), end_year=2020):
    """Write tornado-plots.html: a table of trend plots per month and period.

    The table has one row per entry of MONTHS and one column per start year.
    Each cell links to the plot that tornadoCounts saves for that month and
    the period start_year..end_year. Below the plot it lists each series
    label with its slope, starred when p < 0.05.

    Args:
        years: dict of year -> year data already processed by summarize_year.
        start_years: first year of each period (one table column each).
        end_year: last year (inclusive) shared by all periods.
    """
    with open("tornado-plots.html", "w", encoding="utf-8") as f_html:
        f_html.write("<HTML><BODY><TABLE>\n")
        header = "".join("<TH>" + str(y1) + '-' + str(end_year) + "</TH>" for y1 in start_years)
        f_html.write("<TR><TH></TH>" + header + "</TR>\n")
        for month in MONTHS:
            f_html.write("<TR><TD>" + month + "</TD>")
            for y1 in start_years:
                fig = plt.figure(figsize=(9, 3))
                try:
                    slopes, png_file = tornadoCounts(years, month, y1, end_year)
                finally:
                    # Release the figure even if plotting fails.
                    plt.close(fig)
                f_html.write("<TD align=center><A HREF=\"" + png_file + "\" target=\"_blank\"><img src=\"" + png_file + "\" width=200></A><br>")
                for label, (slope, p_value) in slopes.items():
                    marker = "*" if p_value < 0.05 else ""
                    f_html.write(label + ': ' + str(round(slope, 2)) + marker + "<br>")
                f_html.write("</TD>")
            f_html.write("</TR>\n")
        f_html.write("</TABLE></BODY></HTML>\n")

if __name__ == '__main__':
    # Load every StormEvents details file in the working directory, tally
    # tornadoes per year, then write the HTML table of trend plots.
    years = {}
    # Sorted so files (and therefore duplicate years) are processed in a
    # deterministic order.
    for f in sorted(os.listdir(".")):
        if not (os.path.isfile(f) and f.startswith(FILE_NAME_START)):
            continue
        result = read_year(f)
        # read_year may give back None, or a None year, for an unusable file.
        if not result or result[0] is None:
            print("skipping", f)
            continue
        year, year_data = result
        years[year] = year_data
    for year in sorted(years):
        summarize_year(years[year])
        print(year, years[year]['fcounts'])
    do_plots(years)
