Add Index simplification & API update (apache#26)
* Add vectorized cooperative_fetching test

* Update math simplify for vectorized CF

* File rename

* Update tune_network

* API update
jcf94 authored and merrymercy committed Jun 20, 2020
1 parent 2f241ed commit 18d44b8
Showing 18 changed files with 344 additions and 111 deletions.
3 changes: 0 additions & 3 deletions python/tvm/ansor/auto_schedule.py
@@ -245,8 +245,6 @@ def auto_schedule(workload, target=None,
Returns
-------
state : State
sch : tvm.Schedule
tensors : List[Tensor]
@@ -270,4 +268,3 @@ def auto_schedule(workload, target=None,
else:
raise ValueError("Invalid workload: " + workload +
". Expect a string or SearchTask")

24 changes: 19 additions & 5 deletions python/tvm/ansor/compute_dag.py
@@ -20,7 +20,7 @@
import tvm._ffi
from tvm.runtime import Object
from tvm import te
from .loop_state import State
from .loop_state import State, StateObject
from . import _ffi_api


@@ -63,8 +63,12 @@ def apply_steps_from_state(self, state, layout_rewrite_level=None):
sch : Schedule
args : List[Tensor]
"""
sch, args = _ffi_api.ComputeDAGApplyStepsFromState(self, state)
return sch, args
if isinstance(state, State):
return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object)
elif isinstance(state, StateObject):
return _ffi_api.ComputeDAGApplyStepsFromState(self, state)
else:
raise ValueError("The input must be a State or StateObject")

def print_python_code_from_state(self, state):
"""
@@ -76,7 +80,12 @@ def print_python_code_from_state(self, state):
-------
str : Str
"""
return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state)
if isinstance(state, State):
return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state.state_object)
elif isinstance(state, StateObject):
return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state)
else:
raise ValueError("The input must be a State or StateObject")

def infer_bound_from_state(self, state):
"""
@@ -88,7 +97,12 @@ def infer_bound_from_state(self, state):
-------
state : StateObject
"""
return _ffi_api.ComputeDAGInferBoundFromState(self, state)
if isinstance(state, State):
return State(_ffi_api.ComputeDAGInferBoundFromState(self, state.state_object))
elif isinstance(state, StateObject):
return State(_ffi_api.ComputeDAGInferBoundFromState(self, state))
else:
raise ValueError("The input must be a State or StateObject")

def gen_schedule(state, bufs):
if not state or not state.complete:
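
The three `ComputeDAG` helpers above now accept either the Python-side `State` wrapper or the raw `StateObject`, and `infer_bound_from_state` wraps its result back into a `State`. A minimal sketch of calling code under those assumptions — `dag` and `state` are assumed to exist already, and `lower_state` is a hypothetical helper name:

```python
def lower_state(dag, state):
    """Hypothetical helper: `state` may be a State wrapper or a raw StateObject."""
    # infer_bound_from_state now returns a State wrapper in either case
    bound = dag.infer_bound_from_state(state)
    # the printing and lowering entry points accept both forms as well
    print(dag.print_python_code_from_state(bound))
    sch, args = dag.apply_steps_from_state(state)
    return sch, args
```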
10 changes: 7 additions & 3 deletions python/tvm/ansor/feature.py
@@ -23,7 +23,7 @@
import struct
import numpy as np

from .loop_state import StateObject
from .loop_state import State, StateObject
from .measure import MeasureInput, MeasureResult
from . import _ffi_api

@@ -131,12 +131,16 @@ def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput],
return unpack_feature(byte_arr)


def get_per_stmt_features_from_states(states: List[StateObject],
def get_per_stmt_features_from_states(states,
task: "SearchTask",
max_n_bufs: int = None) -> List[np.ndarray]:
"""Get per_stmt features from states"""
if isinstance(states[0], State):
state_objects = [s.state_object for s in states]
elif isinstance(states[0], StateObject):
state_objects = states
byte_arr = _ffi_api.GetPerStmtFeaturesFromStates(
states, task, max_n_bufs or DEFAULT_MAX_N_BUFS)
state_objects, task, max_n_bufs or DEFAULT_MAX_N_BUFS)
return unpack_feature(byte_arr)[0]


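
`get_per_stmt_features_from_states` now unwraps `State` objects itself, so callers can pass either form. A hedged sketch, assuming a populated list of `State` wrappers and a `SearchTask`; `extract_features` is an illustrative name:

```python
from tvm.ansor.feature import get_per_stmt_features_from_states

def extract_features(states, task):
    """Hypothetical wrapper; `states` is a list of State wrappers here."""
    # new form: pass the State wrappers directly
    feats = get_per_stmt_features_from_states(states, task)
    # the older form with raw StateObjects keeps working
    same_feats = get_per_stmt_features_from_states(
        [s.state_object for s in states], task)
    return feats, same_feats
```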
14 changes: 1 addition & 13 deletions python/tvm/ansor/loop_state.py
@@ -60,18 +60,6 @@ def iters(self):
setattr(self, "iterators_cache", _ffi_api.StageGetIterators(self))
return getattr(self, "iterators_cache")

def iter(self, index):
"""
Parameters
----------
index : Int
Returns
-------
iter : Iterator
"""
return _ffi_api.StageGetIterator(self, index)


@tvm._ffi.register_object("ansor.State")
class StateObject(Object):
@@ -302,7 +290,7 @@ def bind_thread(self, stage_id, it, thread_name):
}
thread_id = trans_table[thread_name]

self.state_object, res = _ffi_api.StateUnroll(self.state_object, stage_id, it, thread_id)
self.state_object, res = _ffi_api.StateBindThread(self.state_object, stage_id, it, thread_id)
self.clear_cache()
return res

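
Two changes land in `loop_state.py`: the per-index `Stage.iter(i)` accessor is removed in favour of indexing the cached `iters` property, and `bind_thread` now dispatches to the correct `StateBindThread` FFI call rather than `StateUnroll`. A small sketch under those assumptions — `state`, `stage`, and `stage_id` are assumed to come from elsewhere, the thread name must be one the method's `trans_table` accepts, and `bind_outer_iterator` is an illustrative name:

```python
def bind_outer_iterator(state, stage, stage_id, thread_name="blockIdx.x"):
    """Illustrative only: bind a stage's outermost iterator to a GPU thread axis."""
    it = stage.iters[0]            # previously stage.iter(0)
    return state.bind_thread(stage_id, it, thread_name)
```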
2 changes: 1 addition & 1 deletion python/tvm/ansor/measure.py
@@ -62,7 +62,7 @@ class MeasureInput(Object):
"""

def __init__(self, task, state):
self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state)
self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)


@tvm._ffi.register_object("ansor.BuildResult")
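
`MeasureInput` keeps its `(task, state)` constructor but now unwraps the `State` to its `StateObject` before crossing the FFI boundary, so callers are unchanged. A hedged sketch with an illustrative helper name:

```python
from tvm.ansor.measure import MeasureInput

def make_measure_inputs(task, states):
    """Illustrative: callers keep passing State wrappers; unwrapping happens inside."""
    return [MeasureInput(task, s) for s in states]
```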
4 changes: 3 additions & 1 deletion python/tvm/ansor/serialization.py
@@ -22,6 +22,7 @@
import tvm._ffi
from tvm.runtime import Object
from .measure import MeasureCallback, MeasureErrorNo
from .loop_state import State
from . import _ffi_api


@@ -74,7 +75,8 @@ def write_measure_records_to_file(filename, inputs, results):

def get_states_from_measure_inputs(inputs, task):
"""Get states from measure inputs"""
return _ffi_api.GetStatesFromMeasureInputs(inputs, task)
state_objects = _ffi_api.GetStatesFromMeasureInputs(inputs, task)
return [State(s) for s in state_objects]


def best_measure_pair_in_file(filename, workload_key=None, target=None):
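
`get_states_from_measure_inputs` now returns `State` wrappers rather than raw `StateObject`s, so the results can be fed straight back into the `ComputeDAG` API. A sketch assuming `inputs`, `task`, and `dag` already exist; `replay_python_code` is an illustrative name:

```python
from tvm.ansor.serialization import get_states_from_measure_inputs

def replay_python_code(inputs, task, dag):
    """Illustrative: print the schedule code of every recorded state."""
    for state in get_states_from_measure_inputs(inputs, task):
        print(dag.print_python_code_from_state(state))
```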
136 changes: 71 additions & 65 deletions scripts/tune_network.py
@@ -191,22 +191,14 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof
return module, ctx


def tune_and_evaluate(network_name, model_path, batch_size, target, target_host,
local_measure, device_key, host, port, n_parallel, ndk_cc,
build_timeout, run_timeout, num_threads, tune, check_correctness,
debug_profile, tuning_parameters, record_file, layout_set):
task_scheduler, model_type, policy, log_file, load_log_file = (tuning_parameters['task_scheduler'],
tuning_parameters['model_type'], tuning_parameters['policy'],
tuning_parameters['log_file'], tuning_parameters['load_log_file'])

if layout_set:
layout = layout_set

def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune,
debug_profile, check_correctness, network_parameters,
task_scheduler_parameters, tune_parameters, module_parameters):
# Extract workloads from relay program
print("=============== Extract workloads ===============")
mod, params, input_name, data_shape, out_shape = get_network(network_name, model_path, batch_size, layout)
mod, params, input_name, data_shape, out_shape = get_network(**network_parameters)

if tune:
print("=============== Extracting workloads ===============")
workloads, wkl_weights = ansor.extract_from_program(mod, target=target,
params=params, ops=(relay.op.nn.dense, relay.op.nn.softmax,
relay.op.nn.conv2d, relay.op.nn.conv2d_transpose,
@@ -215,7 +207,7 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host,
relay.op.nn.conv3d, relay.op.nn.adaptive_avg_pool3d,
relay.op.nn.batch_matmul, relay.op.mean,
))
print("Total workload number: %d" % (len(workloads)))
print("Totally %d workload extracted." % (len(workloads)))

# Tune workloads with auto scheduler
print("=============== Tuning ===============")
@@ -225,23 +217,13 @@ def tune_and_evaluate(network_name, model_path, batch_size, target, target_host,
print("[========= Task %d =========]\n" % i, dag)
tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host))

def objective_func(costs):
return sum(c * w for c, w in zip(costs, wkl_weights))

tuner = ansor.SimpleTaskScheduler(tasks, objective_func, strategy=task_scheduler,
load_log_file=load_log_file,
load_model_file=tuning_parameters['load_model'])
tuner = ansor.SimpleTaskScheduler(tasks,
lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)),
**task_scheduler_parameters)
tune_option, measure_ctx = create_tune_option(target, **tune_parameters)

tune_option, measure_ctx = create_tune_option(target, log_file,
tuning_parameters['n_trials'], tuning_parameters['num_measure_per_iter'],
tuning_parameters['verbose'], n_parallel, build_timeout,
local_measure, device_key, host, port, ndk_cc,
tuning_parameters['early_stopping'])
search_policy = "%s.%s" % (policy, model_type)

if local_measure and target.target_name != 'cuda':
if tune_parameters['local_measure'] and target.target_name != 'cuda':
os.environ['TVM_BIND_MASTER_CORE_0'] = "1"

tuner.tune(tune_option, search_policy)

if measure_ctx:
@@ -251,15 +233,13 @@ def objective_func(costs):

# Compile graph with best states found by auto-scheduler
print("=============== Compile ===============")
with ansor.apply_history_best(log_file, args.log_n_lines):
#if True:
#with ansor.BlockingEmptyContext():
with ansor.apply_history_best(tune_parameters['log_file'], log_n_lines):
os.environ['TVM_AUTO_CACHE_FLUSH'] = "0"
os.environ['TVM_BIND_MASTER_CORE_0'] = "1"
if kernel_layout_rewrite:
ansor.prepare_layout_rewrite(mod, target=target,
params=params,
ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d))
params=params,
ops=(relay.op.nn.dense, relay.op.nn.conv2d, relay.op.nn.conv3d))
else:
# disable layout rewrite
ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE
@@ -268,22 +248,12 @@ def objective_func(costs):
with relay.build_config(opt_level=3):
graph, lib, opt_params = relay.build_module.build(
mod, target=target, params=params)
'''
from tvm.relay.backend import graph_runtime_codegen
with relay.build_config(opt_level=3):
opt_mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(opt_mod["main"])
with tvm.transform.PassContext(opt_level=3):
graph, lib, opt_params = relay.build_module.build(
mod, target=target, params=params)
'''

ansor.finish_layout_rewrite()
print("=============== Compile Finish ===============")

module, ctx = create_module(data_shape, graph, lib, target, input_name, opt_params,
debug_profile, local_measure, ndk_cc,
device_key, host, port, run_timeout, num_threads)
module, ctx = create_module(data_shape, graph, lib, target, input_name,
opt_params, debug_profile, **module_parameters)

# Evaluate
print("========== Evaluate ==========")
@@ -315,9 +285,8 @@ def objective_func(costs):
graph, lib, opt_params = relay.build_module.build(
mod, target=target, params=params)

module, _ = create_module(data_shape, graph, lib, target, input_name, opt_params,
debug_profile, local_measure, ndk_cc,
device_key, host, port, run_timeout, num_threads)
module, _ = create_module(data_shape, graph, lib, target, input_name,
opt_params, debug_profile, **module_parameters)
module.run()

expected_output = module.get_output(0).asnumpy()
@@ -343,7 +312,7 @@ def objective_func(costs):
# Strategy related options
parser.add_argument("--seed", type=int, default=0, help='random seed')
parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'],
default='meta-rewrite')
default='meta-rewrite')
parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb')
parser.add_argument("--task-scheduler", type=str, default='gradient',
choices=['no', 'gradient', 'round-robin'],
@@ -359,6 +328,7 @@ def objective_func(costs):
# Detailed control options
parser.add_argument("--build-timeout", type=int, default=10)
parser.add_argument("--run-timeout", type=int, default=10)
parser.add_argument("--early-stopping", type=int, default=-1)
parser.add_argument("--verbose", type=int, default=1)
parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True)
parser.add_argument("--device-key", type=str, default=None)
@@ -375,23 +345,59 @@ def objective_func(costs):
logging.getLogger('ansor').setLevel(logging.DEBUG)

target = tvm.target.create(args.target)
log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size,
target.target_name)
load_log_file = args.load_log or log_file
search_policy = "%s.%s" % (args.policy, args.model_type)
if args.layout:
layout = args.layout
elif target.target_name == "cuda":
layout = "NCHW"
else:
layout = "NHWC"

network_parameters = {
'name': args.network,
'model_path': args.model_path,
'batch_size': args.batch_size,
'layout': layout
}

task_scheduler_parameters = {
'strategy': args.task_scheduler,
'load_log_file': load_log_file,
'load_model_file': args.load_model,
'verbose': args.verbose,
}

tuning_parameters = {
control_parameters = {
'local_measure': args.local_measure,
'device_key': args.device_key,
'host': args.host,
'port': args.port,
'ndk_cc': args.ndk_cc,
}

tune_parameters = {
'log_file': log_file,
'n_trials': args.n_trials,
'num_measure_per_iter': args.num_measure_per_iter,
'log_file': args.log_file or "%s-B%d.json" % (args.network, args.batch_size),
'load_model': args.load_model,
'model_type': args.model_type,
'task_scheduler': args.task_scheduler,
'policy': args.policy,
'early_stopping': -1,
'verbose': 1,
'verbose': args.verbose,
'n_parallel': args.n_parallel,
'build_timeout': args.build_timeout,
'run_timeout': args.run_timeout,
'early_stopping': args.early_stopping,
**control_parameters
}

module_parameters = {
'run_timeout': args.run_timeout,
'num_threads': args.num_threads,
**control_parameters
}
tuning_parameters['load_log_file'] = args.load_log or tuning_parameters['log_file']

os.environ["TOPHUB_LOCATION"] = "NONE"
tune_and_evaluate(args.network, args.model_path, args.batch_size, target, args.target_host,
args.local_measure, args.device_key, args.host,
args.port, args.n_parallel, args.ndk_cc, args.build_timeout,
args.run_timeout, args.num_threads, args.tune, args.check_correctness,
args.debug_profile, tuning_parameters, args.out_file, args.layout)
tune_and_evaluate(target, args.target_host, args.log_n_lines, search_policy,
args.tune, args.debug_profile, args.check_correctness,
network_parameters, task_scheduler_parameters, tune_parameters,
module_parameters)
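
The script refactor replaces `tune_and_evaluate`'s long positional argument list with a few parameter dictionaries that are unpacked with `**` at each call site. A stripped-down sketch of that pattern — the keys mirror the script, but the values below are placeholders, not taken from the commit:

```python
# Placeholder values; only the grouping / **-unpacking pattern matters here.
control_parameters = {
    'local_measure': True,
    'device_key': None,
    'host': '0.0.0.0',
    'port': 9190,
    'ndk_cc': None,
}

tune_parameters = {
    'log_file': 'network-B1-llvm.json',
    'n_trials': 200,
    'num_measure_per_iter': 48,
    'verbose': 1,
    'n_parallel': 1,
    'build_timeout': 10,
    'run_timeout': 10,
    'early_stopping': -1,
    **control_parameters,          # shared measurement settings folded in once
}

# later in the script: create_tune_option(target, **tune_parameters)
```

Grouping the shared measurement settings in `control_parameters` keeps the tuning and module dictionaries from drifting apart as options are added.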
4 changes: 2 additions & 2 deletions scripts/tune_test.py
@@ -14,7 +14,7 @@

def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose,
n_parallel, build_timeout, local_measure, device_key, host,
port, ndk_cc, early_stopping=-1):
port, ndk_cc, early_stopping=-1, run_timeout=10):
builder = runner = measure_ctx = None
if local_measure:
builder = ansor.LocalBuilder(timeout=build_timeout)
@@ -26,7 +26,7 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose
else:
os.environ['TVM_NDK_CC'] = ndk_cc
builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk')
runner = ansor.RPCRunner(key=device_key, host=host, port=port,
runner = ansor.RPCRunner(key=device_key, host=host, port=port, timeout=run_timeout,
n_parallel=n_parallel, repeat=1, min_repeat_ms=400)

tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping,
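
`create_tune_option` gains a `run_timeout` argument that is forwarded to the `RPCRunner` used for remote measurement. A hedged sketch of a remote-measurement call — the import path, device key, tracker address, and NDK compiler are placeholders, and an RPC tracker is assumed to be running:

```python
import tvm
from tune_test import create_tune_option   # assumes running from the scripts/ directory

target = tvm.target.create("llvm")
tune_option, measure_ctx = create_tune_option(
    target, log_file="test.json", n_trials=100, num_measure_per_iter=48,
    verbose=1, n_parallel=4, build_timeout=10, local_measure=False,
    device_key="my-device", host="0.0.0.0", port=9190,
    ndk_cc="aarch64-linux-android-clang++",
    run_timeout=20)                         # new: forwarded to the RPCRunner timeout
```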
5 changes: 0 additions & 5 deletions src/ansor/loop_state.cc
@@ -1019,11 +1019,6 @@
PrintState(&p->stream, node, true);
});


TVM_REGISTER_GLOBAL("ansor.StageGetIterator").set_body_typed([](const Stage& stage, int index) {
return stage->iters[index];
});

TVM_REGISTER_GLOBAL("ansor.StageGetIterators").set_body_typed([](const Stage& stage) {
return Array<Iterator>(stage->iters);
});
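
Only the bulk `ansor.StageGetIterators` packed function remains registered on the C++ side; per-index lookup now happens in Python, as the `iters` property change above reflects. A minimal sketch using the internal `_ffi_api` module (illustrative; ordinary code would simply index `Stage.iters` instead):

```python
from tvm.ansor import _ffi_api

def get_iterator(stage, index):
    """Illustrative replacement for the removed per-index accessor."""
    return _ffi_api.StageGetIterators(stage)[index]
```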