Skip to content

Commit

Permalink
Pin numpy version to <=1.21.6 (#247)
Browse files Browse the repository at this point in the history
search

search, download draft

draft, successful search and download

draft

Update: `ModelAnalysis.from_onnx(...)` to additionally work with loaded `ModelProto` (#253)

refactor

search, download
  • Loading branch information
rahul-tuli authored and horheynm committed Feb 1, 2023
1 parent 20b4ba4 commit c4426e9
Show file tree
Hide file tree
Showing 14 changed files with 439 additions and 379 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
_PACKAGE_NAME = "sparsezoo" if is_release else "sparsezoo-nightly"

_deps = [
"numpy>=1.0.0",
"numpy>=1.0.0,<=1.21.6",
"onnx>=1.5.0,<=1.12.0",
"pyyaml>=5.1.0",
"requests>=2.0.0",
Expand Down
13 changes: 9 additions & 4 deletions src/sparsezoo/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy
import onnx
import yaml
from onnx import NodeProto
from onnx import ModelProto, NodeProto
from pydantic import BaseModel, Field

from sparsezoo.analysis.utils.models import (
Expand Down Expand Up @@ -424,15 +424,20 @@ class ModelAnalysis(BaseModel):
)

@classmethod
def from_onnx(cls, onnx_file_path: str):
def from_onnx(cls, onnx_file_path: Union[str, ModelProto]):
"""
Model Analysis
:param cls: class being constructed
:param onnx_file_path: path to onnx file being analyzed
:param onnx_file_path: path to onnx file, or a loaded onnx ModelProto to
analyze
:return: instance of cls
"""
model_onnx = onnx.load(onnx_file_path)
model_onnx = (
onnx_file_path
if isinstance(onnx_file_path, ModelProto)
else onnx.load(onnx_file_path)
)
model_graph = ONNXGraph(model_onnx)

node_analyses = cls.analyze_nodes(model_graph)
Expand Down
88 changes: 88 additions & 0 deletions src/sparsezoo/api/graphql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, List, Optional

import requests

from sparsezoo.utils import BASE_API_URL

from .query_parser import QueryParser
from .utils import map_keys, to_snake_case


QUERY_BODY = """
{{
{operation_body} {arguments}
{{
{fields}
}}
}}
"""


class GraphQLAPI:

@staticmethod
def get_file_download_url(
model_id: str,
file_name: str,
base_url: str = BASE_API_URL,
):
"""Url to download a file"""
return f"{base_url}/v2/models/{model_id}/files/{file_name}"

@staticmethod
def fetch(
operation_body: str,
arguments: Dict[str, str],
fields: Optional[List[str]] = None,
url: Optional[str] = None,
) -> Dict[str, Any]:
"""
Fetch data for models via api. Uses graohql convention of post, not get for requests.
Input args are parsed to make a query body for the api request.
For more details on the appropriate values, please refer to the url endpoint on the browser
:param operation_body: The data object of interest
:param arguments: Used to filter data object in the backend
:param field: the object's field of interest
"""

query = QueryParser(
operation_body=operation_body, arguments=arguments, fields=fields
)
query.parse()

response = requests.post(
url=url or f"{BASE_API_URL}/v2/graphql",
json={
"query": QUERY_BODY.format(
operation_body=query.operation_body,
arguments=query.arguments,
fields=query.fields,
)
},
)

response.raise_for_status()
response_json = response.json()

respose_objects = response_json["data"][query.operation_body]

return [
map_keys(dictionary=respose_object, mapper=to_snake_case)
for respose_object in respose_objects
]
128 changes: 128 additions & 0 deletions src/sparsezoo/api/query_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict, List, Optional

from .utils import to_camel_case


DEFAULT_MODELS_FIELDS = ["modelId", "stub"]

DEFAULT_FILES_FIELDS = [
"displayName",
"fileSize",
"modelId",
]

DEFAULT_TRAINING_RESULTS_FIELDS = [
"datasetName",
"datasetType",
"recordedUnits",
"recordedValue",
]

DEFAULT_BENCHMARK_RESULTS_FIELDS = [
"batchSize",
"deviceInfo",
"numCores",
"recordedUnits",
"recordedValue",
]

DEPRECATED_STUB_ARGS_MAPPER = {"sub_domain": "task", "dataset": "source_dataset"}
DEFAULT_FIELDS = {
"models": DEFAULT_MODELS_FIELDS,
"files": DEFAULT_FILES_FIELDS,
"trainingResults": DEFAULT_TRAINING_RESULTS_FIELDS,
"benchmarkResults": DEFAULT_BENCHMARK_RESULTS_FIELDS,
}


class QueryParser:
"""Parse the class input arg fields to be used for graphql post requests"""

def __init__(
self,
operation_body: str,
arguments: Dict[str, str],
fields: Optional[List[str]] = None,
):
self._operation_body = operation_body
self._arguments = arguments
self._fields = fields

def parse(self):
"""Parse to a string compatible with graphql requst body"""

self._parse_operation_body()
self._parse_arguments()
self._parse_fields()

def _parse_operation_body(self) -> None:
self._operation_body = to_camel_case(self._operation_body)

def _parse_arguments(self) -> None:
"""Transform deprecated stub args and convert to camel case"""
parsed_arguments = ""
for key, value in self.arguments.items():
if value is not None:
contemporary_key = DEPRECATED_STUB_ARGS_MAPPER.get(key, key)
camel_case_key = to_camel_case(contemporary_key)

# single, double quotes matters
parsed_arguments += f'{camel_case_key}: "{value}",'

if bool(parsed_arguments):
parsed_arguments = "(" + parsed_arguments + ")"

self._arguments = parsed_arguments

def _parse_fields(self) -> None:
fields = self.fields or DEFAULT_FIELDS.get(self.operation_body)

parsed_fields = ""
for field in fields:
camel_case_field = to_camel_case(field)
parsed_fields += rf"{camel_case_field} "
self.fields = parsed_fields

@property
def operation_body(self) -> str:
"""Return the query operation body"""
return self._operation_body

@operation_body.setter
def operation_body(self, operation_body: str) -> None:
self._operation_body = operation_body

@property
def arguments(self) -> str:
"""Return the query arguments"""
return self._arguments

@arguments.setter
def arguments(self, arguments: str) -> None:
self._operation_body = arguments

@property
def fields(self) -> str:
"""Return the query fields"""
return self._fields

@fields.setter
def fields(self, fields: str) -> None:
self._fields = fields


42 changes: 42 additions & 0 deletions src/sparsezoo/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, List


def to_camel_case(string: str):
"Concert string to snake case"
components = string.split("_")
return components[0] + "".join(word.title() for word in components[1:])


def to_snake_case(string: str):
"Convert string to snake case"
return "".join(
[
"_" + character.lower() if character.isupper() else character
for character in string
]
).lstrip("_")


def map_keys(
dictionary: Dict[str, str], mapper: Callable[[str], str]
) -> Dict[str, str]:
"""Given a dictionary, update its key to a given mapper callable. (ex. to_snake_case)"""
mapped_dict = {}
for key, value in dictionary.items():
mapped_dict[mapper(key)] = value

return mapped_dict
4 changes: 3 additions & 1 deletion src/sparsezoo/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from sparsezoo.inference import ENGINES, InferenceRunner
from sparsezoo.model.result_utils import ModelResult
from sparsezoo.model.utils import (
from sparsezoo.model.utils import (
SAVE_DIR,
ZOO_STUB_PREFIX,
load_files_from_directory,
Expand Down Expand Up @@ -351,7 +351,9 @@ def initialize_model_from_stub(
path = os.path.join(SAVE_DIR, model_id)
if not files:
raise ValueError(f"No files found for given stub {stub}")

url = os.path.dirname(files[0].get("url"))

return files, path, url, validation_results, size

@staticmethod
Expand Down
Loading

0 comments on commit c4426e9

Please sign in to comment.