You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
33 lines
803 B
33 lines
803 B
import pickle
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from sklearn.metrics import accuracy_score
|
|
|
|
# Path of the pickled pandas object that holds the training table.
INFILE = "data1.pickle"
# Path the classifier fitted on the complete table is pickled to.
OUTFILE = "model.pickle"


def main() -> None:
    """Evaluate a decision tree on a hold-out split, then save a full-data fit.

    The table read from ``INFILE`` is split column-wise: every column except
    the last is used as a feature and the last column is the target.  A first
    classifier is fitted on 70% of the rows and scored on the remaining 30%;
    the accuracy (multiplied by 100) is printed.  A second, fresh classifier
    is then fitted on all rows and pickled to ``OUTFILE``.
    """
    # NOTE(review): unpickling can execute arbitrary code, so INFILE must come
    # from a trusted source.
    df = pd.read_pickle(INFILE)

    # Materialise the underlying array once and slice it twice.
    data = df.values
    X = data[:, :-1]
    y = data[:, -1]

    # Fixed seed keeps the train/test partition identical between runs.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=100
    )

    # NOTE(review): no random_state is passed to the classifiers, so the fitted
    # trees may differ between runs -- confirm whether that is intended.
    clf = DecisionTreeClassifier()
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy score: {accuracy * 100:.4f}")

    # The split above was only a sanity check; the model that gets saved is
    # trained on the entire data set.
    clf_full = DecisionTreeClassifier()
    clf_full.fit(X, y)

    with open(OUTFILE, "wb") as fp:
        pickle.dump(clf_full, fp)


if __name__ == "__main__":
    main()
|
|
|