
Good day,

As input, I have an n x n square, symmetric matrix with zeros on the diagonal. In each row, I would like to keep the top k values (excluding the diagonal) and set the rest of the values to zero.

I have seen the solution in "R: fast determine top k maximum value in a matrix", but I am not sure how to adapt it so that the diagonal (which is 0) is excluded when picking the top k values.

I have tried to demonstrate this with a reproducible example with n = 4 and k = 2, though in reality n is over 1000 and I would like to try different values of k, for example 5, 10 or 15. The problem is best seen in rows 2 and 4, where the diagonal 0 would otherwise be among the two largest values.

Thanks

row1 = c(0, 0.2, 0.28,  -0.11)
row2 = c(0.2, 0, -0.65, -0.50)
row3 = c(0.28, -0.65, 0,    0.3)
row4 = c(-0.11, -0.50, 0.3, 0)

input_matrix = rbind(row1, row2, row3, row4)
input_matrix = as.matrix(input_matrix)
input_matrix

k = 2 # keep the two largest off-diagonal values in each row and set the rest to 0

# desired output for k = 2:
row1 = c(0, 0.2, 0.28,  0)
row2 = c(0.2, 0, 0, -0.50)
row3 = c(0.28, 0, 0,    0.3)
row4 = c(-0.11, 0, 0.3, 0)

output_matrix = rbind(row1, row2, row3, row4)
output_matrix = as.matrix(output_matrix)
output_matrix
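To try the approaches at the size mentioned above (n over 1000), a random symmetric test matrix with a zero diagonal is handy. A minimal sketch, where make_sym_matrix() is a hypothetical helper and the uniform values on [-1, 1] are an assumption chosen to resemble the example:

```r
# Hypothetical helper: random n x n symmetric matrix with a zero diagonal.
# Note: set.seed() resets the global random number generator.
make_sym_matrix <- function(n, seed = 1) {
  set.seed(seed)
  m <- matrix(runif(n * n, min = -1, max = 1), nrow = n)
  m[lower.tri(m)] <- t(m)[lower.tri(m)]  # mirror the upper triangle into the lower
  diag(m) <- 0
  m
}

big <- make_sym_matrix(1000)
isSymmetric(big)  # TRUE
```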


2 Answers


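You can use order(). Setting the diagonal to NA makes it sort last with decreasing = TRUE, so it is never picked as one of the top k values: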

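diag(input_matrix) <- NA  # overwrites the diagonal of input_matrix
k <- 2
# for each row, set everything except the k largest entries (NA sorts last) to 0
t(apply(input_matrix, 1, function(x) 
        replace(x, -order(x, decreasing = TRUE)[1:k], 0)))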

#      [,1] [,2] [,3] [,4]
#row1  0.00  0.2 0.28  0.0
#row2  0.20  0.0 0.00 -0.5
#row3  0.28  0.0 0.00  0.3
#row4 -0.11  0.0 0.30  0.0
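The order() approach can be wrapped in a function so that the caller's matrix is not overwritten and several values of k can be tried in turn. A sketch, where top_k_rows() is a hypothetical name and the k < n check is an added assumption (with k = n the NA diagonal would not be replaced):

```r
# Hypothetical wrapper around the order() approach: keeps the k largest
# off-diagonal entries in each row and sets everything else to 0.
top_k_rows <- function(mat, k) {
  stopifnot(is.matrix(mat), nrow(mat) == ncol(mat), k >= 1, k < ncol(mat))
  diag(mat) <- NA  # changes the local copy only, not the caller's matrix
  t(apply(mat, 1, function(x)
    replace(x, -order(x, decreasing = TRUE)[seq_len(k)], 0)))
}

# Example matrix from the question
input_matrix <- rbind(row1 = c(0, 0.2, 0.28, -0.11),
                      row2 = c(0.2, 0, -0.65, -0.50),
                      row3 = c(0.28, -0.65, 0, 0.3),
                      row4 = c(-0.11, -0.50, 0.3, 0))

top_k_rows(input_matrix, k = 2)

# Trying several values of k in one go
by_k <- lapply(1:3, function(k) top_k_rows(input_matrix, k))
```

For the real data, the same call works with k = 5, 10 or 15 as long as k is smaller than n.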

Using your provided data:

row1 = c(0, 0.2, 0.28,  -0.11)
row2 = c(0.2, 0, -0.65, -0.50)
row3 = c(0.28, -0.65, 0,    0.3)
row4 = c(-0.11, -0.50, 0.3, 0)

input_matrix = rbind(row1, row2, row3, row4)
input_matrix = as.matrix(input_matrix)
input_matrix

output_matrix <- input_matrix 
diag(output_matrix) <- NA

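kmat <- function(mat, k){
  # rank each column (the NA diagonal is ranked first); transposing gives
  # per-row ranks because the matrix is symmetric
  # entries ranked at or below n - k are not in the top k, so set them to 0
  mat[t(apply(mat, 2, rank, na.last = FALSE)) <= (ncol(mat)-k)] <- 0
  mat
}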

kmat(output_matrix, k = 2)
#      [,1] [,2] [,3] [,4]
#row1  0.00  0.2 0.28  0.0
#row2  0.20  0.0 0.00 -0.5
#row3  0.28  0.0 0.00  0.3
#row4 -0.11  0.0 0.30  0.0
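The two answers can differ when a row contains tied values at the cut-off. rank() defaults to ties.method = "average", so kmat() may keep more than k entries in such a row, whereas the order() approach always keeps exactly k and drops one of the tied values. A small sketch, assuming a made-up 3 x 3 symmetric matrix with a tie in row 1 and k = 1:

```r
# Made-up 3 x 3 symmetric matrix with a tie (0.5, 0.5) in row 1
m <- rbind(c(0,   0.5, 0.5),
           c(0.5, 0,   0.1),
           c(0.5, 0.1, 0))
diag(m) <- NA
k <- 1

# order()-based approach from the first answer: exactly k entries per row
via_order <- t(apply(m, 1, function(x)
  replace(x, -order(x, decreasing = TRUE)[seq_len(k)], 0)))

# rank()-based kmat() from the second answer: tied entries share an average rank
kmat <- function(mat, k) {
  mat[t(apply(mat, 2, rank, na.last = FALSE)) <= (ncol(mat) - k)] <- 0
  mat
}
via_rank <- kmat(m, k)

rowSums(via_order != 0)  # 1 1 1
rowSums(via_rank != 0)   # 2 1 1
```

If exactly k entries per row are required with kmat(), passing ties.method = "first" to rank() is one option, though which tied entry survives then depends on its column position.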