We now recommend using JAX instead of Autograd.
JAX is Autograd and XLA (Accelerated Linear Algebra), brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more.
Here's a simple example of how you can use JAX to compute the derivative of the logistic function:
import jax.numpy as jnp
from jax import grad, jit, vmap
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
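The snippet above builds `derivative_fn` but never calls it. As a minimal sketch of one way to evaluate and sanity-check it, the block below compares the autodiff result against the closed-form derivative of the logistic function, s(x) * (1 - s(x)). The names `logistic`, `autodiff`, `analytic`, and `elementwise` are illustrative assumptions, not part of the original example; the `jit(vmap(grad(...)))` line is included only to show how the imported transformations might compose.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def logistic(x):
    # Scalar (or elementwise) logistic function; helper added for illustration.
    return 1.0 / (1.0 + jnp.exp(-x))

def sum_logistic(x):
    # Same function as in the example above: sum of the logistic over all elements.
    return jnp.sum(logistic(x))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)

# Gradient of the sum with respect to each input element.
autodiff = derivative_fn(x_small)

# Closed-form derivative of the logistic function, for comparison.
s = logistic(x_small)
analytic = s * (1.0 - s)

# Alternative route: differentiate the scalar function, map it over the
# array with vmap, and compile the result with jit.
elementwise = jit(vmap(grad(logistic)))(x_small)

print(autodiff)
print(jnp.allclose(autodiff, analytic, atol=1e-6))
print(jnp.allclose(elementwise, analytic, atol=1e-6))
```

Because `sum_logistic` sums independent per-element terms, its gradient with respect to each element equals the derivative of the logistic function at that element, so both routes should agree with the closed form.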