
Deploying the MLX array framework for Apple silicon

Reviewed on 25 June 2024
Published on 15 December 2023
  • apple-silicon
  • mlx
  • framework
  • apple
  • mac-mini

MLX is an array framework for efficient and flexible machine learning on Apple silicon, developed by Apple's machine learning research team. It was designed by machine learning researchers for their peers, balancing ease of use with efficient model training and deployment. Its deliberately simple design makes the framework easy for researchers to extend and improve, so new ideas can be explored quickly.

On the API side, the Python interface of MLX closely follows NumPy, with a few exceptions. MLX also offers a full-featured C++ API that closely mirrors its Python counterpart.

MLX draws inspiration from frameworks such as PyTorch, JAX, and ArrayFire, but differs from them in its unified memory model. In MLX, arrays live in shared memory, so operations on MLX arrays can run on any supported device without copying data. This is a notable distinction from other frameworks.

Key features of MLX include:

  • Familiar APIs: MLX provides a Python API closely aligned with NumPy, complemented by a full-featured C++ API that mirrors the Python one. Higher-level packages such as mlx.nn and mlx.optimizers closely follow PyTorch, which simplifies building complex models.
  • Composable Function Transformations: MLX supports composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization.
  • Lazy Computation: Computations in MLX are lazy; arrays are only materialized when they are needed.
  • Dynamic Graph Construction: MLX builds computation graphs dynamically, so changing the shapes of function arguments does not trigger slow compilations, and debugging remains simple and intuitive.
  • Multi-device Support: MLX operations can run on any supported device, currently the CPU and the GPU.
  • Unified Memory: MLX's most distinctive feature is its unified memory model. Arrays live in shared memory, so operations on them can run on any supported device type without transferring data.

Before you start

To complete the actions presented below, you must have:

  • A Scaleway account logged into the console
  • Owner status or IAM permissions allowing you to perform actions in the intended Organization
  • Created a Mac mini
  • Installed native Python >= 3.8 on the Mac (preinstalled by default)
Note

MLX is only available on devices running macOS >= 13.3.
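Before installing, you can check these requirements from Python itself. The following is a minimal, hypothetical pre-flight check that uses only the standard library. The thresholds (macOS 13.3, Python 3.8) come from this page. The helper names `parse_version` and `mlx_preflight` are illustrative and not part of MLX. The check also assumes that a native Python on Apple silicon reports `arm64` from `platform.machine()`, while a Python running under Rosetta reports `x86_64`.

```python
import platform
import sys


def parse_version(text):
    """Turn a version string like '13.3.1' into (13, 3, 1); '' gives ()."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def mlx_preflight(min_macos=(13, 3), min_python=(3, 8)):
    """Hypothetical helper: report which MLX prerequisites this machine meets."""
    mac_version = parse_version(platform.mac_ver()[0])  # empty on non-macOS systems
    return {
        "is_macos": platform.system() == "Darwin",
        # Assumption: native Python on Apple silicon reports 'arm64'
        "is_apple_silicon": platform.machine() == "arm64",
        "macos_ok": bool(mac_version) and mac_version >= min_macos,
        "python_ok": sys.version_info[:2] >= min_python,
    }


if __name__ == "__main__":
    for check, passed in mlx_preflight().items():
        print(f"{check}: {'OK' if passed else 'FAIL'}")
```

Comparing tuples keeps the version logic simple. For example, `(14, 0)` and `(13, 3, 1)` both compare as greater than or equal to `(13, 3)`.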

Installing MLX

To install MLX from PyPI on your Apple silicon Mac, run the following command in a terminal or an SSH session:

pip3 install mlx

Getting started with MLX

MLX basics

To work with MLX, begin by importing mlx.core and creating an array:

import mlx.core as mx
a = mx.array([1, 2, 3, 4])
print(a.shape) # Output: [4]
print(a.dtype) # Output: int32
b = mx.array([1.0, 2.0, 3.0, 4.0])
print(b.dtype) # Output: float32

Operations in MLX are lazy: their outputs are not computed until they are needed. To force an array to be evaluated, call mx.eval. Arrays are also evaluated automatically in certain cases, such as when you inspect a scalar with array.item, print an array, or convert an mx.array to a numpy.ndarray.

c = a + b # c not yet evaluated
mx.eval(c) # evaluates c
c = a + b
print(c) # Also evaluates c
# Output: array([2, 4, 6, 8], dtype=float32)
c = a + b
import numpy as np
np.array(c) # Also evaluates c
# Output: array([2., 4., 6., 8.], dtype=float32)

Function and graph transformations

MLX supports standard function transformations like grad and vmap, allowing arbitrary compositions. For instance, grad(vmap(grad(fn))) is a valid composition.

x = mx.array(0.0)
print(mx.sin(x))
# Output: array(0, dtype=float32)
print(mx.grad(mx.sin)(x))
# Output: array(1, dtype=float32)
print(mx.grad(mx.grad(mx.sin))(x))
# Output: array(-0, dtype=float32)

Other gradient transformations include vjp for vector-Jacobian products and jvp for Jacobian-vector products.

Use value_and_grad to efficiently compute both a function's output and its gradient with respect to the function's input.

Going further

  • For the complete reference, see the official MLX documentation.
  • The MLX examples repository hosts a diverse collection of examples, including:
    • Training a Transformer language model.
    • Large-scale text generation with LLaMA and fine-tuning with LoRA.
    • Image generation employing Stable Diffusion.
    • Speech recognition utilizing OpenAI’s Whisper.
© 2023-2024 – Scaleway