the code that I am using for experimenting with GRU.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import *
class N(nn.Module):
def __init__(self):
super().__init__()
self.embed = nn.Embedding(5,2)
self.layers = 4
self.gru = nn.GRU(2, 512, self.layers, batch_first=True)
self.bat = nn.BatchNorm1d(4)
self.bat1 = nn.BatchNorm1d(4)
self.bat2 = nn.BatchNorm1d(4)
self.fc = nn.Linear(512,100)
self.fc1 = nn.Linear(100,100)
self.fc2 = nn.Linear(100,5)
self.s = nn.Softmax(dim=-1)
def forward(self,x):
h0 = torch.zeros(self.layers, x.size(0), 512).requires_grad_()
x = self.embed(x)
x,hn = self.gru(x,h0)
x = self.bat(x)
x = self.fc(x)
x = nn.functional.relu(x)
x = self.bat1(x)
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.bat2(x)
x = self.fc2(x)
softmaxed = self.s(x)
return softmaxed
inp = torch.tensor([[4,3,2,1],[2,3,4,1],[4,1,2,3],[1,2,3,4]])
out = torch.tensor([[3,2,1,4],[3,2,4,1],[1,2,3,4],[2,3,4,1]])
k = 0
n = N()
opt = torch.optim.Adam(n.parameters(),lr=0.0001)
while k<10000:
print(inp.shape)
o = n(inp)
o = o.view(-1, o.size(-1))
out = out.view(-1)
loss = nn.functional.cross_entropy(o.view(-1,o.size(-1)),out.view(-1)-1)
acc = ((torch.argmax(o, dim=1) == (out -1)).sum().item() / out.size(0))
if k==10000:
print(torch.argmax(o, dim=1))
print(out-1)
exit()
print(loss,acc)
loss.backward()
opt.step()
opt.zero_grad()
k+=1
print(o[0])
Shrinked Output:
torch.Size([4, 4])
tensor(0.9593, grad_fn=<NllLossBackward>) 0.9375
torch.Size([4, 4])
tensor(0.9593, grad_fn=<NllLossBackward>) 0.9375
tensor([4.8500e-01, 9.7813e-06, 5.1498e-01, 6.2428e-06, 7.5929e-06],
grad_fn=<SelectBackward>)
The Loss is 0.9593 and accuracy reached up to 0.9375. For this simple input data, the GRU loss is this big. What is the reason? Is there anything wrong in this code? I used cross_entropy as loss function and Adam as the optimizer. Learning rate is 0.001. I tried multiple learning rates but all gave the same final result. I added batch normalization, it speed up the training, but the same loss and accuracy. Why loss does not decrease up to 0.2 or something.