StockPredict/predict.py
2023-06-18 13:24:04 +08:00

63 lines
1.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import os
import sys

import baostock as bs
import joblib
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# Step 1: download daily K-line history for one stock and save it as CSV.
#
# The stock code comes from the first command-line argument. Fail with a
# usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit('usage: {prog} <stock_code>'.format(prog=sys.argv[0]))
code = str(sys.argv[1])

# Open a baostock session before querying.
lg = bs.login()
print('login respond error_code:'+lg.error_code)
print('login respond error_msg:'+lg.error_msg)
if lg.error_code != '0':
    # Continuing after a failed login would only yield an empty dataset.
    sys.exit('baostock login failed: ' + lg.error_msg)

try:
    rs = bs.query_history_k_data_plus(
        code,
        "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,peTTM,pbMRQ,psTTM,pcfNcfTTM,isST",
        start_date='2020-01-01',
        end_date=datetime.datetime.now().strftime('%Y-%m-%d'),
        # frequency="d" requests daily K-lines; adjustflag="3" means no
        # price adjustment (the default).
        frequency="d",
        adjustflag="3",
    )
    print('query_history_k_data_plus respond error_code:'+rs.error_code)
    print('query_history_k_data_plus respond error_msg:'+rs.error_msg)
    if rs.error_code != '0':
        sys.exit('baostock query failed: ' + rs.error_msg)

    # Drain the result set row by row. `and` short-circuits, so next() is
    # not called once an error code has been reported.
    data_list = []
    while rs.error_code == '0' and rs.next():
        data_list.append(rs.get_row_data())
    if not data_list:
        # An empty CSV would only fail later, inside model training, with a
        # far less helpful error.
        sys.exit('no k-line data returned for {code}'.format(code=code))

    result = pd.DataFrame(data_list, columns=rs.fields)
    # The output directory is relative to the working directory and may not
    # exist on a fresh checkout.
    os.makedirs('datasets', exist_ok=True)
    result.to_csv("datasets/{code}.csv".format(code=code), encoding="utf-8", index=False)
    print(result)
finally:
    # Always release the session, including on the early exits above.
    bs.logout()
# Step 2: train a decision-tree classifier on the saved history and persist it.
df = pd.read_csv('datasets/{code}.csv'.format(code=code))

# Label each row with the direction of the NEXT row's move: 1 if its pctChg
# is positive, else 0. Labelling a row with its own pctChg would leak the
# answer, because that day's open/close are among the features and its
# outcome is already known by the time those values exist.
# NOTE(review): assumes rows are in ascending date order -- confirm.
df['target'] = (df['pctChg'].shift(-1) > 0).astype(int)

features = ['open', 'high', 'low', 'close', 'volume']
# The last row has no following day, hence no real label; keep it out of
# both the training and the evaluation data.
X = df[features].iloc[:-1]
y = df['target'].iloc[:-1]

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split
# and the fitted tree reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print('模型准确率:', accuracy)

# The model directory is relative to the working directory and may not exist
# on a fresh checkout.
os.makedirs('models', exist_ok=True)
joblib.dump(clf, 'models/{code}.joblib'.format(code=code))
# Step 3: classify the most recent row of the saved history with the
# persisted model and print the verdict.
df = pd.read_csv('datasets/{code}.csv'.format(code=code)).tail(1)
features = ['open', 'high', 'low', 'close', 'volume']
X_test = df[features]
clf = joblib.load('models/{code}.joblib'.format(code=code))
prediction = clf.predict(X_test)

# The model's classes follow the training labels: 1 stands for a positive
# pctChg, 0 for a flat or negative one. Both branches must print a distinct,
# non-empty label or the result line carries no information.
# 跌 = "down", 涨 = "up".
print('预测结果:', '跌' if prediction[0] == 0 else '涨')