
How to count FLOPs during the CLIP inference #143

Closed

Akashcodes732 opened this issue Aug 27, 2021 · 4 comments

@Akashcodes732

I tried the already existing FLOP counters on the CLIP model, but they don't seem to work.
I need help with counting the FLOPs for a single inference pass of the CLIP model.

@jongwook
Collaborator

We used fvcore's flop_count module with the following modifications to add the operations that it doesn't support out-of-the-box:

import typing
from collections import Counter
from math import prod

from fvcore.nn import flop_count
from fvcore.nn.jit_handles import batchnorm_flop_jit, matmul_flop_jit, generic_activation_jit, get_shape


def generic_pooling_jit(name, multiplier=1):
    def pool_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
        # Inputs[0] contains the shape of the input.
        input_shape = get_shape(inputs[0])
        output_shape = get_shape(outputs[0])
        assert 2 <= len(input_shape) <= 5, input_shape
        flop = prod(input_shape) + prod(output_shape)  # summing all elements + denominating in each for output
        flop_counter = Counter({name: flop * multiplier})
        return flop_counter

    return lambda inputs, outputs: pool_jit(inputs, outputs)

def softmax_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
    input_shape = get_shape(inputs[0])
    output_shape = get_shape(outputs[0])
    flop = prod(input_shape) * 2 + prod(output_shape) # exponentiating & summing inputs + denominating in each batch
    flop_counter = Counter({"softmax": flop})
    return flop_counter

def bmm_flop_jit(inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]:
    input1_shape = get_shape(inputs[0])
    input2_shape = get_shape(inputs[1])
    assert len(input1_shape) == len(input2_shape) == 3
    assert input1_shape[0] == input2_shape[0] and input1_shape[2] == input2_shape[1], [input1_shape, input2_shape]
    flop = prod(input1_shape) * input2_shape[-1]  # matmul of bnk * bkm -> bnm; flop = bnkm
    flop_counter = Counter({"bmm": flop})
    return flop_counter


flops, skips = flop_count(
    ForwardWrapper(model),
    inputs=(example_input,),
    supported_ops={
        "aten::batch_norm": batchnorm_flop_jit,
        "aten::group_norm": batchnorm_flop_jit,
        "aten::layer_norm": batchnorm_flop_jit,
        "aten::add": generic_activation_jit("add"),
        "aten::sub": generic_activation_jit("sub"),
        "aten::mul": generic_activation_jit("mul"),
        "aten::div": generic_activation_jit("div"),
        "aten::sqrt": generic_activation_jit("sqrt"),
        "aten::sigmoid": generic_activation_jit("sigmoid"),
        "aten::sigmoid_": generic_activation_jit("sigmoid_"),
        "aten::relu": generic_activation_jit("relu"),
        "aten::relu_": generic_activation_jit("relu_"),
        "aten::gelu": generic_activation_jit("gelu"),
        "aten::add_": generic_activation_jit("add_"),
        "aten::sub_": generic_activation_jit("sub_"),
        "aten::mul_": generic_activation_jit("mul_"),
        "aten::div_": generic_activation_jit("div_"),
        "aten::sqrt_": generic_activation_jit("sqrt_"),
        "aten::adaptive_avg_pool2d": generic_pooling_jit("adaptive_avg_pool2d"),
        "aten::adaptive_max_pool2d": generic_pooling_jit("adaptive_max_pool2d"),
        "aten::avg_pool2d": generic_pooling_jit("avg_pool2d"),
        "aten::max_pool2d": generic_pooling_jit("max_pool2d"),
        "aten::bmm": bmm_flop_jit,
        "aten::mean": generic_pooling_jit("mean"),
        "aten::var": generic_pooling_jit("var", multiplier=3),  # subtracting mean, exponentiate, summing
        "aten::var_mean": generic_pooling_jit("mean_var", multiplier=4),
        "aten::softmax": softmax_jit,
        "aten::dropout": generic_activation_jit("dropout"),
        "aten::frobenius_norm": generic_pooling_jit("frobenius_norm"),
    }
)
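
For reference, ForwardWrapper and example_input are not defined in the snippet above. A minimal sketch of what they could look like when profiling only the image encoder, assuming the ViT-B/16 weights, CPU execution, and a 224x224 input (all placeholders, not part of the original answer):

import torch
import torch.nn as nn
import clip


class ForwardWrapper(nn.Module):
    # Hypothetical wrapper: restricts the traced forward pass to the image
    # encoder so that flop_count sees a single tensor input.
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        return self.model.encode_image(image)


model, preprocess = clip.load("ViT-B/16", device="cpu")
example_input = torch.randn(1, 3, 224, 224)  # one preprocessed 224x224 image

A similar wrapper around model.encode_text, with a batch of tokenized text as example_input, would give the text-tower FLOPs.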

@sandipan211

sandipan211 commented Jun 13, 2024

Hi @Akashcodes732,

Did your issue get solved? I am stuck at the same problem and urgently need help. Kindly help me solve it.

@sandipan211

> We used fvcore's flop_count module with the following modifications to add the operations that it doesn't support out-of-the-box: [full snippet quoted from @jongwook's comment above]

Hi @jongwook,
I am trying to use your solution but am unable to do so. Where is this ForwardWrapper() defined? Moreover, my model's forward() takes two inputs: an image and its CLIP-preprocessed version. What should the inputs to my flop_count() call be?

Kindly help - I am in urgent need of this solution.
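
For what it's worth, fvcore's flop_count passes the elements of the inputs tuple as positional arguments to the module's forward(), so a model that takes two tensors can be profiled directly. A minimal sketch, using a stand-in two-input module and placeholder shapes:

import torch
import torch.nn as nn
from fvcore.nn import flop_count


class DummyTwoInputModel(nn.Module):
    # Stand-in for a model whose forward() takes a raw image and its
    # CLIP-preprocessed version; replace with the actual model.
    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, image, clip_image):
        return self.proj(image).mean() + self.proj(clip_image).mean()


image = torch.randn(1, 3, 256, 256)       # placeholder raw image
clip_image = torch.randn(1, 3, 224, 224)  # placeholder CLIP-preprocessed image

# Both tensors are forwarded positionally; supported_ops can be the same
# dictionary as in the snippet above.
flops, skips = flop_count(DummyTwoInputModel(), inputs=(image, clip_image))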

@X-funbean

Hi @sandipan211, have you figured out the best practice for counting the FLOPs of the CLIP model? I have tried several tools on CLIP with ViT-B/16 (e.g. torchsummaryX, thop, and torchinfo), but got different results. Among them, the result closest to the FLOPs plotted in the CLIP paper, Learning Transferable Visual Models From Natural Language Supervision, is the one from torchinfo, which is 14.04 GFLOPs (multiply-adds). I also tried the code provided by @jongwook, but it gave a result of over 161 GFLOPs. In addition, according to the model profile log provided by open_clip (https://github.com/mlfoundations/open_clip/blob/main/docs/model_profile.csv), the computational complexity of CLIP with ViT-B/16 should be 41.09 GFLOPs. Any idea on this?
