@@ -64,6 +64,8 @@ train_Y = data_Y[:train_size] | |||
test_X = data_X[train_size:] | |||
test_Y = data_Y[train_size:] | |||
train_Y.shape | |||
# 最后,我们需要将数据改变一下形状,因为 RNN 读入的数据维度是 (seq, batch, feature),所以要重新改变一下数据的维度,这里只有一个序列,所以 batch 是 1,而输入的 feature 就是我们希望依据的几个月份,这里我们定的是两个月份,所以 feature 就是 2. | |||
# + | |||
@@ -0,0 +1,117 @@ | |||
import numpy as np | |||
import pandas as pd | |||
import matplotlib.pyplot as plt | |||
import torch | |||
from torch import nn | |||
from torch.autograd import Variable | |||
""" | |||
Using torch to do time series analysis by LSTM model | |||
""" | |||
# load data | |||
data_csv = pd.read_csv("./lstm_data.csv", usecols=[1]) | |||
#plt.plot(data_csv) | |||
#plt.show() | |||
# data pre-processing | |||
data_csv = data_csv.dropna() | |||
dataset = data_csv.values | |||
dataset = dataset.astype("float32") | |||
val_max = np.max(dataset) | |||
val_min = np.min(dataset) | |||
val_scale = val_max - val_min | |||
dataset = (dataset - val_min) / val_scale | |||
# generate dataset | |||
def create_dataset(dataset, look_back=6): | |||
dataX, dataY = [], [] | |||
dataset = dataset.tolist() | |||
for i in range(len(dataset) - look_back): | |||
a = dataset[i:(i+look_back)] | |||
dataX.append(a) | |||
dataY.append(dataset[i+look_back]) | |||
return np.array(dataX), np.array(dataY) | |||
look_back = 1 | |||
data_X, data_Y = create_dataset(dataset, look_back) | |||
# split train/test dataset | |||
train_size = int(len(data_X) * 0.7) | |||
test_size = len(data_X) - train_size | |||
train_X = data_X[:train_size] | |||
train_Y = data_Y[:train_size] | |||
test_X = data_X[train_size:] | |||
test_Y = data_Y[train_size:] | |||
# convert data for torch | |||
train_X = train_X.reshape(-1, 1, look_back) | |||
train_Y = train_Y.reshape(-1, 1, 1) | |||
test_X = test_X.reshape(-1, 1, look_back) | |||
train_x = torch.from_numpy(train_X).float() | |||
train_y = torch.from_numpy(train_Y).float() | |||
test_x = torch.from_numpy(test_X).float() | |||
# define LSTM model | |||
class LSTM_Reg(nn.Module): | |||
def __init__(self, input_size, hidden_size, output_size=1, num_layer=2): | |||
super(LSTM_Reg, self).__init__() | |||
self.rnn = nn.LSTM(input_size, hidden_size, num_layer) | |||
self.reg = nn.Linear(hidden_size, output_size) | |||
def forward(self, x): | |||
x, _ = self.rnn(x) | |||
s, b, h = x.shape | |||
x = x.view(s*b, h) | |||
x = self.reg(x) | |||
x = x.view(s, b, -1) | |||
return x | |||
net = LSTM_Reg(look_back, 4, num_layer=1) | |||
criterion = nn.MSELoss() | |||
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2) | |||
for e in range(1000): | |||
var_x = Variable(train_x) | |||
var_y = Variable(train_y) | |||
# forward | |||
out = net(var_x) | |||
loss = criterion(out, var_y) | |||
# backward | |||
optimizer.zero_grad() | |||
loss.backward() | |||
optimizer.step() | |||
# print progress | |||
if e % 100 == 0: | |||
print("epoch: %5d, loss: %.5f" % (e, loss.data[0])) | |||
# do test | |||
net = net.eval() | |||
data_X = data_X.reshape(-1, 1, look_back) | |||
data_X = torch.from_numpy(data_X).float() | |||
var_data = Variable(data_X) | |||
pred_test = net(var_data) | |||
pred_test = pred_test.view(-1).data.numpy() | |||
# plot | |||
plt.plot(pred_test, 'r', label="Prediction") | |||
plt.plot(dataset, 'b', label="Real") | |||
plt.legend(loc="best") | |||
plt.show() |
@@ -0,0 +1,148 @@ | |||
"Month","International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60" | |||
"1949-01",112 | |||
"1949-02",118 | |||
"1949-03",132 | |||
"1949-04",129 | |||
"1949-05",121 | |||
"1949-06",135 | |||
"1949-07",148 | |||
"1949-08",148 | |||
"1949-09",136 | |||
"1949-10",119 | |||
"1949-11",104 | |||
"1949-12",118 | |||
"1950-01",115 | |||
"1950-02",126 | |||
"1950-03",141 | |||
"1950-04",135 | |||
"1950-05",125 | |||
"1950-06",149 | |||
"1950-07",170 | |||
"1950-08",170 | |||
"1950-09",158 | |||
"1950-10",133 | |||
"1950-11",114 | |||
"1950-12",140 | |||
"1951-01",145 | |||
"1951-02",150 | |||
"1951-03",178 | |||
"1951-04",163 | |||
"1951-05",172 | |||
"1951-06",178 | |||
"1951-07",199 | |||
"1951-08",199 | |||
"1951-09",184 | |||
"1951-10",162 | |||
"1951-11",146 | |||
"1951-12",166 | |||
"1952-01",171 | |||
"1952-02",180 | |||
"1952-03",193 | |||
"1952-04",181 | |||
"1952-05",183 | |||
"1952-06",218 | |||
"1952-07",230 | |||
"1952-08",242 | |||
"1952-09",209 | |||
"1952-10",191 | |||
"1952-11",172 | |||
"1952-12",194 | |||
"1953-01",196 | |||
"1953-02",196 | |||
"1953-03",236 | |||
"1953-04",235 | |||
"1953-05",229 | |||
"1953-06",243 | |||
"1953-07",264 | |||
"1953-08",272 | |||
"1953-09",237 | |||
"1953-10",211 | |||
"1953-11",180 | |||
"1953-12",201 | |||
"1954-01",204 | |||
"1954-02",188 | |||
"1954-03",235 | |||
"1954-04",227 | |||
"1954-05",234 | |||
"1954-06",264 | |||
"1954-07",302 | |||
"1954-08",293 | |||
"1954-09",259 | |||
"1954-10",229 | |||
"1954-11",203 | |||
"1954-12",229 | |||
"1955-01",242 | |||
"1955-02",233 | |||
"1955-03",267 | |||
"1955-04",269 | |||
"1955-05",270 | |||
"1955-06",315 | |||
"1955-07",364 | |||
"1955-08",347 | |||
"1955-09",312 | |||
"1955-10",274 | |||
"1955-11",237 | |||
"1955-12",278 | |||
"1956-01",284 | |||
"1956-02",277 | |||
"1956-03",317 | |||
"1956-04",313 | |||
"1956-05",318 | |||
"1956-06",374 | |||
"1956-07",413 | |||
"1956-08",405 | |||
"1956-09",355 | |||
"1956-10",306 | |||
"1956-11",271 | |||
"1956-12",306 | |||
"1957-01",315 | |||
"1957-02",301 | |||
"1957-03",356 | |||
"1957-04",348 | |||
"1957-05",355 | |||
"1957-06",422 | |||
"1957-07",465 | |||
"1957-08",467 | |||
"1957-09",404 | |||
"1957-10",347 | |||
"1957-11",305 | |||
"1957-12",336 | |||
"1958-01",340 | |||
"1958-02",318 | |||
"1958-03",362 | |||
"1958-04",348 | |||
"1958-05",363 | |||
"1958-06",435 | |||
"1958-07",491 | |||
"1958-08",505 | |||
"1958-09",404 | |||
"1958-10",359 | |||
"1958-11",310 | |||
"1958-12",337 | |||
"1959-01",360 | |||
"1959-02",342 | |||
"1959-03",406 | |||
"1959-04",396 | |||
"1959-05",420 | |||
"1959-06",472 | |||
"1959-07",548 | |||
"1959-08",559 | |||
"1959-09",463 | |||
"1959-10",407 | |||
"1959-11",362 | |||
"1959-12",405 | |||
"1960-01",417 | |||
"1960-02",391 | |||
"1960-03",419 | |||
"1960-04",461 | |||
"1960-05",472 | |||
"1960-06",535 | |||
"1960-07",622 | |||
"1960-08",606 | |||
"1960-09",508 | |||
"1960-10",461 | |||
"1960-11",390 | |||
"1960-12",432 | |||
International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60 | |||