import csv
import html
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

# Month abbreviations in calendar order; a month's position in this list is
# the index used into each year's list of twelve monthly extremes.
MONTHS = "JAN FEB MAR APR MAY JUN JUL AUG SEP OCT NOV DEC".split()
# One plotting color per month, parallel to MONTHS.
COLORS = [
    'blue', 'purple', 'aqua', 'yellow', 'green', 'red',
    'orange', 'lightgray', 'darkgray', 'gray', 'dimgray', 'black',
]
# Placeholder coordinates handed back when a station cannot be located.
LATITUDE_FLAG = 40
LONGITUDE_FLAG = -100
# CSV file providing the city/state -> latitude/longitude lookup.
LAT_LONG_FILE = "uscities.csv"
# Last year included when a station's time series is assembled.
END_YEAR = 2021
# True: track the highest high of each month.
# False: track the lowest low of each month.
HIGHEST_HIGHS = True
# Sentinel stored for a month that has no usable reading yet.
FLAG = 999 if HIGHEST_HIGHS else -999
# True: input files use the "Station Name" / "Max Temp (°C)" column layout.
CANADA = False
    
class CityLoc():
    """Resolve weather-station names to (latitude, longitude) pairs.

    On construction the city table named by the module-level LAT_LONG_FILE
    is loaded into ``lat_long_dict``, keyed by lower-case "city,state"
    strings built from its ``city_ascii`` and ``state_id`` columns, with
    values taken from its ``lat`` and ``lng`` columns.
    """

    # Station names whose city cannot be derived from the name itself, mapped
    # to the "city,state" key to look up instead.  Values must be lower case
    # because every key of lat_long_dict is lower-cased when it is loaded.
    TRANSLATE = {
        "SEDRO WOOLLEY, WA US": "burlington,wa",
        "HEADWORKS PORTLAND WATER B, OR US": "portland,or",
        "WASHINGTON REAGAN NATIONAL AIRPORT, VA US": "arlington,va",
        "BLUE HILL, MA US": "quincy,ma",
        "FARMINGTON, ME US": "augusta,me",
        "DOWNTOWN CHARLESTON, SC US": "charleston,sc",
        "SAINT LEO, FL US": "dade city,fl",
        "TABLE MOUNTAIN, CA US": "oroville,ca",
        "FAIRMONT, CA US": "lancaster,ca",
        "SANDBERG, CA US": "lebec,ca",
        "DAL FTW WSCMO AIRPORT, TX US": "dallas,tx",
        "CAVE RUN LAKE, KY US": "morehead,ky",
        "PASO ROBLES MUNICIPAL AIRPORT, CA US": "templeton,ca",
        "BARROW AIRPORT, AK US": "utqiagvik,ak",
        "MONTICELLO, VA US": "charlottesville,va",
        "SHADWELL 1.0 E, VA US": "charlottesville,va",
        "AFTON 5.4 ESE, VA US": "waynesboro,va",
        "BARBOURSVILLE 1.1 NW, VA US": "gordonsville,va",
        "KESWICK 2.8 NE, VA US": "charlottesville,va",
        "CHARLOTTEBURG RESERVOIR, NJ US": "kinnelon,nj",
        "WEST PALM BEACH INTERNATIONAL AIRPORT, FL US": "west palm beach,fl",
        "UNIVERSITY EXPERIMENT STATION, AK US": "fairbanks,ak",
        "MORAN 5 WNW, WY US": "hoback,wy",
        # Was "Fond du Lac,wi", which could never match a lower-cased key.
        "FOND DU LAC WWTP, WI US": "fond du lac,wi",
        "BEOWAWE, NV US": "carlin,nv",
        "ARTHUR 4 NW, NV US": "spring creek,nv",
        "CLEARBROOK, WA US": "sumas,wa",
    }

    def __init__(self):
        """Copy the translation table and load the city coordinates.

        Raises:
            OSError: if the lookup file cannot be opened.
            KeyError: if a data row is met before the header row, or the
                header lacks one of the required columns.
        """
        # Per-instance copy so callers may add entries without touching the
        # class-level table.
        self.translate = dict(self.TRANSLATE)

        # Keep the attribute in step with the file that is actually opened.
        self.LAT_LONG_FILE = LAT_LONG_FILE
        self.lat_long_dict = {}
        columns = {}
        # newline='' is what the csv module asks for; the with-block closes
        # the handle, which the earlier code never did.
        with open(self.LAT_LONG_FILE, 'r', newline='') as f:
            for line in csv.reader(f, delimiter=','):
                if not line:
                    # Blank row: nothing to index.
                    continue
                if "city" in line:
                    # Header row: remember the position of every column.
                    for index, column in enumerate(line):
                        columns[column] = index
                    continue
                city = line[columns["city_ascii"]]
                state = line[columns["state_id"]]
                lat = float(line[columns["lat"]])
                lng = float(line[columns["lng"]])
                self.lat_long_dict[city.lower() + "," + state.lower()] = (lat, lng)

    # Examples of station names handled by lookup_name:
    # "BOZEMAN MONTANA STATE UNIVERSITY, MT US"
    # "BILLINGS WATER TREATMENT PLANT, MT US"
    # "DILLON U OF MONTANA WESTERN, MT US"
    # "GLENDIVE, MT US"
    # "NEW YORK, NY US"
    def lookup_name(self, name):
        """Return ``((lat, lng), resolved_name)`` for a station name.

        A name present in ``translate`` is looked up under its translated
        key.  Otherwise the name must end in "US"; the next-to-last word is
        taken as the state and the first word, then the first two words, are
        tried as the city.

        Returns:
            ``((lat, lng), "city,state")`` when a candidate is found;
            ``((10, 10), name)`` when the name does not end in "US";
            ``((LATITUDE_FLAG, LONGITUDE_FLAG), name)`` when no candidate is
            in the city table.
        """
        translated = self.translate.get(name)
        if translated is not None:
            # Lower-case defensively so caller-added entries also match.
            alt_names = [translated.lower()]
        else:
            ele = name.split()
            # An empty name used to raise IndexError here.
            if not ele or ele[-1] != "US":
                print("not a US location:", name)
                return (10, 10), name
            alt_names = []
            # Need at least city, state and "US" to build a candidate.
            if len(ele) >= 3:
                state = ele[-2].lower()
                ele0 = ele[0].lower().strip(',')
                ele1 = ele[1].lower().strip(',')
                alt_names.append(ele0 + ',' + state)
                alt_names.append(ele0 + ' ' + ele1 + ',' + state)
        for alt_name in alt_names:
            lat_long = self.lat_long_dict.get(alt_name)
            if lat_long is not None:
                return (lat_long[0], lat_long[1]), alt_name
        print("location not found in uscities.csv:", name)
        return (LATITUDE_FLAG, LONGITUDE_FLAG), name


