Moving Beyond “1 Device, 1 Model”

Multi-tenancy, where multiple computing applications are served by the same infrastructure, helps ensure efficient resource utilization in public or private cloud systems. This is especially important for long-running jobs like machine learning model training, which produces the models that power everything from facial recognition to natural language understanding. Training a modern ML model is amazingly expensive, and it is common for industrial infrastructure to have hundreds to thousands of GPUs.

Today’s popular model training frameworks insist on exclusive use of allocated devices: a “1 device, 1 model” mantra. While this design paradigm makes scheduling and resource management easy (or at least easy-ish), it is not sustainable as the number of different models that we train outpaces our hardware resources. But before we can build multi-tenant training systems, a first question is how to train a model in dynamic environments with changing or pre-empted resources.

In our first project, we look at the problem of auto-scaling Stochastic Gradient Descent (SGD)-based training tasks. If a training task observes that additional threads are available, should it try to request those resources? Counter-intuitively, this may hurt model performance. SGD repeatedly samples random training data and calculates a sample approximation of the gradient (the direction in which to update the model to lower the prediction error). It takes a step in that direction, updates the model, and repeats until convergence. Intuitively, as long as the updates are in the right direction on average and the approximation error is small relative to the size of the steps, SGD converges. Low-level concerns like the level of parallelism and the frequency of synchronization can affect the optimization algorithm, either by increasing the number of conflicting updates (making each step less exact) or by forcing inconvenient synchronization (making each exact step take longer).
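To make the SGD loop described above concrete, here is a minimal single-threaded sketch in Python with NumPy, assuming a least-squares objective on synthetic data. The function name `sgd_linear_regression` and every parameter value are illustrative assumptions and are not taken from our system.

```python
import numpy as np

def sgd_linear_regression(X, y, step_size=0.05, batch_size=8, num_steps=2000, seed=0):
    """Minibatch SGD for least squares: minimize mean((X @ w - y) ** 2)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(num_steps):
        # Sample a random minibatch of training data (with replacement).
        idx = rng.integers(0, n, size=batch_size)
        Xb, yb = X[idx], y[idx]
        # Sample approximation of the gradient of the mean squared error.
        grad = 2.0 * Xb.T @ (Xb @ w - yb) / batch_size
        # Step against the gradient to lower the prediction error.
        w -= step_size * grad
    return w

# Synthetic data (an assumption for illustration only).
data_rng = np.random.default_rng(1)
X = data_rng.normal(size=(500, 3))
true_w = np.array([1.5, -2.0, 0.5])
y = X @ true_w + 0.01 * data_rng.normal(size=500)

w_hat = sgd_linear_regression(X, y)
print(w_hat)  # expected to land close to true_w
```

Here `batch_size` and `step_size` are the knobs that trade the noise in each step against its size; the sketch has no parallelism, so it does not model conflicting updates.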

Not surprisingly, to safely auto-scale SGD, one must make these repeated steps “less aggressive” when there are more conflicts. We develop a principled theory that relates the sampling batch size, number of threads, step size, and momentum parameters. Scaling on one axis requires an adjustment on the others. Our system, which we call OPT^2, automatically makes those adjustments and evaluates their effects to ensure stable and accurate model training.

Threads are only one form of contended resource. We also recently explored training models under dynamic and stringent memory constraints. We explore gracefully degrading the precision of the forward and backward passes of a Convolutional Neural Network to accommodate these memory constraints. This approximation focuses on the convolutional layers of the network, and we apply the degradation in the frequency domain (akin to what happens in JPEG compression). Because such networks already favor low-frequency components, the accuracy of the trained model degrades smoothly with the degree of approximation. We find reductions in memory usage and floating-point operations that are competitive with reduced-precision arithmetic, but with the added advantages of increased training stability and a full continuum of compression ratios.
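The frequency-domain idea can be sketched on a toy 1D example. The code below is a simplified illustration, assuming a circular convolution computed with NumPy's real FFT in which only the lowest-frequency coefficients are kept. The function name `band_limited_conv1d` and the `keep_fraction` parameter are hypothetical, and this is not the implementation used in our experiments.

```python
import numpy as np

def band_limited_conv1d(signal, kernel, keep_fraction):
    """Circular 1D convolution via FFT, keeping only low-frequency bins."""
    n = len(signal)
    sig_f = np.fft.rfft(signal, n=n)
    ker_f = np.fft.rfft(kernel, n=n)  # zero-pads the kernel to length n
    num_bins = sig_f.shape[0]
    keep = max(1, int(np.ceil(keep_fraction * num_bins)))
    # Only the first `keep` coefficients take part in the product; a real
    # implementation would store just these to save memory.
    out_f = np.zeros_like(sig_f)
    out_f[:keep] = sig_f[:keep] * ker_f[:keep]
    return np.fft.irfft(out_f, n=n)

n = 64
rng = np.random.default_rng(0)
signal = rng.normal(size=n)
kernel = np.array([0.25, 0.5, 0.25])  # a smoothing kernel, chosen for illustration

# Direct circular convolution as the full-spectrum reference.
exact_out = np.array([
    sum(kernel[k] * signal[(i - k) % n] for k in range(len(kernel)))
    for i in range(n)
])

# Approximation error for a few compression levels.
errors = {
    frac: np.linalg.norm(band_limited_conv1d(signal, kernel, frac) - exact_out)
    for frac in (1.0, 0.5, 0.25)
}
print(errors)
```

With `keep_fraction=1.0` the result matches the full-spectrum convolution, and the error grows as fewer coefficients are kept. Because `keep_fraction` can take any value in (0, 1], the compression level varies continuously rather than in fixed precision steps.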

This form of training, which we call band-limiting, may additionally provide a new perspective on adversarial robustness. Adversarial attacks on neural networks tend to involve high-frequency perturbations of input data. By truncating the convolutional operations in the frequency domain, we may introduce an inherent robustness to high-frequency attacks. Our experiments suggest that band-limited training produces models that can better reject such noise than their full-spectrum counterparts.
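As a signal-level intuition only, the sketch below shows a low-pass truncation removing a high-frequency perturbation while leaving a low-frequency signal intact. It assumes a toy 1D input and a hypothetical helper `low_pass`; it says nothing by itself about the robustness of a trained network.

```python
import numpy as np

def low_pass(x, keep_fraction):
    """Zero all but the lowest `keep_fraction` of the real-FFT bins of x."""
    x_f = np.fft.rfft(x)
    keep = max(1, int(np.ceil(keep_fraction * x_f.shape[0])))
    x_f[keep:] = 0.0
    return np.fft.irfft(x_f, n=len(x))

n = 64
t = np.arange(n)
clean = np.sin(2 * np.pi * 2 * t / n)   # low-frequency content (2 cycles)
perturbation = 0.1 * (-1.0) ** t        # highest representable frequency

filtered = low_pass(clean + perturbation, keep_fraction=0.5)

# Largest deviation from the clean signal before and after filtering.
before = np.max(np.abs((clean + perturbation) - clean))
residual = np.max(np.abs(filtered - clean))
print(before, residual)
```

In this toy case the perturbation lies entirely in the discarded bins, so filtering recovers the clean signal up to floating-point error. Real adversarial perturbations are not confined to a single frequency.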

So what have we learned? Dynamic resource management during model training is hard, but the mathematical properties of machine learning models may give us more flexibility than we are used to. Multi-tenant model training systems will have to come to terms with such problems at the intersection of machine learning, signal processing, and distributed systems.