Skip to content

Commit

Permalink
Add type error suppressions for upcoming upgrade (#3109)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3109

X-link: facebookresearch/FBGEMM#196

Reviewed By: connernilsen

Differential Revision: D62447645
  • Loading branch information
Maggie Moss authored and facebook-github-bot committed Sep 10, 2024
1 parent 20cb987 commit 17b7bd0
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 2 deletions.
28 changes: 27 additions & 1 deletion fbgemm_gpu/bench/sparse_ops_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,21 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:

# Benchmark forward
time_ref, output_ref = benchmark_torch_function(
torch.index_select, (input, 0, offset_indices), **bench_kwargs
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
torch.index_select,
(input, 0, offset_indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

input_group = input.split(batch_size, 0)
time, output_group = benchmark_torch_function(
torch.ops.fbgemm.group_index_select_dim0,
(input_group, indices_group),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)
logging.info(
Expand All @@ -306,13 +314,19 @@ def gen_inverse_index(curr_size: int, final_size: int) -> np.array:
time_ref, _ = benchmark_torch_function(
functools.partial(output_ref.backward, retain_graph=True),
(grad,),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

# pyre-fixme[6]: For 1st argument expected `Union[List[Tensor],
# typing.Tuple[Tensor, ...]]` but got `Tensor`.
cat_output = torch.cat(output_group)
time, _ = benchmark_torch_function(
functools.partial(cat_output.backward, retain_graph=True),
(grad,),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)
logging.info(
Expand Down Expand Up @@ -714,6 +728,8 @@ def batch_group_index_select_bwd(
time_pyt, out_pyt = benchmark_torch_function(
index_select_fwd_ref,
(inputs, indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -726,12 +742,16 @@ def batch_group_index_select_bwd(
input_rows,
input_columns,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

time_gis, out_gis = benchmark_torch_function(
group_index_select_fwd,
(gis_inputs, indices),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -746,6 +766,8 @@ def batch_group_index_select_bwd(
time_bwd_pyt, _ = benchmark_torch_function(
index_select_bwd_ref,
(out_pyt, grads),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -756,6 +778,8 @@ def batch_group_index_select_bwd(
concat_grads,
optim_batch,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand All @@ -766,6 +790,8 @@ def batch_group_index_select_bwd(
concat_grads,
optim_group,
),
# pyre-fixme[6]: For 3rd argument expected `bool` but got `int`.
# pyre-fixme[6]: For 3rd argument expected `str` but got `int`.
**bench_kwargs,
)

Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,9 @@ def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> N

time_per_iter = benchmark_requests(
requests_uvm,
# pyre-fixme[6]: For 2nd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> Tensor` but got `(indices: Tensor, offsets: Tensor,
# per_sample_weights: Tensor) -> None`.
run_bench,
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
Expand Down Expand Up @@ -1922,6 +1925,9 @@ def nbit_uvm(
indices,
offsets,
),
# pyre-fixme[6]: For 3rd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> None` but got `(indices: Any, offsets: Any,
# indices_weights: Any) -> Tensor`.
lambda indices, offsets, indices_weights: emb_mixed.forward(
indices,
offsets,
Expand Down Expand Up @@ -2409,6 +2415,9 @@ def nbit_cache( # noqa C901
indices,
offsets,
),
# pyre-fixme[6]: For 3rd argument expected `(Tensor, Tensor,
# Optional[Tensor]) -> None` but got `(indices: Any, offsets: Any,
# indices_weights: Any) -> Tensor`.
lambda indices, offsets, indices_weights: emb.forward(
indices,
offsets,
Expand Down Expand Up @@ -3049,6 +3058,7 @@ def device_with_spec( # noqa C901
reuse=reuse,
alpha=alpha,
weighted=weighted,
# pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None,
zipf_oversample_ratio=3 if Ls[t] > 5 else 5,
)
Expand Down
6 changes: 5 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/quantize_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@


def none_throws(
optional: Optional[TypeVar("_T")], message: str = "Unexpected `None`"
# pyre-fixme[31]: Expression `typing.Optional[typing.TypeVar("_T")]` is not a
# valid type.
optional: Optional[TypeVar("_T")],
message: str = "Unexpected `None`",
# pyre-fixme[31]: Expression `typing.TypeVar("_T")` is not a valid type.
) -> TypeVar("_T"):
if optional is None:
raise AssertionError(message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,15 @@ class IntNBitTableBatchedEmbeddingBagsCodegen(nn.Module):

embedding_specs: List[Tuple[str, int, int, SparseType, EmbeddingLocation]]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `cache_miss_counter` is never initialized.
cache_miss_counter: torch.Tensor
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
local_uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `weights_offsets` is never initialized.
weights_offsets: torch.Tensor
# pyre-fixme[13]: Attribute `weights_placements` is never initialized.
weights_placements: torch.Tensor

def __init__( # noqa C901
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,12 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
lxu_cache_locations_empty: Tensor
timesteps_prefetched: List[int]
record_cache_metrics: RecordCacheMetrics
# pyre-fixme[13]: Attribute `uvm_cache_stats` is never initialized.
uvm_cache_stats: torch.Tensor
# pyre-fixme[13]: Attribute `local_uvm_cache_stats` is never initialized.
local_uvm_cache_stats: torch.Tensor
uuid: str
# pyre-fixme[13]: Attribute `last_uvm_cache_print_state` is never initialized.
last_uvm_cache_print_state: torch.Tensor
_vbe_B_offsets: Optional[torch.Tensor]
_vbe_max_B: int
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,7 @@ def forward(
offsets: Tensor,
per_sample_weights: Optional[Tensor] = None,
feature_requires_grad: Optional[Tensor] = None,
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
) -> Tensor:
indices, offsets, per_sample_weights = self.prepare_inputs(
indices, offsets, per_sample_weights
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/utils/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def dequantize_embs(
weight_ty: SparseType,
use_cpu: bool,
fp8_config: Optional[FP8QuantizationConfig] = None,
# pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
) -> torch.Tensor:
print(f"weight_ty: {weight_ty}")
assert (
Expand Down

0 comments on commit 17b7bd0

Please sign in to comment.