class Reader():
    """Scan one temperature CSV and collect monthly extremes per station.

    The result, kept in ``stations``, maps each station name to a dict of
    ``year -> list of 12 values`` (index 0 is January).  A value equal to
    the module-level FLAG means that month had no usable reading.  Whether
    highs or lows are tracked is decided by HIGHEST_HIGHS; CANADA selects
    the "Year"/"Month"/"Station Name"/"Max Temp (°C)"/"Min Temp (°C)"
    column layout instead of "DATE"/"NAME"/"TMAX"/"TMIN".

    Usage: call read_first_line(); if it returns True, call read_the_rest().
    Rows are assumed to be grouped by station and ordered by date within a
    station; a row dated earlier than the year in progress is folded into
    that year.
    """

    def __init__(self, filename):
        """Remember the file to read; nothing is opened yet."""
        self.filename = filename
        self.stations = {}
        self.current_station = None
        self.current_year = 0
        # Open handle shared by read_first_line() and read_the_rest().
        self._file = None

    def init_extremes(self):
        """Reset the per-station and per-year accumulators."""
        self.station_extremes = {}
        self.this_year_extremes = [FLAG] * 12
        self.current_year = 0

    def save_station(self):
        """File the year in progress and publish the current station."""
        self.station_extremes[self.current_year] = self.this_year_extremes
        self.stations[self.current_station] = self.station_extremes

    def _close(self):
        """Close the input file if it is still open."""
        if self._file is not None:
            self._file.close()
            self._file = None

    def read_first_line(self):
        """Open the file and index its header row.

        Returns:
            True when the first row is a recognised header (``columns`` then
            maps column name to position); False otherwise, including for an
            empty file, in which case the file is closed again.
        """
        self.columns = {}
        self._close()
        self._file = open(self.filename, 'r', newline='')
        self.csv_reader = csv.reader(self._file, delimiter=',')
        header = next(self.csv_reader, None)
        if header is not None:
            us_header = "STATION" in header and "TMAX" in header
            canada_header = (CANADA and "Station Name" in header
                             and "Max Temp (°C)" in header)
            if us_header or canada_header:
                for index, column in enumerate(header):
                    self.columns[column] = index
                return True
        # Not one of ours: do not leave the handle dangling.
        self._close()
        return False

    def _record(self, line, mnum):
        """Fold one row's reading into the month ``mnum`` of the current year."""
        # Only the column for the active mode is read, so a missing high can
        # no longer be replaced by that row's low (or vice versa).
        if HIGHEST_HIGHS:
            column = "Max Temp (°C)" if CANADA else "TMAX"
        else:
            column = "Min Temp (°C)" if CANADA else "TMIN"
        raw = line[self.columns[column]].strip()
        if not raw:
            return
        if CANADA:
            # Celsius to Fahrenheit, rounded to a whole degree.
            temp = round(float(raw) * 9 / 5 + 32)
        else:
            temp = int(raw)
        current = self.this_year_extremes[mnum]
        if HIGHEST_HIGHS:
            # sanity check, data has some errors
            if temp < 130 and (current == FLAG or temp > current):
                self.this_year_extremes[mnum] = temp
        else:
            # sanity check, data has some errors
            if temp > -50 and (current == FLAG or temp < current):
                self.this_year_extremes[mnum] = temp

    def read_the_rest(self):
        """Consume the data rows and return the ``stations`` dict.

        Must follow a successful read_first_line().  The input file is
        closed when this method returns or raises.
        """
        try:
            for line in self.csv_reader:
                if not line:
                    # Blank row.
                    continue
                if CANADA:
                    year = int(line[self.columns["Year"]])
                    mnum = int(line[self.columns["Month"]]) - 1
                    new_station = line[self.columns["Station Name"]]
                else:
                    ymd = line[self.columns["DATE"]].split('-')
                    year = int(ymd[0])
                    # NOTE: arrays start at zero
                    mnum = int(ymd[1]) - 1
                    new_station = line[self.columns["NAME"]]
                if self.current_station is None:
                    self.init_extremes()
                    self.current_station = new_station
                elif self.current_station != new_station:
                    self.save_station()
                    self.init_extremes()
                    self.current_station = new_station
                if self.current_year == 0:
                    self.current_year = year
                if year > self.current_year:
                    # File the finished year under ITS OWN key.  It used to
                    # be stored under the new year, shifting every series by
                    # one year and losing the next-to-last year.
                    self.station_extremes[self.current_year] = self.this_year_extremes
                    self.this_year_extremes = [FLAG] * 12
                    self.current_year = year
                # The first row of a new year is data too; it used to be
                # skipped.
                self._record(line, mnum)
        finally:
            self._close()
        # With no data rows there is no station to publish.
        if self.current_station is not None:
            self.save_station()
        return self.stations
                
