import sys
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from datetime import date
import csv

# Upper bound on the simulated soil-moisture bucket in findDrySeason; precipitation
# beyond this is discarded.
MAX_MOISTURE = 5
# Length (days) of the countdown started when a dry-season start is recorded; a refill
# to >= 1.5 while it is still non-negative discards that start as a false start.
COUNTDOWN_SPRING = 5
# Length (days) of the countdown started when a dry-season end is recorded in findDrySeason.
COUNTDOWN_FALL = 15
# First and last year (inclusive) of the analysis; plot() draws exactly this range.
YEARS = [1878, 2019]

def findDrySeason(filename, name):
	"""Estimate the dry-season start and end day-of-year for every year.

	A daily soil-moisture "bucket" is simulated over the records of the
	station whose NAME contains ``name``. Precipitation (PRCP) refills it,
	capped at MAX_MOISTURE. On days with 60F < TMAX < 125F it loses
	0.1 + (TMAX - 60) / 100 to evaporation, never dropping below 0.

	* The season starts on the first day the bucket empties. If it refills
	  to >= 1.5 within COUNTDOWN_SPRING days (or before day 60), that start
	  is discarded as a false start.
	* The season ends on the first day after day 180 with moisture > 1.5.
	  If the bucket empties again within COUNTDOWN_FALL days, that end is
	  discarded as a false end. A season still open on day 365 ends there.

	Assumes rows are in date order: a new year is detected when the year
	increases. Rows outside YEARS (inclusive) are ignored.

	Args:
		filename: CSV path. A header row containing "STATION" names the
			columns; NAME, DATE (YYYY-MM-DD), PRCP and TMAX are read.
		name: substring identifying the station in the NAME column.

	Returns:
		(dryStarts, dryEnds): dicts mapping year -> day of year. 999 means
		"not found" or "no data for that year".
	"""
	UNSET = 999  # sentinel day-of-year; plot() filters on this value
	dryStarts = {}
	dryEnds = {}
	columns = {}
	prev_year = YEARS[0]
	dryStart = UNSET
	dryEnd = UNSET
	# Initialised so a countdown is never read before being set. For example,
	# a leap year's day 366 follows the end-of-year fallback below.
	countdown_spring = -1
	countdown_fall = -1
	moisture = 2
	with open(filename, 'r', newline='') as f:
		for line in csv.reader(f, delimiter=','):
			if "STATION" in line:
				# Header row: remember each column's index by name.
				columns.update((column, index) for index, column in enumerate(line))
				continue
			if name not in line[columns["NAME"]]:
				continue
			ymd = line[columns["DATE"]].split('-')
			year = int(ymd[0])
			if year < YEARS[0] or year > YEARS[1]:
				continue
			day_of_year = date(year, int(ymd[1]), int(ymd[2])).timetuple().tm_yday

			if year > prev_year:
				# New year: file the finished season under the year it belongs to,
				# and mark years entirely absent from the data as missing.
				dryStarts[prev_year] = dryStart
				dryEnds[prev_year] = dryEnd
				for missing_year in range(prev_year + 1, year):
					dryStarts[missing_year] = UNSET
					dryEnds[missing_year] = UNSET
				dryStart = UNSET
				dryEnd = UNSET
				prev_year = year
			# The first record of a new year is then processed like any other day.

			precip = line[columns["PRCP"]]
			if len(precip) > 0:
				moisture = moisture + float(precip)
			moisture = min(moisture, MAX_MOISTURE)

			# Countdowns tick once per day while a start/end is recorded.
			if dryStart != UNSET:
				countdown_spring -= 1
			if dryEnd != UNSET:
				countdown_fall -= 1

			tmax_str = line[columns["TMAX"]]
			if len(tmax_str) == 0:
				continue
			tmax = int(tmax_str)

			# SPRING. TMAX >= 125 is treated as a bad reading (the data has some errors).
			if 60 < tmax < 125:
				# Linear evaporation: 0.1 a day at 60F rising to 0.5 a day at 100F.
				moisture = moisture - (0.1 + (tmax - 60) / 100)
				if moisture <= 0:
					moisture = 0
					if dryStart == UNSET:
						dryStart = day_of_year
						countdown_spring = COUNTDOWN_SPRING
				# Refilling soon after the start (or before day 60) means a false start.
				if moisture >= 1.5 and dryStart != UNSET and (countdown_spring >= 0 or day_of_year < 60):
					dryStart = UNSET

			# FALL
			if day_of_year == 365 and dryEnd == UNSET:
				# Still dry at year end: close the season here. The countdown is
				# left negative so this fallback is never discarded as a false end.
				dryEnd = 365
				countdown_fall = -1
			if moisture > 1.5 and day_of_year > 180 and dryEnd == UNSET:
				dryEnd = day_of_year
				countdown_fall = COUNTDOWN_FALL
			# Drying out again within COUNTDOWN_FALL days means the rain was a false end.
			if moisture == 0 and dryEnd != UNSET and countdown_fall >= 0:
				dryEnd = UNSET

	# The last year read has no following year to trigger the rollover above.
	dryStarts[prev_year] = dryStart
	dryEnds[prev_year] = dryEnd
	return dryStarts, dryEnds
		
def plot(csv, station, name, pngFile="output.png"):
	"""Plot each year's dry season as a bar, with linear trends of its start and end.

	Args:
		csv: path of the station CSV passed to findDrySeason. This parameter
			shadows the csv module inside this function.
		station: display name used in the plot title.
		name: substring identifying the station in the CSV NAME column.
		pngFile: path the figure is saved to (default "output.png").
	"""
	dryStarts, dryEnds = findDrySeason(csv, name)

	# Keep only years with both a start and an end. 999 marks "not found";
	# a missing key means the year was never recorded.
	starts = []
	ends = []
	lengths = []
	dataYears = []
	for year in range(YEARS[0], YEARS[1] + 1):
		start = dryStarts.get(year, 999)
		end = dryEnds.get(year, 999)
		if start == 999 or end == 999:
			continue
		starts.append(start)
		ends.append(end)
		lengths.append(end - start)
		dataYears.append(year)

	plt.figure(figsize=(10, 4))
	# Each bar spans from the season's start day (bottom) to its end day.
	plt.bar(dataYears, lengths, bottom=starts, width=1.0, color='brown', alpha=0.6)

	# linregress needs at least two points; skip the trends rather than crash.
	if len(dataYears) >= 2:
		_plotTrend(dataYears, starts, 'green')
		_plotTrend(dataYears, ends, 'red')
		plt.legend(ncol=4, loc='lower right')

	plt.title(station + " dry season")
	plt.savefig(pngFile)
	plt.show()


def _plotTrend(xs, ys, color):
	"""Draw the least-squares line of ys over xs, labelled with its slope rounded to 3 places."""
	slope, intercept, r_value, p_value, std_err = stats.linregress(xs, ys)
	points = [slope * x + intercept for x in xs]
	plt.plot(xs, points, color=color, label=str(round(slope, 3)))

if __name__ == "__main__":
	# Analyse the rows of california.csv whose NAME contains "NEVA".
	plot(csv="california.csv", station="Nevada City, CA", name="NEVA")
