Error when using torchcompile option for CLIP training #726

Open
kkjh0723 opened this issue Nov 1, 2023 · 6 comments

@kkjh0723

kkjh0723 commented Nov 1, 2023

Hello,

While trying to use the --torchcompile option to train a CLIP ViT-B-32 model, I got an error.
Below is the command I used to run training.

```shell
torchrun --nproc_per_node 16 -m training.main --save-frequency 1 --zeroshot-frequency 1 --report-to tensorboard --train-data={data_dir}  --csv-img-key filepath --csv-caption-key title --imagenet-val={imagenet val dir} --workers=8 --model ViT-B-32 --precision amp_bf16 --workers 4 --csv-separator "," --local-loss --gather-with-grad --aug-cfg scale='(0.5, 1.0)' --name test --accum-freq 4 --grad-checkpointing --torchcompile
```

I got the error message below.
How can I fix this?
Note that my PyTorch version is 2.1.0, and no error occurs when I run the above command without the --torchcompile option.

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/open_clip/src/training/main.py", line 508, in <module>
    main(sys.argv[1:])
  File "/workspace/open_clip/src/training/main.py", line 436, in main
    train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)
  File "/workspace/open_clip/src/training/train.py", line 117, in train_one_epoch
    model_out = model(images, texts)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 487, in catch_errors
    return hijacked_callback(frame, cache_entry, hooks, frame_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 586, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1167, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 331, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1155, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 331, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1155, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2179, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2286, in inline_call_
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1115, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1123, in call_function
    p_args, _, example_value = self.create_wrapped_node(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1025, in create_wrapped_node
    ) = speculate_subgraph(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 203, in speculate_subgraph
    f"to trace function `{f.get_name()}` into a single graph. This means "
torch._dynamo.exc.InternalTorchDynamoError: 'NNModuleVariable' object has no attribute 'get_name'

from user code:
   File "/workspace/open_clip/src/open_clip/model.py", line 293, in forward
    image_features = self.encode_image(image, normalize=True) if image is not None else None
  File "/workspace/open_clip/src/open_clip/model.py", line 266, in encode_image
    features = self.visual(image)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/open_clip/src/open_clip/transformer.py", line 516, in forward
    x = self.transformer(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/open_clip/src/open_clip/transformer.py", line 322, in forward
    x = checkpoint(r, x, None, None, attn_mask)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
@rwightman
Collaborator

@kkjh0723 I think it might be breaking because of gradient checkpointing? Not sure there is a workaround; possibly using non-reentrant mode?

@EIFY
Contributor

EIFY commented Nov 2, 2023

I got the same error when running with both --grad-checkpointing and --torchcompile. However, as of PyTorch 2.1.0, --torchcompile does work with --accum-freq > 1, which is the next best option.
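Until the two flags work together, a launcher could refuse the combination up front instead of failing deep inside Dynamo. The sketch below is hypothetical and not part of open_clip: it builds a parser with only the three flags discussed here (`--torchcompile`, `--grad-checkpointing`, `--accum-freq`) and, when both of the first two are set, disables checkpointing with a warning so that gradient accumulation can be used instead.

```python
import argparse
import warnings


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: only the flags relevant to this issue are included.
    parser = argparse.ArgumentParser()
    parser.add_argument("--torchcompile", action="store_true")
    parser.add_argument("--grad-checkpointing", action="store_true")
    parser.add_argument("--accum-freq", type=int, default=1)
    return parser


def resolve_memory_options(args: argparse.Namespace) -> argparse.Namespace:
    """Hypothetical guard: drop grad checkpointing when torch.compile is requested."""
    if args.torchcompile and args.grad_checkpointing:
        warnings.warn(
            "--grad-checkpointing fails under --torchcompile on PyTorch 2.1.0 "
            "(issue #726); disabling it. Consider --accum-freq > 1 instead."
        )
        args.grad_checkpointing = False
    return args


# Example: the flag combination from the original report.
args = resolve_memory_options(
    build_parser().parse_args(
        ["--torchcompile", "--grad-checkpointing", "--accum-freq", "4"]
    )
)
```

Silently downgrading is a design choice; raising an error instead would make the memory trade-off more visible to the user.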

@rwightman
Collaborator

@EIFY Did you try forcing non-reentrant checkpointing? We could look at changing the default if that works.

@EIFY
Contributor

EIFY commented Nov 2, 2023

@rwightman No, I haven't tried that.

In that regard, there is good news. The current checkpointing code has this TODO:

```python
if self.grad_checkpointing and not torch.jit.is_scripting():
    # TODO: handle kwargs https:/pytorch/pytorch/issues/79887#issuecomment-1161758372
    x = checkpoint(r, x, None, None, attn_mask)
```

pytorch/pytorch#79887 is now fixed, so we should be able to do, e.g.:

```python
if self.grad_checkpointing and not torch.jit.is_scripting():
    x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
```
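For reference, non-reentrant checkpointing also forwards keyword arguments to the wrapped callable, so the positional `None, None` placeholders could be replaced by passing `attn_mask` by name. The sketch below uses a toy `ToyBlock` module whose parameter names (`q_x`, `k_x`, `v_x`, `attn_mask`) are assumptions loosely mirroring the positional call above, not open_clip's actual block. It checks on CPU that the checkpointed forward and its input gradient match a plain forward. It only illustrates the `torch.utils.checkpoint` API; it says nothing about whether `torch.compile` can trace it.

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyBlock(torch.nn.Module):
    """Hypothetical stand-in for a residual transformer block (not open_clip code)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, q_x, k_x=None, v_x=None, attn_mask=None):
        out = self.proj(q_x)
        if attn_mask is not None:
            # Toy use of the mask so that it actually influences the output.
            out = out + attn_mask
        return q_x + out


torch.manual_seed(0)
block = ToyBlock(8)
x = torch.randn(4, 8, requires_grad=True)
mask = torch.zeros(4, 8)

# Plain forward/backward as the reference.
ref = block(x, attn_mask=mask)
(ref_grad,) = torch.autograd.grad(ref.sum(), x)

# Non-reentrant checkpointing passes attn_mask through as a keyword argument,
# so k_x and v_x no longer need explicit None placeholders.
ckpt = checkpoint(block, x, attn_mask=mask, use_reentrant=False)
(ckpt_grad,) = torch.autograd.grad(ckpt.sum(), x)

print(torch.allclose(ref, ckpt), torch.allclose(ref_grad, ckpt_grad))
```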

The bad news is that, apart from that call site, grad_checkpointing is either delegated to the vision/text trunks with no way to pass extra arguments:

```python
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    self.visual.set_grad_checkpointing(enable)
    self.transformer.grad_checkpointing = enable
```

or not supported at all:

```python
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    # FIXME support for non-transformer
    pass
```

So fairly involved changes would be necessary. When I get a chance, I will try the easy part and see whether it at least gets past this error.
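One shape the more involved change could take is letting `set_grad_checkpointing` accept extra checkpoint options and push them down to each tower, which would then splat them into its `checkpoint(...)` call. The sketch below is purely hypothetical: `Trunk`, `TwoTowerModel`, and `checkpoint_kwargs` are illustrative names, not open_clip classes, and it uses plain Python objects (no `torch`, no `@torch.jit.ignore`) just to show the plumbing.

```python
from typing import Any, Dict


class Trunk:
    """Hypothetical stand-in for a vision or text tower (not open_clip code)."""

    def __init__(self) -> None:
        self.grad_checkpointing = False
        # Would be forwarded as checkpoint(r, x, ..., **self.checkpoint_kwargs).
        self.checkpoint_kwargs: Dict[str, Any] = {}

    def set_grad_checkpointing(self, enable: bool = True, **checkpoint_kwargs: Any) -> None:
        self.grad_checkpointing = enable
        # Keep the options only while checkpointing is enabled.
        self.checkpoint_kwargs = dict(checkpoint_kwargs) if enable else {}


class TwoTowerModel:
    """Hypothetical model that forwards checkpoint options to both towers."""

    def __init__(self) -> None:
        self.visual = Trunk()
        self.transformer = Trunk()

    def set_grad_checkpointing(self, enable: bool = True, **checkpoint_kwargs: Any) -> None:
        self.visual.set_grad_checkpointing(enable, **checkpoint_kwargs)
        self.transformer.set_grad_checkpointing(enable, **checkpoint_kwargs)


model = TwoTowerModel()
model.set_grad_checkpointing(True, use_reentrant=False)
```

Towers that currently have no checkpointing support at all would still need their own implementation; this only addresses the argument plumbing.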

@EIFY
Contributor

EIFY commented Nov 8, 2023

@rwightman OK, it turns out that use_reentrant=False doesn't help. It still breaks at the same point:

[2023-11-08 12:56:29,383] [0/0] torch._utils_internal: [INFO] CompilationMetrics(frame_key='1', co_name='forward', co_filename='/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py', co_firstlineno=256, cache_size=0, guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, fail_reason="'NNModuleVariable' object has no attribute 'get_name'")
Traceback (most recent call last):
(...)
torch._dynamo.exc.InternalTorchDynamoError: 'NNModuleVariable' object has no attribute 'get_name'

from user code:
   File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py", line 274, in forward
    image_features = dim_scale_img * self.encode_image(image, normalize=self.normalize) if image is not None else None
  File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/model.py", line 239, in encode_image
    features = self.visual(image)
  File "/home/jason-chou/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/transformer.py", line 486, in forward
    x = self.transformer(x)
  File "/home/jason-chou/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jason-chou/.local/lib/python3.10/site-packages/open_clip/transformer.py", line 319, in forward
    x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)

@lavoiems

lavoiems commented Jan 3, 2024

Is there any update on this? I am facing the same issue.
