import sys
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from scipy import stats
import csv
from datetime import date

# Two-day rainfall bins. Bin i covers the closed range [min[i], max[i]], where
# INCRS[i] is the width (max - min) of bin i and each bin starts 0.01 above the
# previous bin's max. 'avg' is the rounded midpoint of each bin and is used as
# the bin's key/x-position elsewhere. Units presumably inches — TODO confirm.
INCRS = [0.02, 0.03, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.17, 0.18, 0.19, 0.20, 0.22, 0.23, 0.24, 0.25, 0.25, 0.25]
mins, maxs, avgs = [], [], []
rmin = 0.01
for incr in INCRS:
	rmax = round(rmin + incr, 2)
	ravg = round((rmin + rmax) / 2, 2)
	mins.append(rmin)
	maxs.append(rmax)
	avgs.append(ravg)
	# The next bin begins one hundredth above this bin's upper edge.
	rmin = round(rmax + 0.01, 2)
# Widen the last bin so totals up to 99 still land somewhere.
maxs[-1] = 99
RAINS = {'min': mins, 'max': maxs, 'avg': avgs}
print(RAINS)

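# NOTE: counts are based on two-day rainfall totals (each day's PRCP plus the
# previous recorded day's PRCP), so every day's rain is counted twice: once in
# its own day's total and once in the following day's.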

def count(filename, y1, y2, bins=None):
	"""Tally two-day rainfall totals into bins, per year.

	Reads a CSV file whose header row contains a "STATION" column plus
	"DATE" (YYYY-MM-DD) and "PRCP" columns. For every data row with
	y1 <= year <= y2 and a non-empty PRCP value, the day's precipitation is
	added to the previous recorded day's value. The sum is then counted in
	every bin whose closed [min, max] range contains it.

	Rows are assumed to be in ascending date order: reading stops at the
	first row whose year exceeds y2.

	Args:
		filename: path of the CSV file.
		y1: first year (inclusive) to count.
		y2: last year (inclusive) to count.
		bins: optional dict with parallel 'min', 'max' and 'avg' lists.
			Defaults to the module-level RAINS table.

	Returns:
		dict mapping year -> {bin avg: number of days}. Every year with at
		least one row in range appears, even if no day fell into a bin.
	"""
	if bins is None:
		bins = RAINS
	# Hoisted out of the row loop: the bin table never changes.
	bin_table = list(zip(bins['min'], bins['max'], bins['avg']))
	columns = {}
	raindays = {}
	previous_day_rain = 0.0
	with open(filename, 'r', newline='') as f:
		for line in csv.reader(f, delimiter=','):
			if not line:
				continue  # blank line in the file
			if "STATION" in line:
				# Header row: remember each column's index by name.
				columns = {name: index for index, name in enumerate(line)}
				continue
			year = int(line[columns["DATE"]].split('-')[0])
			if year < y1:
				continue
			if year > y2:
				break
			# Counts are stored under the year the row belongs to, including
			# the first row of each year.
			year_counts = raindays.setdefault(year, {})
			precip = line[columns["PRCP"]]
			if not precip:
				# NOTE(review): a missing reading leaves previous_day_rain
				# untouched, so the next reading is paired with the last
				# *recorded* day rather than the calendar day before it.
				continue
			today = float(precip)
			# Round to the bins' 2-decimal resolution so float error
			# (e.g. 0.1 + 0.2 -> 0.30000000000000004) cannot fall in the
			# 0.01-wide gap between adjacent bins.
			two_day_rain = round(previous_day_rain + today, 2)
			for low, high, avg in bin_table:
				if low <= two_day_rain <= high:
					year_counts[avg] = year_counts.get(avg, 0) + 1
			previous_day_rain = today
	return raindays
		
def plot(station, y1, y2):
	"""Plot the per-bin trend of two-day rainfall days for a station.

	Reads "<station>.csv" and tallies two-day rainfall per RAINS bin per year
	with count(). For each bin it fits a least-squares line of
	days-in-bin against year and scatters the slopes (days per year) against
	the bin midpoints. The figure is saved to "<station>_rftrends.png" and
	then shown.

	Raises:
		ValueError: if fewer than two years of data were found, since no
			trend can be fitted.
	"""
	csvFile = station + ".csv"
	raindays = count(csvFile, y1, y2)
	print(raindays)

	years = list(raindays)
	if len(years) < 2:
		raise ValueError(
			"need at least two years of data in {} for {}-{}, found {}".format(
				csvFile, y1, y2, len(years)))

	# Days per year in each bin; a year with no days in a bin counts as 0.
	averages = {
		avg: [raindays[year].get(avg, 0) for year in years]
		for avg in RAINS['avg']
	}

	plt.figure(figsize=(6, 4))
	plt.xlim(0, 3)
	xs = list(RAINS['avg'])
	ys = [stats.linregress(years, averages[avg]).slope for avg in xs]
	# Draw the points once; calling scatter inside the loop would re-draw
	# every accumulated point on each iteration.
	plt.scatter(xs, ys, color='b', alpha=0.6)

	# Zero-trend reference line.
	plt.plot(xs, [0] * len(xs), color='black', alpha=0.9)

	plt.title('double-counted rainfall trends  at ' + station + ' ' + str(years[0]) + "-" + str(years[-1]))
	pngFile = station + "_rftrends.png"
	plt.savefig(pngFile)
	plt.show()

if __name__ == '__main__':
	# Command line: station id, start year, end year (space-separated).
	usage = "usage: python3 rainfall.py station startYear endYear"
	if len(sys.argv) < 4:
		# sys.exit(str) prints to stderr and exits with status 1; the bare
		# exit() helper comes from the site module and is not guaranteed.
		sys.exit(usage)
	station = sys.argv[1]
	try:
		y1 = int(sys.argv[2])
		y2 = int(sys.argv[3])
	except ValueError:
		sys.exit(usage)
	plot(station, y1, y2)
