I am trying to implement gradient descent for linear regression using this resource: https://spin.atomicobject.com/2014/06/24/gradient-descent-linear-regression/
The problem is that my parameters explode (their magnitude grows exponentially), which is the opposite of what is intended.
First I created a data set:
```python
import numpy as np

def y(x, a):
    # points on the line y = 2x plus uniform noise in [-a/2, a/2)
    return 2*x + a*np.random.random_sample(len(x)) - a/2

x = np.arange(20)
y_true = y(x, 10)
```
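If reproducible runs help while debugging, the same kind of data can be drawn from a seeded generator. This is a minimal sketch; `make_data` is a hypothetical helper name, and the seed is an arbitrary assumption. `Generator.uniform(low, high, size)` samples from the half-open interval [low, high), which matches the noise range of the function above.

```python
import numpy as np

def make_data(x, a, seed=0):
    """Hypothetical helper: y = 2x plus uniform noise in [-a/2, a/2).

    Same distribution as 2*x + a*np.random.random_sample(len(x)) - a/2,
    but drawn from a seeded Generator so every run gives the same data.
    """
    rng = np.random.default_rng(seed)
    return 2 * x + rng.uniform(-a / 2, a / 2, size=len(x))

x = np.arange(20)
y_true = make_data(x, 10)
```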
Plotted, this gives a noisy scatter of points around the line y = 2x.
And the linear function to be optimized:
```python
def y_predict(x, m, b):
    return m*x + b
```
For some arbitrarily chosen starting parameters, the fit looks like this:
```python
import matplotlib.pyplot as plt

m0 = 1
b0 = 1
a = y_predict(x, m0, b0)

plt.scatter(x, y_true)
plt.plot(x, a)
plt.show()
```
The cost is half the sum of squared errors:
```python
cost = (1/2) * np.sum((y_true - a) ** 2)
```
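Written out, with $a_i = m x_i + b$ the prediction for the $i$-th of $N$ points, the cost and the partial derivatives that the following lines compute are:

$$
C(m, b) = \frac{1}{2}\sum_{i=1}^{N} \bigl(y_i - a_i\bigr)^2, \qquad a_i = m x_i + b
$$

$$
\frac{\partial C}{\partial a_i} = a_i - y_i, \qquad
\frac{\partial C}{\partial m} = \sum_{i=1}^{N} (a_i - y_i)\, x_i, \qquad
\frac{\partial C}{\partial b} = \sum_{i=1}^{N} (a_i - y_i)
$$

Scaling these sums by $2/N$ gives instead the gradient of the mean squared error $\frac{1}{N}\sum_i (y_i - a_i)^2$.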
The partial derivative of the cost with respect to the prediction (dc_da):
```python
dc_da = (a - y_true)  # still a vector
```
The partial derivative of the cost with respect to the slope parameter (dc_dm):
```python
dc_dm = dc_da.dot(x)  # now a scalar
```
And the partial derivative of the cost with respect to the y-intercept parameter (dc_db):
```python
dc_db = np.sum(dc_da)  # also a scalar
```
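One way to sanity-check these derivatives is to compare them with central finite differences of the cost. The sketch below assumes the half-sum-of-squares cost used above; `cost`, `analytic_grad` and `numeric_grad` are hypothetical helper names. Because the cost is quadratic in m and b, central differences are exact up to floating-point rounding, so the two results should agree closely.

```python
import numpy as np

def cost(m, b, x, y):
    """Half the sum of squared errors, as in the question."""
    return 0.5 * np.sum((y - (m * x + b)) ** 2)

def analytic_grad(m, b, x, y):
    """dc_dm and dc_db computed the same way as above."""
    dc_da = (m * x + b) - y
    return dc_da.dot(x), np.sum(dc_da)

def numeric_grad(m, b, x, y, h=1e-4):
    """Central finite differences of cost() with respect to m and b."""
    dm = (cost(m + h, b, x, y) - cost(m - h, b, x, y)) / (2 * h)
    db = (cost(m, b + h, x, y) - cost(m, b - h, x, y)) / (2 * h)
    return dm, db

x = np.arange(20.0)
y = 2 * x + np.random.default_rng(1).uniform(-5, 5, size=20)
print(analytic_grad(1.0, 1.0, x, y))
print(numeric_grad(1.0, 1.0, x, y))
```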
And finally the implementation of gradient descent:
```python
iterations = 10
m0 = 1
b0 = 1
learning_rate = 0.1
N = len(x)

for i in range(iterations):
    a = y_predict(x, m0, b0)
    cost = (1/2) * np.sum((y_true - a) ** 2)
    dc_da = (a - y_true)
    mgrad = dc_da.dot(x)
    bgrad = np.sum(dc_da)
    m0 -= learning_rate * (2 / N) * mgrad
    b0 -= learning_rate * (2 / N) * bgrad
    if (i % 2 == 0):
        print("Iteration {}".format(i))
        print("Cost: {}, m: {}, b: {}\n".format(cost, m0, b0))
```
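The same update can also be written with a design matrix whose columns are `x` and a column of ones, so that both parameters move in one vectorized step. This is a sketch under that assumption; `gd_step` and `theta = [m, b]` are hypothetical names, and it keeps the loop's 2/N scaling so one step should produce the same `m0` and `b0` as one pass of the loop above.

```python
import numpy as np

def gd_step(theta, X, y, learning_rate):
    """One gradient-descent step on the mean squared error.

    theta = [m, b]; X has columns [x, 1], so X @ theta == m*x + b.
    X.T @ residual stacks the same two sums as mgrad and bgrad,
    scaled here by 2/N exactly as in the loop.
    """
    residual = X @ theta - y                  # a - y_true
    grad = (2 / len(y)) * (X.T @ residual)    # [(2/N)*mgrad, (2/N)*bgrad]
    return theta - learning_rate * grad

x = np.arange(20.0)
X = np.column_stack([x, np.ones_like(x)])
```

Usage: starting from `theta = np.array([1.0, 1.0])`, calling `theta = gd_step(theta, X, y_true, learning_rate)` once per iteration replaces the four gradient and update lines in the loop.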
The output is:
```
Iteration 0
Cost: 1341.5241150881411, m: 26.02473879743261, b: 2.8683883457327797

Iteration 2
Cost: 409781757.38124645, m: 13657.166910552878, b: 1053.5831308528543

Iteration 4
Cost: 132510115599264.75, m: 7765058.4350503925, b: 598610.1166795876

Iteration 6
Cost: 4.284947676217907e+19, m: 4415631880.089208, b: 340401694.5610262

Iteration 8
Cost: 1.3856132043127762e+25, m: 2510967578365.3584, b: 193570850213.62192
```
Clearly something is wrong, but I cannot see what is wrong with my implementation.
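A sketch for testing one common cause of this kind of blow-up, a fixed learning rate that is too large for the scale of `x`: for a quadratic cost, fixed-step gradient descent converges only if the step size is below 2 divided by the largest eigenvalue of the Hessian. Here the Hessian of the mean squared error is (2/N) XᵀX with X = [x, 1]. The helpers `run_gd` and `max_stable_learning_rate`, the noiseless test line y = 2x + 1, and the particular rates tried are all assumptions for illustration.

```python
import numpy as np

def run_gd(x, y, learning_rate, iterations, m=1.0, b=1.0):
    """Hypothetical helper mirroring the loop above (gradient scaled by 2/N)."""
    n = len(x)
    for _ in range(iterations):
        residual = (m * x + b) - y              # a - y_true
        m_grad = (2 / n) * residual.dot(x)
        b_grad = (2 / n) * residual.sum()
        m -= learning_rate * m_grad
        b -= learning_rate * b_grad
    return m, b

def max_stable_learning_rate(x):
    """Upper bound on the step size for which fixed-step GD on the MSE converges."""
    X = np.column_stack([x, np.ones_like(x)])
    hessian = (2 / len(x)) * (X.T @ X)          # Hessian of (1/N) * sum((y - a)**2)
    return 2 / np.linalg.eigvalsh(hessian).max()

x = np.arange(20, dtype=float)
y = 2 * x + 1                                   # assumed noiseless test line
print(max_stable_learning_rate(x))              # roughly 0.008 for x = 0..19
print(run_gd(x, y, 0.1, 10))                    # learning rate from the question
print(run_gd(x, y, 0.005, 5000))                # a rate below the bound
```

If the bound printed for the data in use is well below the learning rate in the loop, that alone would explain parameters that grow by a large factor each iteration.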