
Unexpected Overheads with Activation Checkpointing with Pipeline Parallelism #17

Open
abhinavgoel95 opened this issue Apr 5, 2023 · 1 comment

abhinavgoel95 (Contributor) commented Apr 5, 2023

We notice buggy behavior with bitcasts and dynamic update slices. When we turn on activation checkpointing (e.g., saving the outputs of projection layers via the SAVE_OUT_PROJ flag in PAXML), we see multiple extra updates and copies.

For example, we want to checkpoint an activation of shape [2,2048,48,128]. However, in the HLO below we see that the copies are of shape [15,1,2,2048,48,128]. Here, 15 is the number of microbatches we are using with pipeline parallelism.

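For context, checkpointing policies of this kind can be expressed directly in JAX roughly as follows. This is a simplified sketch, not the actual PAXML implementation of SAVE_OUT_PROJ; the function and tag names are illustrative only.

from functools import partial

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Save only activations tagged "out_proj"; everything else is rematerialized
# in the backward pass.
policy = jax.checkpoint_policies.save_only_these_names("out_proj")

@partial(jax.checkpoint, policy=policy)
def attention_block(x, w_qkv, w_out):  # hypothetical layer, for illustration
    qkv = jnp.dot(x, w_qkv)
    # ... attention math elided ...
    out = jnp.dot(qkv, w_out)
    # Tag the projection output so the policy above keeps it resident.
    return checkpoint_name(out, "out_proj")
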
Snippet of HLO:

fusion.549 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, ..., kind=kLoop, calls=fused_computation.549, metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
get-tuple-element.5874 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=0
copy.583 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5874)
get-tuple-element.5866 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=1
copy.575 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5866)
get-tuple-element.5868 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=2
copy.577 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5868)
get-tuple-element.5870 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=3
copy.579 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5870)
get-tuple-element.5872 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=4
copy.581 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5872)

...

fused_computation.549 {
  param_1.8511 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(1)
  bitcast.52601 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_1.8511)
  param_0.6313 = bf16[2,48,128,2048]{3,2,1,0} parameter(0)
  bitcast.52600 = bf16[1,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_0.6313)
  param_2.5901 = s32[] parameter(2)
  constant_7564 = s32[] constant(0)
  compare.3477 = pred[] compare(param_2.5901, constant_7564), direction=LT, metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/pipeline._scan_fn/pipeline._get_iteration_inputs/jit(remainder)/rem" source_file="/pax/praxis/praxis/layers/pipeline.py" source_line=422}
  constant_11524 = s32[] constant(15)
  add.6580 = s32[] add(param_2.5901, constant_11524), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/add" source_file="/pax/praxis/praxis/base_layer.py" source_line=695}
  select.5360 = s32[] select(compare.3477, add.6580, param_2.5901), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/select_n" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  dynamic-update-slice.325 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} dynamic-update-slice(bitcast.52601, bitcast.52600, select.5360, constant_7564, constant_7564, /*index=5*/constant_7564, constant_7564, constant_7564), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  bitcast.52599 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} bitcast(dynamic-update-slice.325), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  param_4.7770 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(4)
  bitcast.52617.clone.1 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_4.7770)
  param_3.8428 = bf16[2,48,128,2048]{3,2,1,0} parameter(3)
  bitcast.52616.clone.1 = bf16[1,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_3.8428)
  dynamic-update-slice.333.clone.1 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} dynamic-update-slice(bitcast.52617.clone.1, bitcast.52616.clone.1, select.5360, constant_7564, constant_7564, /*index=5*/constant_7564, constant_7564, constant_7564), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  ...
  ROOT tuple.356 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}) tuple(bitcast.52599, bitcast.52615.clone.1, bitcast.52611.clone.1, bitcast.52607.clone.1, bitcast.52603.clone.1)
}

It seems like there is a big buffer of shape [15,1,2,2048,48,128] holding the checkpointed activations for all microbatches. In each microbatch iteration, we only want to update one row of this buffer (of shape [2,2048,48,128]). But XLA loads the entire buffer into memory, performs the update, and then copies the whole buffer back. We see this problem in our profiles: the time spent on D2D copies (copy.575 through copy.583 above) is much larger than expected for the amount of data that actually needs to be copied. Right now, activation checkpointing accounts for 5% to 8% of the overall run time of a GPT-3 style model.

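For reference, the pattern that produces this HLO is essentially a scanned dynamic_update_slice into the stacked per-microbatch buffer. A minimal JAX sketch of that pattern (not the actual PAXML pipeline code; the singleton dimension from the HLO is omitted):

import jax
import jax.numpy as jnp

NUM_MICROBATCHES = 15
ACT_SHAPE = (2, 2048, 48, 128)  # per-microbatch activation being checkpointed

def body(buf, i):
    act = jnp.ones(ACT_SHAPE, jnp.bfloat16)  # stand-in for the saved activation
    # Only one [1,2,2048,48,128] row of the [15,2,2048,48,128] buffer changes per
    # iteration; ideally XLA aliases buf across iterations and updates it in place.
    buf = jax.lax.dynamic_update_slice(buf, act[None], (i, 0, 0, 0, 0))
    return buf, None

init = jnp.zeros((NUM_MICROBATCHES,) + ACT_SHAPE, jnp.bfloat16)
stacked, _ = jax.lax.scan(body, init, jnp.arange(NUM_MICROBATCHES))

If that in-place update is not recognized, each iteration instead materializes a fresh copy of the whole [15,...] buffer, which matches the copies seen in the profile.
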
Our current understanding: the copy happens because when bitcast is treated as computing a new value (like a convert or sqrt), a new tensor must be produced in each loop iteration, and therefore a copy of each DUS result must be made. This should be fixable by treating bitcast as an aliasing operation, rather than as computing a new value, in the dataflow analysis. I think there is an option in the dataflow analysis that configures how bitcast is treated; in XLA TPU, that option is set so that bitcasts are treated as simply an aliasing operation.

Would someone be able to look into this?

I am attaching a link to the HLO: https://drive.google.com/drive/folders/1fYUsqfDgYRRpgOklE-k7qx_5ixkJzKPD?usp=sharing

akuegel (Member) commented Apr 6, 2023

The option is set in the same way for XLA GPU (both TPU and GPU use the default value for this flag), so it is not as easy to fix as that. There was also some doubt about whether it can be fixed at all, apart from avoiding the reshape bitcast (a change that might have to be made on the model side).

Quoting from the chat: "A reshape of a dus is no longer the same tensor. A copy might be needed. Not always though."

Cjkkkk pushed a commit to Cjkkkk/paxml that referenced this issue May 21, 2024