How to train on batches with Keras but predict one example at a time with an LSTM?
I have a list of training data that I use to train the model. At prediction time, however, the predictions are made online, one example at a time.
If I declare my model like this:
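from keras.models import Sequential
from keras.layers import Dense, LSTM
from keras.optimizers import SGD
model = Sequential()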
model.add(Dense(64, batch_input_shape=(100, 5, 1), activation='tanh'))
model.add(LSTM(32, stateful=True))
model.add(Dense(1, activation='linear'))
optimizer = SGD(lr=0.0005)
model.compile(loss='mean_squared_error', optimizer=optimizer)
When I try to predict on a single example of shape (1, 5, 1), I get the following error:
ValueError: Shape mismatch: x has 100 rows but z has 1 rows
The workaround I found was to train the model with batch_input_shape=(1, 5, 1) and call fit on each example individually, but that is incredibly slow.
Is there really no way to train an LSTM with a large batch size but predict on one example at a time?
Thanks for the help.
Try something like this:
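# Assumes the imports and the trained `model` from the question
# Same architecture as the trained model, but with a batch size of 1
model2 = Sequential()
model2.add(Dense(64, batch_input_shape=(1, 5, 1), activation='tanh'))
model2.add(LSTM(32, stateful=True))
model2.add(Dense(1, activation='linear'))
optimizer2 = SGD(lr=0.0005)
model2.compile(loss='mean_squared_error', optimizer=optimizer2)
# Copy the trained weights layer by layer from model into model2
for nb, layer in enumerate(model.layers):
    model2.layers[nb].set_weights(layer.get_weights())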
This simply copies the trained weights from one model into a second model that is identical except that its batch size is 1.
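A quick way to convince yourself that copying weights is safe: an LSTM's parameters never include the batch dimension, so the same weights give the same result for a given sequence whether it is run inside a batch of 100 or on its own. The sketch below is a plain NumPy LSTM, not Keras code; the function lstm_last_hidden, the random weights and the gate order (i, f, g, o) are assumptions made for illustration.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_last_hidden(x, W, U, b):
    """Run a plain LSTM over x of shape (batch, timesteps, features)
    and return the final hidden state, shape (batch, units)."""
    batch, units = x.shape[0], U.shape[0]
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    for t in range(x.shape[1]):
        z = x[:, t, :] @ W + h @ U + b  # (batch, 4 * units)
        i, f, g, o = np.split(z, 4, axis=1)  # assumed gate order
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
    return h


rng = np.random.default_rng(0)
features, units = 1, 32
# Parameter shapes depend only on features and units, never on batch size.
W = rng.normal(scale=0.1, size=(features, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)

x_batch = rng.normal(size=(100, 5, features))  # a "training" batch of 100
h_batch = lstm_last_hidden(x_batch, W, U, b)

# Same weights, batch of one: matches the first row of the batched result.
h_single = lstm_last_hidden(x_batch[:1], W, U, b)
print(np.allclose(h_single, h_batch[:1]))  # True
```

This holds for a non-stateful pass as shown; with stateful=True the carried states are a separate concern from the weights.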
You defined batch_input_shape in the first layer, so feeding input whose shape does not match that preset batch_input_shape is not allowed.
There are two ways to achieve this. Method 1: change batch_input_shape=(100, 5, 1) to input_shape=(5, 1) so that the batch size is not fixed, and pass batch_size=100 to model.fit() instead.
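With input_shape=(5, 1) the batch dimension is left open, so the same model can be fed batches of 100 while training and batches of 1 when predicting. As a rough plain-Python illustration (the helper batch_slices is hypothetical, not a Keras function), this is how a batch_size of 100 carves a training set into mini-batches. Note that the last batch may be smaller, which a batch dimension hard-coded to 100 generally could not accept either.

```python
def batch_slices(n_samples, batch_size):
    """Return (start, stop) index pairs that split n_samples rows into
    consecutive mini-batches of at most batch_size rows."""
    return [(start, min(start + batch_size, n_samples))
            for start in range(0, n_samples, batch_size)]


# Training: 250 samples with batch_size=100 gives two full batches and one of 50.
print(batch_slices(250, 100))  # [(0, 100), (100, 200), (200, 250)]

# Online prediction: each call sees a batch of exactly one sample.
print(batch_slices(1, 100))  # [(0, 1)]
```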
Edit: Method 2
Define the same model again as model2, but with batch_input_shape=(1, 5, 1). Then copy the weights with model2.set_weights(model1.get_weights()), where model1 is the trained model.
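To see why set_weights() can load the weights of the batch-100 model into a batch-1 model, it helps to list the array shapes involved. The sketch below computes them by hand for the question's Dense(64) -> LSTM(32) -> Dense(1) stack, assuming the standard Keras 2 layout (kernel and bias for Dense; kernel, recurrent kernel and bias for LSTM, with the four gates stacked along the last axis). The function weight_shapes is hypothetical.

```python
def weight_shapes(features=1, dense_units=64, lstm_units=32, out_units=1):
    """Shapes of the arrays get_weights() would return for the question's
    Dense -> LSTM -> Dense stack (assuming the standard Keras 2 layout)."""
    return [
        (features, dense_units), (dense_units,),          # Dense(64): kernel, bias
        (dense_units, 4 * lstm_units),                     # LSTM kernel
        (lstm_units, 4 * lstm_units), (4 * lstm_units,),   # LSTM recurrent kernel, bias
        (lstm_units, out_units), (out_units,),             # Dense(1): kernel, bias
    ]


shapes = weight_shapes()
print(shapes)  # [(1, 64), (64,), (64, 128), (32, 128), (128,), (32, 1), (1,)]
# The batch size (100 for training, 1 for prediction) appears nowhere,
# so the same list of arrays fits either model.
```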
If you want to use stateful=True, it is because you want the hidden states from the last batch to be used as the initial states for the next batch. Therefore the batch sizes must match. Otherwise, you can simply remove stateful=True.
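Why the batch sizes must match can be seen in a toy sketch. The class ToyStatefulLayer below is hypothetical, written in NumPy rather than Keras, and its tanh update is a stand-in for the real LSTM equations; it only mimics the bookkeeping of carrying one state row per batch position from one batch to the next.

```python
import numpy as np


class ToyStatefulLayer:
    """Toy stand-in for a stateful recurrent layer (not an LSTM): it keeps one
    state row per batch position and uses it as the starting state of the
    next call, the way stateful=True carries states between batches."""

    def __init__(self, batch_size, units):
        self.state = np.zeros((batch_size, units))

    def step(self, x):
        # Row k of this batch continues the sequence in row k of the previous
        # batch, so the batch sizes must agree. NumPy would silently broadcast
        # a (1, units) input against (batch_size, units), hence the explicit check.
        if x.shape[0] != self.state.shape[0]:
            raise ValueError(
                f"input batch size {x.shape[0]} does not match "
                f"state batch size {self.state.shape[0]}"
            )
        self.state = np.tanh(self.state + x)  # toy recurrence
        return self.state


train_layer = ToyStatefulLayer(batch_size=100, units=32)
train_layer.step(np.ones((100, 32)))  # OK: 100 input rows, 100 state rows

message = None
try:
    train_layer.step(np.ones((1, 32)))  # one example against 100 carried states
except ValueError as err:
    message = str(err)
print(message)  # input batch size 1 does not match state batch size 100

predict_layer = ToyStatefulLayer(batch_size=1, units=32)  # analogous to model2
print(predict_layer.step(np.ones((1, 32))).shape)  # (1, 32)
```

A layer built for a batch of one, like model2, carries a single state row, which is what online prediction one example at a time needs.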