def readCSVs(cityloc):
    """Gather station data from every recognised CSV in the working directory.

    Each entry of the current directory whose name ends in ".csv" is handed
    to a Reader; files whose first line is not accepted are ignored.

    Args:
        cityloc: object whose ``lookup_name(station)`` returns
            ``(lat_lon, altname)``.

    Returns:
        dict mapping station name to
        ``{"lat_lon": ..., "altname": ..., "extremes": ...}``.
    """
    stations = {}
    csv_names = (entry for entry in os.listdir(".") if entry.endswith(".csv"))
    for csv_name in csv_names:
        reader = Reader(csv_name)
        if not reader.read_first_line():
            continue
        for station, extremes in reader.read_the_rest().items():
            lat_lon, altname = cityloc.lookup_name(station)
            stations[station] = {"lat_lon": lat_lon, "altname": altname, "extremes": extremes}
    return stations

def plot(station, data):
    """Save a scatter chart of one station's monthly extremes with trend lines.

    For every month, the yearly values from the station's earliest year
    through END_YEAR are plotted in that month's color.  When a month has at
    least two values a least-squares line is drawn and its slope shown in
    the legend, followed by "*" when the p-value is below 0.05.  The chart
    is written to ``plots/<altname>.png``; nothing is written when no month
    has data.

    Args:
        station: station name, used in the chart title.
        data: dict with "extremes" (year -> list of monthly values, FLAG
            meaning no reading) and "altname" (used for the file name).
    """
    extremes = data["extremes"]
    altname = data['altname']
    if not extremes:
        # No years at all: nothing to draw (this used to raise IndexError).
        return

    plt.figure(figsize=(10, 4))
    at_least_one_plot = False
    # min() rather than "first key" so the result does not depend on the
    # order in which years were inserted.
    first_year = min(extremes)
    blank_year = [FLAG] * 12
    for month, color in enumerate(COLORS):
        validYears = []
        validTemps = []
        for year in range(first_year, END_YEAR + 1):
            monthData = extremes.get(year, blank_year)
            tExtreme = monthData[month] if month < len(monthData) else FLAG
            if tExtreme != FLAG:
                validYears.append(year)
                validTemps.append(tExtreme)

        if not validYears:
            continue
        at_least_one_plot = True
        label = MONTHS[month]
        trend = None
        # A line through a single point is undefined, so the regression is
        # only attempted with two or more years.
        if len(validYears) >= 2:
            slope, intercept, _r_value, p_value, _std_err = stats.linregress(validYears, validTemps)
            trend = [slope * y + intercept for y in validYears]
            label = label + ' ' + str(round(slope, 3))
            if p_value < 0.05:
                label = label + "*"
        plt.scatter(validYears, validTemps, color=color, label=label, alpha=0.5)
        if trend is not None:
            plt.plot(validYears, trend, color=color)

    if at_least_one_plot:
        plt.legend(ncol=4, loc='lower right')
        if HIGHEST_HIGHS:
            plt.title(station + " highest monthly temp (F)")
        else:
            plt.title(station + " lowest monthly temp (F)")
        pngFile = altname + ".png"
        # savefig does not create missing directories.
        os.makedirs("plots", exist_ok=True)
        plt.savefig("plots/" + pngFile)
        #plt.show()
    plt.close()

