
I am trying to do some operations on the top-K values of a tensor in TensorFlow. Basically, I want to first get the indices of the top-K values, perform some operations on those values, and then assign the new values back. For example:

A = tf.constant([[1,2,3,4,5],[6,7,8,9,10]])
values, indices = tf.nn.top_k(A, k=3)

Here, values will be array([[ 5,  4,  3], [10,  9,  8]], dtype=int32).

After doing some operation on values, say prob = tf.nn.softmax(values), how do I assign the result back to A according to indices, similar to A[indices] = prob in NumPy? I couldn't find the right function in TensorFlow to do this.
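For reference, here is a NumPy sketch of the intended result. Because indices is 2-D, a plain A[indices] = prob would treat every entry of indices as a row number. np.take_along_axis and np.put_along_axis instead pair each column index with its own row. The array is made float here because softmax probabilities are not integers, and argsort on the negated array stands in for tf.nn.top_k. Both of these are assumptions of the sketch, not part of the question.

```python
import numpy as np

# Float array, since the softmax probabilities are not integers
A = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=np.float64)
k = 3

# Column indices of the k largest entries in each row, largest first
# (mirrors the `indices` returned by tf.nn.top_k)
indices = np.argsort(-A, axis=1)[:, :k]
values = np.take_along_axis(A, indices, axis=1)

# Row-wise softmax over the top-k values (shifted by the row max for stability)
e = np.exp(values - values.max(axis=1, keepdims=True))
prob = e / e.sum(axis=1, keepdims=True)

# Scatter prob back into A at the top-k positions of each row
np.put_along_axis(A, indices, prob, axis=1)
print(A)
```

The first two columns keep their original values, and the last three columns of each row now sum to 1.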

You can't change the value of A, which is a constant array. - vijay m
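A constant cannot be modified in place, but a new tensor with the top-k positions replaced can be built with tf.tensor_scatter_nd_update. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution. The casts are needed because tf.nn.softmax requires floating-point input and A is int32. The names coords and A_new are illustrative.

```python
import tensorflow as tf

A = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
values, indices = tf.math.top_k(A, k=3)

# softmax needs a floating-point input, so cast the int32 values first
prob = tf.nn.softmax(tf.cast(values, tf.float32))

# Pair every column index with its row index -> coordinates of shape [2, 3, 2]
rows = tf.broadcast_to(tf.range(tf.shape(A)[0])[:, tf.newaxis], tf.shape(indices))
coords = tf.stack([rows, indices], axis=-1)

# Build a new float tensor: a copy of A with the top-k entries replaced by prob
A_new = tf.tensor_scatter_nd_update(tf.cast(A, tf.float32), coords, prob)
print(A_new.numpy())
```

Because A_new is a new tensor, the constant A itself is left unchanged.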

1 Answer


Unfortunately, TensorFlow is quite painful once you want to update tensors by index, so implementing your idea requires some ugly workarounds. My approach would be:

# TensorFlow 1.x graph-mode code (tf.Session, tf.assign, tf.scatter_update)
import tensorflow as tf

# First, use a Variable: a constant is not designed to be updated
A = tf.Variable(initial_value=[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])

# Buffer Variable that holds one row of A while it is being updated;
# its initial contents don't matter because each row overwrites it
t = tf.Variable(initial_value=tf.zeros(shape=[5], dtype=tf.int32))
values, indices = tf.nn.top_k(A, k=3)

# The manipulation you want to apply to the top-k values
def val_manipulation(v):
  return 2*v + 1

# A while loop that rewrites A one row at a time
i = tf.constant(0)
# Stop once every row has been updated
c = lambda i, x: tf.less(i, tf.shape(A)[0])

# Each iteration copies row A[i] into the buffer t, overwrites the
# top-k positions of t with the manipulated values, and writes t
# back as row i of A
def b(i, x):
  row = tf.assign(t, A[i])
  new_row = tf.scatter_update(row, indices[i], val_manipulation(values[i]))
  return [i + 1, tf.scatter_nd_update(A, [[i]], [new_row])]

# While loop
r = tf.while_loop(c, b, [i, A])

# tf.initialize_all_variables() is deprecated
init = tf.global_variables_initializer()

with tf.Session() as s:
  s.run(init)
  # Test it!
  print(s.run(A))
  s.run(r)
  print(s.run(A))

So basically, each iteration of the loop does the following:

  1. scatter_update works only with Variables, so take a slice of A (A[i]) and copy its values into the buffer Variable t.
  2. Overwrite the top-k entries of the buffer with the desired values.
  3. Replace the i-th slice of A with the updated t.
  4. Repeat for the remaining rows of A.
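The four steps can be mirrored in plain NumPy. There, a row copy plays the role of the buffer Variable t. This is only an illustration of the loop's logic under that substitution, not TensorFlow code. argsort on the negated array stands in for tf.nn.top_k.

```python
import numpy as np

A = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
k = 3

# Same inputs as tf.nn.top_k(A, k=3): per-row top-k column indices and values
indices = np.argsort(-A, axis=1)[:, :k]
values = np.take_along_axis(A, indices, axis=1)

def val_manipulation(v):
    return 2 * v + 1

for i in range(A.shape[0]):                       # 4. repeat for every row
    t = A[i].copy()                               # 1. copy row i into a buffer
    t[indices[i]] = val_manipulation(values[i])   # 2. overwrite its top-k entries
    A[i] = t                                      # 3. write the buffer back as row i

print(A)
```

This produces the same final array as the TensorFlow loop.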

Eventually you should get the following output:

[[ 1  2  3  4  5]
 [ 6  7  8  9 10]]
[[ 1  2  7  9 11]
 [ 6  7 17 19 21]]
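In TensorFlow 2.x with eager execution, the row-by-row loop can probably be avoided. When each index passed to scatter_nd_update is a full [row, column] coordinate, it writes individual elements rather than whole slices. The following is a sketch under that assumption. It uses the tf.Variable.scatter_nd_update method. The names rows and coords are illustrative, and val_manipulation is the same example function as above.

```python
import tensorflow as tf

A = tf.Variable([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
values, indices = tf.math.top_k(A, k=3)

def val_manipulation(v):
    return 2 * v + 1

# Full [row, col] coordinates of every top-k entry, shape [rows, k, 2]
rows = tf.broadcast_to(tf.range(tf.shape(A)[0])[:, tf.newaxis], tf.shape(indices))
coords = tf.stack([rows, indices], axis=-1)

# One in-place update of all selected elements, no while loop or buffer needed
A.scatter_nd_update(coords, val_manipulation(values))
print(A.numpy())
```

The updates tensor has shape [rows, k], which matches coords without its last dimension. Each value is written to exactly one element of A.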