Affiliations:
1. Laboratory for Information and Decision Systems, Massachusetts Institute of Technology, Cambridge, Massachusetts 02139;
2. Department of Computer Science and Elmore Family School of Electrical and Computer Engineering, Purdue University, West Lafayette, Indiana 47907;
3. Department of Electrical Engineering and Computer Science, Massachusetts Institute of Technology, Cambridge, Massachusetts 02139
Abstract
In this paper, we consider the widely studied problem of empirical risk minimization (ERM) of strongly convex and smooth loss functions using iterative gradient-based methods. A major goal of the existing literature has been to compare different prototypical algorithms, such as batch gradient descent (GD) or stochastic gradient descent (SGD), by analyzing their rates of convergence to ϵ-approximate solutions with respect to the number of gradient computations, which is also known as the oracle complexity. For example, the oracle complexity of GD is [Formula: see text], where n is the number of training samples and p is the parameter space dimension. When n is large, this can be prohibitively expensive in practice, and SGD is preferred due to its oracle complexity of [Formula: see text]. Such standard analyses only utilize the smoothness of the loss function in the parameter being optimized. In contrast, we demonstrate that when the loss function is smooth in the data, we can learn the oracle at every iteration and beat the oracle complexities of GD, SGD, and their variants in important regimes. Specifically, at every iteration, our proposed algorithm, Local Polynomial Interpolation-based Gradient Descent (LPI-GD), first performs local polynomial regression with a virtual batch of data points to learn the gradient of the loss function and then estimates the true gradient of the ERM objective function. We establish that the oracle complexity of LPI-GD is [Formula: see text], where d is the data space dimension, and the gradient of the loss function is assumed to belong to an η-Hölder class with respect to the data. Our proof extends the analysis of local polynomial regression in nonparametric statistics to provide supremum norm guarantees for interpolation in multivariate settings and also exploits tools from the inexact GD literature. Unlike the complexities of GD and SGD, the complexity of our method depends on d. 
However, our algorithm outperforms GD, SGD, and their variants in oracle complexity for a broad range of settings where d is small relative to n. For example, with typical loss functions (such as squared or cross-entropy loss), when [Formula: see text] for any [Formula: see text] and [Formula: see text] is at the statistical limit, our method can be made to require [Formula: see text] oracle calls for any [Formula: see text], while SGD and GD require [Formula: see text] and [Formula: see text] oracle calls, respectively.
Funding: This work was supported in part by the Office of Naval Research [Grant N000142012394], in part by the Army Research Office [Multidisciplinary University Research Initiative Grant W911NF-19-1-0217], and in part by the National Science Foundation [Transdisciplinary Research In Principles Of Data Science, Foundations of Data Science].
Publisher
Institute for Operations Research and the Management Sciences (INFORMS)