def plot_map(stations, month, table):
    """Draw one month's temperature trends on the map image and fill ``table``.

    For each station that passes the filters below, a least-squares slope of
    the month's yearly values (station's earliest year through END_YEAR) is
    computed and stored in ``table[station][month]`` with its p-value under
    ``month + 'p'``.  Stations are drawn over "usa3.png" as:

    * '^' (Reds)     - p <= 0.05 and positive slope
    * 'v' (Blues)    - p <= 0.05 and non-positive slope
    * 'o' (coolwarm) - p > 0.05

    The figure is saved to ``plots/<month>.png`` and shown.

    Args:
        stations: dict of station name -> {"extremes", "lat_lon", "altname"}.
        month: one of the MONTHS abbreviations.
        table: dict updated in place; every station in it also receives an
            "altname" entry and must therefore be present in ``stations``.
    """
    up_xs, up_ys, up_cs = [], [], []
    down_xs, down_ys, down_cs = [], [], []
    nt_xs, nt_ys, nt_cs = [], [], []
    month_num = MONTHS.index(month)
    blank_year = [FLAG] * 12
    for station in stations:
        data = stations[station]
        extremes = data["extremes"]
        lat_lon = data["lat_lon"]
        if not extremes:
            # No years recorded (this used to raise IndexError).
            continue
        start_year = min(extremes)
        # Only stations whose record starts in 1940 or earlier are used.
        if start_year > 1940:
            continue
        # Latitude 10 is the marker CityLoc.lookup_name returns for names
        # that do not end in "US".
        if lat_lon[0] == 10:
            continue
        # NOTE(review): presumably drops stations west of the map image
        # (e.g. Alaska/Hawaii) - confirm against usa3.png.
        if lat_lon[1] < -140:
            continue
        validYears = []
        validTemps = []
        for year in range(start_year, END_YEAR + 1):
            monthData = extremes.get(year, blank_year)
            tExtreme = monthData[month_num] if month_num < len(monthData) else FLAG
            if tExtreme != FLAG:
                validYears.append(year)
                validTemps.append(tExtreme)
        if not validYears:
            continue
        row = table.setdefault(station, {})
        if len(validYears) < 2:
            # A trend through one point is undefined: keep the table cell
            # (as NaN) but do not ask SciPy for a regression or draw a marker.
            row[month] = float("nan")
            row[month + 'p'] = float("nan")
            continue
        slope, _intercept, _r_value, p_value, _std_err = stats.linregress(validYears, validTemps)
        row[month] = slope
        row[month + 'p'] = p_value
        # Longitude/latitude to pixel position on the background image.
        # NOTE(review): the constants look calibrated to usa3.png - confirm.
        x = 1600 - (lat_lon[1] * -1 * (830 - 0) / (130 - 65))
        y = 838 - (lat_lon[0] * (670 - 0) / (65 - 25))
        c = slope * 20
        if p_value > 0.05:
            nt_xs.append(x)
            nt_ys.append(y)
            nt_cs.append(c)
        elif c > 0:
            up_xs.append(x)
            up_ys.append(y)
            up_cs.append(c)
        else:
            down_xs.append(x)
            down_ys.append(y)
            down_cs.append(c)

    for station in table:
        table[station]["altname"] = stations[station]['altname']

    # clip, so coolwarm colormap is centered on zero
    limit = 0
    new_nt_cs = []
    if nt_cs:  # min()/max() of an empty list would raise ValueError
        lowest, highest = min(nt_cs), max(nt_cs)
        if lowest < 0 < highest:
            # Values on both sides of zero: clip to the smaller extreme.
            limit = min(highest, -lowest)
        else:
            # All on one side: keep the values and let vmin/vmax below do
            # the centering instead of collapsing them to a single color.
            limit = max(abs(lowest), abs(highest))
        new_nt_cs = [max(-limit, min(limit, c)) for c in nt_cs]

    img = plt.imread("usa3.png")
    plt.imshow(img)
    if HIGHEST_HIGHS:
        plt.title("Highest monthly temperature trend - " + month)
    else:
        plt.title("lowest monthly temperature trend - " + month)
    if up_xs:
        plt.scatter(up_xs, up_ys, marker='^', c=up_cs, cmap="Reds")
    if down_xs:
        plt.scatter(down_xs, down_ys, marker='v', c=down_cs, cmap="Blues")
    if nt_xs:
        if limit > 0:
            plt.scatter(nt_xs, nt_ys, marker='o', c=new_nt_cs, cmap="coolwarm",
                        vmin=-limit, vmax=limit)
        else:
            plt.scatter(nt_xs, nt_ys, marker='o', c=new_nt_cs, cmap="coolwarm")
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()
    pngFile = month + ".png"
    # savefig does not create missing directories.
    os.makedirs("plots", exist_ok=True)
    plt.savefig("plots/" + pngFile)
    plt.show()
    # Start the next month on a clean figure; under a non-interactive
    # backend show() returns at once and markers would otherwise pile up.
    plt.close()

