I'm getting started with Google JAX and its built-in jit and grad functionality. Both work nicely on my machine, but when I increase the number of arguments to my objective function I get the following notification:
********************************
Slow compile? XLA was built without compiler optimizations, which can be slow. Try rebuilding with -c opt.
Compiling module jit_obj_func__1.9055
********************************
I would like to increase the number of input parameters further, so I expect I will soon need faster compile times. That makes this notification very relevant to me, but I don't understand how to act on it.
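To see how compile time grows as arguments are added, a small timing script along these lines can help. It is a minimal sketch that assumes a recent JAX with the ahead-of-time jax.jit(...).lower(...).compile() API; make_obj_func, obj_func and its quadratic body are hypothetical stand-ins, since the real function (traced as jit_obj_func in the log) isn't shown here.

```python
import time

import jax
import jax.numpy as jnp


def make_obj_func(n_params):
    """Build a hypothetical objective that takes n_params scalar arguments."""
    def obj_func(*params):
        # The Python loop is unrolled while tracing, so the traced
        # computation grows with the number of arguments.
        total = 0.0
        for i, p in enumerate(params):
            total = total + (p - i) ** 2
        return total
    return obj_func


def make_grad_fn(n_params):
    # jit-compiled gradient with respect to every positional argument.
    argnums = tuple(range(n_params))
    return jax.jit(jax.grad(make_obj_func(n_params), argnums=argnums))


def compile_seconds(n_params):
    grad_fn = make_grad_fn(n_params)
    args = [jnp.float32(0.0)] * n_params
    start = time.perf_counter()
    grad_fn.lower(*args).compile()  # trace, lower and compile; nothing is executed
    return time.perf_counter() - start


for n in (10, 50, 200):
    print(f"{n:4d} arguments: {compile_seconds(n):.3f} s to compile")
```

Timing the lower(...).compile() step on its own separates compilation cost from execution time, which makes it easier to tell whether the slowdown really sits in XLA compilation.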
I've been installing JAX with conda. Basically, I run the following commands in the terminal:
~$ conda create --name jax
~$ conda activate jax
~$ conda install -c conda-forge jax matplotlib cudatoolkit
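XLA, which prints the message, ships inside the jaxlib package, so it may be useful to record exactly what the conda commands above installed and which backend JAX picked up. This is a sketch: jax.__version__, jax.default_backend() and jax.devices() are standard JAX calls, while jaxlib.version.__version__ is assumed to be present in the installed jaxlib.

```python
import jax
import jaxlib.version

# Versions of the two packages conda installed (XLA lives in jaxlib).
print("jax:   ", jax.__version__)
print("jaxlib:", jaxlib.version.__version__)

# Backend JAX selected, e.g. "cpu", or "gpu" when a CUDA-enabled jaxlib is found.
print("backend:", jax.default_backend())
print("devices:", jax.devices())
```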
I'm certain there must be a way to pass some options when installing with conda (for example, something like conda install jax=arguments), but I can't find how to do it anywhere in the documentation. There doesn't seem to be anything on Stack Overflow either; a search only turned up the following question:
Very slow jit compile for XLA when using jax
Any advice would be greatly appreciated!