import autograd.numpy as np
from autograd import grad
def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), evaluated through tanh.

    Uses the identity sigmoid(x) = (tanh(x / 2) + 1) / 2, which never calls
    exp() and so cannot overflow for large-magnitude inputs.
    """
    half_x = x / 2.
    return (np.tanh(half_x) + 1) * 0.5
def logistic_predictions(weights, inputs):
    """Return the probability of a label being true under a logistic model.

    The linear score ``np.dot(inputs, weights)`` is passed through
    ``sigmoid`` to turn it into a probability.
    """
    scores = np.dot(inputs, weights)
    return sigmoid(scores)
def training_loss(weights):
    """Negative log-likelihood of the training labels for the given weights.

    The data comes from the module-level ``inputs`` and ``targets`` (assigned
    further down the script). Only ``weights`` is a parameter, and it is the
    argument that ``grad(training_loss)`` differentiates with respect to.
    """
    probs = logistic_predictions(weights, inputs)
    # Probability the model gives to the label actually observed: p where the
    # target is True, 1 - p where it is False.
    observed = probs * targets + (1 - probs) * (1 - targets)
    log_likelihood = np.sum(np.log(observed))
    return -log_likelihood
# Toy dataset: four samples with three features each, and one boolean label
# per sample.
inputs = np.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])
targets = np.array([True, True, False, True])

# Autograd's grad() returns a function that computes the gradient of
# training_loss with respect to its first argument (the weights).
training_gradient_fun = grad(training_loss)

# Fixed-step gradient descent starting from all-zero weights.
step_size = 0.01
weights = np.zeros(3)
print("Initial loss:", training_loss(weights))
for i in range(100):
    weights -= step_size * training_gradient_fun(weights)
print("Trained loss:", training_loss(weights))
Today, we recommend using JAX instead of Autograd.
JAX is Autograd and XLA (Accelerated Linear Algebra), brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more.
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from jax import grad as jax_grad
def function(x):
    """Quadratic test objective f(x) = x**2, minimized at x = 0."""
    return pow(x, 2)
def analytical_gradient(x):
    """Closed-form derivative of f(x) = x**2, i.e. f'(x) = 2x."""
    return x * 2
def gradient_descent(starting_point, learning_rate, num_iterations, solver="analytical"):
    """Run fixed-step gradient descent on the module-level ``function``.

    ``function`` and ``analytical_gradient`` are looked up when this is
    called, so redefining them later in the script changes what is optimized.

    Args:
        starting_point: Initial x value.
        learning_rate: Step size multiplied by the gradient on each update.
        num_iterations: Number of update steps to take.
        solver: ``"analytical"`` to use ``analytical_gradient``, or ``"jax"``
            to differentiate ``function`` automatically with ``jax.grad``.

    Returns:
        ``(trajectory_x, trajectory_y)``: lists of ``num_iterations + 1``
        entries each, holding the starting point and every iterate, and
        ``function`` evaluated at each of them.

    Raises:
        ValueError: If ``solver`` is not ``"analytical"`` or ``"jax"``.
    """
    # Resolve the gradient first, so an unknown solver fails fast with a clear
    # error instead of an UnboundLocalError inside the loop.
    if solver == "analytical":
        grad_fn = analytical_gradient
    elif solver == "jax":
        grad_fn = jax_grad(function)
    else:
        raise ValueError(f"unknown solver {solver!r}; expected 'analytical' or 'jax'")

    x = starting_point
    trajectory_x = [x]
    trajectory_y = [function(x)]

    if solver == "jax":
        # jax.grad rejects integer inputs, so promote x to JAX's default float
        # dtype. Passing dtype=float (rather than jnp.float64) avoids the
        # "float64 not available" warning when 64-bit mode is disabled.
        x = jnp.asarray(x, dtype=float)

    for _ in range(num_iterations):
        x = x - learning_rate * grad_fn(x)
        trajectory_x.append(x)
        trajectory_y.append(function(x))
    return trajectory_x, trajectory_y
# First experiment: plot f(x) = x**2 and overlay the iterates produced by
# the analytical gradient and by JAX's automatic differentiation.
plt.figure()  # fresh figure so this plot is not mixed with later ones
x = np.linspace(-5, 5, 100)
plt.plot(x, function(x), label="f(x)")
descent_x, descent_y = gradient_descent(5, 0.1, 10, solver="analytical")
jax_descend_x, jax_descend_y = gradient_descent(5, 0.1, 10, solver="jax")
plt.plot(descent_x, descent_y, label="Gradient descent", marker="o")
plt.plot(jax_descend_x, jax_descend_y, label="JAX", marker="x")
plt.legend()  # without this, the label= arguments are never displayed
# Array namespace used by the redefined `function`/`analytical_gradient`
# below. It starts as NumPy and is switched to jax.numpy before the JAX run.
backend = np
def function(x):
    """Non-convex objective f(x) = x * sin(x**2 + 1).

    ``sin`` comes from the module-level ``backend``, which is read at call
    time, so the same definition works with whichever array library is
    currently selected.
    """
    phase = x**2 + 1
    return x * backend.sin(phase)
def analytical_gradient(x):
    """Closed-form derivative of f(x) = x * sin(x**2 + 1).

    By the product and chain rules, f'(x) = sin(x**2 + 1) + 2 x**2 cos(x**2 + 1).
    ``sin`` and ``cos`` come from the module-level ``backend``.
    """
    inner = x**2 + 1
    return backend.sin(inner) + 2*x**2*backend.cos(inner)
# Second experiment: descend the non-convex x * sin(x**2 + 1) from x = 1
# with both gradient sources, and scatter the iterates over the curve.
plt.figure()  # when run as a script, this would otherwise draw over the x**2 plot
x = np.linspace(-5, 5, 100)
plt.plot(x, function(x), label="f(x)")
descent_x, descent_y = gradient_descent(1, 0.01, 300, solver="analytical")
# Change the backend to JAX: jax.grad traces `function`, so its sin must be
# the jax.numpy version rather than NumPy's.
backend = jnp
jax_descend_x, jax_descend_y = gradient_descent(1, 0.01, 300, solver="jax")
plt.scatter(descent_x, descent_y, label="Gradient descent", marker="v", s=10, color="red")
plt.scatter(jax_descend_x, jax_descend_y, label="JAX", marker="x", s=5, color="black")
plt.legend()  # show the labels passed to plot/scatter above