How to count FLOPs during the CLIP inference #143
We used fvcore's `flop_count` module, with the following modifications to add the operations it doesn't support out of the box:

```python
import typing
from collections import Counter
from math import prod  # used below but missing from the original snippet (Python 3.8+)

from fvcore.nn import flop_count
from fvcore.nn.jit_handles import batchnorm_flop_jit, generic_activation_jit, get_shape


def generic_pooling_jit(name, multiplier=1):
    def pool_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
        # inputs[0] is the op's input tensor; get_shape() returns its shape as a list of ints.
        input_shape = get_shape(inputs[0])
        output_shape = get_shape(outputs[0])
        assert 2 <= len(input_shape) <= 5, input_shape
        # Sum all input elements, then one division per output element.
        flop = prod(input_shape) + prod(output_shape)
        return Counter({name: flop * multiplier})

    return pool_jit


def softmax_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
    input_shape = get_shape(inputs[0])
    output_shape = get_shape(outputs[0])
    # Exponentiate and sum every input element, then one division per output element.
    flop = prod(input_shape) * 2 + prod(output_shape)
    return Counter({"softmax": flop})


def bmm_flop_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
    input1_shape = get_shape(inputs[0])
    input2_shape = get_shape(inputs[1])
    assert len(input1_shape) == len(input2_shape) == 3
    assert input1_shape[0] == input2_shape[0] and input1_shape[2] == input2_shape[1], [input1_shape, input2_shape]
    # (b, n, k) @ (b, k, m) -> (b, n, m): b * n * k * m multiply-adds.
    flop = prod(input1_shape) * input2_shape[-1]
    return Counter({"bmm": flop})


# ForwardWrapper, model and example_input are not defined in this comment.
# flops maps each op name to its count in GFLOPs; skips counts the ops fvcore could not handle.
flops, skips = flop_count(
    ForwardWrapper(model),
    inputs=(example_input,),
    supported_ops={
        "aten::batch_norm": batchnorm_flop_jit,
        "aten::group_norm": batchnorm_flop_jit,
        "aten::layer_norm": batchnorm_flop_jit,
        "aten::add": generic_activation_jit("add"),
        "aten::sub": generic_activation_jit("sub"),
        "aten::mul": generic_activation_jit("mul"),
        "aten::div": generic_activation_jit("div"),
        "aten::sqrt": generic_activation_jit("sqrt"),
        "aten::sigmoid": generic_activation_jit("sigmoid"),
        "aten::sigmoid_": generic_activation_jit("sigmoid_"),
        "aten::relu": generic_activation_jit("relu"),
        "aten::relu_": generic_activation_jit("relu_"),
        "aten::gelu": generic_activation_jit("gelu"),
        "aten::add_": generic_activation_jit("add_"),
        "aten::sub_": generic_activation_jit("sub_"),
        "aten::mul_": generic_activation_jit("mul_"),
        "aten::div_": generic_activation_jit("div_"),
        "aten::sqrt_": generic_activation_jit("sqrt_"),
        "aten::adaptive_avg_pool2d": generic_pooling_jit("adaptive_avg_pool2d"),
        "aten::adaptive_max_pool2d": generic_pooling_jit("adaptive_max_pool2d"),
        "aten::avg_pool2d": generic_pooling_jit("avg_pool2d"),
        "aten::max_pool2d": generic_pooling_jit("max_pool2d"),
        "aten::bmm": bmm_flop_jit,
        "aten::mean": generic_pooling_jit("mean"),
        "aten::var": generic_pooling_jit("var", multiplier=3),  # subtract mean, square, sum
        "aten::var_mean": generic_pooling_jit("mean_var", multiplier=4),
        "aten::softmax": softmax_jit,
        "aten::dropout": generic_activation_jit("dropout"),
        "aten::frobenius_norm": generic_pooling_jit("frobenius_norm"),
    },
)
```
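To sanity-check custom handles like these without tracing a model, the same per-op formulas can be evaluated directly on shapes. The sketch below is plain Python (no fvcore); it assumes shapes are lists of ints like those `get_shape` returns, and the function names are hypothetical stand-ins that mirror `bmm_flop_jit`, `softmax_jit` and `generic_pooling_jit`.

```python
from math import prod
from typing import Sequence


def bmm_flops(a_shape: Sequence[int], b_shape: Sequence[int]) -> int:
    """(b, n, k) @ (b, k, m) -> (b, n, m): one multiply-add per b * n * k * m."""
    assert len(a_shape) == len(b_shape) == 3
    assert a_shape[0] == b_shape[0] and a_shape[2] == b_shape[1], (a_shape, b_shape)
    return prod(a_shape) * b_shape[-1]


def softmax_flops(in_shape: Sequence[int], out_shape: Sequence[int]) -> int:
    """Exponentiate and sum every input element, then divide every output element."""
    return prod(in_shape) * 2 + prod(out_shape)


def pooling_flops(in_shape: Sequence[int], out_shape: Sequence[int], multiplier: int = 1) -> int:
    """Visit every input element once, plus one division per output element."""
    assert 2 <= len(in_shape) <= 5, in_shape
    return (prod(in_shape) + prod(out_shape)) * multiplier


if __name__ == "__main__":
    # Hypothetical shapes: attention scores with 12 heads, 197 tokens and head dim 64.
    print(bmm_flops([12, 197, 64], [12, 64, 197]))
    print(softmax_flops([12, 197, 197], [12, 197, 197]))
    # Hypothetical variance over a 768-channel 7x7 map, using the multiplier of 3 from above.
    print(pooling_flops([1, 768, 7, 7], [1, 768, 1, 1], multiplier=3))
```

Comparing these hand-computed numbers against the per-op entries in `flops` is a quick way to catch a handle that is wired to the wrong op.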
Hi @Akashcodes732, did your issue get solved? I am stuck on the same problem and urgently need help. Could you help me solve it?
Hi @jongwook, could you please help? I urgently need a solution to this.
Hi @sandipan211, have you figured out the best practice for counting the FLOPs of the CLIP model? I have tried several tools on CLIP with ViT-B/16 (e.g., torchsummaryX, thop, and torchinfo) but got a different result from each. Among them, I think the closest result to the FLOPs plotted in the CLIP paper
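One plausible reason different tools disagree is the counting convention: some report multiply-accumulates (MACs, sometimes labelled "mult-adds"), while others count each MAC as two FLOPs; fvcore, for instance, counts one fused multiply-add as one flop. Tools also differ in which ops they skip (the `skips` counter returned by `flop_count` lists the ones fvcore ignored). The minimal sketch below, with hypothetical helper names, only illustrates how the same matmul yields numbers a factor of two apart.

```python
def linear_macs(tokens: int, in_features: int, out_features: int) -> int:
    """Multiply-accumulates for a (tokens, in) @ (in, out) matmul; bias is ignored."""
    return tokens * in_features * out_features


def macs_to_flops(macs: int) -> int:
    """Convention in which every multiply-accumulate counts as two FLOPs."""
    return 2 * macs


if __name__ == "__main__":
    # Hypothetical example: an MLP up-projection of a width-768 transformer on 197 tokens.
    macs = linear_macs(197, 768, 3072)
    print(f"{macs / 1e9:.3f} GMACs vs {macs_to_flops(macs) / 1e9:.3f} GFLOPs")
```

Before comparing tools, it is worth checking which convention each one uses and normalizing to one of them.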
I tried existing FLOP counters on the CLIP model, but they don't seem to work.
I need help counting the FLOPs of a single inference pass of the CLIP model.