import sys
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import csv

# Month abbreviations analysed; used as dictionary keys and as legend labels.
MONTHS = ["JUN", "JUL", "AUG", "SEP"]
# Calendar month numbers, parallel to MONTHS (zipped with it when matching
# the month part of a DATE value).
MONTHNUM = [6, 7, 8, 9]
# Plot colours, parallel to MONTHS (zipped with it when drawing each series).
COLORS = ["green", "red", "blue", "black"]
# [first year, last year] of the range that is read and plotted, inclusive.
YEARS = [1902, 2019]

def create_months(filename, station):
	"""Return each month's highest TMAX for every year in the YEARS range.

	The file is read as comma-separated text. The row containing a
	"STATION" cell is taken as the header and must come before the data
	rows; it has to name a "DATE" column (formatted YYYY-MM-DD) and a
	"TMAX" column. Blank rows, rows with an empty TMAX cell, rows outside
	the YEARS range and rows for months not listed in MONTHNUM are ignored.

	Args:
		filename: Path of the CSV file to read.
		station: Not used; accepted so existing callers keep working.

	Returns:
		A dict mapping every name in MONTHS to a list holding one value per
		year from YEARS[0] through YEARS[1] inclusive, in order. A value is
		the largest TMAX read for that month and year, or 999 when the file
		has no TMAX reading for it.

	Raises:
		KeyError: If a data row appears before the header row, or the
			header has no "DATE" or "TMAX" column.
		ValueError: If the year, the month or a non-empty TMAX cell is not
			an integer.
	"""
	first_year = YEARS[0]
	last_year = YEARS[1]
	name_by_number = dict(zip(MONTHNUM, MONTHS))
	columns = {}
	# (year, month name) -> largest TMAX read so far. Keying on the year keeps
	# every value tied to its own year, so gaps in the data, a late first
	# year, or a file that ends in the last year cannot shift or drop entries.
	highest = {}
	# newline='' is what the csv module asks for when it is given a file.
	with open(filename, 'r', newline='') as f:
		for line in csv.reader(f, delimiter=','):
			if not line:
				continue
			if "STATION" in line:
				columns = {name: index for index, name in enumerate(line)}
				continue
			ymd = line[columns["DATE"]].split('-')
			year = int(ymd[0])
			if year < first_year or year > last_year:
				continue
			month = name_by_number.get(int(ymd[1]))
			reading = line[columns["TMAX"]]
			if month is None or not reading:
				continue
			tmax = int(reading)
			key = (year, month)
			# Presence in the dict marks "has data", so readings of zero or
			# below are kept instead of being mistaken for missing values.
			if key not in highest or tmax > highest[key]:
				highest[key] = tmax
	return {
		month: [
			highest.get((year, month), 999)
			for year in range(first_year, last_year + 1)
		]
		for month in MONTHS
	}
		
def plot():
	"""Plot each month's yearly maximum temperature with a linear trend line.

	Reads "kentfield.csv" through create_months, then for every month in
	MONTHS draws a scatter series of the yearly values plus a least-squares
	line through them. The legend shows the month name followed by the
	slope rounded to three decimals. The figure is saved to "output.png"
	and then displayed.
	"""
	years = list(range(YEARS[0], YEARS[1] + 1))
	months = create_months("kentfield.csv", "KENTFIELD, CA US")
	# Sanity check on stdout: the two counts should match (one value per year).
	print(len(years))
	print(len(months['JUN']))

	plt.figure(figsize=(8,4))
	for month, color in zip(MONTHS, COLORS):
		# Keep only real readings (999 marks a missing year). zip stops at the
		# shorter sequence, so a short series cannot raise IndexError.
		validYears = []
		validTemps = []
		for year, tmax in zip(years, months[month]):
			if tmax != 999:
				validYears.append(year)
				validTemps.append(tmax)

		if len(validYears) < 2:
			# A trend line needs at least two points; show what exists
			# without a fit instead of letting the regression fail.
			if validYears:
				plt.scatter(validYears, validTemps, color=color, label=month, alpha=0.5)
			continue

		fit = stats.linregress(validYears, validTemps)
		points = [fit.slope * y + fit.intercept for y in validYears]
		label = month + ' ' + str(round(fit.slope, 3))
		plt.scatter(validYears, validTemps, color=color, label=label, alpha=0.5)
		plt.plot(validYears, points, color=color)

	plt.legend(ncol = 4, loc='lower right')
	plt.title("KENTFIELD, CA US")
	pngFile = "output.png"
	plt.savefig(pngFile)
	plt.show()

# Draw the plot only when this file is run as a script, not when it is imported.
if __name__ == '__main__':
	plot()
