Unfortunately, TensorFlow is quite painful to work with once you want to index into tensors, so implementing your idea requires some ugly workarounds. One option (TensorFlow 1.x, graph mode) would be:
```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session, ref Variables)

# Use a Variable rather than a constant: constants are not designed to be updated
A = tf.Variable(initial_value=[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
# Create a buffer Variable that stores the tentative update of one row;
# it is initialized with random values, which are overwritten in every iteration
t = tf.Variable(initial_value=tf.cast(tf.random_normal(shape=[5]), dtype=tf.int32))
values, indices = tf.nn.top_k(A, k=3)

# Function applied to the selected (top-k) values
def val_manipulation(v):
    return 2 * v + 1

# A while loop updates A one row (slice) at a time: each row is copied into t,
# updated there with scatter_update, and written back with scatter_nd_update
i = tf.constant(0)
# Stop once every row has been updated
c = lambda i, x: tf.less(i, tf.shape(A)[0])
# Each iteration increments i and replaces row A[i] with its updated copy
b = lambda i, x: [i + 1,
                  tf.scatter_nd_update(A, [[i]], [tf.scatter_update(tf.assign(t, A[i]), indices[i], val_manipulation(values[i]))])]
# While loop
r = tf.while_loop(c, b, [i, A])

init = tf.global_variables_initializer()
with tf.Session() as s:
    s.run(init)
    # Test it!
    print(s.run(A))
    s.run(r)
    print(s.run(A))
```
So basically what you do is:
- `scatter_update` can work with Variables only, so we take a slice of A (`A[i]`) and store its values in the buffer Variable `t`.
- Update the values in the buffer Variable with the desired ones.
- Update the `i`-th slice of A with the updated `t`.
- Repeat for the remaining rows of A.
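To make the steps above concrete, here is a minimal NumPy sketch of the same per-row procedure. NumPy is an assumption (the answer itself only uses TensorFlow), and `update_top_k` is a hypothetical helper name. It sorts negated values with a stable sort so that, among equal values, the lower index is picked first, which matches the tie-breaking documented for `tf.nn.top_k`.

```python
import numpy as np

def val_manipulation(v):
    # Same transformation as in the TensorFlow code: 2*v + 1
    return 2 * v + 1

def update_top_k(A, k=3):
    """Hypothetical helper: for each row, replace its k largest entries
    with val_manipulation(value). Returns a new array; the input is unchanged."""
    A = np.array(A)  # np.array copies its input by default
    for i in range(A.shape[0]):
        t = A[i].copy()                              # buffer row, like the Variable t
        idx = np.argsort(-t, kind="stable")[:k]      # columns of the k largest values
        t[idx] = val_manipulation(t[idx])            # update the buffer
        A[i] = t                                     # write the row back into A
    return A

print(update_top_k([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]))
```

On the example matrix this reproduces the second array shown in the expected output for the TensorFlow version.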
Eventually you should get the following output:
```
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]]
[[ 1  2  7  9 11]
 [ 6  7 17 19 21]]
```
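In principle the row-by-row loop can be avoided: scatter-nd style ops accept indices whose last dimension equals the tensor's rank, so each `[row, col]` pair addresses a single element. The NumPy sketch below (NumPy standing in for TensorFlow, with a hypothetical `update_top_k_vectorized` name) builds those pairs from the per-row top-k columns and applies the update in one step. In TensorFlow 2.x the same `(N, 2)` index array would be the kind of input `tf.tensor_scatter_nd_update` takes, but that variant is not tested here.

```python
import numpy as np

def update_top_k_vectorized(A, k=3):
    """Hypothetical helper: update the k largest entries of every row at once,
    using explicit (row, col) index pairs instead of a per-row loop."""
    A = np.asarray(A)
    # Column indices of the k largest values per row (ties: lower index first)
    cols = np.argsort(-A, axis=1, kind="stable")[:, :k]            # shape (rows, k)
    rows = np.broadcast_to(np.arange(A.shape[0])[:, None], cols.shape)
    # Full [row, col] pairs, one per element to update: shape (rows * k, 2)
    full_idx = np.stack([rows, cols], axis=-1).reshape(-1, 2)
    B = A.copy()
    r, c = full_idx[:, 0], full_idx[:, 1]
    B[r, c] = 2 * A[r, c] + 1                                       # same 2*v + 1 update
    return B, full_idx

B, full_idx = update_top_k_vectorized([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
print(full_idx.shape)
print(B)
```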