I have a tensor of probabilities, and I want to select the highest probability in each column.
import tensorflow as tf
x = tf.random.normal(mean=0, stddev=1, shape=(5, 4))
<tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[ 0.29182285, -0.30140358, -1.6745052 , 0.04754949],
[-0.4013166 , 0.98574334, -0.33083338, -1.2661675 ],
[-0.49397126, -0.19175254, 1.4922557 , -0.9857015 ],
[ 0.57078785, 0.30159777, 0.34956333, 0.17599773],
[ 1.0847769 , -0.58611375, -0.10727347, -0.13440256]],
dtype=float32)>
I'm trying to select the highest probability per column the way I would with NumPy-style indexing,
but it doesn't work:
x[tf.argmax(x, axis=0), :]
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Tensor: shape=(4,), dtype=int32, numpy=array([4, 1, 2, 3])>
What would be the correct way to do this?
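
For reference, here is a minimal sketch of two approaches that might do this, assuming TensorFlow 2.x in eager mode. The tensor values are rounded from the example above, and the variable names (`col_max`, `rows`, `cols`, `idx`, `picked`) are only illustrative. If only the maximum values are needed, `tf.reduce_max` along axis 0 should be enough. If the indices matter too, the `tf.argmax` result can be paired with column indices and passed to `tf.gather_nd`, since `x[...]` only accepts scalar tensor indices.

```python
import tensorflow as tf

# Fixed tensor (values rounded from the example above) so the result is reproducible.
x = tf.constant([[ 0.29, -0.30, -1.67,  0.05],
                 [-0.40,  0.99, -0.33, -1.27],
                 [-0.49, -0.19,  1.49, -0.99],
                 [ 0.57,  0.30,  0.35,  0.18],
                 [ 1.08, -0.59, -0.11, -0.13]])

# Option 1: only the per-column maximum values are needed.
col_max = tf.reduce_max(x, axis=0)                  # shape (4,)

# Option 2: keep the row indices and gather explicitly.
rows = tf.argmax(x, axis=0, output_type=tf.int32)   # row of the max in each column
cols = tf.range(tf.shape(x)[1])                     # column indices 0..3 (int32)
idx = tf.stack([rows, cols], axis=1)                # (row, col) pairs, shape (4, 2)
picked = tf.gather_nd(x, idx)                       # shape (4,)

print(col_max.numpy())
print(picked.numpy())
```

Both options should give the same four values. The second is mainly useful when the `argmax` indices are needed for something else as well.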