# MKLDNN RNN seg fault #19265

## Description
A customer is hitting a seg fault when feeding a large input to the MKL-DNN LSTM. I have reduced their code to the following reproduction:
```python
import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn, rnn

hidden_size = 30
num_embed = 100
vocab_size = 13028  # len(vocab.token_to_idx.keys()) in the original code

inp = nd.random.uniform(0, vocab_size, (16758, 500))
print(inp)

context = mx.cpu()
model = nn.Sequential()
model.add(nn.Embedding(vocab_size, num_embed),                      # Embedding layer
          rnn.LSTM(hidden_size, num_layers=1, bidirectional=True),  # Recurrent layer
          nn.Dense(3))                                              # Output layer
model.collect_params().initialize(mx.init.Xavier(), ctx=context)

val_predictions = model(inp)
nd.waitall()
print(val_predictions)
```
I suspect this is some sort of out-of-memory issue: if we shrink the input (the first dimension of `inp`), the seg fault goes away. Still, shouldn't we raise a clear error message here so that users know to reduce the input size instead of crashing?
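In the meantime, a minimal user-side guard might look like the sketch below. `MAX_ROWS` and `checked_forward` are hypothetical names; the threshold would have to be found empirically, since it is not a limit MXNet exposes:

```python
# Hypothetical user-side guard until MXNet raises a proper error itself.
# MAX_ROWS is an assumed, empirically found safe size, not a published limit.
MAX_ROWS = 4096

def checked_forward(model, inp):
    if inp.shape[0] > MAX_ROWS:
        raise ValueError(
            "input of shape %s exceeds the assumed safe size of %d rows for "
            "the MKL-DNN LSTM; split it into smaller chunks"
            % (inp.shape, MAX_ROWS)
        )
    return model(inp)
```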
I also noticed that the same input runs fine with `export MXNET_USE_MKLDNN_RNN=0`, but that fallback is about 3x slower than the MKL-DNN implementation. Another suggestion I made to the customer was to find an empirical size threshold for the seg fault and split the work into batches below it (they were forward-passing the entire validation set at once; see the sketch below), but that is also a pretty hacky solution. So maybe, better yet, we can fix or optimize the MKL-DNN implementation so it can process inputs that are currently too large?
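For reference, a sketch of that batching workaround, assuming the first dimension of `inp` indexes independent samples (as in the customer's usage) and that `chunk_size` stays below the crash threshold:

```python
from mxnet import nd

def predict_in_chunks(model, inp, chunk_size=1024):
    """Run the forward pass in chunks small enough for the MKL-DNN LSTM."""
    outputs = []
    for start in range(0, inp.shape[0], chunk_size):
        out = model(inp[start:start + chunk_size])
        out.wait_to_read()  # materialize this chunk before starting the next
        outputs.append(out)
    return nd.concat(*outputs, dim=0)

val_predictions = predict_in_chunks(model, inp, chunk_size=1024)
print(val_predictions)
```

`chunk_size=1024` is an arbitrary placeholder; having users hunt for a safe value is exactly why this remains a hacky workaround rather than a real fix.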