Browse Source

add lstm-time-series demo code

fetches/feikei/master
Shuhui Bu 6 years ago
parent
commit
bccc4a07d3
4 changed files with 308 additions and 37 deletions
  1. +41
    -37
      2_pytorch/3_RNN/time-series/lstm-time-series.ipynb
  2. +2
    -0
      2_pytorch/3_RNN/time-series/lstm-time-series.py
  3. +117
    -0
      demo_code/4_LSTM_timeseries.py
  4. +148
    -0
      demo_code/lstm_data.csv

+ 41
- 37
2_pytorch/3_RNN/time-series/lstm-time-series.ipynb
File diff suppressed because it is too large
View File


+ 2
- 0
2_pytorch/3_RNN/time-series/lstm-time-series.py View File

@@ -64,6 +64,8 @@ train_Y = data_Y[:train_size]
test_X = data_X[train_size:]
test_Y = data_Y[train_size:]

train_Y.shape

# 最后,我们需要将数据改变一下形状,因为 RNN 读入的数据维度是 (seq, batch, feature),所以要重新改变一下数据的维度,这里只有一个序列,所以 batch 是 1,而输入的 feature 就是我们希望依据的几个月份,这里我们定的是两个月份,所以 feature 就是 2.

# +


+ 117
- 0
demo_code/4_LSTM_timeseries.py View File

@@ -0,0 +1,117 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.autograd import Variable


"""
Using torch to do time series analysis by LSTM model
"""

# load data
data_csv = pd.read_csv("./lstm_data.csv", usecols=[1])

#plt.plot(data_csv)
#plt.show()

# data pre-processing
data_csv = data_csv.dropna()
dataset = data_csv.values
dataset = dataset.astype("float32")
val_max = np.max(dataset)
val_min = np.min(dataset)
val_scale = val_max - val_min
dataset = (dataset - val_min) / val_scale


# generate dataset
def create_dataset(dataset, look_back=6):
dataX, dataY = [], []
dataset = dataset.tolist()
for i in range(len(dataset) - look_back):
a = dataset[i:(i+look_back)]
dataX.append(a)
dataY.append(dataset[i+look_back])

return np.array(dataX), np.array(dataY)

look_back = 1
data_X, data_Y = create_dataset(dataset, look_back)


# split train/test dataset
train_size = int(len(data_X) * 0.7)
test_size = len(data_X) - train_size

train_X = data_X[:train_size]
train_Y = data_Y[:train_size]
test_X = data_X[train_size:]
test_Y = data_Y[train_size:]


# convert data for torch
train_X = train_X.reshape(-1, 1, look_back)
train_Y = train_Y.reshape(-1, 1, 1)
test_X = test_X.reshape(-1, 1, look_back)

train_x = torch.from_numpy(train_X).float()
train_y = torch.from_numpy(train_Y).float()
test_x = torch.from_numpy(test_X).float()

# define LSTM model
class LSTM_Reg(nn.Module):
def __init__(self, input_size, hidden_size, output_size=1, num_layer=2):
super(LSTM_Reg, self).__init__()

self.rnn = nn.LSTM(input_size, hidden_size, num_layer)
self.reg = nn.Linear(hidden_size, output_size)

def forward(self, x):
x, _ = self.rnn(x)
s, b, h = x.shape
x = x.view(s*b, h)
x = self.reg(x)
x = x.view(s, b, -1)
return x

net = LSTM_Reg(look_back, 4, num_layer=1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)


for e in range(1000):
var_x = Variable(train_x)
var_y = Variable(train_y)

# forward
out = net(var_x)
loss = criterion(out, var_y)

# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()

# print progress
if e % 100 == 0:
print("epoch: %5d, loss: %.5f" % (e, loss.data[0]))

# do test
net = net.eval()

data_X = data_X.reshape(-1, 1, look_back)
data_X = torch.from_numpy(data_X).float()
var_data = Variable(data_X)
pred_test = net(var_data)

pred_test = pred_test.view(-1).data.numpy()

# plot
plt.plot(pred_test, 'r', label="Prediction")
plt.plot(dataset, 'b', label="Real")
plt.legend(loc="best")
plt.show()

+ 148
- 0
demo_code/lstm_data.csv View File

@@ -0,0 +1,148 @@
"Month","International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60"
"1949-01",112
"1949-02",118
"1949-03",132
"1949-04",129
"1949-05",121
"1949-06",135
"1949-07",148
"1949-08",148
"1949-09",136
"1949-10",119
"1949-11",104
"1949-12",118
"1950-01",115
"1950-02",126
"1950-03",141
"1950-04",135
"1950-05",125
"1950-06",149
"1950-07",170
"1950-08",170
"1950-09",158
"1950-10",133
"1950-11",114
"1950-12",140
"1951-01",145
"1951-02",150
"1951-03",178
"1951-04",163
"1951-05",172
"1951-06",178
"1951-07",199
"1951-08",199
"1951-09",184
"1951-10",162
"1951-11",146
"1951-12",166
"1952-01",171
"1952-02",180
"1952-03",193
"1952-04",181
"1952-05",183
"1952-06",218
"1952-07",230
"1952-08",242
"1952-09",209
"1952-10",191
"1952-11",172
"1952-12",194
"1953-01",196
"1953-02",196
"1953-03",236
"1953-04",235
"1953-05",229
"1953-06",243
"1953-07",264
"1953-08",272
"1953-09",237
"1953-10",211
"1953-11",180
"1953-12",201
"1954-01",204
"1954-02",188
"1954-03",235
"1954-04",227
"1954-05",234
"1954-06",264
"1954-07",302
"1954-08",293
"1954-09",259
"1954-10",229
"1954-11",203
"1954-12",229
"1955-01",242
"1955-02",233
"1955-03",267
"1955-04",269
"1955-05",270
"1955-06",315
"1955-07",364
"1955-08",347
"1955-09",312
"1955-10",274
"1955-11",237
"1955-12",278
"1956-01",284
"1956-02",277
"1956-03",317
"1956-04",313
"1956-05",318
"1956-06",374
"1956-07",413
"1956-08",405
"1956-09",355
"1956-10",306
"1956-11",271
"1956-12",306
"1957-01",315
"1957-02",301
"1957-03",356
"1957-04",348
"1957-05",355
"1957-06",422
"1957-07",465
"1957-08",467
"1957-09",404
"1957-10",347
"1957-11",305
"1957-12",336
"1958-01",340
"1958-02",318
"1958-03",362
"1958-04",348
"1958-05",363
"1958-06",435
"1958-07",491
"1958-08",505
"1958-09",404
"1958-10",359
"1958-11",310
"1958-12",337
"1959-01",360
"1959-02",342
"1959-03",406
"1959-04",396
"1959-05",420
"1959-06",472
"1959-07",548
"1959-08",559
"1959-09",463
"1959-10",407
"1959-11",362
"1959-12",405
"1960-01",417
"1960-02",391
"1960-03",419
"1960-04",461
"1960-05",472
"1960-06",535
"1960-07",622
"1960-08",606
"1960-09",508
"1960-10",461
"1960-11",390
"1960-12",432
International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60

Loading…
Cancel
Save