Build beginner-first intuition for training on Apple silicon: what Metal and MPS are, why unified memory changes the CUDA mental model, how PyTorch exposes the `mps` device, how to check availability, where CPU fallback appears, and how synchronization and memory pressure still shape performance.
If you train on Linux or Windows with an NVIDIA GPU, start with CUDA for ML Training. If you train on a Mac with Apple silicon, use the matching path here.
The CUDA chapter moved a text-classification batch onto an NVIDIA GPU. On an Apple silicon Mac, the same batch shape (32, 128, 768) can run on the Apple GPU through PyTorch's mps device. You still place tensors deliberately, measure queued work honestly, and manage memory pressure; the backend name and memory architecture change.[1][2]
Build on that CUDA placement and timing model, then translate it to Apple's backend checks and unified-memory behavior.
That distinction matters when you develop locally. Large batched tensor math can use the Apple GPU, but unsupported operators, hidden synchronization, or an oversized workload can still leave a training step slow or broken. Use the Mac-specific checks with the same issue classifier: predict needs_docs, needs_test, or needs_review from developer text.
Metal is Apple's graphics and compute framework. PyTorch uses a Metal Performance Shaders (MPS) backend so ordinary tensor code can target Apple GPUs through MPS Graph and tuned MPS kernels.[1][2]
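For a beginner, three names matter:

- Metal: Apple's graphics and compute framework.
- MPS (Metal Performance Shaders): the backend PyTorch uses to run tensor code on Apple GPUs.
- `mps` device: the PyTorch device string used for tensors and modules assigned to this backend.

So the question isn't "CUDA or GPU?" It's "which backend does this machine expose to PyTorch?"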
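That question can be asked directly in code. The sketch below is one way to list the device types PyTorch can use on the current machine; `visible_backends` is a hypothetical helper name for this lesson, not a PyTorch API.

```python
import torch


def visible_backends() -> list[str]:
    """Return the device types PyTorch can use here, accelerators first."""
    names = []
    if torch.cuda.is_available():
        names.append("cuda")
    # Older PyTorch builds may not ship torch.backends.mps at all.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        names.append("mps")
    names.append("cpu")  # the CPU backend is always present
    return names


backends = visible_backends()
print("backends visible to PyTorch:", backends)
print("first choice:", torch.device(backends[0]))
```

On an Apple silicon Mac with a working install, the list would be expected to start with `mps`; on an NVIDIA machine, with `cuda`.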
In the CUDA chapter, the key picture was system memory and discrete GPU video memory, with copies crossing an interconnect. Apple silicon uses a unified memory architecture: CPU and GPU share system memory instead of dividing storage into CPU RAM and separate video RAM.[4]
Keep the contrast precise:

- `tensor.to("cuda")` places tensor storage in the GPU's separate device memory.
- `tensor.to("mps")` doesn't cross the same CPU-RAM-to-discrete-VRAM boundary. It still returns an `mps` tensor and selects MPS-backed execution. Don't infer that a framework-level device move is always free or always zero-copy.[4][2]

Two practical consequences follow, and both show up below:

- Memory pressure is system-wide: a large run competes with macOS and every other app for the same memory.
- You still distinguish `cpu` and `mps`, keep model and batch on compatible devices, and avoid needless host-visible scalar reads in the hot path.

So unified memory doesn't let you skip device discipline. It changes the hardware boundary, not the PyTorch contract.
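The PyTorch side of that contract is easy to inspect. This minimal sketch creates a tensor on the host, requests a placement, and prints where each tensor lives; the tensor shape is arbitrary and chosen only for illustration.

```python
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

host_batch = torch.zeros(4, 8)          # tensors are created on cpu by default
placed_batch = host_batch.to(device)    # explicit placement request

print("host tensor device:  ", host_batch.device)
print("placed tensor device:", placed_batch.device)
# With MPS available the two device types differ; on the cpu path they match.
print("same device type:", host_batch.device.type == placed_batch.device.type)
```

Even with shared physical memory, PyTorch reports two different devices on a Mac with MPS, which is why model and batch placement still has to agree.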
The tensor shape from the CUDA lesson gives a concrete first budget. A float32 text-classification batch uses four bytes for every feature value:

    batch, tokens, features = 32, 128, 768
    bytes_per_float = 4
    batch_bytes = batch * tokens * features * bytes_per_float

    print("shape:", (batch, tokens, features))
    print(f"one float32 batch: {batch_bytes / (1024 ** 2):.1f} MiB")
    print("training also keeps weights, activations, gradients, and optimizer state")

Output:

    shape: (32, 128, 768)
    one float32 batch: 12.0 MiB
    training also keeps weights, activations, gradients, and optimizer state

Apple's MPS setup page lists an Apple silicon Mac, macOS 14.0 or later, Python 3.10 or later, and Xcode command-line tools for its documented path.[1] PyTorch release numbers change, so use the normal wheel install below and verify the backend instead of pinning a version in your setup notes.
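The hardware and version requirements above can also be checked from Python with the standard library. This is a rough sketch: the thresholds are copied from the paragraph above, the variable names are made up for illustration, and a Python interpreter running under Rosetta may report `x86_64` even on Apple silicon.

```python
import platform
import sys

# Thresholds copied from the requirements quoted above; adjust if they change.
MIN_MACOS = (14, 0)
MIN_PYTHON = (3, 10)

is_macos = platform.system() == "Darwin"
is_arm64 = platform.machine() == "arm64"

mac_release = platform.mac_ver()[0]     # empty string on non-Mac systems
parts = [int(p) for p in mac_release.split(".") if p.isdigit()]
macos_version = tuple((parts + [0, 0])[:2])   # pad so "14" becomes (14, 0)

checks = {
    "Apple silicon Mac": is_macos and is_arm64,
    "macOS new enough": macos_version >= MIN_MACOS,
    "Python new enough": sys.version_info[:2] >= MIN_PYTHON,
}
for name, ok in checks.items():
    print(f"{name:18s}: {'ok' if ok else 'not met'}")
```

A passing result only says the machine looks compatible; the backend check on the installed PyTorch build is still the deciding test.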
Before writing model code, make sure the machine satisfies those basics.

    xcode-select --install
    python3 --version
    sw_vers

When the machine looks compatible, install PyTorch from the normal wheel path Apple documents:

    pip3 install torch torchvision torchaudio

Start with one tiny script that distinguishes three states cleanly: the installed PyTorch build has no MPS support, the build supports MPS but this machine can't use it, or MPS is built and available and the script selects `mps`.

    import torch

    has_mps_backend = hasattr(torch.backends, "mps")
    mps_built = bool(has_mps_backend and torch.backends.mps.is_built())
    mps_available = bool(has_mps_backend and torch.backends.mps.is_available())

    device = torch.device("mps") if mps_available else torch.device("cpu")

    x = torch.arange(6, dtype=torch.float32).reshape(2, 3).to(device)
    model = torch.nn.Linear(3, 2).to(device)
    y = model(x)

    print(f"mps built: {mps_built}")
    print(f"mps available: {mps_available}")
    print(f"selected device: {device}")
    print(f"output shape: {tuple(y.shape)}")

Output on a compatible Mac:

    mps built: True
    mps available: True
    selected device: mps
    output shape: (2, 2)

PyTorch's MPS note uses the same `is_built()` and `is_available()` distinction.[2] Run this check on the target machine: on a compatible Mac it should select `mps`; elsewhere the fallback path is `cpu`. That split matters: only when both checks pass should code select `torch.device("mps")`.

Check your reasoning before moving on: `is_built()` answers "does this wheel even know about MPS?" while `is_available()` answers "can this specific machine use it right now?" A strong answer keeps those two questions separate.
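Keeping those two questions separate is easier when the device choice also reports a reason. A minimal sketch, assuming a hypothetical helper named `describe_mps_state` (not part of PyTorch):

```python
import torch


def describe_mps_state() -> tuple[torch.device, str]:
    """Pick a device and explain the choice (hypothetical helper for this lesson)."""
    mps = getattr(torch.backends, "mps", None)
    if mps is None or not mps.is_built():
        return torch.device("cpu"), "this PyTorch build has no MPS support"
    if not mps.is_available():
        return torch.device("cpu"), "MPS is built in, but this machine cannot use it"
    return torch.device("mps"), "MPS is built and available"


device, reason = describe_mps_state()
print(f"selected device: {device} ({reason})")
```

Printing the reason once at startup makes a silent CPU run easier to spot in logs than a bare device name.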
Continue with the CUDA lesson's text-classification batch. Thirty-two issue reports become 128 hidden token positions each, with 768 features per position. For this small classifier, average those token positions into one vector per report, then produce three logits per report for needs_docs, needs_test, and needs_review.
On a Mac training run, one small step usually looks like this:
| Step | CPU side | mps side | Why a beginner should care |
|---|---|---|---|
| batch assembly | tokenizer, collator, padding, labels | nothing yet | data still starts on host |
| device move | Python asks for .to("mps") | batch becomes an mps tensor | placement is explicit and inspectable |
| forward pass | host launches ops | Metal kernels run embeddings, matmuls, norms | most heavy math lives here |
| loss read | maybe host asks for scalar | device may need to finish queued work first | innocent logging can stall loop |
| backward pass | autograd schedules gradient work | gradient kernels run on mps | memory now includes activations and grads |
| optimizer step | host calls step() | parameter updates happen on mps | model stays on device across steps |
That table isn't theory for theory's sake. Later chapters on training loops, mixed precision, and checkpointing assume you can point at each row and say what ran on the CPU, what ran on the accelerator, and what could unexpectedly bounce back.
If you can't narrate one batch this way yet, stop here and do it slowly. Name tensor location, kernel location, sync point, and likely failure for each row. That habit transfers directly to real model debugging.
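One way to practice that narration is to record where the tensor of interest lives at each stage of a single step. The sketch below uses random stand-in data and a hypothetical `trace` list; it reads the loss last so the one deliberate sync point sits at the end of the step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.manual_seed(0)

trace = []  # (stage, where the tensor of interest lives)

batch = torch.randn(32, 128, 768)            # batch assembly happens on the host
labels = torch.randint(0, 3, (32,))
trace.append(("batch assembly", batch.device.type))

batch, labels = batch.to(device), labels.to(device)   # explicit device move
trace.append(("device move", batch.device.type))

model = nn.Linear(768, 3).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

logits = model(batch.mean(dim=1))            # forward pass
trace.append(("forward pass", logits.device.type))

loss = F.cross_entropy(logits, labels)
optimizer.zero_grad()
loss.backward()                              # backward pass
trace.append(("backward pass", model.weight.grad.device.type))

optimizer.step()                             # optimizer step
trace.append(("optimizer step", model.weight.device.type))

loss_value = loss.item()                     # loss read: the host waits here
trace.append(("loss read", type(loss_value).__name__))

for stage, place in trace:
    print(f"{stage:15s} -> {place}")
```

Only the first and last rows should mention the host; anything else reporting `cpu` on a Mac with MPS is worth investigating.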
Unified memory tempts Mac users into thinking "one laptop, one memory pool, one device." The hardware shares memory, but PyTorch still doesn't work that way at the code level. cpu and mps are separate device targets in your code, and the forward pass still fails if model and batch land on different devices.[2]
Training code on Mac still follows the same pattern as CUDA:

    import torch
    import torch.nn as nn

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = nn.Linear(768, 3).to(device)
    ticket_batch = torch.randn(32, 128, 768).to(device)

    ticket_vectors = ticket_batch.mean(dim=1)
    logits = model(ticket_vectors)
    devices_match = next(model.parameters()).device == ticket_batch.device == logits.device
    print("model and batch agree:", devices_match)
    print("logits shape:", tuple(logits.shape))

Output:

    model and batch agree: True
    logits shape: (32, 3)

Continue from the placed classifier and ticket batch with one small optimizer step. It uses MPS when available and remains executable on a non-Mac machine:

    import math

    import torch.nn.functional as F

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    labels = torch.tensor([0, 2, 1, 0], device=device)
    weights_before = model.weight.detach().clone()

    optimizer.zero_grad()
    loss = F.cross_entropy(model(ticket_vectors[:4]), labels)
    loss.backward()
    optimizer.step()
    logged_loss = loss.detach().cpu().item()

    print("weights changed:", not torch.equal(weights_before, model.weight.detach()))
    print("finite loss:", math.isfinite(logged_loss))

Output:

    weights changed: True
    finite loss: True

The `.cpu().item()` call is outside gradient computation and deliberately marks the boundary where a scalar returns to the host for reporting.
If model and batch devices don't agree, fix placement before looking for deeper bugs. Catch that failure before a full training run by checking the device contract:

    import torch


    def require_same_device(model_device: torch.device, batch: torch.Tensor) -> None:
        # Compare device types: a placed tensor reports an indexed device such as
        # mps:0, which does not compare equal to a bare torch.device("mps").
        if batch.device.type != model_device.type:
            raise RuntimeError(f"batch device does not match model device {model_device}")


    batch = torch.randn(8, 4)
    try:
        require_same_device(torch.device("mps"), batch)
    except RuntimeError as error:
        print("caught:", error)

Output:

    caught: batch device does not match model device mps

That's the same rule you learned in CUDA. The Mac path isn't an exemption from device consistency.
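Real batches usually arrive as several tensors rather than one, so the same rule has to hold for every entry. A minimal sketch, assuming a collated batch is a plain dictionary of tensors and using a hypothetical helper named `move_batch`:

```python
import torch


def move_batch(batch: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    """Move every tensor in a collated batch to one device (hypothetical helper)."""
    return {name: tensor.to(device) for name, tensor in batch.items()}


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
collated = {
    "features": torch.randn(8, 768),          # stand-in for pooled report vectors
    "labels": torch.randint(0, 3, (8,)),      # three classes, as in the lesson
}
placed = move_batch(collated, device)
device_types = {tensor.device.type for tensor in placed.values()}
print("device types in batch:", device_types)
```

Moving the whole batch in one place leaves a single line to audit when a device mismatch appears.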
An operation without an MPS implementation can stop an otherwise valid training loop. PyTorch exposes `PYTORCH_ENABLE_MPS_FALLBACK=1` so unsupported MPS operations can run on CPU instead of failing immediately.[3] Treat that flag as a debugging aid, not a promise that every unsupported operation will work.

    PYTORCH_ENABLE_MPS_FALLBACK=1 python train.py

Fallback keeps debugging unblocked, but it can also hide backend switches and synchronization. If one layer keeps falling back to CPU inside a hot loop, throughput can collapse even though the script "works."
CPU work isn't automatically fallback. Tokenization and batch assembly normally happen on CPU before .to("mps"). MPS fallback means a PyTorch operation on the accelerator path has no MPS implementation and runs on CPU instead.
Use fallback to identify the blocking operation. Then decide whether to rewrite that part, upgrade PyTorch, change precision, or accept the CPU path for that workload.
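With fallback turned off, a missing MPS kernel typically surfaces as a `NotImplementedError` that names the operator. The sketch below wraps individual operations to report that case; `probe_op` is a hypothetical helper, and `matmul` and `cumsum` are placeholders rather than operations known to be missing.

```python
import torch


def probe_op(name, fn, *args):
    """Run one operation and report whether the current backend implements it."""
    try:
        fn(*args)
        return f"{name}: ran"
    except NotImplementedError as error:
        # Assumption: an unsupported MPS operator raises this error when fallback is off.
        first_line = (str(error).splitlines() or [""])[0]
        return f"{name}: not implemented ({first_line})"


device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
x = torch.randn(4, 4, device=device)

report = [
    probe_op("matmul", torch.matmul, x, x),
    probe_op("cumsum", torch.cumsum, x, 0),
]
for line in report:
    print(line)
```

Probing the suspicious operations one at a time, on small tensors, narrows the search before any rewrite or upgrade decision.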
Your issue classifier may spend most of its time in embedding lookup, attention, and classifier matmuls on mps, while one tensor-indexing operation has no MPS implementation and falls back to CPU inside every batch.
At small scale, you may barely notice:

- most of the step runs on `mps`
- the unsupported operation runs on CPU
- `mps` work resumes

At larger scale, that repeated detour gets expensive because every backend switch disrupts the fast path. Throughput drops, step time becomes noisy, and profiler traces stop looking like one clean accelerator-bound loop. "Script finished" isn't the same as "training path is healthy."
You can reason about fallback cost without requiring an unsupported operator on your machine. Each backend change below is a boundary where execution or data handling has to switch paths:

    path = ["host batch", "mps forward", "fallback op", "mps classifier", "host log"]
    backends = [stage.split()[0] for stage in path]
    switches = sum(left != right for left, right in zip(backends, backends[1:]))

    print(" -> ".join(path))
    print("backend changes:", switches)
    print("fix target: remove fallback from hot path")

Output:

    host batch -> mps forward -> fallback op -> mps classifier -> host log
    backend changes: 4
    fix target: remove fallback from hot path

The first move and final log are deliberate boundaries. The middle CPU fallback is the boundary to remove from the hot path.
The Mac timing trap looks like the CUDA timing trap. Kernel launches and queued device work can make naive timers lie. Warm up the operation first, then synchronize before and after the measured block. PyTorch exposes `torch.mps.synchronize()` for that boundary.[5]

    import time
    import torch

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    x = torch.randn(256, 256, device=device)
    w = torch.randn(256, 256, device=device)

    for _ in range(3):
        y = x @ w
    if device.type == "mps":
        torch.mps.synchronize()

    start = time.perf_counter()
    for _ in range(10):
        y = x @ w
    if device.type == "mps":
        torch.mps.synchronize()
    elapsed_ms = (time.perf_counter() - start) * 1000

    print("timed matmuls:", 10)
    print("result shape:", tuple(y.shape))
    print("elapsed is nonnegative:", elapsed_ms >= 0)

Output:

    timed matmuls: 10
    result shape: (256, 256)
    elapsed is nonnegative: True

The same hidden sync points still matter:

- `loss.item()` when `loss` is an MPS tensor
- `tensor.cpu()`, including `tensor.cpu().numpy()` for NumPy analysis

Calling `.numpy()` directly on an MPS tensor isn't the route back to NumPy: move the data to CPU first. Keep that reporting boundary explicit:

    import torch

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    loss = torch.tensor(2.5, device=device)
    reported_loss = loss.detach().cpu().item()

    print(f"reported loss: {reported_loss:.1f}")

Output:

    reported loss: 2.5

Don't trust timing claims until you know whether the host waited for the device.
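One way to keep scalar reads out of the hot path is to accumulate the loss on the device and read it only every few steps. The sketch below assumes a logging interval of five steps and random stand-in data; the names `log_every`, `running_loss`, and `host_reads` are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.manual_seed(0)

model = nn.Linear(768, 3).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
log_every = 5                                   # assumed logging interval
running_loss = torch.zeros((), device=device)   # accumulator stays on the device
host_reads = 0

for step in range(1, 11):
    features = torch.randn(32, 768, device=device)
    labels = torch.randint(0, 3, (32,), device=device)

    optimizer.zero_grad()
    loss = F.cross_entropy(model(features), labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.detach()               # no host read in the hot path
    if step % log_every == 0:
        mean_loss = (running_loss / log_every).item()   # one deliberate sync
        host_reads += 1
        print(f"step {step:2d}: mean loss {mean_loss:.3f}")
        running_loss.zero_()

print("host reads:", host_reads, "for 10 steps")
```

Ten steps produce two host reads instead of ten, and the reported number is a mean over the interval rather than a single noisy batch.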
Say one issue-classifier forward pass launches 35 ms of mps work, but Python finishes queueing it in 3 ms.
That gap changes your engineering decisions.
Good accelerator debugging starts with honest measurement before clever optimization.
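The 35 ms versus 3 ms example can be turned into throughput numbers to show how far an unsynchronized timer can mislead. This is a rough model only: it assumes a synchronized timer sees about the device time and that the batch holds 32 issue reports, as elsewhere in this lesson.

```python
batch_size = 32      # issue reports per batch, from the running example
queue_ms = 3.0       # time for Python to queue the work
device_ms = 35.0     # time for the device to finish it

# Rough model: an unsynchronized timer stops after queueing,
# while a synchronized timer waits for the device to finish.
naive_rate = batch_size / (queue_ms / 1000)
honest_rate = batch_size / (device_ms / 1000)

print(f"unsynchronized timer: {queue_ms:.0f} ms -> {naive_rate:,.0f} reports/s")
print(f"synchronized timer:   {device_ms:.0f} ms -> {honest_rate:,.0f} reports/s")
print(f"overstatement: {naive_rate / honest_rate:.1f}x")
```

Under these numbers the unsynchronized figure overstates throughput by roughly 11.7x, which is large enough to send an optimization effort in the wrong direction.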
The first memory lesson is the same as CUDA: weights aren't the whole bill. Activations, gradients, optimizer state, and temporary workspaces matter too.
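A back-of-envelope count makes that bill concrete for the small classifier head used in this lesson. The sketch assumes float32 everywhere and, for the optimizer line, assumes Adam with two moment buffers per parameter; plain SGD without momentum, as used earlier, keeps no such state. Temporary workspaces are not counted.

```python
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2

# Classifier head from this lesson: nn.Linear(768, 3) has a 3x768 weight plus 3 biases.
param_count = 768 * 3 + 3
weight_bytes = param_count * BYTES_PER_FLOAT32
grad_bytes = weight_bytes                  # one gradient value per parameter
adam_state_bytes = 2 * weight_bytes        # assumption: Adam keeps two buffers per parameter
batch_bytes = 32 * 128 * 768 * BYTES_PER_FLOAT32   # the input batch itself

for label, size in [
    ("weights", weight_bytes),
    ("gradients", grad_bytes),
    ("Adam state (assumed)", adam_state_bytes),
    ("input batch", batch_bytes),
]:
    print(f"{label:21s}: {size / MIB:8.3f} MiB")
```

For a head this small the input batch dominates; with a full encoder in front of it, weights, gradients, and optimizer state would grow with the parameter count in the same proportions.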
Unified memory adds one twist. There isn't a separate video-memory pool: a large run competes for system memory with macOS and every other app. PyTorch exposes current_allocated_memory() for bytes occupied by live tensors, driver_allocated_memory() for total memory allocated by Metal for the process (including cached allocator blocks and MPS/MPSGraph allocations), and empty_cache() to release unoccupied cached memory.[5]
Inspect those counters only after MPS is available:

    import torch

    if torch.backends.mps.is_available():
        before = torch.mps.current_allocated_memory()
        tensor = torch.ones(1024, 1024, device="mps")
        after = torch.mps.current_allocated_memory()
        del tensor
        torch.mps.empty_cache()
        print("live tensor allocation increased:", after > before)
        print("recommended limit reported:", torch.mps.recommended_max_memory() > 0)
    else:
        print("MPS allocator counters need an available mps device")

When memory gets tight, fix order should stay boring:
| Symptom | First question | First fix |
|---|---|---|
| OOM on first real batch | Is batch or sequence length too large? | shrink batch size first |
| Step time swings wildly | Are unsupported ops or sync points bouncing work back to CPU? | check fallback and logging paths |
| MPS allocator errors | Are you near working-set limits? | reduce workload before touching allocator env vars |
PyTorch also exposes MPS-specific allocator controls such as PYTORCH_MPS_HIGH_WATERMARK_RATIO and PYTORCH_MPS_LOW_WATERMARK_RATIO.[3] They are advanced tuning knobs, not first response. The documentation warns that disabling the high watermark can cause system failure under system-wide out-of-memory conditions. Start by shrinking work.
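Before reaching for those knobs, it can help to see how close the process already is to the recommended limit. The sketch below combines two counters mentioned in this lesson; `mps_headroom` is a hypothetical helper, and the ratio is only a rough indicator, not an official threshold.

```python
import torch


def mps_headroom():
    """Return (driver_bytes, recommended_bytes, fraction_used), or None without MPS."""
    if not torch.backends.mps.is_available():
        return None
    driver_bytes = torch.mps.driver_allocated_memory()
    recommended_bytes = torch.mps.recommended_max_memory()
    return driver_bytes, recommended_bytes, driver_bytes / recommended_bytes


headroom = mps_headroom()
if headroom is None:
    print("no mps device: nothing to measure")
else:
    driver_bytes, recommended_bytes, fraction = headroom
    print(f"driver allocated: {driver_bytes / 1024 ** 2:.1f} MiB")
    print(f"recommended max:  {recommended_bytes / 1024 ** 2:.1f} MiB")
    print(f"fraction used:    {fraction:.1%}")
```

Logging this once per epoch gives a trend to watch while shrinking batch size or sequence length.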
For the text-classification example, common first fixes are boring on purpose:

- shrink the batch size
- shorten the sequence length
- check for stray `.cpu()` calls before blaming Metal

Measure the simplest workload reductions before changing allocator limits:

    batch_size = 32
    sequence_length = 128
    baseline_positions = batch_size * sequence_length

    for label, batch, tokens in [
        ("baseline", 32, 128),
        ("half batch", 16, 128),
        ("half length", 32, 64),
    ]:
        share = (batch * tokens) / baseline_positions
        print(f"{label:11s}: {share:.0%} of token positions")

Output:

    baseline   : 100% of token positions
    half batch : 50% of token positions
    half length: 50% of token positions

Memory lever: Both changes halve token positions and many activation tensors. Shorter sequences can reduce attention score tensors faster because attention has two sequence axes.
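The "two sequence axes" point can be checked with arithmetic. This sketch assumes float32 scores, one materialized score tensor per layer, and 12 attention heads; the head count is an assumption, since the lesson doesn't specify one, and fused attention kernels may never hold the full tensor in memory.

```python
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2
HEADS = 12   # assumed head count; not specified in this lesson


def attention_score_bytes(batch: int, tokens: int) -> int:
    # One score per (query position, key position) pair, per head, per example.
    return batch * HEADS * tokens * tokens * BYTES_PER_FLOAT32


baseline = attention_score_bytes(32, 128)
for label, batch, tokens in [
    ("baseline", 32, 128),
    ("half batch", 16, 128),
    ("half length", 32, 64),
]:
    size = attention_score_bytes(batch, tokens)
    print(f"{label:11s}: {size / MIB:5.1f} MiB per layer ({size / baseline:.0%} of baseline)")
```

Halving the batch halves the score tensor, while halving the sequence length cuts it to a quarter, because the token count appears twice in the product.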
Different backend, same engineering questions:
If you can answer those on Mac, later training chapters stop feeling platform-specific.
Cover the sections above and answer these before you peek.
- MPS is the backend PyTorch uses to reach the Apple GPU; `mps` is PyTorch's device name for using it.
- Unified memory removes the separate video-memory pool, but `.to("mps")` still selects MPS execution and doesn't promise a cost-free move. Apple currently labels the MPS backend beta.[4][1]
- `is_built()` and `is_available()` answer different setup questions.
- `PYTORCH_ENABLE_MPS_FALLBACK=1` is useful, but it can hide slow CPU detours.

If you want broader accelerator intuition or you also work on NVIDIA servers, read CUDA for ML Training too. Same mental model, different backend details.
Key terms: `mps` device; `is_built()` vs `is_available()`.

You're ready to move on when you can explain `is_built()` versus `is_available()`, move model and batch onto `mps`, and identify when unsupported operations fall back to CPU. A common mistake is checking only `is_available()` and missing the more basic case where the installed PyTorch build has no MPS support at all.