"""Small decision-tree building blocks demonstrated on the Iris data set.

Every data point is a 5-tuple: four numeric features followed by an integer
class label (1 = "Iris-setosa", 2 = "Iris-versicolor", 0 = anything else).
Running the file as a script loads ``Iris.csv``, makes a random train/test
split and shows a scatter plot of feature 2 against feature 3.
"""
import math
import random

# Species names that get a dedicated label; every other name maps to 0.
_SPECIES_LABELS = {"Iris-setosa": 1, "Iris-versicolor": 2}

# (class label, marker colour) pairs, in the order the classes are drawn.
_PLOT_COLOURS = ((0, 'red'), (1, 'blue'), (2, 'green'))


def get_data(path="Iris.csv"):
    """Read the Iris CSV file and return its rows as labelled tuples.

    The first line is treated as a header and skipped. Every other
    non-blank line must have at least six comma-separated columns: column 0
    is ignored, columns 1-4 are parsed as floats and column 5 is the
    species name.

    Args:
        path: Location of the CSV file. Defaults to ``"Iris.csv"`` in the
            current working directory.

    Returns:
        A list of ``(f1, f2, f3, f4, label)`` tuples, where ``label`` is 1
        for "Iris-setosa", 2 for "Iris-versicolor" and 0 otherwise.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a feature column is not a valid float.
        IndexError: If a non-blank row has fewer than six columns.
    """
    rows = []
    with open(path) as infile:
        infile.readline()  # skip the header line
        for line in infile:
            line = line.strip()
            if not line:
                # A blank line (typically a trailing newline at the end of
                # the file) carries no data; skip it instead of crashing.
                continue
            values = line.split(',')
            label = _SPECIES_LABELS.get(values[5], 0)
            rows.append((
                float(values[1]),
                float(values[2]),
                float(values[3]),
                float(values[4]),
                label,
            ))
    return rows


def stats(lista):
    """Count how many points of ``lista`` fall into each class.

    Returns:
        A three-element list ``[n0, n1, n2]`` indexed by class label (the
        last item of every point).
    """
    counts = [0, 0, 0]
    for element in lista:
        counts[element[-1]] += 1
    return counts


def gini(lista):
    """Return the Gini impurity of the class labels in ``lista``.

    An empty list is treated as perfectly pure and yields ``0.0`` rather
    than dividing by zero.
    """
    total = len(lista)
    if total == 0:
        return 0.0
    impurity = 1
    for count in stats(lista):
        impurity -= (count / total) ** 2
    return impurity


def entropy(lista):
    """Return the Shannon entropy (in bits) of the class labels in ``lista``.

    Classes that do not occur contribute nothing. An empty list yields
    ``0.0`` rather than dividing by zero.
    """
    total = len(lista)
    if total == 0:
        return 0.0
    result = 0
    for count in stats(lista):
        if count != 0:
            proportion = count / total
            result -= proportion * math.log2(proportion)
    return result


def unique(lista):
    """Return the distinct items of ``lista`` in first-occurrence order."""
    try:
        # dict keys keep insertion order, so this de-duplicates in O(n).
        return list(dict.fromkeys(lista))
    except TypeError:
        # Unhashable items (e.g. lists): fall back to a linear scan.
        result = []
        for value in lista:
            if value not in result:
                result.append(value)
        return result


def midpoints(lista, axis):
    """Return the candidate split values along one coordinate axis.

    The distinct values of ``pt[axis]`` are sorted and the mean of every
    pair of neighbours is returned, rounded to three decimals. Fewer than
    two distinct values give an empty list.
    """
    values = unique(sorted(pt[axis] for pt in lista))
    return [round((low + high) / 2, 3) for low, high in zip(values, values[1:])]


def split(lista, axis, value):
    """Partition ``lista`` by comparing ``pt[axis]`` with ``value``.

    Returns:
        ``(left, right)`` where ``left`` holds the points with
        ``pt[axis] < value`` and ``right`` holds all the others; the
        original order is kept inside each list.
    """
    left, right = [], []
    for element in lista:
        if element[axis] < value:
            left.append(element)
        else:
            right.append(element)
    return left, right


def best_split(lista, threshold=3, n_features=4):
    """Find the axis/value pair with the largest Gini gain.

    Every midpoint of every feature axis is tried. A candidate is only
    considered when both resulting sides contain strictly more than
    ``threshold`` points.

    Args:
        lista: The data points to split.
        threshold: Each side of a split must have more points than this.
        n_features: Number of leading feature columns to examine.

    Returns:
        ``(axis, value)`` for the best split, or ``None`` when no admissible
        split has a strictly positive gain (including for empty input).
    """
    best_gain = 0
    best = None
    gini_total = gini(lista)
    nr = len(lista)
    for axis in range(n_features):
        for value in midpoints(lista, axis):
            left, right = split(lista, axis, value)
            if len(left) > threshold and len(right) > threshold:
                gain = (gini_total
                        - len(left) / nr * gini(left)
                        - len(right) / nr * gini(right))
                if gain > best_gain:
                    best_gain = gain
                    best = (axis, value)
    return best


def separate(lista, p):
    """Randomly divide ``lista`` into a training and a test list.

    Each element independently goes to the training list with probability
    ``p`` (one ``random.random()`` draw per element, in order), so the
    sizes of the two lists vary from run to run.

    Returns:
        ``(train, test)``.
    """
    train, test = [], []
    for element in lista:
        if random.random() < p:
            train.append(element)
        else:
            test.append(element)
    return train, test


def predict(element):
    """Classify one point with a fixed, hard-coded decision tree.

    Only ``element[2]`` and ``element[3]`` are inspected.

    Returns:
        The predicted class label: 1, 2 or 0.
    """
    if element[2] < 2.45:
        return 1
    if element[3] < 1.75 and element[2] < 4.95:
        return 2
    return 0


def plot_classes(data):
    """Show a scatter plot of feature 2 against feature 3, one colour per class.

    Class 0 is drawn in red, class 1 in blue and class 2 in green. The call
    blocks until the plot window is closed.
    """
    # Imported here so the pure-Python helpers above remain importable (and
    # testable) on machines without matplotlib.
    from matplotlib import pyplot as plt

    plt.figure(figsize=(5, 6))
    for label, colour in _PLOT_COLOURS:
        members = [element for element in data if element[-1] == label]
        plt.scatter(
            [element[2] for element in members],
            [element[3] for element in members],
            c=colour,
        )
    plt.show()


if __name__ == "__main__":
    # Script behaviour: load the data, split it 80/20 and show the plot.
    Iris = get_data()
    train, test = separate(Iris, 0.8)
    plot_classes(Iris)