diff --git a/lit_nlp/notebook.py b/lit_nlp/notebook.py index e3996529..aae2eb9a 100644 --- a/lit_nlp/notebook.py +++ b/lit_nlp/notebook.py @@ -95,13 +95,16 @@ def _encode(v): class LitWidget(object): """Class for using LIT inside notebooks.""" - def __init__(self, - *args, - height=1000, - render=False, - proxy_url=None, - layouts: Optional[layout.LitComponentLayouts] = None, - **kw): + def __init__( + self, + *args, + height=1000, + render=False, + proxy_url=None, + layouts: Optional[layout.LitComponentLayouts] = None, + warm_start: bool = False, + **kw, + ): """Start LIT server and optionally render the UI immediately. Args: @@ -111,15 +114,16 @@ def __init__(self, to False. proxy_url: Optional proxy URL, if using in a notebook with a server proxy. Defaults to None. - layouts: Optional custom UI layouts. TODO(lit-dev): support simple module - lists here as well. + layouts: Optional custom UI layouts. + warm_start: If true, run predictions for every model on every compatible + dataset before returning a renderable widget. **kw: Keyword arguments for the LitApp. """ app_flags = dict(server_config.get_flags()) app_flags['server_type'] = 'notebook' app_flags['host'] = 'localhost' app_flags['port'] = None - app_flags['warm_start'] = 1 + app_flags['warm_start'] = 1 if warm_start else 0 app_flags['warm_start_progress_indicator'] = progress_indicator app_flags['sync_state'] = True