Distributed ML model inference

As of 2024, some Large Language Models (LLMs) have hundreds of billions of parameters. To run them you need GPUs, big GPUs. With BLOOM-176B or OPT-175B you will roughly need 3 Nvidia A100 GPUs, costing $15K each. A paper published in March 2023 introduces Petals, a framework for collaborative inference (i.e., processing real users' requests). It concludes that the bill can be drastically reduced. Let's see how: we first introduce how distributed training works, then how inference works for a big model, and then explain how Petals improves on that. We'll conclude with the system's limitations.

Distributed training

Distributed machine learning is required to train large models efficiently on very large datasets (on the order of terabytes). It generally means training the model across multiple instances (each hosting one or more GPUs) rather than on a single one. The data is split across the instances, and each of them trains a replica of the model on its own portion of the data. The workers' updates (typically their gradients, averaged at every step) are then combined to produce a single final model. This approach can significantly reduce the time it takes to train large models.
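To make the data-parallel idea concrete, here is a minimal sketch of synchronous data parallelism. It assumes a toy linear model y ≈ w*x + b, simulated workers in a single process, and equal-sized shards. Each worker computes a gradient on its own shard, and the gradients are averaged before every update. That averaging is the role an all-reduce collective plays in libraries such as Horovod. The function names are illustrative and not any library's API.

```python
import random

def shard_gradient(w, b, shard):
    """Mean-squared-error gradient of the model y ≈ w*x + b on one worker's shard."""
    gw = gb = 0.0
    for x, y in shard:
        err = (w * x + b) - y
        gw += 2 * err * x
        gb += 2 * err
    return gw / len(shard), gb / len(shard)

def data_parallel_sgd(data, num_workers=4, lr=0.01, steps=500):
    # Split the dataset into one shard per (simulated) worker.
    # With equal-sized shards, the average of shard gradients equals the full-batch gradient.
    shards = [data[i::num_workers] for i in range(num_workers)]
    w = b = 0.0  # every worker holds an identical replica of the parameters
    for _ in range(steps):
        # Each worker computes a gradient on its own shard (in parallel on a real cluster).
        grads = [shard_gradient(w, b, s) for s in shards]
        # "All-reduce": average the gradients so every replica applies the same update.
        gw = sum(g[0] for g in grads) / num_workers
        gb = sum(g[1] for g in grads) / num_workers
        w -= lr * gw
        b -= lr * gb
    return w, b

rng = random.Random(0)
xs = [rng.uniform(-2, 2) for _ in range(400)]
data = [(x, 3.0 * x + 1.0 + rng.gauss(0, 0.1)) for x in xs]
print(data_parallel_sgd(data))  # expected to land close to (3.0, 1.0)
```

Averaging gradients at every step, rather than merging fully trained models at the end, keeps all replicas identical throughout training.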

Tools like Hivemind, Horovod and BigDL fit this purpose well, and this approach overcomes many single-instance hardware limitations. Concerning privacy, distributed-ML design patterns such as Private Aggregation of Teacher Ensembles (PATE) are built to keep data as private as possible during training.
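In PATE, several "teacher" models are trained on disjoint private partitions of the data. Their votes are then aggregated with noise to label public data for a "student" model, and only the student is released. The sketch below shows a noisy-max aggregation step, assuming Laplace noise of scale 1/gamma on each vote count. The function names and the scenario are hypothetical, and this is an illustration, not a reference implementation.

```python
import random
from collections import Counter

def laplace(scale, rng):
    # The difference of two i.i.d. exponential draws is Laplace(0, scale)-distributed.
    return rng.expovariate(1 / scale) - rng.expovariate(1 / scale)

def noisy_max_label(teacher_votes, num_classes, gamma, rng):
    """PATE-style aggregation: add Laplace(1/gamma) noise to every class's vote count
    and return the class with the highest noisy count."""
    counts = Counter(teacher_votes)  # missing classes count as 0
    noisy = [counts[c] + laplace(1 / gamma, rng) for c in range(num_classes)]
    return max(range(num_classes), key=noisy.__getitem__)

# Hypothetical scenario: 250 teachers, each trained on a disjoint private shard,
# vote on the label of one public, unlabeled example for the student model.
rng = random.Random(42)
votes = [2] * 200 + [0] * 30 + [1] * 20
print(noisy_max_label(votes, num_classes=3, gamma=0.5, rng=rng))
```

When the teachers largely agree, the noise rarely changes the answer. When they disagree, no single private partition can reliably tip the label, and that is where the privacy protection comes from.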

Inference time!

Once the model is trained and fine-tuned, it must be post-processed to prepare it for inference, which is not an easy task.

A common pipeline looks like this:

  1. Optional: Distill your model into a smaller one (knowledge distillation, a kind of transfer learning) to drastically reduce the number of parameters, at some cost in accuracy
  2. Optional: Quantize the model by replacing 4-byte floats with 1-byte (int8) or 2-byte (int16) values. It can be a good move, but it depends heavily on your model (typically you won't quantize a trigonometric function...). If done completely, a model can run on a CPU or NPU instead of a GPU
  3. Deploy an API or a Triton server in front of the model to receive input data and queries
  4. Put the model on an instance and pray that it fits into RAM (or VRAM). To illustrate, the BLOOM model "fits" into 352GB of RAM
  5. It doesn't fit? Well, you have to split your model into smaller pieces (layer by layer), offload them to RAM or SSD and load them dynamically when needed. A performance loss is guaranteed
  6. Because your model is in high demand (congrats!), you need to scale your pool of GPU-equipped instances to handle the traffic, which leads to orchestration and load balancing… You know the drill.
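The arithmetic behind steps 2 and 4 can be sketched as follows. It is a weights-only estimate: it ignores activations, caches and runtime overhead, and it assumes 80GB GPUs. For the quantization side, it uses simple symmetric "absmax" int8 quantization, which is one common scheme among several and not necessarily the one any particular model uses. All function names are illustrative.

```python
import math

BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}

def weights_gb(num_params, dtype):
    """Memory needed just to hold the weights (activations, caches, etc. excluded)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

def gpus_needed(num_params, dtype, gpu_gb=80):
    # Lower bound on the number of GPUs, assuming the weights can be split perfectly.
    return math.ceil(weights_gb(num_params, dtype) / gpu_gb)

def quantize_int8(weights):
    """Symmetric absmax quantization: map [-max|w|, max|w|] onto integers in [-127, 127]."""
    scale = max(abs(w) for w in weights) / 127 or 1.0  # avoid a zero scale for all-zero input
    return [round(w / scale) for w in weights], scale

def dequantize_int8(q, scale):
    return [v * scale for v in q]

for dtype in ("float32", "float16", "int8"):
    print(dtype, weights_gb(176e9, dtype), "GB,", gpus_needed(176e9, dtype), "x A100 80GB")

q, s = quantize_int8([0.5, -1.27, 0.0, 1.0])
print(q, dequantize_int8(q, s))
```

For 176 billion parameters, the 16-bit case gives the 352GB quoted above. Each quantized weight is off by at most half the scale, which is why accuracy suffers most for layers with a few very large values.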

Even if you optimize your billion-parameter model (as is the trend with LLMs), each instance still has to load and run the entire model, which is quite expensive. Some sources claim that ChatGPT costs about $700K a day to run.

Distributed… inference?

In March 2023, BigScience released a new paper on arXiv that looks like a relevant proof of concept.

In short: Petals is a protocol that connects a swarm of heterogeneous GPU instances, owned by different parties, to share the inference of a large language model. The proof of concept uses the BLOOM model, with about 176 billion parameters, comparable to GPT-3. Instead of the whole model, each instance runs only one layer (or a few consecutive layers) of the model for forward and backward passes:

  • When an inference request is received, the instance running the first layer applies the forward pass
  • The result is sent to the instance hosting the second layer
  • And so on until the last layer of the model
  • The final output is the response payload to the input request
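The request flow in the list above can be sketched as a chain of calls. The sketch assumes in-process "servers" and toy layers, which are hypothetical stand-ins for transformer blocks. In a real swarm, each hop would be a network call to a remote peer. The point it illustrates is that splitting the layers across servers produces the same output as running them all locally.

```python
from dataclasses import dataclass
from typing import Callable, List

Layer = Callable[[List[float]], List[float]]

@dataclass
class Server:
    """Hypothetical peer hosting a contiguous range of the model's layers."""
    name: str
    layers: List[Layer]

    def forward(self, hidden: List[float]) -> List[float]:
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden

def pipeline_inference(servers: List[Server], inputs: List[float]) -> List[float]:
    hidden = inputs
    for server in servers:  # in a real swarm, this hop is a network call to another peer
        hidden = server.forward(hidden)
    return hidden  # output of the last layer = response payload

def make_layer(scale: float, shift: float) -> Layer:
    # Toy stand-in for a transformer block: elementwise affine map followed by ReLU.
    return lambda h: [max(0.0, scale * x + shift) for x in h]

layers = [make_layer(1.5, 0.1), make_layer(0.5, -0.2), make_layer(2.0, 0.0), make_layer(1.0, 0.3)]
swarm = [Server("gpu-a", layers[:2]), Server("gpu-b", layers[2:3]), Server("gpu-c", layers[3:])]
print(pipeline_inference(swarm, [1.0, -2.0, 0.5]))
```

Only activations travel between peers, so no single server ever needs to hold the full set of weights.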

The achievements fall into three areas: the nature of the instances, layer-level load balancing, and memory-efficient fine-tuning.

Heterogeneous fleet

In the article, the instances themselves are very diverse: some are "only" equipped with gaming GPUs such as the RTX 2080 or RTX 3060, others with more powerful A100 GPUs.
By comparison, running the whole BLOOM model (352GB) on a single A100 (80GB of GPU memory) with offloading would take 5.5 seconds to compute one inference.

An interesting benchmark was run on a set of 14 small servers under real-world conditions (firewalls, heterogeneous networks across 2 continents). It shows good performance compared to offloading: up to 6× faster for single-batch inference (processing one request), up to 15× faster for forward passes at batch size 1, and roughly equivalent at batch size 64 (i.e., 64 requests at the same time).

Layer-level load balancing

Because the model's layers can be assigned to many kinds of instances, the workload can be scaled at a fine granularity. Under high inference demand, for instance, it can help to add replicas of compute-intensive layers and remove replicas of lightweight ones. This can be done either by adding new instances or by rebalancing how many instances serve each layer (i.e., keeping the same total number of instances).
In practice, though, each instance's computing power must be taken into account so that it hosts the layers that suit it best.
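One way to picture layer-level rebalancing is a simplified greedy heuristic: each joining server serves the contiguous span of blocks whose current total throughput is lowest, that is, the current bottleneck. The sketch assumes each server can host a fixed number of consecutive blocks and contributes a single "speed" number to each block it serves. It is an illustration of the idea described above, not Petals' actual placement algorithm, and all names are hypothetical.

```python
from typing import List

def choose_span(block_throughput: List[float], span: int) -> int:
    """Return the start index of the `span` consecutive blocks with the lowest total throughput."""
    best_start, best_total = 0, float("inf")
    for start in range(len(block_throughput) - span + 1):
        total = sum(block_throughput[start:start + span])
        if total < best_total:
            best_start, best_total = start, total
    return best_start

def join_swarm(block_throughput: List[float], server_speed: float, span: int) -> range:
    """A new server takes the bottleneck span and adds its speed to every block it serves.
    `span` reflects how many blocks fit in its memory; `server_speed` its compute power."""
    start = choose_span(block_throughput, span)
    for i in range(start, start + span):
        block_throughput[i] += server_speed
    return range(start, start + span)

throughput = [0.0] * 8  # 8 blocks, nobody serving them yet
for speed, span in [(10, 4), (10, 4), (5, 2), (20, 4)]:
    print(join_swarm(throughput, speed, span), throughput)
```

Rerunning `choose_span` periodically, and letting servers move when their span is no longer the bottleneck, would be one way to implement the "same number of instances, different assignment" variant.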

Scalable fine tuning

Another interesting property of this distribution scheme is the ability to run fine-tuning 1) without loading the whole model and 2) for several users simultaneously.

As a reminder: fine-tuning specializes a general-purpose pretrained model (also called a foundation model) by training it further on a task-specific dataset. For instance, you could fine-tune an animal-detector model into a more precise cat-breed detector.

In this setting, we only want to modify specific layers, which is what Petals allows. 1) Each ML engineer can handle a specific set of layers that fits in their local RAM, compute a forward pass on a new dataset, then ask the instances hosting the other layers to run backpropagation (without changing their original pretrained weights). 2) Each fine-tuning result is versioned and stored on its respective instance, so many ML engineers can work on their own tasks without interfering with each other. For such a system, storing data through IPFS could be an efficient way to reduce data redundancy.
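The key trick is that backpropagation can flow through layers whose weights stay frozen. The toy sketch below assumes scalar affine "remote" layers and a single trainable parameter added to the input, loosely in the spirit of a soft prompt. Each peer only returns the gradient with respect to its input. Where the resulting fine-tuned parameters get stored is left out. The classes and functions are hypothetical illustrations, not the Petals API.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple

@dataclass(frozen=True)
class FrozenRemoteLayer:
    """Hypothetical peer serving one pretrained layer h -> w*h + b with read-only weights."""
    w: float
    b: float

    def forward(self, h: float) -> float:
        return self.w * h + self.b

    def backward(self, grad_out: float) -> float:
        # Returns dLoss/dInput for the caller; the peer's own weights are never updated.
        return grad_out * self.w

def finetune_local_param(layers: Sequence[FrozenRemoteLayer],
                         data: List[Tuple[float, float]],
                         lr: float = 0.05, epochs: int = 200) -> float:
    p = 0.0  # the only trainable parameter, kept on the engineer's machine
    for _ in range(epochs):
        for x, target in data:
            h = x + p
            for layer in layers:  # forward pass through the swarm
                h = layer.forward(h)
            grad = 2 * (h - target)  # derivative of the squared error w.r.t. the output
            for layer in reversed(layers):  # backward pass: each peer returns an input gradient
                grad = layer.backward(grad)
            p -= lr * grad  # since the input is x + p, dLoss/dp equals dLoss/dInput
    return p

swarm = [FrozenRemoteLayer(2.0, 0.1), FrozenRemoteLayer(0.5, -0.3), FrozenRemoteLayer(1.5, 0.2)]

def reference(x: float) -> float:
    for layer in swarm:
        x = layer.forward(x)
    return x

# Hypothetical task: the targets behave as if every input were shifted by +1.
task = [(x, reference(x + 1.0)) for x in (0.0, 1.0, -2.0)]
print(finetune_local_param(swarm, task))  # converges towards 1.0
```

Because the shared layers are never modified, many engineers can train their own small parameters against the same swarm at the same time.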

Limitations & challenges

Petals is an interesting step, but it comes with many challenges.
First, there is no incentive or reward system for participants who share their hardware for model inference.

As is, the model served by Petals is already quantized, halving its memory footprint compared to the original model. Participating instances still need at least a GPU with 8GB of VRAM and 16GB of RAM, though. Such optimizations depend heavily on the model architecture; done wisely, they could open distributed inference to smaller devices with a CPU or NPU (such as a Mac mini).

In the current approach, there is a serious security concern about the peers running the first layers of the model. They could use the inputs they receive to recover the original input, which could be sensitive data such as health or financial records. The current workaround is to run these first layers only on trusted instances.

To address this last issue, there are privacy-preserving inference (PPI) methods:

One is to apply homomorphic encryption (HE) to the input data, run the forward pass on the encrypted data, then return the still-encrypted result. This approach is very secure, but it has drawbacks: 1) it increases the compute cost, 2) non-linear and pooling operations cannot be computed directly on encrypted data, and 3) there is a significant loss of accuracy. The most recent research on HE applied to LLMs looks at encrypting the embedding layer (one of the first layers, which typically converts tokens into numeric vectors) instead of the input itself.
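To see why linear layers work under encryption while non-linear ones do not, here is a toy sketch using the Paillier cryptosystem. It is swapped in for illustration because it is short. Paillier is only additively homomorphic, whereas research on encrypted ML typically relies on richer schemes such as CKKS. The small fixed primes make it insecure, and the function names are hypothetical. A peer holding only the public key `n` computes an integer dot product on ciphertexts, and only the client can decrypt the result.

```python
import math
import random

def keygen(p=1_000_003, q=999_983):
    """Toy Paillier key pair from small fixed primes (insecure: real keys use a ~2048-bit n)."""
    n = p * q
    lam = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)  # lcm(p-1, q-1)
    mu = pow(lam, -1, n)  # modular inverse; valid shortcut because we use g = n + 1
    return n, (lam, mu)  # n is public, (lam, mu) stays with the client

def encrypt(n, m, rng):
    n2 = n * n
    r = rng.randrange(1, n)
    while math.gcd(r, n) != 1:  # r must be invertible modulo n
        r = rng.randrange(1, n)
    return pow(n + 1, m % n, n2) * pow(r, n, n2) % n2

def decrypt(n, secret, c):
    lam, mu = secret
    u = pow(c, lam, n * n)
    m = (u - 1) // n * mu % n
    return m if m <= n // 2 else m - n  # map residues back to signed integers

def he_add(n, c1, c2):
    return c1 * c2 % (n * n)  # Enc(a) * Enc(b) decrypts to a + b

def he_scale(n, c, k):
    return pow(c, k % n, n * n)  # Enc(a) ** k decrypts to k * a (k is a plaintext integer)

def encrypted_linear(n, enc_x, weights, bias, rng):
    """Peer side: w . x + bias on ciphertexts, using only the public key n.
    A ReLU or max-pooling would need comparisons, which Paillier cannot evaluate."""
    acc = encrypt(n, bias, rng)
    for c, w in zip(enc_x, weights):
        acc = he_add(n, acc, he_scale(n, c, w))
    return acc

rng = random.Random(0)
n, secret = keygen()
x = [3, -1, 4]  # client's private, integer-encoded input
enc_x = [encrypt(n, v, rng) for v in x]  # only ciphertexts are sent to the peer
enc_y = encrypted_linear(n, enc_x, [2, 5, -1], 7, rng)
print(decrypt(n, secret, enc_y))  # 2*3 + 5*(-1) + (-1)*4 + 7 = 4
```

The inputs must first be encoded as integers, and the activation function cannot be evaluated on ciphertexts. These two constraints are where the accuracy loss and extra cost mentioned above come from.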

Another method, differential privacy (DP), consists of adding calibrated noise during training and inference (progressively or only at the end). This is an easy way to better protect input data. However, after a certain number of inferences, the noise could be reverse-engineered, which forces the entire model to be retrained… This work aims to address this issue using quantum computing, generating true random noise from quantum particles (photons, atoms, …).
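A minimal sketch of why repeated queries erode DP protection is shown below. It uses the Laplace mechanism, a standard DP building block chosen here for illustration; the text does not say which mechanism is meant. The noise scale is sensitivity / epsilon. An attacker who can ask the same question many times simply averages the answers. The secret value and function names are hypothetical.

```python
import random
import statistics

def laplace_noise(scale, rng):
    # The difference of two i.i.d. exponential draws is Laplace(0, scale)-distributed.
    return rng.expovariate(1 / scale) - rng.expovariate(1 / scale)

def laplace_mechanism(true_value, sensitivity, epsilon, rng):
    """Release true_value with epsilon-DP, assuming one record changes it by at most `sensitivity`."""
    return true_value + laplace_noise(sensitivity / epsilon, rng)

rng = random.Random(7)
secret = 42.0  # hypothetical sensitive quantity
single = laplace_mechanism(secret, sensitivity=1.0, epsilon=0.1, rng=rng)
# An attacker who can repeat the same query simply averages the answers:
# by basic composition, k answers together only guarantee (k * epsilon)-DP.
repeated = [laplace_mechanism(secret, 1.0, 0.1, rng) for _ in range(10_000)]
print(f"one answer: {single:.1f}, mean of 10,000 answers: {statistics.mean(repeated):.2f}")
```

A single answer is off by about 10 on average, while the mean of many answers converges on the secret. This is why a privacy budget has to cap the number of queries.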

In practice, some public inference instances could be malicious and return incorrect outputs instead of the actual result of the forward pass. Indeed, if there is no result-checking, why not return a random result instead of computing a costly one? Even worse, an instance could return a deliberately crafted wrong result to steer the following layers so that the model ultimately generates a desired output.
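A simple way to detect such peers, swapped in here for illustration and distinct from the cryptographic approach discussed next, is redundant execution. The same activations are sent to several peers, and a result is accepted only if a majority of them agree. The sketch assumes deterministic forward passes, compared within a tolerance, and non-colluding peers. The peer functions are hypothetical stand-ins.

```python
import random
from typing import Callable, List, Sequence

Forward = Callable[[List[float]], List[float]]

def honest(h: List[float]) -> List[float]:
    return [2.0 * x + 1.0 for x in h]  # toy stand-in for a real forward pass

def lazy(h: List[float]) -> List[float]:
    return [random.random() for _ in h]  # malicious peer skipping the costly computation

def checked_forward(peers: Sequence[Forward], h: List[float], tol: float = 1e-6) -> List[float]:
    """Send the same activations to several peers; accept a result only if a majority agree."""
    outputs = [peer(h) for peer in peers]
    for candidate in outputs:
        agree = sum(
            all(abs(a - b) <= tol for a, b in zip(candidate, other)) for other in outputs
        )  # counts the candidate itself too
        if agree > len(outputs) // 2:
            return candidate
    raise RuntimeError("no majority: flag these peers as suspicious")

print(checked_forward([honest, lazy, honest], [1.0, 2.0]))
```

The trade-off is cost: every checked layer is computed several times, and colluding peers that form a majority defeat the check. Those limitations are what proof-based approaches try to avoid.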

To verify that a computation was done correctly, researchers turn to Zero-Knowledge Proofs (ZKPs). In this setting, ML servers providing an inference result must also provide a cryptographic proof of their computation, which is checked by a verifying party (the Verifier). Current work focuses on reducing the cost of ZKPs to make them scalable. This field is quite new; you can find more resources on this emerging topic here.

Compute over data - sounds like a trend

Borzunov's work on Petals is part of a wider trend that emerged in the 2020s, called Compute over Data. Its goal is to build collaborative, decentralized ways to share, manage and work on data.

Companies like Bacalhau, Expanso, Fluence and Kamu are actively establishing the Open Data Fabric specifications, which describe how to ingest, exchange and compute data (including machine learning) and store the results in a distributed way.
