@@ -64,6 +64,8 @@ train_Y = data_Y[:train_size] | |||||
test_X = data_X[train_size:] | test_X = data_X[train_size:] | ||||
test_Y = data_Y[train_size:] | test_Y = data_Y[train_size:] | ||||
train_Y.shape | |||||
# 最后,我们需要将数据改变一下形状,因为 RNN 读入的数据维度是 (seq, batch, feature),所以要重新改变一下数据的维度,这里只有一个序列,所以 batch 是 1,而输入的 feature 就是我们希望依据的几个月份,这里我们定的是两个月份,所以 feature 就是 2. | # 最后,我们需要将数据改变一下形状,因为 RNN 读入的数据维度是 (seq, batch, feature),所以要重新改变一下数据的维度,这里只有一个序列,所以 batch 是 1,而输入的 feature 就是我们希望依据的几个月份,这里我们定的是两个月份,所以 feature 就是 2. | ||||
# + | # + | ||||
@@ -0,0 +1,117 @@ | |||||
import numpy as np | |||||
import pandas as pd | |||||
import matplotlib.pyplot as plt | |||||
import torch | |||||
from torch import nn | |||||
from torch.autograd import Variable | |||||
""" | |||||
Using torch to do time series analysis by LSTM model | |||||
""" | |||||
# load data | |||||
data_csv = pd.read_csv("./lstm_data.csv", usecols=[1]) | |||||
#plt.plot(data_csv) | |||||
#plt.show() | |||||
# data pre-processing | |||||
data_csv = data_csv.dropna() | |||||
dataset = data_csv.values | |||||
dataset = dataset.astype("float32") | |||||
val_max = np.max(dataset) | |||||
val_min = np.min(dataset) | |||||
val_scale = val_max - val_min | |||||
dataset = (dataset - val_min) / val_scale | |||||
# generate dataset | |||||
def create_dataset(dataset, look_back=6): | |||||
dataX, dataY = [], [] | |||||
dataset = dataset.tolist() | |||||
for i in range(len(dataset) - look_back): | |||||
a = dataset[i:(i+look_back)] | |||||
dataX.append(a) | |||||
dataY.append(dataset[i+look_back]) | |||||
return np.array(dataX), np.array(dataY) | |||||
look_back = 1 | |||||
data_X, data_Y = create_dataset(dataset, look_back) | |||||
# split train/test dataset | |||||
train_size = int(len(data_X) * 0.7) | |||||
test_size = len(data_X) - train_size | |||||
train_X = data_X[:train_size] | |||||
train_Y = data_Y[:train_size] | |||||
test_X = data_X[train_size:] | |||||
test_Y = data_Y[train_size:] | |||||
# convert data for torch | |||||
train_X = train_X.reshape(-1, 1, look_back) | |||||
train_Y = train_Y.reshape(-1, 1, 1) | |||||
test_X = test_X.reshape(-1, 1, look_back) | |||||
train_x = torch.from_numpy(train_X).float() | |||||
train_y = torch.from_numpy(train_Y).float() | |||||
test_x = torch.from_numpy(test_X).float() | |||||
# define LSTM model | |||||
class LSTM_Reg(nn.Module): | |||||
def __init__(self, input_size, hidden_size, output_size=1, num_layer=2): | |||||
super(LSTM_Reg, self).__init__() | |||||
self.rnn = nn.LSTM(input_size, hidden_size, num_layer) | |||||
self.reg = nn.Linear(hidden_size, output_size) | |||||
def forward(self, x): | |||||
x, _ = self.rnn(x) | |||||
s, b, h = x.shape | |||||
x = x.view(s*b, h) | |||||
x = self.reg(x) | |||||
x = x.view(s, b, -1) | |||||
return x | |||||
net = LSTM_Reg(look_back, 4, num_layer=1) | |||||
criterion = nn.MSELoss() | |||||
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2) | |||||
for e in range(1000): | |||||
var_x = Variable(train_x) | |||||
var_y = Variable(train_y) | |||||
# forward | |||||
out = net(var_x) | |||||
loss = criterion(out, var_y) | |||||
# backward | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
# print progress | |||||
if e % 100 == 0: | |||||
print("epoch: %5d, loss: %.5f" % (e, loss.data[0])) | |||||
# do test | |||||
net = net.eval() | |||||
data_X = data_X.reshape(-1, 1, look_back) | |||||
data_X = torch.from_numpy(data_X).float() | |||||
var_data = Variable(data_X) | |||||
pred_test = net(var_data) | |||||
pred_test = pred_test.view(-1).data.numpy() | |||||
# plot | |||||
plt.plot(pred_test, 'r', label="Prediction") | |||||
plt.plot(dataset, 'b', label="Real") | |||||
plt.legend(loc="best") | |||||
plt.show() |
@@ -0,0 +1,148 @@ | |||||
"Month","International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60" | |||||
"1949-01",112 | |||||
"1949-02",118 | |||||
"1949-03",132 | |||||
"1949-04",129 | |||||
"1949-05",121 | |||||
"1949-06",135 | |||||
"1949-07",148 | |||||
"1949-08",148 | |||||
"1949-09",136 | |||||
"1949-10",119 | |||||
"1949-11",104 | |||||
"1949-12",118 | |||||
"1950-01",115 | |||||
"1950-02",126 | |||||
"1950-03",141 | |||||
"1950-04",135 | |||||
"1950-05",125 | |||||
"1950-06",149 | |||||
"1950-07",170 | |||||
"1950-08",170 | |||||
"1950-09",158 | |||||
"1950-10",133 | |||||
"1950-11",114 | |||||
"1950-12",140 | |||||
"1951-01",145 | |||||
"1951-02",150 | |||||
"1951-03",178 | |||||
"1951-04",163 | |||||
"1951-05",172 | |||||
"1951-06",178 | |||||
"1951-07",199 | |||||
"1951-08",199 | |||||
"1951-09",184 | |||||
"1951-10",162 | |||||
"1951-11",146 | |||||
"1951-12",166 | |||||
"1952-01",171 | |||||
"1952-02",180 | |||||
"1952-03",193 | |||||
"1952-04",181 | |||||
"1952-05",183 | |||||
"1952-06",218 | |||||
"1952-07",230 | |||||
"1952-08",242 | |||||
"1952-09",209 | |||||
"1952-10",191 | |||||
"1952-11",172 | |||||
"1952-12",194 | |||||
"1953-01",196 | |||||
"1953-02",196 | |||||
"1953-03",236 | |||||
"1953-04",235 | |||||
"1953-05",229 | |||||
"1953-06",243 | |||||
"1953-07",264 | |||||
"1953-08",272 | |||||
"1953-09",237 | |||||
"1953-10",211 | |||||
"1953-11",180 | |||||
"1953-12",201 | |||||
"1954-01",204 | |||||
"1954-02",188 | |||||
"1954-03",235 | |||||
"1954-04",227 | |||||
"1954-05",234 | |||||
"1954-06",264 | |||||
"1954-07",302 | |||||
"1954-08",293 | |||||
"1954-09",259 | |||||
"1954-10",229 | |||||
"1954-11",203 | |||||
"1954-12",229 | |||||
"1955-01",242 | |||||
"1955-02",233 | |||||
"1955-03",267 | |||||
"1955-04",269 | |||||
"1955-05",270 | |||||
"1955-06",315 | |||||
"1955-07",364 | |||||
"1955-08",347 | |||||
"1955-09",312 | |||||
"1955-10",274 | |||||
"1955-11",237 | |||||
"1955-12",278 | |||||
"1956-01",284 | |||||
"1956-02",277 | |||||
"1956-03",317 | |||||
"1956-04",313 | |||||
"1956-05",318 | |||||
"1956-06",374 | |||||
"1956-07",413 | |||||
"1956-08",405 | |||||
"1956-09",355 | |||||
"1956-10",306 | |||||
"1956-11",271 | |||||
"1956-12",306 | |||||
"1957-01",315 | |||||
"1957-02",301 | |||||
"1957-03",356 | |||||
"1957-04",348 | |||||
"1957-05",355 | |||||
"1957-06",422 | |||||
"1957-07",465 | |||||
"1957-08",467 | |||||
"1957-09",404 | |||||
"1957-10",347 | |||||
"1957-11",305 | |||||
"1957-12",336 | |||||
"1958-01",340 | |||||
"1958-02",318 | |||||
"1958-03",362 | |||||
"1958-04",348 | |||||
"1958-05",363 | |||||
"1958-06",435 | |||||
"1958-07",491 | |||||
"1958-08",505 | |||||
"1958-09",404 | |||||
"1958-10",359 | |||||
"1958-11",310 | |||||
"1958-12",337 | |||||
"1959-01",360 | |||||
"1959-02",342 | |||||
"1959-03",406 | |||||
"1959-04",396 | |||||
"1959-05",420 | |||||
"1959-06",472 | |||||
"1959-07",548 | |||||
"1959-08",559 | |||||
"1959-09",463 | |||||
"1959-10",407 | |||||
"1959-11",362 | |||||
"1959-12",405 | |||||
"1960-01",417 | |||||
"1960-02",391 | |||||
"1960-03",419 | |||||
"1960-04",461 | |||||
"1960-05",472 | |||||
"1960-06",535 | |||||
"1960-07",622 | |||||
"1960-08",606 | |||||
"1960-09",508 | |||||
"1960-10",461 | |||||
"1960-11",390 | |||||
"1960-12",432 | |||||
International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60 | |||||