Cut Cross Entropy: 20x Memory Reduction in LLM Pre-training through Optimized Cross-Entropy Kernels

Table of Contents
Introduction
While working on pretraining SabiYarn in 2025, I came across a really interesting paper from a team at Apple, “Cut Your Losses in Large-Vocabulary Language Models”. It makes an interesting proposition: the cross entropy loss function has a memory problem that has quietly crept up with a recent trend in LLM development, large vocabulary sizes.
DeepSeek’s emergence in December 2024 marked a significant turning point in the LLM industry. While major AI labs continued to scale model performance through ever-increasing compute budgets, DeepSeek showed that gains in performance, cost and scalability come from optimizing the whole stack, from compute kernels and memory access to networking and storage. While pretraining DeepSeek V3, the team developed an open source distributed file system, 3FS (Fire-Flyer File System), optimized for high-throughput training; a new attention mechanism (Multi-head Latent Attention, with custom kernels); a highly tuned communication library for mixture-of-experts models (DeepEP); and DeepGEMM, an FP8-optimized matrix multiplication kernel library.
A bit on Cross Entropy
Cross entropy originates from information theory: it is a function that measures the difference between two probability distributions. In machine learning, the cross entropy loss is used to measure the mismatch between a model’s predicted probabilities and the true labels in the dataset.
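Formally, for a true distribution p and a predicted distribution q over the same set of outcomes, the cross entropy is:
\[ H(p, q) = -\sum_{x} p(x)\,\log q(x) \]
With one-hot labels, p places all of its mass on the correct class, and the sum collapses to the negative log probability the model assigns to that class.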
LLM Pretraining
Large language models are trained by autoregressively predicting the next token in a dataset. In large-scale pretraining, this dataset can span much of the internet, scientific texts and literature, textbooks and more; in fine-tuning, it is a smaller corpus focused on a specific niche or domain.
Given a sequence of N tokens, a large language model can be defined as:
\[ P(x) = \prod_{i=1}^{N} P(x_i | x_1 \ldots x_{i-1}) \]
That is, a large language model parameterizes a probability distribution over all possible tokens in its vocabulary. This probability distribution is used to predict the next token in a sequence.
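For a concrete three-token sequence, this factorization reads:
\[ P(x_1, x_2, x_3) = P(x_1)\, P(x_2 | x_1)\, P(x_3 | x_1, x_2) \]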
The probability distribution is encoded as weights in the model’s architecture, consisting of a backbone network (transformer layers) and a final classifier layer.
The output of the backbone network is often referred to as an embedding vector. These embeddings are the model’s representation of a token, generated by multiple self-attention layers where the model encodes contextual information from the sequence into each token’s embeddings.
\[ E = f(x_1 \ldots x_{i-1}) \]
where $f(x_1 \ldots x_{i-1})$ signifies the backbone network, the transformer architecture described above.
These embeddings are then fed as inputs to the classifier layer C⊤ to produce unnormalized Log Probabilities (Logits)
\[ \text{logits} = C^\top f(x_1 \ldots x_{i-1}) \]
These logits are then converted to a probability distribution over all the tokens in the model’s vocabulary (V) using the softmax function, and the result is taken as the model’s prediction for the next token.
\[ \text{softmax}_k(\mathbf{v}) = \frac{\exp(v_k)}{\sum_j \exp(v_j)} \]
where $v_k$ is the logit for the desired class and $v_j$ are the logits for all indices $j$ in the model’s vocabulary (V).
The output of the softmax activation function is our desired probability distribution, and is used in generating the next token in inference or calculating the loss score during training.
Thus we could summarize a language model as:
\[ P(x_i | x_1 \ldots x_{i-1}) = \text{softmax}_{x_i}(C^\top f(x_1 \ldots x_{i-1})) \]
As with every other classification problem in machine learning, we take the class with the highest predicted probability as the model’s prediction; in our case, the model’s predicted next token in the sequence, which could be anything from a simple “.” to the answer to the question of God, the Universe and life.
This process is autoregressive: the predicted token is appended to the model’s input sequence to predict the next token, and so on, looping through the entire training corpus.
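As a minimal sketch of this loop, assuming a model callable that maps a [1, seq_len] tensor of token ids to logits of shape [1, seq_len, vocab_size] (the names here are mine, not any particular library's API):

import torch

@torch.no_grad()
def greedy_generate(model, input_ids, max_new_tokens=32):
    # input_ids: [1, seq_len] tensor of token ids
    for _ in range(max_new_tokens):
        logits = model(input_ids)                                # [1, seq_len, vocab_size]
        next_token = logits[:, -1, :].argmax(-1, keepdim=True)   # most probable next token
        input_ids = torch.cat([input_ids, next_token], dim=-1)   # append and repeat
    return input_ids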
Training a language model is, in this sense, like training any other deep-learning classifier: the model makes a prediction (the next token), a loss (cross entropy) is computed from the model’s predicted probabilities and the true value (the actual next token in the sequence), and this loss is used to compute gradients that are backpropagated through the model to adjust its weights, repeatedly, until a desired loss value is achieved or a desired number of training tokens has been exhausted.
During training, a language model aims to minimize its cross entropy loss, or equivalently maximize the probability it assigns to the correct next token:
\[ \ell(\hat{x}) = -\sum_{i=1}^{N} \log P(\hat{x}_i|\hat{x}_1, \ldots, \hat{x}_{i-1}) \]
Given a training batch with batch size 8, containing sequences of length 2048,
we have our input tensor x with shape: [8, 2048].
Thus our model takes 8 sequences of 2048 tokens each as input.
After a forward pass through the model, the backbone network (transformer layers) generates an embedding matrix E of shape:
\[ (B, S, D) \]
where B = Batch Size (8), S = Sequence Length (2048), D = Hidden Dimension
D represents the size of our embedding vectors. Essentially every token in the batch gets an embedding vector computed.
Note: E is computed in parallel for all tokens, thanks to teacher forcing, the parallel nature of the transformer layers (self-attention) and broadcasting in modern deep learning frameworks like PyTorch.
These embeddings are then passed through the Classifier C⊤, to generate the logits. These logits are of shape:
\[ (B, S, V) \]
where V = Vocabulary Size
Thus for every token, we predict the log probabilities for every token in the model’s vocab. This is also done in parallel.
As mentioned earlier, the Logits are then used to generate the probability distribution, used in computing the cross entropy loss.
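Here is a rough PyTorch sketch of these shapes, with random tensors standing in for a real backbone and classifier (the hidden dimension and vocabulary size below are made up for illustration):

import torch
import torch.nn.functional as F

B, S, D, V = 8, 2048, 1024, 32000        # illustrative sizes, not a specific model

x = torch.randint(0, V, (B, S))          # input token ids       [B, S]
E = torch.randn(B, S, D)                 # backbone output       [B, S, D]
C = torch.randn(V, D)                    # classifier weights    [V, D]

logits = E @ C.T                         # logits                [B, S, V]

# next-token loss: predictions at positions 0..S-2 vs the actual tokens at 1..S-1
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, V),
    x[:, 1:].reshape(-1),
)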
The Memory Problem
At inference, sampling from the model’s predicted probabilities is efficient for any token regardless of its position in the sequence, since the model produces one token at a time at each step using the hidden state E of the last token in the sequence.
At training, for a sequence
\[ (x_1, \ldots, x_N) \]
the LLM’s outputs
\[ f(x_1), f(x_1, x_2), \ldots, f(x_1, \ldots, x_N) \]
are also computed in parallel. However, extra memory is required to save the activation outputs of each layer and the optimizer states needed to compute gradients in the backward pass, together with the model’s weights.
Researchers have, over the years, come up with optimization techniques like activation checkpointing and gradient accumulation to reduce memory consumption during training; these are outside the scope of our interests here.
However, despite optimizing the memory consumption of gradients, activations and optimizer states, the unsuspecting final layer (computing the cross entropy loss) has been shown to occupy a significant chunk of the model’s memory footprint during training.
In the Cut Cross Entropy paper, the authors note that in large-vocabulary models (see why large vocabularies are better), the log probabilities materialized while computing the cross entropy loss account for 40% to 90% of the memory consumption during training. This poses a problem, since larger vocabularies are better suited to multilingual and low-resource-language applications, and they also compress text better, which in turn means the model uses less compute during inference.
To understand why the cross entropy layer hogs so much memory during training, let’s look closely at how the cross entropy loss is computed.
The cross entropy loss is given as:
\[ \ell_i(x) = - \log(\text{softmax}_{x_i}(C^\top E_i)) \]
where $C^\top$ and $E$ maintain their notation as the classifier and embedding matrices.
we know the softmax function to be:
\[ \frac{\exp(C^\top_{x_i} E_i)}{\sum_j \exp(C^\top_j E_i)} \]
thus, the cross entropy loss can be written as:
\[ \ell_i(x) = - \log \frac{\exp(C^\top_{x_i} E_i)}{\sum_j \exp(C^\top_j E_i)} \]
expanding the log of the quotient gives:
\[ \ell_i(x) = - \log \exp(C^\top_{x_i} E_i) + \log \sum_j \exp(C^\top_j E_i) \]
which finally reduces to
\[ \ell_i(x) = - C^\top_{x_i} E_i + \log \sum_j \exp(C^\top_j E_i) \]
Using Gemma 2, with a vocabulary size of 256,128, as an example:
To compute the cross entropy loss, we must first perform an indexed matrix multiplication
\[ C^\top_{x_i} E_i \]
This is a simple dot product between the column vector $C^\top_{x_i}$, which corresponds to the classifier weights for the target class (or token, in this case) $x_i$, and the embedding vector $E_i$ of token $i$ in the input sequence.
To compute the second part of the equation, the Log Sum Exp:
\[ \log \sum_j \exp(C^\top_j E_i) \]
we:
- first compute the logits for all tokens $j$ in the model’s vocab: $C^\top_j E_i$
- find the max of the computed logits
- subtract the max from the logits (this is done for numerical stability)
- compute the exponent of the normalized logits: $\exp(C^\top_j E_i - \max_j C^\top_j E_i)$
- sum the exponents: $\sum_j \exp(C^\top_j E_i - \max_j C^\top_j E_i)$
- finally, take the log of the sum (and add the max back)
Each of these steps requires an intermediary tensor of size $(B, S, V)$ in global memory.
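A naive PyTorch version of this computation makes the intermediates explicit (a sketch with small made-up sizes, not what any framework literally runs):

import torch

B, S, D, V = 2, 512, 1024, 32000                 # small illustrative sizes
E = torch.randn(B, S, D)                         # backbone output
C = torch.randn(V, D)                            # classifier weights

logits = E @ C.T                                 # [B, S, V] first (B, S, V) intermediate
m = logits.max(dim=-1, keepdim=True).values      # per-token max, for numerical stability
exps = torch.exp(logits - m)                     # [B, S, V] another (B, S, V) intermediate
lse = m.squeeze(-1) + torch.log(exps.sum(dim=-1))  # [B, S] log-sum-exp per token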
Gemma 2 has a max sequence length of 8,192 and a vocab size of 256,128. Assuming a batch of 8 sequences and bf16 floating point precision, the memory required to materialize just one of these (B, S, V) intermediates comes to:
\[ 8 \times 8192 \times 256128 \times 2 \ \text{bytes} \approx 33.6\ \text{GB} \]
about 33.6 GB of memory per intermediate, spent just computing the cross entropy loss, which for smaller models can account for up to 90% of the total memory use in training.
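If you want to sanity-check the arithmetic yourself:

B, S, V = 8, 8192, 256_128
bytes_per_elem = 2                       # bf16
logit_bytes = B * S * V * bytes_per_elem
print(logit_bytes / 1e9)                 # ~33.6 GB for a single (B, S, V) tensor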
Cut Cross Entropy
The Cut Cross Entropy paper approaches this problem by implementing efficient forward and backward passes as custom fused Triton kernels, using ideas like tiling and indexed loads together with the reformulation above. This approach ensures that only small chunks of E and C are loaded into the fast GPU shared memory at a time, and by parallelizing across thread blocks, we don’t incur any meaningful latency overhead.
Forward Pass Kernels
Since we already reformulated the cross entropy loss as:
\[ \ell_i(x) = - C^\top_{x_i} E_i + \log \sum_j \exp(C^\top_j E_i) \]
the Cut Cross Entropy paper breaks the two terms above into two separate Triton kernels in the forward pass:
- An indexed Negative Dot Product
- The Log Sum Exp
Indexed Negative Dot Product
A naive computation of the indexed matrix multiplication involves either indexing the classifier weight matrix ($C^\top$) first, with a memory cost of $O(ND)$, and then performing the dot product; or computing $C^\top E_i$, which materializes the logits for every token, and then indexing into the result to get the logit for the target class, with an $O(N|V|)$ memory cost.
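Both naive options are easy to write in PyTorch (a sketch with made-up sizes; E is [N, D], C is [V, D], and x is the [N] vector of target token ids):

import torch

N, D, V = 4096, 1024, 32000
E = torch.randn(N, D)
C = torch.randn(V, D)
x = torch.randint(0, V, (N,))

# Option 1: gather the target rows of C first -> an extra [N, D] tensor
dots_1 = (C[x] * E).sum(dim=-1)                     # [N]

# Option 2: materialize all logits, then index -> an [N, V] tensor
logits = E @ C.T                                    # [N, V]
dots_2 = logits.gather(1, x[:, None]).squeeze(1)    # [N]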
Cut Cross Entropy takes a different approach.

The paper’s algorithm can be summarized as follows:
- Each thread block computes $N_B$ dot products for $N_B$ tokens from the input sequence and writes them to O.
- For a token position $i$ in the input sequence, to compute $C^\top_{x_i} E_i$ we need to load $x_i$ (the actual token value), $E_i$ and $C_{x_i}$ into the shared memory (SRAM) of our thread block.
- Since we have limited shared memory, we can’t load the full hidden states $E(N_B, D)$ and $C(N_B, D)$ at once.
- We break $E$ and $C$ into tiles of size $(N_B, D_B)$, compute the dot product for these tiles, iterating over the D dimension of E and C and accumulating the scores in shared memory, before writing to O in global memory.
Here’s my implementation:
import triton
import triton.language as tl


@triton.jit
def indexed_dot_kernel(
    E_ptr,  # Hidden states
    C_ptr,  # Model's Classifier weight matrix
    x_ptr,  # target tokens
    O_ptr,  # where are we writing dot products to?
    N,  # sequence_length
    D,  # Hidden_Dim
    V,  # Vocab size
    BLOCK_N: tl.constexpr,  # Block size across x, how many tokens is this block attending to
    BLOCK_D: tl.constexpr,  # how much of E are we loading at once across the D dimension
    stride_en,
    stride_ed,
    stride_cv,
    stride_cd,
):
    pid = tl.program_id(axis=0)  # what tokens am I handling in the input sequence?
    x_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)  # computing address offsets
    x_mask = x_offsets < N
    # loading token indices
    x = tl.load(x_ptr + x_offsets, mask=x_mask, other=0)
    # temporary tensor to store dot products for N_B tokens in SRAM
    o = tl.zeros([BLOCK_N], dtype=tl.float32)
    # reduction over D
    for d in range(0, D, BLOCK_D):
        d_offsets = d + tl.arange(0, BLOCK_D)
        # loading E tile
        e = tl.load(
            E_ptr + x_offsets[:, None] * stride_en + d_offsets[None, :] * stride_ed,
            mask=(x_offsets[:, None] < N) & (d_offsets[None, :] < D),
            other=0.0,
        )
        # loading C tile
        c = tl.load(
            C_ptr + x[:, None] * stride_cv + d_offsets[None, :] * stride_cd,
            mask=(x[:, None] < V) & (d_offsets[None, :] < D),
            other=0.0,
        )
        # accumulating
        o += tl.sum((e * c).to(tl.float32), axis=1)
    # writing back to global memory (negated, since we want the negative dot product)
    tl.store(O_ptr + x_offsets, -o, mask=x_mask)
Each thread block loads $N_B$ tokens, performs a reduction over the D dimension of the hidden states, loading tiles of $E$ and $C$, performing the dot product and accumulating the result. However, this is not efficient enough, as we only parallelize over the token sequence length N, and would end up utilizing less than 60% of our streaming multiprocessors in most cases.
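A minimal launch wrapper for the kernel above might look like this (a sketch: the `indexed_neg_dot` wrapper and the block sizes are mine and untuned; this is the 1D-over-N launch the paragraph above describes):

import torch

def indexed_neg_dot(E: torch.Tensor, C: torch.Tensor, x: torch.Tensor,
                    BLOCK_N: int = 128, BLOCK_D: int = 128) -> torch.Tensor:
    N, D = E.shape          # E: [N, D] hidden states (batch and sequence flattened)
    V, _ = C.shape          # C: [V, D] classifier weights
    O = torch.empty(N, device=E.device, dtype=torch.float32)
    grid = (triton.cdiv(N, BLOCK_N),)   # one program per block of N_B tokens
    indexed_dot_kernel[grid](
        E, C, x, O,
        N, D, V,
        BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D,
        stride_en=E.stride(0), stride_ed=E.stride(1),
        stride_cv=C.stride(0), stride_cd=C.stride(1),
    )
    return O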
The official Apple implementation, however, parallelizes over both tokens and hidden dimensions, using a 2D logical kernel grid.
@triton.jit
def _indexed_neg_dot_forward_kernel(
    E,
    C,
    Inds,
    Valids,
    Out,
    B,
    D,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_ib,
    stride_vb,
    B_BIN,
    BLOCK_B: tl.constexpr,
    BLOCK_D: tl.constexpr,
    GROUP_B: tl.constexpr,
    HAS_VALIDS: tl.constexpr,
    EVEN_D: tl.constexpr,
    SHIFT: tl.constexpr,
):
    # map the 1D program id onto a 2D (token-block, D-block) grid, with grouped
    # ordering for better L2 reuse
    pid = tl.program_id(axis=0)
    num_b_chunks = tl.cdiv(B, BLOCK_B)
    num_d_chunks = tl.cdiv(D, BLOCK_D)
    num_d_in_group = GROUP_B * num_d_chunks
    group_id = pid // num_d_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_b_chunks - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_d_in_group) % group_size_b)
    pid_d = (pid % num_d_in_group) // group_size_b

    offs_b = (tl.arange(0, BLOCK_B) + pid_b * BLOCK_B) % B
    if HAS_VALIDS:
        offs_b = tl.load(Valids + stride_vb * offs_b)

    # load a (BLOCK_B, BLOCK_D) tile of E
    offs_d = tl.arange(0, BLOCK_D) + pid_d * BLOCK_D
    e_ptrs = E + (stride_eb * offs_b[:, None] + stride_ed * offs_d[None, :])
    if EVEN_D:
        e = tl.load(e_ptrs)
    else:
        e = tl.load(e_ptrs, mask=offs_d[None, :] < D, other=0.0)

    # load the target token ids, then the matching rows of C for this D-block
    inds = tl.load(Inds + stride_ib * ((offs_b + 1) if SHIFT else offs_b))
    c_ptrs = C + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd)
    if EVEN_D:
        c = tl.load(c_ptrs)
    else:
        c = tl.load(c_ptrs, mask=offs_d[None, :] < D, other=0.0)

    offs_b = tl.arange(0, BLOCK_B) + pid_b * BLOCK_B
    out_ptrs = Out + offs_b

    # partial dot product over this D-block, accumulated into Out atomically
    dot = (e * c).to(tl.float32)
    neg_dot = -tl.sum(dot, 1).to(out_ptrs.dtype.element_ty)

    tl.atomic_add(out_ptrs, neg_dot, mask=offs_b < B)
This way, multiple thread blocks can work on the same block of tokens (each handling a different slice of D), which introduces a need to synchronize their writes. The overhead is minimal compared to the cost of underutilizing the GPU, and the synchronization is done by:
tl.atomic_add(out_ptrs, neg_dot, mask=offs_b < B)
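For reference, the launch grid for this 2D decomposition would be sized roughly as below (a sketch; the `neg_dot_grid` helper is mine, and the official implementation picks its block sizes and GROUP_B with its own heuristics):

def neg_dot_grid(B, D, BLOCK_B=128, BLOCK_D=128):
    # one program per (token-block, D-block) pair, flattened into a 1D grid
    return (triton.cdiv(B, BLOCK_B) * triton.cdiv(D, BLOCK_D),)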
The second part of our equation, the log-sum-exp, is also implemented using a similar approach.

Again, to summarize:
- Each thread block computes the log-sum-exp for $N_B$ tokens.
- We parallelize over both $N_B$ (tokens) and $V_B$ (vocab size), to avoid loading the full C matrix into SRAM for each token.
@triton.jit
def _cce_lse_forward_kernel(
    E,
    C,
    LSE,
    LA,
    Locks,
    Valids,
    softcap,
    B,
    V,
    D,
    stride_eb,
    stride_ed,
    stride_cv,
    stride_cd,
    stride_lse_b,
    stride_vb,
    num_locks,
    # Meta-parameters
    B_BIN,
    HAS_VALIDS: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_V: tl.constexpr,
    BLOCK_D: tl.constexpr,
    GROUP_B: tl.constexpr,
    EVEN_D: tl.constexpr,
    HAS_SOFTCAP: tl.constexpr,
    HAS_LA: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    # map the 1D program id onto a 2D (token-block, vocab-block) grid
    pid = tl.program_id(axis=0)
    num_pid_b = tl.cdiv(B, BLOCK_B)
    num_pid_v = tl.cdiv(V, BLOCK_V)
    num_pid_in_group = GROUP_B * num_pid_v
    group_id = pid // num_pid_in_group
    first_pid_b = group_id * GROUP_B
    group_size_b = min(num_pid_b - first_pid_b, GROUP_B)
    pid_b = first_pid_b + ((pid % num_pid_in_group) % group_size_b)
    pid_v = (pid % num_pid_in_group) // group_size_b

    offs_b = (pid_b * BLOCK_B + tl.arange(0, BLOCK_B)) % B
    if HAS_VALIDS:
        offs_b = tl.load(Valids + stride_vb * offs_b)
    offs_v = (pid_v * BLOCK_V + tl.arange(0, BLOCK_V)) % V
    offs_d = tl.arange(0, BLOCK_D)
    # pointers to the first (BLOCK_B, BLOCK_D) tile of E and (BLOCK_D, BLOCK_V) tile of C
    e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed)
    c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd)
- We use a reduction over D to avoid loading either E or C across the entire D dimension, so we only ever hold $E(N_B, D_B)$ and $C(D_B, V_B)$ tiles at any point in time, accumulating the dot products in $\text{accum}$ in shared memory. After the loop, $\text{accum}$ contains partial logits for the $N_B$ tokens in this program, since it was computed from $E(N_B, D)$ and $C(D, V_B)$, and not the full $C(D, V)$.
    accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32)
    for d in range(0, tl.cdiv(D, BLOCK_D)):
        # Load the next block of E and C, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_D:
            e = tl.load(e_ptrs)
            c = tl.load(c_ptrs)
        else:
            e = tl.load(e_ptrs, mask=offs_d[None, :] < D - d * BLOCK_D, other=0.0)
            c = tl.load(c_ptrs, mask=offs_d[:, None] < D - d * BLOCK_D, other=0.0)
        # accumulate the (BLOCK_B, BLOCK_V) partial logits and advance to the next D tile
        accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION)
        e_ptrs += BLOCK_D * stride_ed
        c_ptrs += BLOCK_D * stride_cd
- The log-sum-exp over the partial logits is then computed by finding the max of the logits over this $V_B$ block, subtracting it from all the logits, exponentiating, summing the exponents, taking the log of the sum, and adding the max back.
    # `logits` here is the accumulated `accum` (after optional logit soft-capping in the full kernel)
    this_mx = tl.max(logits, axis=1)
    e = tl.exp(logits - this_mx[:, None])
    this_lse = this_mx + tl.log(tl.sum(e, axis=1))
- Remember, we only have partial logits of size $(N_B, V_B)$, since each thread block computes the logits over a $V_B$ slice of the model’s vocabulary. $\text{this\_lse}$ holds the log-sum-exp of the partial logits for a particular thread block, which is not the full value we need. However, other programs with the same $\text{pid\_b}$ (handling the same token block) but different values of $\text{pid\_v}$ will have computed their own $\text{this\_lse}$ over the other $V_B$ blocks of V. We need a way to combine these values without introducing race conditions across thread blocks.
    lse_ptrs = LSE + (stride_lse_b * offs_b)
    # o_mask (used below) marks the valid rows of this block, i.e. offs_b < B
    this_locks = Locks + (pid_b // tl.cdiv(B, BLOCK_B * num_locks))
    while tl.atomic_cas(this_locks, 0, 1) == 1:
        pass
- The final log-sum-exp values are stored at LSE. However, since multiple thread blocks with the same $\text{pid\_b}$ will write to the same locations in LSE, a spin lock built on an atomic compare-and-swap is used to prevent race conditions, and whoever holds the lock writes their computed LSE value.
    lse = tl.load(lse_ptrs, mask=o_mask, other=0.0, eviction_policy="evict_last")
    lse = tl_logaddexp(lse, this_lse)
    tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last")
    tl.atomic_xchg(this_locks, 0)
- Finally, we load the value currently accumulated in $\text{LSE}$, combine it with our partial $\text{this\_lse}$ using the log-add-exp function, write the result back to our block of $\text{LSE}$, and release the lock.
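This works because log-sum-exps over disjoint vocabulary blocks can be combined pairwise with log-add-exp: $\text{logaddexp}(a, b) = \log(e^a + e^b)$. A quick numerical check with PyTorch (sizes made up):

import torch

logits = torch.randn(4, 1000)                            # 4 tokens, vocab of 1000
full_lse = torch.logsumexp(logits, dim=-1)               # reference LSE over the full vocab

lse_block1 = torch.logsumexp(logits[:, :500], dim=-1)    # LSE over the first vocab block
lse_block2 = torch.logsumexp(logits[:, 500:], dim=-1)    # LSE over the second vocab block
combined = torch.logaddexp(lse_block1, lse_block2)       # combine the partial LSEs

print(torch.allclose(full_lse, combined, atol=1e-5))     # True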
