Strange memory usage
Memory profiling forward/backward passes
One of the assignments for CS336 is to run the forward and backward pass for the transformer, looking at the memory utilisation.
I ran a set of experiments that enabled or disabled the forward and backward passes and toggled no_grad on and off. I expected a forward pass run outside the no_grad context, with no backward pass, to use about as much memory as a full forward + backward step: autograd still records the graph and saves the activations needed for backward, and the only things skipped are the gradient computation and the weight update.
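A quick way to sanity-check that expectation on a toy model is to count what autograd saves for backward. The sketch below is not part of the assignment code: the tiny toy_model and the helper saved_bytes are hypothetical names, and it relies on torch.autograd.graph.saved_tensors_hooks, which calls a pack hook every time a tensor is saved for the backward pass. The byte count is rough (a tensor saved twice is counted twice), but it should show the forward pass saving tensors in grad mode and nothing under no_grad.

```python
import torch

def saved_bytes(fn):
    """Run fn() and return the (rough) number of bytes autograd saved for backward."""
    sizes = []

    def pack(t):
        # Called once for every tensor autograd saves during the forward pass.
        sizes.append(t.numel() * t.element_size())
        return t

    def unpack(t):
        # Called when backward needs the tensor again; return it unchanged.
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        fn()
    return sum(sizes)

torch.manual_seed(0)
# Hypothetical stand-in for the transformer: two linear layers and a ReLU.
toy_model = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)
)
x = torch.randn(4, 8)

def forward():
    return toy_model(x).sum()

def forward_no_grad():
    with torch.no_grad():
        return toy_model(x).sum()

grad_mode_bytes = saved_bytes(forward)
no_grad_bytes = saved_bytes(forward_no_grad)
print(f"saved in grad mode:  {grad_mode_bytes} bytes")
print(f"saved under no_grad: {no_grad_bytes} bytes")
```

Nothing here calls backward(), so whatever the grad-mode run reports is memory that the forward pass alone holds on to.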
Funnily enough, when I ran the forward pass without the backward pass, it actually used more peak memory!
Running only the forward pass uses 16% MORE peak memory than running both the forward and backward passes.
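It seems that the backward pass triggers the cleanup of the activations saved during the forward pass: by default, loss.backward() frees the graph's saved tensors once it has used them. Without backward, PyTorch keeps all of those activations alive until res and loss are rebound to new tensors in the next iteration, at which point I guess the entire graph is cleaned up because nothing references the previous loss any more. Since the rebinding only happens after the next forward pass has run, the old and new activations briefly coexist, which would explain the higher peak.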
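The "backward frees the saved activations" part can be seen on the CPU with a toy example. This is a minimal sketch using a made-up tensor rather than the transformer, and second_backward_fails is a hypothetical helper: after one backward() call, a second call on the same loss fails because the saved tensors are already gone, unless retain_graph=True was passed the first time.

```python
import torch

def second_backward_fails(retain_graph):
    """Return True if a second backward() on the same loss raises RuntimeError."""
    x = torch.randn(1024, requires_grad=True)
    loss = (x * x).sum()  # the multiplication saves x for the backward pass
    loss.backward(retain_graph=retain_graph)
    try:
        loss.backward()  # needs the saved tensors again
    except RuntimeError:
        return True
    return False

# Default behaviour: the first backward() frees the saved tensors.
freed_by_default = second_backward_fails(retain_graph=False)
# With retain_graph=True the saved tensors survive the first backward().
kept_with_retain = not second_backward_fails(retain_graph=True)
print(freed_by_default, kept_with_retain)
```

When backward() is never called, nothing triggers that cleanup, so the saved tensors live for as long as something still references the loss.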
Here is the code I used:
    import timeit
    import torch

    # model, optimizer, corpus, config and cross_entropy are defined elsewhere in the script.
    for i in range(config.TOTAL_STEPS):
        if config.device == "cuda" or config.device.startswith("cuda:"):
            torch.cuda.reset_peak_memory_stats()
        this_batch_x, this_batch_y = corpus.get_batch()
        # forward pass
        forward_start_time = timeit.default_timer()
        if config.no_grad:
            with torch.no_grad():
                res = model(this_batch_x)
                loss = cross_entropy(res, this_batch_y)
        else:
            res = model(this_batch_x)
            loss = cross_entropy(res, this_batch_y)
        if config.sync:
            torch.cuda.synchronize()
        forward_end_time = timeit.default_timer()
        forward_time = forward_end_time - forward_start_time
        # backward pass
        backward_start_time = timeit.default_timer()
        if config.backward:
            loss.backward()
            if config.use_optimizer:
                optimizer.step()
                optimizer.zero_grad()
        if config.sync:
            torch.cuda.synchronize()
        backward_end_time = timeit.default_timer()
        backward_time = backward_end_time - backward_start_time
        # Record peak memory for this step
        peak_memory_mb = 0.0
        if config.device == "cuda" or config.device.startswith("cuda:"):
            peak_memory_mb = torch.cuda.max_memory_allocated() / 1024**2  # bytes to MiB
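
If the explanation above is right, dropping the references to res and loss at the end of each forward-only step should let PyTorch free the old graph before the next forward pass starts. The sketch below is one way to try that. The function forward_only_steps, its release flag, and the toy model are all hypothetical, and the effect on peak memory is something to measure rather than a result reported here.

```python
import torch

def forward_only_steps(model, batches, loss_fn, release=True):
    """Run forward passes without backward; optionally drop the graph each step."""
    on_cuda = next(model.parameters()).is_cuda
    stats = []
    for x, y in batches:
        if on_cuda:
            torch.cuda.reset_peak_memory_stats()
        res = model(x)
        loss = loss_fn(res, y)
        loss_value = loss.item()  # copy the scalar out before dropping the tensor
        if release:
            # With no references left, the autograd graph and its saved
            # activations can be freed before the next forward pass.
            del res, loss
        peak_mb = torch.cuda.max_memory_allocated() / 1024**2 if on_cuda else 0.0
        stats.append((loss_value, peak_mb))
    return stats

# Toy CPU usage; on the CPU the peak is reported as 0.0 because only CUDA is measured.
toy_model = torch.nn.Linear(8, 4)
toy_batches = [(torch.randn(16, 8), torch.randint(0, 4, (16,))) for _ in range(3)]
stats = forward_only_steps(toy_model, toy_batches, torch.nn.functional.cross_entropy)
print(stats)
```

Running this on a GPU with release=True and then release=False would show whether the extra peak memory in the forward-only case comes from the previous step's graph still being alive.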


