Skip to content

Commit

Permalink
Fix warning "Unknown parameter: listen_time_out"
Browse files Browse the repository at this point in the history
  • Loading branch information
mlemainque committed Jan 2, 2019
1 parent 3add0df commit 5ef8107
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
12 changes: 5 additions & 7 deletions dask_lightgbm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ def parse_host_port(address):
return host, port


def build_network_params(worker_addresses, local_worker_ip, local_listen_port, listen_time_out):
def build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out):
addr_port_map = {addr: (local_listen_port + i) for i, addr in enumerate(worker_addresses)}
params = {
"machines": ",".join([parse_host_port(addr)[0] + ":" + str(port) for addr, port in addr_port_map.items()]),
"local_listen_port": addr_port_map[local_worker_ip],
"listen_time_out": listen_time_out,
"time_out": time_out,
"num_machines": len(addr_port_map)
}
return params
Expand All @@ -56,10 +56,8 @@ def concat(L):
raise TypeError("Data must be either numpy arrays or pandas dataframes. Got %s" % type(L[0]))


def _fit_local(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400, listen_time_out=120,
**kwargs):
network_params = build_network_params(worker_addresses, get_worker().address, local_listen_port,
listen_time_out)
def _fit_local(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400, time_out=120):
network_params = build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out)
params = {**params, **network_params}

# Prepare data
Expand Down Expand Up @@ -127,7 +125,7 @@ def train(client, X, y, params, model_factory, sample_weight=None, **kwargs):
list_of_parts=list_of_parts,
worker_addresses=list(worker_map.keys()),
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
time_out=params.get("time_out", 120),
return_model=worker==master_worker,
**kwargs)
for worker, list_of_parts in worker_map.items()]
Expand Down
2 changes: 1 addition & 1 deletion dask_lightgbm/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_build_network_params():
"machines": "192.168.0.1:12400,192.168.0.2:12401,192.168.0.3:12402",
"local_listen_port": 12401,
"num_machines": len(workers_ips),
"listen_time_out": 120
"time_out": 120
}
assert exp_params == params

Expand Down

0 comments on commit 5ef8107

Please sign in to comment.