My question in brief: Is the Long Short-Term Memory (LSTM) network detailed below appropriately designed to generate new dance sequences, given dance sequence training data?
Context: I am working with a dancer who wishes to use a neural network to generate new dance sequences. She sent me the 2016 chor-rnn paper, which accomplished this task using an LSTM network with a Mixture Density Network (MDN) layer at the end. After adding an MDN layer to my LSTM network, however, my loss goes negative and the results seem chaotic. This may be due to the very small training set, but I'd like to validate the model fundamentals before scaling up the training data. If anyone can advise whether the model below overlooks something fundamental (which is highly likely), I would be very grateful for their feedback.
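For context on the negative loss: an MDN is trained by minimizing the negative log-likelihood of the targets under the predicted mixture, and for a continuous density that quantity drops below zero whenever the density at the target exceeds 1. A minimal sketch for a single 1-D Gaussian (plain NumPy, not the MDN layer's actual loss code):

```python
import numpy as np

def gaussian_nll(x, mu, sigma):
    """Negative log-likelihood of x under a 1-D normal N(mu, sigma^2)."""
    return 0.5 * np.log(2 * np.pi * sigma**2) + (x - mu) ** 2 / (2 * sigma**2)

# wide Gaussian: density at the mean is below 1, so the NLL is positive
print(gaussian_nll(0.0, 0.0, 1.0))   # about 0.919
# narrow Gaussian centred on the target: density above 1, so the NLL is negative
print(gaussian_nll(0.0, 0.0, 0.01))  # about -3.686
```

So a negative MDN loss is not by itself evidence of a bug, though it says nothing about whether the sampled frames look right.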
The sample data I'm feeding into the network (X below) has shape (626, 55, 3), which corresponds to 626 time snapshots of 55 body positions, each with 3 coordinates (x, y, then z). So X[1][11][2] is the z position of body part 11 at time step 1 (both zero-indexed):
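```python
import requests
import numpy as np

# download the data and save it locally so np.load can read it
url = 'https://s3.amazonaws.com/duhaime/blog/dancing-with-robots/dance.npy'
with open('dance.npy', 'wb') as f:
    f.write(requests.get(url).content)

# X.shape = time_intervals, n_body_parts, 3
X = np.load('dance.npy')
```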
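The model below consumes each frame as a flat vector of 55 × 3 = 165 values. A small sketch with a synthetic array of the same shape (random values standing in for the real data) shows how a (time, body part, coordinate) index maps into that flat vector under NumPy's default row-major reshape:

```python
import numpy as np

# synthetic stand-in with the same shape as the dance data: (time, body part, xyz)
rng = np.random.default_rng(0)
X_demo = rng.random((626, 55, 3))

# flatten each frame's 55 x 3 coordinates into one vector of 165 values
flat = X_demo.reshape(626, 55 * 3)

# in row-major (C) order, body part p, coordinate c lands at column p * 3 + c
t, p, c = 1, 11, 2  # z coordinate of body part 11 at time step 1
assert flat[t, p * 3 + c] == X_demo[t, p, c]
print(flat.shape)  # (626, 165)
```

The same ordering applies in reverse when a generated 165-value row is reshaped back to (55, 3) for plotting.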
To make sure the data was extracted correctly, I visualize the first few frames of X:
```python
import matplotlib
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3  # registers the '3d' projection on older matplotlib versions
from matplotlib import animation
from IPython.display import HTML

matplotlib.rcParams['animation.embed_limit'] = 2**128

def update_points(time, points, X):
    arr = np.array([[X[time][i][0], X[time][i][1]] for i in range(int(X.shape[1]))])
    points.set_offsets(arr)  # set x, y values
    points.set_3d_properties(X[time][:, 2][:], zdir='z')  # set z value

def get_plot(X, lim=2, frames=200, duration=45):
    fig = plt.figure()
    # recent matplotlib versions no longer attach p3.Axes3D(fig) to the figure,
    # so create the 3D axes through the figure instead
    ax = fig.add_subplot(projection='3d')
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.set_zlim(-lim, lim)
    points = ax.scatter(X[0][:, 0][:], X[0][:, 1][:], X[0][:, 2][:], depthshade=False)  # x,y,z vals
    return animation.FuncAnimation(fig,
                                   update_points,
                                   frames,
                                   interval=duration,
                                   fargs=(points, X),
                                   blit=False
                                   ).to_jshtml()

HTML(get_plot(X, frames=int(X.shape[0])))
```
That produces a short animation of the dancer's body points.
So far so good. Next I scale each of the x, y, and z dimensions to the range [0, 1]:
```python
# min-max scale each coordinate axis (x, y, z) independently to [0, 1]
X -= np.amin(X, axis=(0, 1))
X /= np.amax(X, axis=(0, 1))
```
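For reference, those two lines perform per-axis min-max scaling rather than centring: afterwards every x, y and z coordinate lies in [0, 1], with each axis's minimum and maximum taken over all frames and body parts. A sketch on a synthetic array (random values standing in for the real data) checks this and shows zero-mean centring for contrast:

```python
import numpy as np

rng = np.random.default_rng(1)
X_demo = rng.normal(loc=[5.0, -2.0, 10.0], scale=3.0, size=(100, 55, 3))

# per-axis min-max scaling, as in the two lines above
scaled = X_demo - np.amin(X_demo, axis=(0, 1))
scaled /= np.amax(scaled, axis=(0, 1))
print(scaled.min(axis=(0, 1)), scaled.max(axis=(0, 1)))  # [0. 0. 0.] [1. 1. 1.]

# centring instead subtracts the per-axis mean, giving zero-mean coordinates
centred = X_demo - X_demo.mean(axis=(0, 1))
assert np.allclose(centred.mean(axis=(0, 1)), 0.0)
```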
Visualizing the resulting X with HTML(get_plot(X, frames=int(X.shape[0]))) shows that these lines normalize the data correctly. Next I build the model itself using the Sequential API in Keras:
```python
from keras.models import Sequential
from keras.layers import Dense, LSTM
from keras.losses import mean_squared_error
from keras.optimizers import Adam
# assumption: MDN, get_mixture_loss_func and sample_from_output come from the
# keras-mdn-layer package (pip install keras-mdn-layer), imported as `mdn`
from mdn import MDN, get_mixture_loss_func, sample_from_output

# config
look_back = 32  # number of previous time frames to use to predict the positions at time i
lstm_cells = 256  # number of cells in each LSTM "layer"
n_features = int(X.shape[1]) * int(X.shape[2])  # number of coordinate values to be predicted by each of `m` models
input_shape = (look_back, n_features)  # shape of inputs
m = 32  # number of gaussian models to build

# set boolean controlling whether we use MDN or not
use_mdn = True

model = Sequential()
model.add(LSTM(lstm_cells, return_sequences=True, input_shape=input_shape))
model.add(LSTM(lstm_cells, return_sequences=True))
model.add(LSTM(lstm_cells))

if use_mdn:
    model.add(MDN(n_features, m))
    model.compile(loss=get_mixture_loss_func(n_features, m), optimizer=Adam(lr=0.000001))
else:
    model.add(Dense(n_features, activation='tanh'))
    model.compile(loss=mean_squared_error, optimizer='sgd')

model.summary()
```
Once the model is built, I arrange the data in X to prepare for training. Here we want to predict the x, y, z positions of the 55 body parts at some time by examining the positions of each body part at the previous look_back time slices:
```python
# get training data in right shape
train_x = []
train_y = []
n_time, n_obs, n_attrs = [int(i) for i in X.shape]
for i in range(look_back, n_time):
    # input: frames i-look_back .. i-1, flattened to (look_back, n_obs * n_attrs)
    train_x.append(X[i-look_back:i].reshape(look_back, n_obs * n_attrs))
    # target: frame i, the frame immediately after the input window
    train_y.append(X[i].ravel())
train_x = np.array(train_x)
train_y = np.array(train_y)
```
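The windowing can be checked on a tiny synthetic sequence in which every coordinate of frame t simply equals t (a made-up stand-in for X, with 2 body parts and look_back = 3). make_windows is a hypothetical helper that mirrors the loop above: each input window should cover frames i - look_back through i - 1, and its target should be frame i.

```python
import numpy as np

def make_windows(X, look_back):
    """Hypothetical helper: split (time, parts, coords) into (window, next-frame) pairs."""
    n_time, n_obs, n_attrs = X.shape
    xs, ys = [], []
    for i in range(look_back, n_time):
        xs.append(X[i - look_back:i].reshape(look_back, n_obs * n_attrs))
        ys.append(X[i].ravel())  # the frame right after the window
    return np.array(xs), np.array(ys)

# synthetic sequence: every coordinate of frame t equals t
n_time, look_back = 10, 3
X_demo = np.repeat(np.arange(n_time, dtype=float), 2 * 3).reshape(n_time, 2, 3)

wx, wy = make_windows(X_demo, look_back)
print(wx.shape, wy.shape)     # (7, 3, 6) (7, 6)
print(wx[0][:, 0], wy[0][0])  # [0. 1. 2.] 3.0
```

Because each frame holds its own index, any off-by-one error shows up immediately as a target that is not one greater than the window's last frame.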
And finally I train the model:
```python
from livelossplot import PlotLossesKeras

# fit the model
model.fit(train_x, train_y, epochs=1024, batch_size=1, callbacks=[PlotLossesKeras()])
```
After training, I visualize the new time slices created by the model:
```python
# generate `n_frames` of new output time slices
n_frames = 3000

# seed the data to plot with the first `look_back` animation frames
data = X[0:look_back]

for i in range(look_back, n_frames):
    # feed the model the most recent `look_back` frames, which after the
    # first iteration include the model's own generated frames
    window = data[-look_back:].reshape(1, look_back, n_obs * n_attrs)
    # get the model's prediction for the next position of points at time `i`
    result = model.predict(window)
    # if using the mixture density network, sample vertex positions from the predicted mixture
    if use_mdn:
        result = np.apply_along_axis(sample_from_output, 1, result, n_features, m, temp=1.0)
    # reshape the result into the form of rows in `X`
    result = result.reshape(1, n_obs, n_attrs)
    # add the result to the dataset for plotting
    data = np.vstack((data, result))
```
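For readers unfamiliar with the sampling step, here is a NumPy sketch of drawing one frame from a mixture of diagonal Gaussians. It assumes the MDN output vector is laid out as [m × n means, m × n standard deviations, m mixture logits], which is my understanding of the keras-mdn-layer convention but is an assumption here; sample_mixture is a hypothetical name and this is not that package's code. The temp argument divides the mixture logits before the softmax, so values below 1 favour the most likely component.

```python
import numpy as np

def sample_mixture(params, n_dims, n_mix, temp=1.0, rng=None):
    """Draw one sample from a mixture of diagonal Gaussians (hypothetical helper).

    Assumed layout of `params`: [means (n_mix * n_dims),
    standard deviations (n_mix * n_dims), mixture logits (n_mix)].
    """
    rng = np.random.default_rng() if rng is None else rng
    mus = params[:n_mix * n_dims].reshape(n_mix, n_dims)
    sigmas = params[n_mix * n_dims:2 * n_mix * n_dims].reshape(n_mix, n_dims)
    logits = params[2 * n_mix * n_dims:] / temp
    # softmax over components (subtract the max for numerical stability)
    pis = np.exp(logits - logits.max())
    pis /= pis.sum()
    k = rng.choice(n_mix, p=pis)           # pick one component
    return rng.normal(mus[k], sigmas[k])   # sample from its diagonal Gaussian

# two components in 3-D; component 1 has a much larger logit and tiny sigmas
params = np.concatenate([
    [0.0, 0.0, 0.0, 1.0, 2.0, 3.0],     # means of components 0 and 1
    [1.0, 1.0, 1.0, 1e-6, 1e-6, 1e-6],  # standard deviations
    [0.0, 20.0],                        # mixture logits
])
sample = sample_mixture(params, n_dims=3, n_mix=2, rng=np.random.default_rng(0))
print(np.round(sample, 3))  # close to the means of component 1
```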
If I set use_mdn to False above and instead use a simple mean squared error (L2) loss, the resulting visualization seems a little creepy but still has a generally human shape.
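One property of the L2 loss worth keeping in mind here: the single prediction that minimises mean squared error is the average of the possible targets, so when the same input window can be followed by several quite different poses, an L2-trained network tends to predict a blend of them. This averaging is the motivation usually given for mixture density outputs. A toy illustration with two equally likely 1-D targets:

```python
import numpy as np

# two equally likely "next positions" for the same input
targets = np.array([-1.0, 1.0])

def mse(c):
    """Mean squared error of the constant prediction c."""
    return np.mean((targets - c) ** 2)

# predicting either actual target costs more than predicting their average
print(mse(0.0), mse(1.0), mse(-1.0))  # 1.0 2.0 2.0

# a grid search confirms the minimiser sits between the two modes
candidates = np.linspace(-2, 2, 401)
best = candidates[int(np.argmin([mse(c) for c in candidates]))]
print(best)
```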
If I set use_mdn to True, however, and use the custom MDN loss function, the results are quite odd. I recognize that the MDN layer adds a huge number of parameters that need to be trained, and likely requires orders of magnitude more training data to achieve output that's as human-shaped as the L2 loss function output.
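To put a number on the MDN layer's size: assuming the MDN head is a dense mapping that emits m × n_features means, m × n_features standard deviations and m mixture weights (a diagonal-covariance parametrisation, which is an assumption about the layer used above), and using the standard Keras LSTM parameter count of 4 × (units × (inputs + units) + units), the counts for this configuration work out as follows:

```python
def lstm_params(n_in, units):
    # four gates, each with input weights, recurrent weights and a bias
    return 4 * (units * (n_in + units) + units)

def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights plus biases

lstm_cells, n_features, m = 256, 55 * 3, 32

lstm_total = (lstm_params(n_features, lstm_cells)
              + 2 * lstm_params(lstm_cells, lstm_cells))
mdn_outputs = 2 * m * n_features + m  # means + std devs + mixture weights
mdn_head = dense_params(lstm_cells, mdn_outputs)
l2_head = dense_params(lstm_cells, n_features)

print(lstm_total)  # 1482752
print(mdn_head)    # 2722144
print(l2_head)     # 42405
```

Under these assumptions the MDN head alone has roughly 64 times as many parameters as the L2 output layer, and more than the three LSTM layers combined.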
That said, I wanted to ask if others who have worked with neural network models more extensively than myself see anything fundamentally wrong with the approach above. Any insights on this question would be tremendously helpful.