Skip to content

Commit

Permalink
bugfixes addressing issues with imports
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Jul 11, 2024
1 parent a70960b commit bc94021
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions makani/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
parser.add_argument("--epsilon_factor", default=0, type=float)
parser.add_argument("--split_data_channels", action="store_true")
parser.add_argument("--mode", default="score", type=str, choices=["score", "ensemble"], help="Select inference mode")
parser.add_argument("--enable_odirect", action="store_true")

# checkpoint format
parser.add_argument("--checkpoint_format", default="legacy", choices=["legacy", "flexible"], type=str, help="Format in which to load checkpoints.")
Expand Down Expand Up @@ -124,6 +125,7 @@
params["amp_mode"] = args.amp_mode
params["jit_mode"] = args.jit_mode
params["cuda_graph_mode"] = args.cuda_graph_mode
params["enable_odirect"] = args.enable_odirect
params["enable_benchy"] = args.enable_benchy
params["disable_ddp"] = args.disable_ddp
params["enable_nhwc"] = args.enable_nhwc
Expand Down
2 changes: 1 addition & 1 deletion makani/utils/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def is_distributed(name: str):
return False


# initialization routine
# initialization routine
def init(model_parallel_sizes=[1, 1, 1, 1],
model_parallel_names=["h", "w", "fin", "fout"],
verbose=False):
Expand Down
6 changes: 3 additions & 3 deletions makani/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_dataloader(params, files_pattern, device, train=True, final_eval=False):
from makani.utils.dataloaders.data_loader_multifiles import MultifilesDataset as MultifilesDataset2D
from torch.utils.data.distributed import DistributedSampler

# multifiles dataset
# multifiles
dataset = MultifilesDataset2D(params, files_pattern, train)

sampler = DistributedSampler(dataset, shuffle=train, num_replicas=params.data_num_shards, rank=params.data_shard_id) if (params.data_num_shards > 1) else None
Expand All @@ -81,8 +81,8 @@ def get_dataloader(params, files_pattern, device, train=True, final_eval=False):
dataset,
batch_size=int(params.batch_size),
num_workers=params.num_data_workers,
shuffle=False, # (sampler is None),
sampler=sampler if train else None,
shuffle=(sampler is None) and train,
sampler=sampler,
drop_last=True,
pin_memory=torch.cuda.is_available(),
)
Expand Down
1 change: 1 addition & 0 deletions makani/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import io
import numpy as np
import concurrent.futures as cf
Expand Down

0 comments on commit bc94021

Please sign in to comment.