Skip to content

Commit

Permalink
[serving] add default dtype when running in deepspeed (deepjavalibrar…
Browse files Browse the repository at this point in the history
  • Loading branch information
tosterberg authored Apr 11, 2023
1 parent 294ff4d commit 0681015
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ def get_torch_dtype_from_str(dtype: str):
raise ValueError(f"Invalid data type: {dtype}")


def default_dtype():
if torch.cuda.is_available():
if torch.cuda.is_bf16_supported():
return "bf16"
return "fp16"
return "fp32"


class DeepSpeedService(object):

def __init__(self):
Expand Down Expand Up @@ -122,7 +130,7 @@ def _parse_properties(self, properties):
self.model_id_or_path = properties.get("model_id") or properties.get(
"model_dir")
self.task = properties.get("task")
self.data_type = get_torch_dtype_from_str(properties.get("dtype"))
self.data_type = get_torch_dtype_from_str(properties.get("dtype", default_dtype()))
self.max_tokens = int(properties.get("max_tokens", 1024))
self.device = int(os.getenv("LOCAL_RANK", 0))
self.tensor_parallel_degree = int(
Expand Down

0 comments on commit 0681015

Please sign in to comment.