def printTableAsHtml(table):
    """Print the trend table to stdout as an HTML <TABLE>.

    One row per station, one column per entry of MONTHS.  Each cell shows
    the slope rounded to two decimals.  Cells with p < 0.05 are shaded: blue
    for slopes <= -0.01 (darker at <= -0.04), red for slopes >= +0.01
    (darker at >= +0.04).  A month with no entry for a station yields an
    empty cell.

    Args:
        table: dict of station name -> row, where a row maps a month
            abbreviation to its slope, ``month + 'p'`` to its p-value, and
            "altname" to the name used in the link to the station's image.
    """
    print("<TABLE><TR><TH>station</TH>")
    for month in MONTHS:
        print("  <TH>" + month + "</TH>")
    print("</TR>")
    for station, row in table.items():
        print("<TR>")
        # Names come from the input files, so escape them before they are
        # placed in markup.
        link = html.escape(row["altname"], quote=True)
        print("  <TD><A HREF=\"trends/" + link + ".png\" target=\"_blank\">" + html.escape(station) + "</A></TD>")
        for month in MONTHS:
            if month not in row:
                # No trend was computed for this month (used to be KeyError).
                print("  <TD></TD>")
                continue
            trend = round(row[month], 2)
            p = row[month + 'p']
            color = None
            if p < 0.05:
                if trend <= -0.04:
                    color = "#3399ff"
                elif trend <= -0.01:
                    color = "#99ccff"
                elif trend >= +0.04:
                    color = "#ff0000"
                elif trend >= +0.01:
                    color = "#ff9999"
            if color is None:
                print("  <TD>" + str(trend) + "</TD>")
            else:
                print("  <TD bgcolor=" + color + ">" + str(trend) + "</TD>")
        print("</TR>")
    print("</TABLE>")

if __name__ == '__main__':
    # Developer switch for the directory mode below: True draws one chart
    # per station, False draws the monthly maps and prints the HTML table.
    PLOT_EACH_STATION = False

    cityloc = CityLoc()
    cli_files = sys.argv[1:2]
    if cli_files:
        # Single-file mode: chart every station found in the given CSV.
        stations = {}
        reader = Reader(cli_files[0])
        if reader.read_first_line():
            for station, extremes in reader.read_the_rest().items():
                lat_lon, altname = cityloc.lookup_name(station)
                stations[station] = {"lat_lon": lat_lon, "altname": altname, "extremes": extremes}
                plot(station, stations[station])
    else:
        # Directory mode: use every recognised CSV in the working directory.
        stations = readCSVs(cityloc)
        if PLOT_EACH_STATION:
            for station, station_data in stations.items():
                plot(station, station_data)
        else:
            table = {}
            for month in MONTHS:
                plot_map(stations, month, table)
            printTableAsHtml(table)
    
