In this post we describe the high-level idea behind gradient descent for convex optimization. Much of the intuition comes from Nisheeth Vishnoi’s short course; he provides a more theoretical treatment, while we focus on the intuition. We first describe why to use gradient descent. We then define convex sets and functions and describe the intuitive idea behind gradient descent. We follow this with a toy example and some discussion.
Why Use Gradient Descent?
In many data analysis problems we want to minimize some function, for instance a negative log-likelihood. However, the minimizer often either lacks a closed-form solution or is very expensive to compute for large datasets. For instance, logistic regression has no closed-form solution, while the naive closed-form solution to linear regression requires solving a linear system (the normal equations), which can be numerically unstable. For the former, gradient descent gives us a method for solving the problem, while for the latter, gradient descent lets us avoid these stability issues.
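To make the linear-regression case concrete, here is a minimal sketch, assuming a synthetic dataset and the objective $\frac{1}{2n}\|Xw - y\|^2$. The function name `least_squares_gd`, the data, and the step size $1/L$ (with $L$ the largest eigenvalue of $X^\top X / n$) are illustrative choices, not from the original post. It only shows that plain gradient descent reaches the same solution as NumPy's `np.linalg.lstsq` without forming or inverting $X^\top X$; it does not by itself demonstrate the stability issues.

```python
import numpy as np

def least_squares_gd(X, y, n_iters=2000):
    """Minimize (1/2n) * ||X w - y||^2 with plain gradient descent (illustrative helper)."""
    n, d = X.shape
    # Step size 1/L, where L = largest eigenvalue of X^T X / n
    # (the Lipschitz constant of this objective's gradient).
    L = np.linalg.norm(X, 2) ** 2 / n
    eta = 1.0 / L
    w = np.zeros(d)
    for _ in range(n_iters):
        grad = X.T @ (X @ w - y) / n   # gradient of the objective at w
        w = w - eta * grad
    return w

# Synthetic data (assumed for illustration).
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.1 * rng.normal(size=200)

w_gd = least_squares_gd(X, y)
w_ref, *_ = np.linalg.lstsq(X, y, rcond=None)   # reference least-squares solution
print(np.max(np.abs(w_gd - w_ref)))
```

The printed value should be tiny; the two solutions agree up to floating-point error.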
Convex Sets and Convex Functions
Much of the practical application and most of the theory of gradient descent involve convex sets and functions. Intuitively, a convex set is one where for any two points in the set, every point between them is also in the set. For example, consider a line segment in 1d from point $a$ to point $b$, and let $x$ and $y$ be points on the segment. Then any point between $x$ and $y$ is also on the segment. Thus a line segment is a convex set.
More formally, a set $K$ is convex if whenever $x, y \in K$, then $\lambda x + (1 - \lambda) y \in K$ for any $\lambda \in [0, 1]$. In the previous case, $K = [a, b]$ and $x, y \in [a, b]$.
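A grid check can illustrate this definition, though it is not a proof. The sketch below uses a hypothetical helper `find_violation` that tests $\lambda x + (1 - \lambda) y$ for sample points and sample values of $\lambda$; it assumes each set is given as a membership predicate. The interval $[0, 1]$ passes, while the union $[0, 1] \cup [2, 3]$ fails because the points between $1$ and $2$ are missing.

```python
def find_violation(contains, points, lams):
    """Return (x, y, lam) with x, y in the set but lam*x + (1-lam)*y outside it, else None."""
    for x in points:
        for y in points:
            for lam in lams:
                z = lam * x + (1 - lam) * y   # convex combination of x and y
                if not contains(z):
                    return (x, y, lam)
    return None

lams = [i / 10 for i in range(11)]          # lambda in {0, 0.1, ..., 1}
grid = [i / 10 for i in range(11)]          # sample points in [0, 1]

# The interval [0, 1] is convex: no violation should be found.
interval = lambda z: 0.0 <= z <= 1.0
print(find_violation(interval, grid, lams))                    # None

# The union [0, 1] U [2, 3] is not convex.
union = lambda z: (0.0 <= z <= 1.0) or (2.0 <= z <= 3.0)
print(find_violation(union, [0.0, 1.0, 2.0, 3.0], lams))       # (0.0, 2.0, 0.1)
```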
In the 1d case, a convex function is one where, if you draw a line segment between the values of the function at any two points, the segment lies on or above the function everywhere in between. Let’s look at an example: the function $f(x) = x^2$.
We see that the line segment joining the points $(x, f(x))$ and $(y, f(y))$ lies above $f(z)$ for every $z$ between $x$ and $y$. We can also tell just from looking at it that this property holds over the entire domain of the function. More formally, we have the following definition. For $x, y \in K$ and $\lambda \in [0, 1]$, where $K$ is a convex set, a function $f: K \to \mathbb{R}$ is convex if
$$f(\lambda x + (1 - \lambda) y) \le \lambda f(x) + (1 - \lambda) f(y). \tag{1}$$
Here $\lambda f(x) + (1 - \lambda) f(y)$ is the line segment between $f(x)$ and $f(y)$, and $f(\lambda x + (1 - \lambda) y)$ is the function, which lies on or below the line segment. In higher dimensions the definition is the same: the segment joining any two points on the graph lies on or above the graph.
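Inequality (1) can be spot-checked numerically on random samples. This is a minimal sketch, assuming $f(x) = x^2$ as the convex example and $f(x) = -x^2$ as a non-convex one; the helper name `violates_convexity` and the tolerance are illustrative choices. Sampling can only find violations; it cannot prove convexity.

```python
import numpy as np

def violates_convexity(f, x, y, lam, tol=1e-9):
    """True where f(lam*x + (1-lam)*y) > lam*f(x) + (1-lam)*f(y), i.e. inequality (1) fails."""
    lhs = f(lam * x + (1 - lam) * y)
    rhs = lam * f(x) + (1 - lam) * f(y)
    return lhs > rhs + tol   # tol absorbs floating-point rounding

rng = np.random.default_rng(0)
xs, ys = rng.uniform(-5, 5, size=(2, 10_000))   # random pairs of points
lams = rng.uniform(0, 1, size=10_000)           # random lambda in [0, 1]

square = lambda t: t ** 2        # convex
neg_square = lambda t: -t ** 2   # concave, hence not convex

print(np.any(violates_convexity(square, xs, ys, lams)))      # False
print(np.any(violates_convexity(neg_square, xs, ys, lams)))  # True
```

For $f(x) = x^2$ the gap between the two sides is exactly $\lambda(1 - \lambda)(x - y)^2 \ge 0$, which is why no violation appears.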
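An equivalent definition of convexity for differentiable functions is that the tangent line (or, in higher dimensions, the tangent hyperplane) at any point lies on or below the function everywhere. For $f(x) = x^2$, a tangent line drawn at any point stays below the parabola, touching it only at the point of tangency.
Again, this property clearly holds at every point of the parabola. More formally, a differentiable function $f: K \to \mathbb{R}$ on a convex set $K$ is convex if and only if for any $x, y \in K$, we have
$$f(y) \ge f(x) + \langle \nabla f(x), y - x \rangle, \tag{2}$$
where $\langle \cdot, \cdot \rangle$ is the inner product $\langle u, v \rangle = \sum_i u_i v_i$.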
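The first-order condition (2) can be checked the same way. This sketch assumes $f(v) = \|v\|^2$ in 3 dimensions, with gradient $2v$; the helper `first_order_gap` is a hypothetical name for the difference between the two sides of (2). For this function the gap equals $\|y - x\|^2$, so it should never be negative beyond rounding error.

```python
import numpy as np

def first_order_gap(f, grad_f, x, y):
    """f(y) - [f(x) + <grad f(x), y - x>]; inequality (2) says this is >= 0 for convex f."""
    return f(y) - (f(x) + np.dot(grad_f(x), y - x))

f = lambda v: np.dot(v, v)   # f(v) = ||v||^2, a convex function
grad_f = lambda v: 2 * v     # its gradient

rng = np.random.default_rng(1)
gaps = [first_order_gap(f, grad_f, rng.normal(size=3), rng.normal(size=3))
        for _ in range(1000)]
print(min(gaps) >= -1e-12)   # True: the tangent plane never lies above f
```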
Gradient Descent: Idea
The goal of gradient descent is to iteratively find the minimizer $x^\star = \arg\min_{x \in K} f(x)$ of a convex function $f$. The idea is, at each iteration, to use a linear approximation to $f$ around a fixed point $x$ and minimize that approximation. Consider the following approximation:
$$f(y) \approx f(x) + \langle \nabla f(x), y - x \rangle. \tag{3}$$
We want to minimize the right-hand side in order to reduce $f(y)$. Since $x$ is fixed, we want to find the $y$ that minimizes $\langle \nabla f(x), y - x \rangle$. We also start by assuming that the step length $\|y - x\|$ is fixed (the linear approximation is only accurate near $x$): that is, we only look to move a fixed distance. Since the distance between our next value and our current value is fixed, we only care about the direction in which we move.
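A quick numerical check of this idea in 2d: with the step length fixed to 1, the sketch below scans unit directions $d$ and picks the one minimizing $\langle \nabla f(x), d \rangle$. The gradient vector `g` is an arbitrary assumed value, not from the post; up to the grid resolution, the best direction comes out as $-\nabla f(x) / \|\nabla f(x)\|$.

```python
import numpy as np

g = np.array([3.0, -1.0])   # assumed gradient vector grad f(x)

# Unit-length directions d at 3600 evenly spaced angles.
angles = np.linspace(0, 2 * np.pi, 3600, endpoint=False)
directions = np.stack([np.cos(angles), np.sin(angles)], axis=1)

inner = directions @ g                 # <grad f(x), d> for each direction
best = directions[np.argmin(inner)]    # direction with the smallest inner product

print(best)                            # close to -g / ||g||
print(-g / np.linalg.norm(g))
```

The smallest inner product is approximately $-\|g\|$, as the Cauchy-Schwarz inequality predicts.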
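Recall that the gradient points in the direction in which the function grows fastest. Since we want to minimize, we move in the direction opposite the gradient: by the Cauchy-Schwarz inequality, among all $y$ with $\|y - x\|$ fixed, $\langle \nabla f(x), y - x \rangle$ is smallest when $y - x$ points along $-\nabla f(x)$. Thus we can set
$$y - x = -\eta \nabla f(x), \tag{4}$$
where $\eta > 0$ is a constant, often called the step size or learning rate. This moves us in the direction opposite the gradient, and it can be shown that for small enough $\eta$ this step decreases $f$ and moves us closer to the minimizer $x^\star$.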
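The role of "small enough $\eta$" can be seen in a single step. This sketch assumes $f(x) = x^2$ and a current point $x = 2$ (both illustrative choices) and applies update (4) with three step sizes. For this function one step gives $y = x(1 - 2\eta)$, so it decreases $f$ exactly when $0 < \eta < 1$.

```python
def f(x):
    return x ** 2          # convex example function

def grad_f(x):
    return 2 * x           # its derivative

x = 2.0                    # assumed current point
results = {}
for eta in [0.1, 0.5, 1.5]:
    y = x - eta * grad_f(x)        # one step of (4)
    results[eta] = f(y) < f(x)     # did the step decrease f?

print(results)  # {0.1: True, 0.5: True, 1.5: False}
```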
Let’s now look at an example. Consider again the function $f(x) = x^2$, and assume that we are currently at some point $x > 0$. Then the gradient $f'(x) = 2x$ is positive. Thus if we move to the right, we move in the direction of the gradient, and if we move to the left, we move against it and towards the minimizer $x^\star = 0$. The intuition described above tells us to move left.
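The sign argument can be written out directly. In this sketch, `step_direction` is a hypothetical helper, and the step size `eta = 0.1` and the points $x = \pm 2$ are assumed values; it reports which way one gradient step on $f(x) = x^2$ moves.

```python
def deriv(x):
    return 2 * x                 # f'(x) for f(x) = x**2

def step_direction(x, eta=0.1):
    """Which way one gradient step from x moves: 'left', 'right', or 'stay'."""
    step = -eta * deriv(x)
    if step == 0:
        return "stay"            # at the minimizer x = 0 the gradient vanishes
    return "left" if step < 0 else "right"

print(step_direction(2.0))       # left:  f'(2) = 4 > 0
print(step_direction(-2.0))      # right: f'(-2) = -4 < 0
```

Either way, the step points towards the minimizer at $0$.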
Now, taking the fixed $x$ to be our current point $x_t$ and $y$ to be our next point $x_{t+1}$ in an iterative algorithm, we obtain the following update, which is the basis of several variants of gradient descent:
$$x_{t+1} = x_t - \eta \nabla f(x_t). \tag{5}$$
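Update (5) translates almost line for line into code. This is a minimal sketch with a fixed number of iterations; the helper name `gradient_descent`, the 2d quadratic $f(x) = (x_0 - 1)^2 + 10(x_1 + 2)^2$ (minimized at $(1, -2)$), the starting point, and $\eta = 0.04$ are all illustrative assumptions.

```python
import numpy as np

def gradient_descent(grad_f, x0, eta, n_iters):
    """Run x_{t+1} = x_t - eta * grad_f(x_t) for a fixed number of steps."""
    x = np.asarray(x0, dtype=float)
    for _ in range(n_iters):
        x = x - eta * grad_f(x)
    return x

# Gradient of f(x) = (x0 - 1)^2 + 10 * (x1 + 2)^2.
grad_f = lambda x: np.array([2 * (x[0] - 1), 20 * (x[1] + 2)])

x_final = gradient_descent(grad_f, x0=[0.0, 0.0], eta=0.04, n_iters=500)
print(x_final)   # approximately [1, -2]
```

For this quadratic each step multiplies the second coordinate's error by $1 - 20\eta$, so the iteration converges only for $\eta < 0.1$; a larger step size makes it diverge, a first hint of the step-size sensitivity seen in the example below.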
Why are Convex Functions Important for Gradient Descent?
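For a differentiable convex function $f$ on an unconstrained domain such as $\mathbb{R}^n$, the following three properties of a point $x^\star$ are equivalent:
- $x^\star$ is a local minimum of $f$
- $x^\star$ is a global minimum of $f$
- $\nabla f(x^\star) = 0$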
For gradient descent this is helpful because we can decide when to terminate the algorithm by checking whether the magnitude of the gradient is small. Further, we know that when we terminate this way we have reached (close to) the global minimum, not merely a local one.
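The termination rule can be sketched as follows, assuming a tolerance on the gradient norm plus an iteration cap as a safeguard. The helper name `gradient_descent_tol`, the convex test function $f(x) = \|x - c\|^2$ with $c = (3, -1)$, and the parameter values are illustrative. For this particular $f$, the gap $f(x) - f(c)$ equals $\|\nabla f(x)\|^2 / 4$, so a small gradient directly certifies a near-optimal value.

```python
import numpy as np

def gradient_descent_tol(grad_f, x0, eta, tol=1e-6, max_iters=100_000):
    """Iterate x <- x - eta * grad_f(x) until ||grad_f(x)|| <= tol (or max_iters is hit)."""
    x = np.asarray(x0, dtype=float)
    for t in range(max_iters):
        g = grad_f(x)
        if np.linalg.norm(g) <= tol:   # small gradient: (near) global minimum for convex f
            return x, t
        x = x - eta * g
    return x, max_iters

# Convex example: f(x) = ||x - c||^2, minimized at c.
c = np.array([3.0, -1.0])
f = lambda x: np.sum((x - c) ** 2)
grad_f = lambda x: 2 * (x - c)

x_star, n_steps = gradient_descent_tol(grad_f, x0=np.zeros(2), eta=0.1)
print(n_steps, x_star, f(x_star))
```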
A Gradient Descent Example
Let’s try minimizing $f(x) = x^4$ using gradient descent. First we define functions to calculate both the function and its derivative.
```python
import numpy as np
from matplotlib import pyplot as plt

def f(x):
    return x**4

def deriv_f(x):
    # derivative of x**4
    return 4*(x**3)
```
Then we write code to run gradient descent and plot the gradient magnitude over iterations.
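```python
x = 7.5                  # starting point
learning_rate = 1e-3
gradient_magnitudes = []

# Take gradient steps until the derivative is (nearly) zero.
while np.abs(deriv_f(x)) > 1e-4:
    grad = deriv_f(x)
    x = x - learning_rate * grad
    gradient_magnitudes.append(np.abs(grad))

plt.plot(gradient_magnitudes)
plt.xlabel('iteration')
plt.ylabel('gradient magnitude')
plt.title('Gradient Magnitude of Gradient Descent')
plt.show()
```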
We can see fast convergence initially, but as the slope gets smaller, convergence gets very slow: the algorithm needs over 140,000 iterations with the learning rate we set. Further, if we increase the learning rate by a factor of 10, gradient descent diverges. In future posts we will investigate methods to deal with this.
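Both observations can be reproduced without plotting. The sketch below repackages the loop above into a hypothetical helper `run_gd_x4`; the divergence threshold `blowup=1e6` and the iteration cap are assumptions added so the run stops cleanly when the iterates explode.

```python
def run_gd_x4(x0=7.5, eta=1e-3, tol=1e-4, max_iters=1_000_000, blowup=1e6):
    """Gradient descent on f(x) = x**4; returns (iterations, final x, diverged?)."""
    x = x0
    for t in range(max_iters):
        g = 4 * x ** 3                 # f'(x)
        if abs(g) <= tol:              # same stopping rule as the loop above
            return t, x, False
        x = x - eta * g
        if abs(x) > blowup:            # iterates exploding: treat as divergence
            return t + 1, x, True
    return max_iters, x, False

iters, x_final, diverged = run_gd_x4(eta=1e-3)
print(iters, diverged)   # roughly 146,000 iterations, no divergence

iters_big, _, diverged_big = run_gd_x4(eta=1e-2)
print(iters_big, diverged_big)   # blows up within a few steps
```

With $\eta = 10^{-2}$ the first step already overshoots from $7.5$ to about $-9.4$, and each subsequent step overshoots further.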
Discussion
In this post we discussed the intuition behind gradient descent. We first defined convex sets and convex functions, then described the idea behind gradient descent: moving in the direction opposite the direction with the largest rate of increase. We then described why this is useful for convex functions, and finally showed a toy example.