I have a small performance bottleneck in an application that needs to zero out the off-diagonal elements of a large square matrix. For example, the matrix x
    17  24   1   8  15
    23   5   7  14  16
     4   6  13  20  22
    10  12  19  21   3
    11  18  25   2   9
becomes
    17   0   0   0   0
     0   5   0   0   0
     0   0  13   0   0
     0   0   0  21   0
     0   0   0   0   9
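The example matrix above happens to be magic(5), so it is easy to reproduce. A small check, assuming only base MATLAB, that the simplest solution below turns it into the expected result:

```matlab
% The 5-by-5 example above is exactly magic(5).
x = magic(5);
n = size(x, 1);

% Expected result: only the diagonal 17, 5, 13, 21, 9 survives.
expected = [17 0 0 0 0; 0 5 0 0 0; 0 0 13 0 0; 0 0 0 21 0; 0 0 0 0 9];

% Mask with the identity matrix (element-wise product).
y = x .* eye(n);

disp(isequal(y, expected))   % displays 1 (true)
```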
Question: the bsxfun and diag solution below is the fastest I have found so far. I doubt it can be improved while keeping the code in pure MATLAB, but is there a faster way?
Solutions
Here is what I have tried so far.
Perform element-wise multiplication by the identity matrix. This is the simplest solution:
y = x .* eye(n);
Using bsxfun and diag:
y = bsxfun(@times, diag(x), eye(n));
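Why this works: diag(x) returns the diagonal as an n-by-1 column, and bsxfun expands that column across the n columns of eye(n), so row i of the identity is scaled by x(i,i) and only entry (i,i) stays nonzero. A small sketch on the magic(5) example; the second form relies on implicit expansion, which is available in MATLAB R2016b and later:

```matlab
x = magic(5);
n = size(x, 1);

d = diag(x);                      % n-by-1 column: [17; 5; 13; 21; 9]

% bsxfun expands the column d across the columns of eye(n):
% y1(i, j) = d(i) * I(i, j), which is nonzero only when i == j.
y1 = bsxfun(@times, d, eye(n));

% Equivalent form using implicit expansion (R2016b and later).
y2 = d .* eye(n);

disp(isequal(y1, y2))             % displays 1 (true)
```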
Subtracting the strictly lower and strictly upper triangular parts:
y = x - tril(x, -1) - triu(x, 1);
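This relies on x splitting into three disjoint pieces: the part strictly below the diagonal (tril(x, -1)), the diagonal itself, and the part strictly above it (triu(x, 1)). A sketch that checks the split on the magic(5) example; note that this approach builds two extra n-by-n temporaries, which may be part of why it times slower:

```matlab
x = magic(5);

lowerPart = tril(x, -1);   % entries strictly below the diagonal, zeros elsewhere
upperPart = triu(x, 1);    % entries strictly above the diagonal, zeros elsewhere
y = x - lowerPart - upperPart;

% The three pieces add back up to x, so y is exactly the diagonal part.
disp(isequal(lowerPart + y + upperPart, x))   % displays 1 (true)
disp(isequal(y, x .* eye(5)))                 % displays 1 (true)
```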
Two solutions using loops:
y = x;
for ix=1:n
    for jx=1:n
        if ix ~= jx
            y(ix, jx) = 0;
        end
    end
end
and
y = x;
for ix=1:n
    for jx=1:ix-1
        y(ix, jx) = 0;
    end
    for jx=ix+1:n
        y(ix, jx) = 0;
    end
end
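Both loops put the row index ix in the outer loop, so the inner loop steps along a row. MATLAB stores arrays in column-major order, so a variant with the loops swapped walks down each column in memory order instead. This is a sketch, not one of the timed variants below, and whether it closes the gap to the vectorized versions at n = 5000 is untested:

```matlab
x = magic(5);
n = size(x, 1);

y = x;
for jx = 1:n                 % outer loop over columns
    for ix = 1:jx-1          % rows above the diagonal in column jx
        y(ix, jx) = 0;
    end
    for ix = jx+1:n          % rows below the diagonal in column jx
        y(ix, jx) = 0;
    end
end

disp(isequal(y, x .* eye(n)))   % displays 1 (true)
```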
Timing
The bsxfun solution is actually the fastest. This is my timing code:
function timing()
    % Benchmark five ways of zeroing the off-diagonal elements of x.
    % Save as timing.m; tf1 to tf5 are local functions in the same file.
    clear all
    n = 5000;
    x = rand(n, n);
    f1 = @() tf1(x, n);
    f2 = @() tf2(x, n);
    f3 = @() tf3(x);
    f4 = @() tf4(x, n);
    f5 = @() tf5(x, n);
    t1 = timeit(f1);
    t2 = timeit(f2);
    t3 = timeit(f3);
    t4 = timeit(f4);
    t5 = timeit(f5);
    fprintf('t1: %f s\n', t1)
    fprintf('t2: %f s\n', t2)
    fprintf('t3: %f s\n', t3)
    fprintf('t4: %f s\n', t4)
    fprintf('t5: %f s\n', t5)
end

function y = tf1(x, n)
    y = x .* eye(n);
end

function y = tf2(x, n)
    y = bsxfun(@times, diag(x), eye(n));
end

function y = tf3(x)
    y = x - tril(x, -1) - triu(x, 1);
end

function y = tf4(x, n)
    y = x;
    for ix=1:n
        for jx=1:n
            if ix ~= jx
                y(ix, jx) = 0;
            end
        end
    end
end

function y = tf5(x, n)
    y = x;
    for ix=1:n
        for jx=1:ix-1
            y(ix, jx) = 0;
        end
        for jx=ix+1:n
            y(ix, jx) = 0;
        end
    end
end
which prints (n = 5000):
t1: 0.111117 s
t2: 0.078692 s
t3: 0.219582 s
t4: 1.183389 s
t5: 1.198795 s
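The timings only compare like with like if every variant produces the same matrix. A quick sanity check for the three vectorized variants, written as anonymous functions rather than the tf1 to tf3 local functions, on a smaller random matrix (n = 500 is an arbitrary choice to keep it fast):

```matlab
n = 500;
x = rand(n, n);

% The three vectorized variants; each anonymous function captures n.
variants = {
    @(x) x .* eye(n)
    @(x) bsxfun(@times, diag(x), eye(n))
    @(x) x - tril(x, -1) - triu(x, 1)
};

reference = variants{1}(x);
for k = 2:numel(variants)
    % For finite inputs every off-diagonal entry becomes exactly 0,
    % so exact equality (isequal) is the right test here.
    fprintf('variant %d matches: %d\n', k, isequal(variants{k}(x), reference));
end
```

One caveat for real data: if x contains Inf or NaN off the diagonal, the multiplication and subtraction variants produce NaN there (Inf * 0 and Inf - Inf are both NaN), whereas the loop versions write an exact 0.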