This could be solved using np.argpartition to get the indices of the largest k elements and np.ix_ to select the output positions and set them to the dot product of the selected elements from m1 and m2. So we basically have two stages to implement this, as discussed next.
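Here is a small standalone sketch of what np.ix_ does. The 4x4 array and the index values are made up purely for illustration. np.ix_ builds an open mesh from a row-index array and a column-index array, so an assignment fills the whole rows-by-columns block instead of pairing the indices element-wise.

```python
import numpy as np

# Toy output array and index sets (illustrative values only)
out = np.zeros((4, 4))
rows = np.array([0, 2])
cols = np.array([1, 3])

# np.ix_ returns index arrays of shapes (2, 1) and (1, 2); they broadcast
# to a 2x2 block covering rows {0, 2} x columns {1, 3}
out[np.ix_(rows, cols)] = np.array([[1.0, 2.0],
                                    [3.0, 4.0]])
print(out)

# By contrast, plain out[rows, cols] would address only (0, 1) and (2, 3)
```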
First off, get the indices corresponding to the largest k elements in m1 and m2, like so -
import numpy as np
m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel()  # top-k rows of the (n,1) column m1
m2_idx = np.argpartition(-m2,k)[:,:k].ravel()       # top-k columns of the (1,n) row m2
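A minimal sketch of this first stage as a standalone helper, assuming the input is a single row or column. The name topk_indices is hypothetical, and the values are the sample inputs from the run below rounded to two decimals. One deliberate difference from the code above: it uses kth=k-1 instead of kth=k. Both place the k largest values in the first k slots (in no particular order), but np.argpartition requires kth to be a valid index, so kth=k raises an error when k equals n, while kth=k-1 does not.

```python
import numpy as np

def topk_indices(v, k):
    """Hypothetical helper: indices of the k largest entries of a row or
    column vector, returned in no particular order."""
    flat = np.ravel(v)
    # kth=k-1 puts the k-th largest value at position k-1 with all larger
    # values before it; this stays valid even when k == len(flat)
    return np.argpartition(-flat, k - 1)[:k]

m1 = np.array([[0.27], [0.31], [0.60], [0.73], [0.35]])  # column vector (n, 1)
m2 = np.array([[0.31, 0.87, 0.01, 0.82, 0.22]])          # row vector (1, n)

print(sorted(topk_indices(m1, 2).tolist()))  # [2, 3]
print(sorted(topk_indices(m2, 2).tolist()))  # [1, 3]
```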
Finally, set up the output array. Use np.ix_ to broadcast the m1 and m2 indices along the rows and columns respectively, which selects the elements of the output array that are to be set. Then calculate the dot product between the largest k elements of m1 and m2, obtained by indexing m1 and m2 with m1_idx and m2_idx. This way only a k x k block is computed instead of the full n x n product, like so -
n = m1.shape[0]  # m1 is (n,1) and m2 is (1,n)
out = np.zeros((n,n))
out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx])
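Putting both stages together, here's a sketch wrapped as a function. topk_outer is a hypothetical name, m1 is assumed to be an (n,1) column and m2 a (1,n) row as in the question, and kth=k-1 is used so that k equal to n is also accepted. Only the k x k block goes through np.dot; every other output entry stays zero.

```python
import numpy as np

def topk_outer(m1, m2, k):
    """Outer product of column m1 (n,1) and row m2 (1,n), keeping only the
    k largest entries of each input; all other output entries are zero."""
    # Stage 1: indices of the k largest entries of each vector (unordered)
    m1_idx = np.argpartition(-m1, k - 1, axis=0)[:k].ravel()
    m2_idx = np.argpartition(-m2, k - 1, axis=1)[:, :k].ravel()
    # Stage 2: compute only the k x k block and scatter it via np.ix_
    out = np.zeros((m1.shape[0], m2.shape[1]))
    out[np.ix_(m1_idx, m2_idx)] = np.dot(m1[m1_idx], m2[:, m2_idx])
    return out

# Sample inputs from the run in this answer
m1 = np.array([[0.26980423], [0.30698416], [0.60391089], [0.73246763], [0.35276247]])
m2 = np.array([[0.30523552, 0.87411242, 0.01071218, 0.81835438, 0.21693231]])
res = topk_outer(m1, m2, 2)
print(np.count_nonzero(res))  # 4
```

With k equal to n nothing is zeroed, so the result matches the plain m1.dot(m2).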
Let's verify the implementation with a sample run, checking it against another implementation that explicitly sets the lower n-k elements of m1 and m2 to 0s and then performs the dot product. Here's a sample run to perform the check -
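Note that the sample run below zeroes entries of m1 and m2 in place, which overwrites the inputs. Here's a minimal sketch of the same check that works on copies instead and repeats it over a few random sizes. The function names, the seed and the (n, k) pairs are arbitrary choices for illustration.

```python
import numpy as np

def topk_outer(m1, m2, k):
    # Proposed approach: compute only the k x k block of the outer product
    m1_idx = np.argpartition(-m1, k - 1, axis=0)[:k].ravel()
    m2_idx = np.argpartition(-m2, k - 1, axis=1)[:, :k].ravel()
    out = np.zeros((m1.shape[0], m2.shape[1]))
    out[np.ix_(m1_idx, m2_idx)] = np.dot(m1[m1_idx], m2[:, m2_idx])
    return out

def masked_outer(m1, m2, k):
    # Reference: zero the n-k smallest entries on copies, then take the full product
    a, b = m1.copy(), m2.copy()
    a[np.argpartition(-a, k - 1, axis=0)[k:].ravel()] = 0
    b[:, np.argpartition(-b, k - 1, axis=1)[:, k:].ravel()] = 0
    return a.dot(b)

rng = np.random.default_rng(42)
for n, k in [(5, 2), (50, 7), (100, 100)]:
    m1 = rng.random((n, 1))
    m2 = rng.random((1, n))
    assert np.allclose(topk_outer(m1, m2, k), masked_outer(m1, m2, k))
print("all checks passed")
```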
1) Inputs:
In [170]: m1
Out[170]:
array([[ 0.26980423],
[ 0.30698416],
[ 0.60391089],
[ 0.73246763],
[ 0.35276247]])
In [171]: m2
Out[171]: array([[ 0.30523552, 0.87411242, 0.01071218, 0.81835438, 0.21693231]])
In [172]: k = 2
2) Run proposed implementation:
In [173]: # Proposed solution code
...: m1_idx = np.argpartition(-m1,k,axis=0)[:k].ravel()
...: m2_idx = np.argpartition(-m2,k)[:,:k].ravel()
...: out = np.zeros((n,n))
...: out[np.ix_(m1_idx,m2_idx)] = np.dot(m1[m1_idx],m2[:,m2_idx])
...:
3) Use alternative implementation to get the output:
In [174]: # Explicit setting of lower n-k elements to zeros for m1 and m2
...: m1[np.argpartition(-m1,k,axis=0)[k:]] = 0
...: m2[:,np.argpartition(-m2,k)[:,k:].ravel()] = 0
...:
In [175]: m1 # Verify m1 and m2 have lower n-k elements set to 0s
Out[175]:
array([[ 0. ],
[ 0. ],
[ 0.60391089],
[ 0.73246763],
[ 0. ]])
In [176]: m2
Out[176]: array([[ 0. , 0.87411242, 0. , 0.81835438, 0. ]])
In [177]: m1.dot(m2) # Use m1.dot(m2) to directly get output. This is expensive.
Out[177]:
array([[ 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0.52788601, 0. , 0.49421312, 0. ],
[ 0. , 0.64025905, 0. , 0.59941809, 0. ],
[ 0. , 0. , 0. , 0. , 0. ]])
4) Verify our proposed implementation:
In [178]: out # Print output from proposed solution obtained earlier
Out[178]:
array([[ 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0.52788601, 0. , 0.49421312, 0. ],
[ 0. , 0.64025905, 0. , 0.59941809, 0. ],
[ 0. , 0. , 0. , 0. , 0. ]])