import sys
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from scipy import stats
import csv
from datetime import date

# Weekly-rainfall bins. Bin i is INCRS[i] wide; each bin starts 0.01 above
# the previous bin's upper bound (boundaries are rounded to 0.01). The first
# bin starts at 0.01, so a total of 0 lies below every bin's minimum.
INCRS = [0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1, 1.25, 1.5, 1.75, 2, 2.3, 2.6, 2.9, 3.3, 3.7, 4.2, 4.6, 5.0, 5.5, 6.0, 6.6, 7.2]
mins, maxs, avgs = [], [], []
rmin = 0.01
for incr in INCRS:
	rmax = round(rmin + incr, 2)
	ravg = round((rmin + rmax) / 2, 2)
	mins.append(rmin)
	maxs.append(rmax)
	avgs.append(ravg)
	# Next bin begins one hundredth above this bin's upper bound.
	rmin = round(rmax + 0.01, 2)
# Make the top bin open-ended. Its 'avg' label was already computed from the
# original upper bound above, so the label keeps that midpoint.
maxs[-1] = 99
RAINS = {'min': mins, 'max': maxs, 'avg': avgs}
print(RAINS)

def count(filename, y1, y2, bins=None):
	"""Count, per year, how many seven-day rainfall totals fall in each bin.

	Reads a CSV of daily observations. The header row is the row containing a
	"STATION" cell; it must also name "DATE" (year first, '-'-separated) and
	"PRCP" columns. Rows with an empty PRCP value are skipped and do not count
	toward the seven-day window. Every seven recorded days, the summed
	precipitation is credited to each bin whose [min, max] range contains it.
	The running window is not reset at year boundaries, so a window that
	straddles New Year is credited to the year in which it completes.

	Args:
		filename: path of the CSV file.
		y1: first year (inclusive) to include.
		y2: last year (inclusive) to include; reading stops at the first
			row after it (rows are assumed to be in date order).
		bins: optional dict with parallel 'min', 'max' and 'avg' lists;
			defaults to the module-level RAINS table.

	Returns:
		dict mapping year -> {bin 'avg': number of windows in that bin}.
		Only years in [y1, y2] that appear in the file are present; bins that
		received no windows are absent from the inner dict.
	"""
	if bins is None:
		bins = RAINS
	bin_ranges = list(zip(bins['min'], bins['max'], bins['avg']))

	columns = {}
	raindays = {}
	raindaysYear = {}
	current_year = None
	sevenDayRain = 0
	dayCount = 0
	with open(filename, 'r', newline='') as f:
		for line in csv.reader(f, delimiter=','):
			if "STATION" in line:
				columns = {name: index for index, name in enumerate(line)}
				continue
			year = int(line[columns["DATE"]].split('-')[0])
			# Skip rows before the range, and any row that goes back in time.
			if year < y1 or (current_year is not None and year < current_year):
				continue
			if year > y2:
				break
			if year != current_year:
				# Store the finished year under its own key, then start a
				# new one. This row is still processed below.
				if current_year is not None:
					raindays[current_year] = raindaysYear
				raindaysYear = {}
				current_year = year
			precip = line[columns["PRCP"]]
			if not precip:
				continue
			sevenDayRain += float(precip)
			dayCount += 1
			if dayCount == 7:
				# Bin edges are at 0.01 resolution with 0.01 gaps between
				# bins. Rounding keeps float drift (e.g. 0.1 + 0.2) from
				# dropping a total into a gap.
				weekly = round(sevenDayRain, 2)
				for low, high, avg in bin_ranges:
					if low <= weekly <= high:
						raindaysYear[avg] = raindaysYear.get(avg, 0) + 1
				dayCount = 0
				sevenDayRain = 0
	# The last year in range has no following year to trigger its storage.
	if current_year is not None:
		raindays[current_year] = raindaysYear
	return raindays
		
def plot(station, y1, y2):
	"""Plot the trend over time of how many weeks fall in each rainfall bin.

	Counts weeks per bin per year from ``<station>.csv`` (via count()). For
	every bin in RAINS, fits a least-squares line of week count against year
	and scatters the fitted slope at the bin's 'avg' value, with a horizontal
	zero line for reference. The figure is saved to
	``<station>_weekly_trends.png`` and then shown.

	Args:
		station: station id; also the basename of the input CSV and the
			output PNG.
		y1: first year (inclusive) to analyse.
		y2: last year (inclusive) to analyse.

	Raises:
		ValueError: if fewer than two years of data fall within [y1, y2],
			since no trend can be fitted.
	"""
	raindays = count(station + ".csv", y1, y2)
	print(raindays)

	years = list(raindays)
	if len(years) < 2:
		raise ValueError("need at least two years of data for {} in {}-{}, found {}".format(
			station, y1, y2, len(years)))

	xs = list(RAINS['avg'])
	ys = []
	for avg in xs:
		# A bin with no weeks in a given year is absent from that year's dict.
		days = [raindays[year].get(avg, 0) for year in years]
		ys.append(stats.linregress(years, days).slope)

	plt.figure(figsize=(4, 3))
	# Only bins whose midpoint lies in [0, 7] are visible.
	plt.xlim(0, 7)
	# One call so each point is drawn exactly once (repeated draws would stack alpha).
	plt.scatter(xs, ys, color='b', alpha=0.6)
	plt.plot(xs, [0] * len(xs), color='black', alpha=0.9)

	plt.title('weekly rainfall trends  at ' + station + ' ' + str(min(years)) + "-" + str(max(years)))
	pngFile = station + "_weekly_trends.png"
	plt.savefig(pngFile)
	plt.show()

if __name__ == '__main__':
	# Command line: station id (basename of <station>.csv), first year, last year.
	# sys.exit(message) prints the message to stderr and exits with status 1.
	usage = "usage: python3 rainfall.py station startYear endYear"
	if len(sys.argv) <= 3:
		sys.exit(usage)
	station = sys.argv[1]
	try:
		y1 = int(sys.argv[2])
		y2 = int(sys.argv[3])
	except ValueError:
		sys.exit(usage)
	plot(station, y1, y2)
