I am using JAX for some machine learning work. JAX uses XLA to just-in-time compile functions for acceleration, but the compilation itself is too slow on CPU. In my case, the compilation only ever uses a single CPU core, which is not efficient at all.
I have found some answers suggesting that it can be much faster if the compilation runs on the GPU. Can anyone tell me how to make the compile step use the GPU? I have not done any configuration related to compilation. Thanks!
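To show where the time is going, here is how I am measuring it: in JAX the first call to a `jax.jit`-compiled function triggers tracing and XLA compilation, so timing the first call against a later call separates compile time from run time. This is just a minimal sketch with a stand-in function, not my real model:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # stand-in computation; the real model is much larger
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.ones((1000,))

t0 = time.perf_counter()
f(x).block_until_ready()  # first call: trace + compile + run
compile_and_run = time.perf_counter() - t0

t0 = time.perf_counter()
f(x).block_until_ready()  # later call: runs the cached executable
run_only = time.perf_counter() - t0

print(f"first call (compile + run): {compile_and_run:.4f}s")
print(f"cached call (run only):     {run_only:.4f}s")
```

With my real grad/Hessian functions the gap between the two timings is enormous, and during the first call I can see only one CPU core busy.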
Some additional context for the question: I am using JAX to compute the gradient and Hessian, which makes the compilation very slow. The code looks like:
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, vmap

## get results from model ##
def get_model_value(images):
    return jnp.sum(model(images))

def get_model_grad(images):
    images = jnp.expand_dims(images, axis=0)
    image_grad = jacfwd(get_model_value)(images)
    return image_grad

def get_model_hessian(images):
    images = jnp.expand_dims(images, axis=0)
    image_hess = jacfwd(jacrev(get_model_value))(images)
    return image_hess
# get value
model_value = model(dis_img)
FR_value = jnp.expand_dims(FR_value, axis=1)
value_loss = crit_mse(model_value, FR_value)

# get grad
vmap_model_grad = jax.vmap(get_model_grad)
model_grad = vmap_model_grad(dis_img)

# get hessian
vmap_model_hessian = jax.vmap(get_model_hessian)
model_hessian = vmap_model_hessian(dis_img)
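For reference, here is a self-contained toy version of the Hessian pattern above that reproduces the slow compile on a small scale. The two-layer model, the weights, and the input shapes are all hypothetical stand-ins, since my real model and `crit_mse` are too large to post:

```python
import jax
import jax.numpy as jnp
from jax import jacfwd, jacrev, vmap

# hypothetical stand-in for the real model: a tiny two-layer MLP
key = jax.random.PRNGKey(0)
W1 = jax.random.normal(key, (16, 8))
W2 = jax.random.normal(key, (8, 1))

def model(images):                     # images: (batch, 16)
    return jnp.tanh(images @ W1) @ W2  # -> (batch, 1)

def get_model_value(images):
    return jnp.sum(model(images))

def get_model_hessian(images):
    images = jnp.expand_dims(images, axis=0)  # (16,) -> (1, 16)
    # forward-over-reverse Hessian of a scalar w.r.t. the (1, 16) input
    return jacfwd(jacrev(get_model_value))(images)

dis_img = jax.random.normal(key, (4, 16))  # 4 "images" of 16 features each
model_hessian = vmap(get_model_hessian)(dis_img)
print(model_hessian.shape)  # → (4, 1, 16, 1, 16)
```

Even in this toy version, the first (compiling) call is much slower than the subsequent ones, and scaling the model up makes the compile time blow up.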