
Commit

Stub v2 (#271)
* Dummy GraphQL requests module

* GraphQL API request

* return Model instances

* Update NOTICE (#242)

license name change

* bump main to 1.4.0 (#246)

Co-authored-by: dhuang <[email protected]>

* Pin numpy version to <=1.21.6 (#247)

search

search, download draft

draft, successful search and download

draft

Update: `ModelAnalysis.from_onnx(...)` to additionally work with loaded `ModelProto` (#253)

refactor

search, download

* lint

* pass tests

* init files

* lint

* Add dummy test using test-specific subclass

* tests

* add increment_download=False

* allow empty arguments

* comments

* query parser, allow dict as input, add tests for extra functionality

* restore models.utils

* restore models.utils

* v2 stub

* comments

* change stubs to ones on prod

* lint

* Update src/sparsezoo/model/utils.py

Co-authored-by: Danny Guinther <[email protected]>

* Update src/sparsezoo/model/utils.py

Co-authored-by: Danny Guinther <[email protected]>

* Update src/sparsezoo/api/utils.py

Co-authored-by: Danny Guinther <[email protected]>

---------

Co-authored-by: Danny Guinther <[email protected]>
Co-authored-by: Jeannie Finks <[email protected]>
Co-authored-by: dhuangnm <[email protected]>
Co-authored-by: dhuang <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
Co-authored-by: Danny Guinther <[email protected]>
7 people authored Apr 12, 2023
1 parent a9718cc commit 00ccfac
Showing 4 changed files with 140 additions and 53 deletions.
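
For orientation before the diff (an editorial aside, not part of the commit): the headline change is that model stubs may now be written either in the original slash-delimited v1 form or in the new flat, dash-delimited v2 form. Both example strings below are copied verbatim from the updated tests in this commit.

# Two stub spellings the library is expected to accept after this change
# (strings taken from tests/sparsezoo/model/test_model.py below).
v1_stub = "zoo:cv/detection/yolov5-s/pytorch/ultralytics/coco/pruned_quant-aggressive_94"
v2_stub = "yolov5-x-coco-pruned70.4block_quantized"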
23 changes: 19 additions & 4 deletions src/sparsezoo/api/utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict
from typing import Any, Callable, Dict, List


def to_camel_case(string: str):
@@ -32,7 +32,22 @@ def to_snake_case(string: str):


def map_keys(
dictionary: Dict[str, str], mapper: Callable[[str], str]
dictionary: Dict[str, Any], mapper: Callable[[str], str]
) -> Dict[str, str]:
"""Given a dictionary, update its key to a given mapper callable"""
return {mapper(key): value for key, value in dictionary.items()}
"""
Given a dictionary, update its keys to a given mapper callable.
If the value of the dict is a List of Dict or Dict of Dict, recursively map
its keys
"""
mapped_dict = {}
for key, value in dictionary.items():
if isinstance(value, List) or isinstance(value, Dict):
value_type = type(value)
mapped_dict[mapper(key)] = value_type(
map_keys(dictionary=sub_dict, mapper=mapper) for sub_dict in value
)
else:
mapped_dict[mapper(key)] = value

return mapped_dict
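
A rough usage sketch of the recursive map_keys (an assumption based on this file, not part of the diff): GraphQL-style camelCase responses, including nested lists of objects, can be converted to snake_case in one call. The import path and the expected output are assumed.

from sparsezoo.api.utils import map_keys, to_snake_case  # assumed import path

response = {
    "modelId": "abc123",
    "benchmarkResults": [{"batchSize": 64, "recordedValue": 1.2}],
}
converted = map_keys(dictionary=response, mapper=to_snake_case)
# expected (assuming to_snake_case converts camelCase keys):
# {"model_id": "abc123",
#  "benchmark_results": [{"batch_size": 64, "recorded_value": 1.2}]}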
3 changes: 2 additions & 1 deletion src/sparsezoo/model/model.py
@@ -25,6 +25,7 @@
from sparsezoo.model.utils import (
SAVE_DIR,
ZOO_STUB_PREFIX,
is_stub,
load_files_from_directory,
load_files_from_stub,
save_outputs_to_tar,
@@ -78,7 +79,7 @@ def __init__(self, source: str, download_path: Optional[str] = None):
self.source = source
self._stub_params = {}

if self.source.startswith(ZOO_STUB_PREFIX):
if is_stub(self.source):
# initializing the files and params from the stub
_setup_args = self.initialize_model_from_stub(stub=self.source)
files, path, url, validation_results, compressed_size = _setup_args
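
The practical effect of replacing the ZOO_STUB_PREFIX check with is_stub is that Model now accepts both stub formats. A hedged sketch follows; both stub strings appear in this commit's tests, and constructing a Model reaches out to the SparseZoo API.

from sparsezoo import Model

# Both constructions should take the initialize_model_from_stub branch;
# a plain local path would presumably fall through to the directory branch.
legacy = Model(
    "zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-moderate"
)
flat = Model("yolov5-n6-voc_coco-pruned55")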
124 changes: 79 additions & 45 deletions src/sparsezoo/model/utils.py
@@ -66,6 +66,26 @@
SAVE_DIR = os.getenv("SPARSEZOO_MODELS_PATH", CACHE_DIR)
COMPRESSED_FILE_NAME = "model.onnx.tar.gz"

STUB_V1_REGEX_EXPR = (
r"^(zoo:)?"
r"(?P<domain>[\.A-z0-9_]+)"
r"/(?P<sub_domain>[\.A-z0-9_]+)"
r"/(?P<architecture>[\.A-z0-9_]+)(-(?P<sub_architecture>[\.A-z0-9_]+))?"
r"/(?P<framework>[\.A-z0-9_]+)"
r"/(?P<repo>[\.A-z0-9_]+)"
r"/(?P<dataset>[\.A-z0-9_]+)(-(?P<training_scheme>[\.A-z0-9_]+))?"
r"/(?P<sparse_tag>[\.A-z0-9_-]+)"
)

STUB_V2_REGEX_EXPR = (
r"^(zoo:)?"
r"(?P<architecture>[\.A-z0-9_]+)"
r"(-(?P<sub_architecture>[\.A-z0-9_]+))?"
r"-(?P<source_dataset>[\.A-z0-9_]+)"
r"(-(?P<training_dataset>[\.A-z0-9_]+))?"
r"-(?P<sparse_tag>[\.A-z0-9_]+)"
)


def load_files_from_directory(directory_path: str) -> List[Dict[str, Any]]:
"""
@@ -118,33 +138,44 @@ def load_files_from_stub(
models = api.fetch(
operation_body="models",
arguments=arguments,
fields=["modelId", "modelOnnxSizeCompressedBytes"],
fields=[
"model_id",
"model_onnx_size_compressed_bytes",
"files",
"benchmark_results",
"training_results",
],
)

if len(models):
model_id = models[0]["model_id"]

files = api.fetch(
operation_body="files",
arguments={"model_id": model_id},
matching_models = len(models)
if matching_models == 0:
raise ValueError(
f"No matching models found with stub: {stub}." "Please try another stub"
)
if matching_models > 1:
logging.warning(
f"{len(models)} found from the stub: {stub}"
"Using the first model to obtain metadata."
"Proceed with caution"
)

if matching_models:
model = models[0]

model_id = model["model_id"]

files = model.get("files")
include_file_download_url(files)
files = restructure_request_json(request_json=files)

if params is not None:
files = filter_files(files=files, params=params)

training_results = api.fetch(
operation_body="training_results",
arguments={"model_id": model_id},
)
training_results = model.get("training_results")

benchmark_results = api.fetch(
operation_body="benchmark_results",
arguments={"model_id": model_id},
)
benchmark_results = model.get("benchmark_results")

model_onnx_size_compressed_bytes = models[0]["model_onnx_size_compressed_bytes"]
model_onnx_size_compressed_bytes = model.get("model_onnx_size_compressed_bytes")

throughput_results = [
ThroughputResults(**benchmark_result)
@@ -553,6 +584,38 @@ def include_file_download_url(files: List[Dict]):
)


def get_model_metadata_from_stub(stub: str) -> Dict[str, str]:
"""Return a dictionary of the model metadata from stub"""

matches = re.match(STUB_V1_REGEX_EXPR, stub) or re.match(STUB_V2_REGEX_EXPR, stub)
if not matches:
return {}

if "source_dataset" in matches.groupdict():
return {"repo_name": stub}

if "dataset" in matches.groupdict():
return {
"domain": matches.group("domain"),
"sub_domain": matches.group("sub_domain"),
"architecture": matches.group("architecture"),
"sub_architecture": matches.group("sub_architecture"),
"framework": matches.group("framework"),
"repo": matches.group("repo"),
"dataset": matches.group("dataset"),
"sparse_tag": matches.group("sparse_tag"),
}

return {}


def is_stub(candidate: str) -> bool:
return bool(
re.match(STUB_V1_REGEX_EXPR, candidate)
or re.match(STUB_V2_REGEX_EXPR, candidate)
)


def get_file_download_url(
model_id: str,
file_name: str,
@@ -566,32 +629,3 @@ def get_file_download_url(
download_url += "?increment_download=False"

return download_url


def get_model_metadata_from_stub(stub: str) -> Dict[str, str]:
"""
Return a dictionary of the model metadata from stub
"""

stub_regex_expr = (
r"^(zoo:)?"
r"(?P<domain>[\.A-z0-9_]+)"
r"/(?P<sub_domain>[\.A-z0-9_]+)"
r"/(?P<architecture>[\.A-z0-9_]+)(-(?P<sub_architecture>[\.A-z0-9_]+))?"
r"/(?P<framework>[\.A-z0-9_]+)"
r"/(?P<repo>[\.A-z0-9_]+)"
r"/(?P<dataset>[\.A-z0-9_]+)"
r"/(?P<sparse_tag>[\.A-z0-9_-]+)"
)
matches = re.match(stub_regex_expr, stub)

return {
"domain": matches.group("domain"),
"sub_domain": matches.group("sub_domain"),
"architecture": matches.group("architecture"),
"sub_architecture": matches.group("sub_architecture"),
"framework": matches.group("framework"),
"repo": matches.group("repo"),
"dataset": matches.group("dataset"),
"sparse_tag": matches.group("sparse_tag"),
}
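
Taken together, the new regexes, get_model_metadata_from_stub, and is_stub behave roughly as sketched below. The assertions are illustrative assumptions; the stub strings are drawn from this commit's tests.

from sparsezoo.model.utils import get_model_metadata_from_stub, is_stub

v1_stub = "zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-moderate"
v2_stub = "yolov5-n6-voc_coco-pruned55"

assert is_stub(v1_stub) and is_stub(v2_stub)
assert not is_stub("/path/to/a/local/model")  # absolute paths match neither regex

# v1 stubs are decomposed into their path segments ...
assert get_model_metadata_from_stub(v1_stub)["architecture"] == "mobilenet_v1"
# ... while v2 stubs are returned whole under "repo_name"
assert get_model_metadata_from_stub(v2_stub) == {"repo_name": v2_stub}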
43 changes: 40 additions & 3 deletions tests/sparsezoo/model/test_model.py
@@ -82,6 +82,16 @@
("checkpoint", "postqat"),
True,
),
(
"biobert-base_cased-jnlpba_pubmed-pruned80.4block_quantized",
("deployment", "default"),
True,
),
(
"resnet_v1-50-imagenet-pruned95",
("checkpoint", "preqat"),
True,
),
],
scope="function",
)
@@ -127,20 +137,47 @@ def _assert_validation_results_exist(model):
"stub, clone_sample_outputs, expected_files",
[
(
"zoo:cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-moderate", # noqa E501
(
"zoo:"
"cv/classification/mobilenet_v1-1.0/"
"pytorch/sparseml/imagenet/pruned-moderate"
),
True,
files_ic,
),
(
"zoo:nlp/question_answering/distilbert-none/pytorch/huggingface/squad/pruned80_quant-none-vnni", # noqa E501
(
"zoo:"
"nlp/question_answering/distilbert-none/"
"pytorch/huggingface/squad/pruned80_quant-none-vnni"
),
False,
files_nlp,
),
(
"zoo:cv/detection/yolov5-s/pytorch/ultralytics/coco/pruned_quant-aggressive_94", # noqa E501
(
"zoo:"
"cv/detection/yolov5-s/"
"pytorch/ultralytics/coco/pruned_quant-aggressive_94"
),
True,
files_yolo,
),
(
"yolov5-x-coco-pruned70.4block_quantized",
False,
files_yolo,
),
(
"yolov5-n6-voc_coco-pruned55",
False,
files_yolo,
),
(
"resnet_v1-50-imagenet-channel30_pruned90_quantized",
False,
files_yolo,
),
],
scope="function",
)
