I'm getting a ValueError while training a BiLSTM part-of-speech tagger in PyTorch:

ValueError: Expected input batch_size (256) to match target batch_size (128).
def train(model, iterator, optimizer, criterion, tag_pad_idx):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    for batch in iterator:
        text = batch.p
        tags = batch.t
        optimizer.zero_grad()
        # text = [sent len, batch size]
        predictions = model(text)
        # predictions = [sent len, batch size, output dim]
        # tags = [sent len, batch size]
        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1)
        # predictions = [sent len * batch size, output dim]
        # tags = [sent len * batch size]
        loss = criterion(predictions, tags)
        acc = categorical_accuracy(predictions, tags, tag_pad_idx)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
def evaluate(model, iterator, criterion, tag_pad_idx):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            text = batch.p
            tags = batch.t
            predictions = model(text)
            predictions = predictions.view(-1, predictions.shape[-1])
            tags = tags.view(-1)
            loss = criterion(predictions, tags)
            acc = categorical_accuracy(predictions, tags, tag_pad_idx)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
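Since criterion compares the first dimensions of the two flattened tensors, the error means predictions and tags already disagreed on shape before .view() was applied: 256 = 2 * 128. A minimal sketch with dummy tensors (the shapes here are assumptions, chosen only to reproduce the numbers from the error message):

import torch

# Hypothetical shapes: suppose text is twice as long as tags along sent len.
predictions = torch.randn(2, 128, 42)         # [sent len, batch size, output dim]
tags = torch.zeros(1, 128, dtype=torch.long)  # a shorter sent len than predictions

flat_preds = predictions.view(-1, predictions.shape[-1])  # [2 * 128, 42] = [256, 42]
flat_tags = tags.view(-1)                                 # [1 * 128] = [128]
print(flat_preds.shape[0], flat_tags.shape[0])  # 256 128 -> the mismatch in the error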
class BiLSTMPOSTagger(nn.Module):
    def __init__(self,
                 input_dim,
                 embedding_dim,
                 hidden_dim,
                 output_dim,
                 n_layers,
                 bidirectional,
                 dropout,
                 pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=dropout if n_layers > 1 else 0)
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        # text = [sent len, batch size]
        embedded = self.dropout(self.embedding(text))
        # embedded = [sent len, batch size, emb dim]
        outputs, (hidden, cell) = self.lstm(embedded)
        # outputs = [sent len, batch size, hid dim * n directions]
        predictions = self.fc(self.dropout(outputs))
        # predictions = [sent len, batch size, output dim]
        return predictions
INPUT_DIM = len(POS.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
OUTPUT_DIM = len(TAG.vocab)
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25
PAD_IDX = POS.vocab.stoi[POS.pad_token]
print(INPUT_DIM) #output 22147
print(OUTPUT_DIM) #output 42
model = BiLSTMPOSTagger(INPUT_DIM,
                        EMBEDDING_DIM,
                        HIDDEN_DIM,
                        OUTPUT_DIM,
                        N_LAYERS,
                        BIDIRECTIONAL,
                        DROPOUT,
                        PAD_IDX)
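A dummy forward pass can confirm the model's output shape on its own (the sequence length 35 and batch size 128 below are arbitrary assumptions):

import torch

dummy = torch.randint(0, INPUT_DIM, (35, 128))  # [sent len, batch size]
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([35, 128, 42])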
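The loop below also uses optimizer, criterion, and TAG_PAD_IDX, which are defined elsewhere in my notebook; they follow the standard setup (sketched here, assuming Adam and a cross-entropy loss that ignores the padding tag):

import torch.optim as optim

TAG_PAD_IDX = TAG.vocab.stoi[TAG.pad_token]  # index of the <pad> tag

optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=TAG_PAD_IDX)  # skip padded positions in the loss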
N_EPOCHS = 10
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
The full traceback:

ValueError                                Traceback (most recent call last)
<ipython-input-55-83bf30366feb> in <module>()
7 start_time = time.time()
8
----> 9 train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
10 valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
11
4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
2260 if input.size(0) != target.size(0):
2261 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262 .format(input.size(0), target.size(0)))
2263 if dim == 2:
2264 ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
ValueError: Expected input batch_size (256) to match target batch_size (128).
Can you print tags.shape and predictions.shape once before you apply .view() and once after? - Theodor Peifer
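Following that suggestion, a sketch of the shape prints (placement inside the first iteration of the training loop is an assumption; the names mirror the code above):

for batch in train_iterator:
    text = batch.p
    tags = batch.t
    predictions = model(text)
    print('text:       ', text.shape)         # expected [sent len, batch size]
    print('tags:       ', tags.shape)         # expected [sent len, batch size]
    print('predictions:', predictions.shape)  # expected [sent len, batch size, output dim]
    predictions = predictions.view(-1, predictions.shape[-1])
    tags = tags.view(-1)
    print('after view: ', predictions.shape, tags.shape)  # first dims must match
    break  # one batch is enough to spot the mismatch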