-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
107 lines (84 loc) · 3 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import bs4 as bs
import datetime as dt
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
import pandas as pd
import pandas_datareader.data as web
import pickle
import requests
# Select matplotlib's 'ggplot' style sheet for the plots drawn by this script.
style.use('ggplot')
# Scrape the S&P 500 ticker symbols and save them in a pickle file
def save_sp500_tickers():
    """Scrape the S&P 500 ticker symbols from Wikipedia and cache them.

    The symbol is taken from the first ``<td>`` of every row (after the
    header row) of the table whose id is ``constituents``. Dots in a symbol
    are replaced by dashes (e.g. ``BRK.B`` -> ``BRK-B``); presumably this is
    the form the downstream price source expects.

    The list is pickled to ``sp500tickers.pickle`` in the current working
    directory, printed, and returned.

    Returns:
        list[str]: The ticker symbols in table order.

    Raises:
        requests.RequestException: If the page cannot be fetched (network
            error, timeout, or a non-2xx HTTP status).
        ValueError: If the ``constituents`` table is not present in the page.
    """
    # A timeout keeps the script from hanging forever on a stalled
    # connection; raise_for_status() turns an HTTP error page into an
    # explicit failure instead of a confusing parse error further down.
    resp = requests.get(
        'https://en.wikipedia.org/wiki/List_of_S%26P_500_companies',
        timeout=30,
    )
    resp.raise_for_status()

    soup = bs.BeautifulSoup(resp.text, 'lxml')
    table = soup.find('table', {'id': 'constituents'})
    if table is None:
        raise ValueError(
            "Could not find the 'constituents' table on the Wikipedia page"
        )

    tickers = []
    # [1:] skips the header row.
    for row in table.find_all('tr')[1:]:
        cell = row.find('td')
        if cell is None:
            # Row without data cells (e.g. an extra header row): no ticker.
            continue
        # strip() removes the trailing newline and any stray whitespace.
        tickers.append(cell.text.strip().replace('.', '-'))

    with open('sp500tickers.pickle', 'wb') as f:
        pickle.dump(tickers, f)
    print(tickers)
    return tickers
# Get S&P 500 company data
def get_data_from_yahoo(reload_sp500=False, start=None, end=None):
    """Download price history for every cached ticker into ``stock_dfs/``.

    One CSV file per ticker is written to ``stock_dfs/<ticker>.csv``.
    Tickers whose CSV file already exists are skipped, so an interrupted
    run can be resumed.

    Args:
        reload_sp500: When true, refresh the ticker list by calling
            ``save_sp500_tickers()``; otherwise load it from
            ``sp500tickers.pickle``.
        start: First date to request. Defaults to 2000-01-01.
        end: Last date to request. Defaults to 2020-05-07.

    Raises:
        FileNotFoundError: If ``reload_sp500`` is false and
            ``sp500tickers.pickle`` does not exist.
    """
    if reload_sp500:
        tickers = save_sp500_tickers()
    else:
        # NOTE: pickle must only be loaded from trusted files; this one is
        # expected to be the file written by save_sp500_tickers().
        with open('sp500tickers.pickle', 'rb') as f:
            tickers = pickle.load(f)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('stock_dfs', exist_ok=True)

    if start is None:
        start = dt.datetime(2000, 1, 1)
    if end is None:
        end = dt.datetime(2020, 5, 7)

    for ticker in tickers:
        print(ticker)
        csv_path = 'stock_dfs/{}.csv'.format(ticker)
        if os.path.exists(csv_path):
            print('Already have {}'.format(ticker))
            continue
        df = web.DataReader(ticker, 'yahoo', start, end)
        df.to_csv(csv_path)
# Compile all adjusted close numbers
def compile_data():
    """Combine every ticker's adjusted close into one CSV file.

    Reads the ticker list from ``sp500tickers.pickle`` and, for each ticker,
    ``stock_dfs/<ticker>.csv``. From each file the ``Adj Close`` column is
    kept and renamed to the ticker symbol; the ``Open``, ``High``, ``Low``,
    ``Close`` and ``Volume`` columns are dropped. The per-ticker frames are
    outer-joined on ``Date`` so that dates missing for a ticker become NaN.

    The result is written to ``sp500_joined_closes.csv`` and its first rows
    are printed. Progress is printed every tenth ticker.

    Raises:
        FileNotFoundError: If the pickle or one of the ticker CSVs is missing.
        KeyError: If a CSV lacks one of the expected columns.
    """
    # NOTE: pickle must only be loaded from trusted files.
    with open('sp500tickers.pickle', 'rb') as f:
        tickers = pickle.load(f)

    main_df = pd.DataFrame()
    for count, ticker in enumerate(tickers):
        df = pd.read_csv('stock_dfs/{}.csv'.format(ticker))
        df.set_index('Date', inplace=True)
        df.rename(columns={'Adj Close': ticker}, inplace=True)
        # Use the `columns=` keyword: passing the axis positionally
        # (`drop(labels, 1)`) is rejected by pandas >= 2.0.
        df.drop(columns=['Open', 'High', 'Low', 'Close', 'Volume'],
                inplace=True)

        if main_df.empty:
            main_df = df
        else:
            main_df = main_df.join(df, how='outer')

        if count % 10 == 0:
            print(count)

    print(main_df.head())
    main_df.to_csv('sp500_joined_closes.csv')
# Graph data on a heatmap
def visualize_data():
    """Draw a heatmap of the correlations between the tickers' daily returns.

    Reads ``sp500_joined_closes.csv`` (the file written by
    ``compile_data``), computes the percentage change of every column,
    correlates the columns pairwise, and renders the correlation matrix
    with a red-yellow-green colour map clamped to [-1, 1]. Ticks are
    labelled with the ticker symbols; the x axis is drawn on top and the
    y axis is inverted so the first ticker starts at the top-left corner.

    The figure is displayed with ``plt.show()``.

    Raises:
        FileNotFoundError: If ``sp500_joined_closes.csv`` does not exist.
    """
    df = pd.read_csv('sp500_joined_closes.csv')
    df.set_index('Date', inplace=True)
    df_corr = df.pct_change().corr()
    data = df_corr.values

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    heatmap = ax.pcolor(data, cmap=plt.cm.RdYlGn)
    fig.colorbar(heatmap)

    # Place ticks at cell centres. x runs over columns (shape[1]) and
    # y over rows (shape[0]).
    ax.set_xticks(np.arange(data.shape[1]) + 0.5, minor=False)
    ax.set_yticks(np.arange(data.shape[0]) + 0.5, minor=False)
    ax.invert_yaxis()
    ax.xaxis.tick_top()

    column_labels = df_corr.columns
    row_labels = df_corr.index
    ax.set_xticklabels(column_labels)
    ax.set_yticklabels(row_labels)
    plt.xticks(rotation=90)

    # Correlations live in [-1, 1]; pin the colour scale to that range.
    heatmap.set_clim(-1, 1)
    plt.tight_layout()
    # Without show() the figure is built but never displayed when the file
    # is run as a script.
    plt.show()
# Script entry point: draw the correlation heatmap. visualize_data() reads
# sp500_joined_closes.csv, so that file must already exist.
# NOTE(review): this call also runs when the module is imported; consider
# guarding it with `if __name__ == '__main__':` if that is not intended.
visualize_data()