-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pin numpy version to <=1.21.6 (#247)
search search, download draft draft, successful search and download draft Update: `ModelAnalysis.from_onnx(...)` to additionally work with loaded `ModelProto` (#253) refactor search, download
- Loading branch information
1 parent
20b4ba4
commit c4426e9
Showing
14 changed files
with
439 additions
and
379 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from typing import Any, Dict, List, Optional | ||
|
||
import requests | ||
|
||
from sparsezoo.utils import BASE_API_URL | ||
|
||
from .query_parser import QueryParser | ||
from .utils import map_keys, to_snake_case | ||
|
||
|
||
# Skeleton for every graphql request body. The doubled braces become literal
# braces after str.format(); GraphQLAPI.fetch fills in the parsed
# operation_body, arguments, and fields produced by QueryParser.
QUERY_BODY = """
{{
    {operation_body} {arguments}
        {{
            {fields}
        }}
}}
"""
|
||
|
||
class GraphQLAPI:
    """Thin client for the model-zoo GraphQL API."""

    @staticmethod
    def get_file_download_url(
        model_id: str,
        file_name: str,
        base_url: str = BASE_API_URL,
    ) -> str:
        """
        Build the url to download a file belonging to a model.

        :param model_id: id of the model the file belongs to
        :param file_name: display name of the file to download
        :param base_url: api root to use; defaults to BASE_API_URL
        :return: the fully qualified download url
        """
        return f"{base_url}/v2/models/{model_id}/files/{file_name}"

    @staticmethod
    def fetch(
        operation_body: str,
        arguments: Dict[str, str],
        fields: Optional[List[str]] = None,
        url: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Fetch data for models via the api. Uses the graphql convention of
        POST, not GET, for requests. Input args are parsed into a query body
        for the api request. For more details on the appropriate values,
        please refer to the url endpoint in the browser.

        :param operation_body: the data object of interest (e.g. "models")
        :param arguments: used to filter the data object in the backend;
            entries with a None value are dropped by QueryParser
        :param fields: the object's fields of interest; when omitted,
            QueryParser substitutes per-operation defaults
        :param url: optional override of the graphql endpoint; defaults to
            BASE_API_URL + "/v2/graphql"
        :return: list of response objects with their keys converted to
            snake_case
        :raises requests.HTTPError: if the api responds with an error status
        """
        query = QueryParser(
            operation_body=operation_body, arguments=arguments, fields=fields
        )
        query.parse()

        response = requests.post(
            url=url or f"{BASE_API_URL}/v2/graphql",
            json={
                "query": QUERY_BODY.format(
                    operation_body=query.operation_body,
                    arguments=query.arguments,
                    fields=query.fields,
                )
            },
        )

        response.raise_for_status()
        response_json = response.json()

        # graphql wraps results as {"data": {<operationBody>: [...]}}
        response_objects = response_json["data"][query.operation_body]

        return [
            map_keys(dictionary=response_object, mapper=to_snake_case)
            for response_object in response_objects
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from typing import Dict, List, Optional | ||
|
||
from .utils import to_camel_case | ||
|
||
|
||
# Default graphql fields requested per operation body when the caller does
# not supply any; names are camelCase to match the backend schema.

DEFAULT_MODELS_FIELDS = ["modelId", "stub"]

DEFAULT_FILES_FIELDS = [
    "displayName",
    "fileSize",
    "modelId",
]

DEFAULT_TRAINING_RESULTS_FIELDS = [
    "datasetName",
    "datasetType",
    "recordedUnits",
    "recordedValue",
]

DEFAULT_BENCHMARK_RESULTS_FIELDS = [
    "batchSize",
    "deviceInfo",
    "numCores",
    "recordedUnits",
    "recordedValue",
]

# Maps legacy stub argument names to their current equivalents.
DEPRECATED_STUB_ARGS_MAPPER = {"sub_domain": "task", "dataset": "source_dataset"}
# Lookup from operation body -> default fields for that operation.
DEFAULT_FIELDS = {
    "models": DEFAULT_MODELS_FIELDS,
    "files": DEFAULT_FILES_FIELDS,
    "trainingResults": DEFAULT_TRAINING_RESULTS_FIELDS,
    "benchmarkResults": DEFAULT_BENCHMARK_RESULTS_FIELDS,
}
|
||
|
||
class QueryParser:
    """Parse the class input arg fields to be used for graphql post requests"""

    def __init__(
        self,
        operation_body: str,
        arguments: Dict[str, str],
        fields: Optional[List[str]] = None,
    ):
        self._operation_body = operation_body
        self._arguments = arguments
        self._fields = fields

    def parse(self) -> None:
        """Parse to a string compatible with a graphql request body"""
        self._parse_operation_body()
        self._parse_arguments()
        self._parse_fields()

    def _parse_operation_body(self) -> None:
        # graphql operation names are camelCase, e.g.
        # "training_results" -> "trainingResults"
        self._operation_body = to_camel_case(self._operation_body)

    def _parse_arguments(self) -> None:
        """Transform deprecated stub args and convert to camel case"""
        parsed_arguments = ""
        for key, value in self.arguments.items():
            # None means "no filter on this key" — omit it from the query
            if value is not None:
                contemporary_key = DEPRECATED_STUB_ARGS_MAPPER.get(key, key)
                camel_case_key = to_camel_case(contemporary_key)

                # single vs double quotes matters: graphql string literals
                # must be double-quoted
                parsed_arguments += f'{camel_case_key}: "{value}",'

        if parsed_arguments:
            parsed_arguments = "(" + parsed_arguments + ")"

        self._arguments = parsed_arguments

    def _parse_fields(self) -> None:
        # NOTE(review): if no fields were given and operation_body has no
        # DEFAULT_FIELDS entry, this iterates None and raises TypeError —
        # confirm callers only pass known operation bodies
        fields = self.fields or DEFAULT_FIELDS.get(self.operation_body)

        parsed_fields = ""
        for field in fields:
            parsed_fields += f"{to_camel_case(field)} "
        self.fields = parsed_fields

    @property
    def operation_body(self) -> str:
        """Return the query operation body"""
        return self._operation_body

    @operation_body.setter
    def operation_body(self, operation_body: str) -> None:
        self._operation_body = operation_body

    @property
    def arguments(self) -> str:
        """Return the query arguments"""
        return self._arguments

    @arguments.setter
    def arguments(self, arguments: str) -> None:
        # BUG FIX: previously assigned to self._operation_body, so setting
        # .arguments silently clobbered the operation body and left the
        # stored arguments unchanged
        self._arguments = arguments

    @property
    def fields(self) -> str:
        """Return the query fields"""
        return self._fields

    @fields.setter
    def fields(self, fields: str) -> None:
        self._fields = fields
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Callable, Dict, List | ||
|
||
|
||
def to_camel_case(string: str) -> str:
    """Convert a snake_case string to camelCase (e.g. "model_id" -> "modelId")."""
    # first component keeps its casing; each following component is title-cased
    components = string.split("_")
    return components[0] + "".join(word.title() for word in components[1:])
|
||
|
||
def to_snake_case(string: str) -> str:
    """Convert a camelCase string to snake_case (e.g. "modelId" -> "model_id")."""
    pieces = []
    for character in string:
        # every uppercase letter marks a word boundary
        if character.isupper():
            pieces.append("_")
            pieces.append(character.lower())
        else:
            pieces.append(character)
    # a leading uppercase letter would otherwise produce a leading underscore
    return "".join(pieces).lstrip("_")
|
||
|
||
def map_keys(
    dictionary: Dict[str, str], mapper: Callable[[str], str]
) -> Dict[str, str]:
    """Given a dictionary, return a copy whose keys are transformed by the given mapper callable. (ex. to_snake_case)"""
    return {mapper(key): value for key, value in dictionary.items()}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.