I will try to keep the code as short as possible; apologies in advance.

I have implemented a gradient descent algorithm for the following function, evaluated on a grid:

[x,y] = meshgrid(-3:.1:3,-3:.1:3);
f = 80.*(x.^4 )+0.01.*(y.^6 );

which I am plotting using:

surf(x,y,f); xlabel('x'); ylabel('y'); zlabel('f(x,y)');
print('f.png','-dpng');
hold on;

The algorithm comes after this. The hold on; is there because, in each iteration, once I have a new (x,y) point, I want to plot the algorithm's progress on the surface of the original function.

In other words, I want each iteration to add a line segment to the existing plot, tracing the descent along the function.

How do I plot this?

I am adding my gradient descent algorithm for reference. The quiver call is my failed attempt, and any other suggestion is more than welcome:

eta = 0.001;
x0 = 1;
y0 = 1;
eps = 1e-4;
dw = 10;
maxIter=5000000;
itr=1;
f_GD=zeros(maxIter,1);

x = x0;
y= y0;
while itr<maxIter && dw > eps

    dx = 4.*80.*(x.^3 );
    dy = 6.*0.01.*(y.^5 );

    x = x - eta*dx;
    y = y - eta*dy;

    dw = sqrt(dx^2+dy^2);
    f_GD(itr)=80.*(x.^4 )+0.01.*(y.^6 );
    quiver(x,y,dx,dy,2,'linewidth',3);
    itr=itr+1;

end
hold off;
Duplicate: stackoverflow.com/q/48033731/7328782 -- the question doesn't sound like it's a duplicate, but will certainly show you how to solve your problem. - Cris Luengo
I can't understand how you want to plot quiver arrows on the gradient descent surface when the quiver values are much bigger than the grid's range (-3:3). Can you explain, please? - Mohammad nagdawi
@nagdawi I was told it might help. I am happy with any other solution as well. - havakok
@CrisLuengo I have read this question a number of times, but I fail to understand how it answers my need. Could you elaborate? - havakok
It shows how to add points to a plotted line. Inside your gradient descent algorithm, you'd update the line to add the new point, showing how the function progresses. - Cris Luengo

1 Answer

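I found a solution using the plot3(x,y,z) function. Each visited point is stored inside the loop, and the whole path is drawn on the surface once the loop ends (this assumes the surf and hold on; commands from the question have already been run):

eta = 0.001;        % step size
x0 = 1;
y0 = 1;
tol = 1e-4;         % stopping tolerance (renamed from eps, which shadows the built-in)
dw = 10;
maxIter=5000000;
itr=1;
% store every visited point so the path can be drawn on the surface
x_GD=zeros(maxIter,1);
y_GD=zeros(maxIter,1);
f_GD=zeros(maxIter,1);

x = x0;
y = y0;
while itr<maxIter && dw > tol

    dx = 4.*80.*(x.^3 );    % df/dx
    dy = 6.*0.01.*(y.^5 );  % df/dy

    x = x - eta*dx;
    y = y - eta*dy;

    dw = sqrt(dx^2+dy^2);
    x_GD(itr)=x;
    y_GD(itr)=y;
    f_GD(itr)=80.*(x.^4 )+0.01.*(y.^6 );
    itr=itr+1;

end
% itr-1 points were stored; draw them as a single line on top of the surface
n = itr-1;
plot3(x_GD(1:n),y_GD(1:n),f_GD(1:n),'r-','linewidth',3);
hold off;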
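
As an alternative, here is a minimal sketch of what the comments suggest: create one line object and extend its data while the loop runs, so the path grows on the surface as the descent proceeds. The names X, Y, F, px, py, pf, k and redrawEvery are illustrative and not from the original post, and the redraw interval of 10000 steps is an assumption chosen only to keep the redraws infrequent.

```matlab
% Surface of f(x,y) = 80*x^4 + 0.01*y^6 on the grid from the question
[X,Y] = meshgrid(-3:.1:3,-3:.1:3);
F = 80.*X.^4 + 0.01.*Y.^6;
surf(X,Y,F); xlabel('x'); ylabel('y'); zlabel('f(x,y)');
hold on;

eta = 0.001;  tol = 1e-4;  maxIter = 5000000;
redrawEvery = 10000;            % assumed value: refresh the plot every 10000 steps

x = 1;  y = 1;                  % starting point (x0, y0)
px = zeros(maxIter,1); py = zeros(maxIter,1); pf = zeros(maxIter,1);
px(1) = x;  py(1) = y;  pf(1) = 80*x^4 + 0.01*y^6;
k = 1;                          % number of path points stored so far

% One line object, created once; its data is extended inside the loop
h = plot3(px(1),py(1),pf(1),'r-','LineWidth',2);

dw = Inf;
while k < maxIter && dw > tol
    dx = 4*80*x^3;              % df/dx at the current point
    dy = 6*0.01*y^5;            % df/dy at the current point
    x = x - eta*dx;
    y = y - eta*dy;
    dw = sqrt(dx^2 + dy^2);     % norm of the gradient used for this step

    k = k + 1;
    px(k) = x;  py(k) = y;  pf(k) = 80*x^4 + 0.01*y^6;

    if mod(k,redrawEvery) == 0
        % extend the existing line instead of creating a new plot object
        set(h,'XData',px(1:k),'YData',py(1:k),'ZData',pf(1:k));
        drawnow;
    end
end
set(h,'XData',px(1:k),'YData',py(1:k),'ZData',pf(1:k));   % final update
hold off;
```

Updating a single line keeps the figure to two graphics objects (the surface and the path), whereas calling a plotting function in every iteration would create a new object per step. Note that the surface reaches values in the thousands on this grid while the path stays far lower, so zooming in or rotating the view may be needed to see it.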