Skip to content

Commit

Permalink
Fix local thread storage usage and make it typecheck
Browse files Browse the repository at this point in the history
The way we used local thread storage before did not typecheck, since we
assigned to `Thread`. Thread local storage can be a global variable, the
state of this object will be different per thread.
  • Loading branch information
danieldk committed Jan 9, 2024
1 parent d34f536 commit 5c46b82
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions thinc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
# notebook might not have preserved contextvars across cells.
_GLOBAL_STATE = {"ops": None}

# Thread-local state.
_LOCAL_STATE = threading.local()


def set_gpu_allocator(allocator: str) -> None: # pragma: no cover
"""Route GPU memory allocation via PyTorch or tensorflow.
Expand Down Expand Up @@ -152,22 +155,14 @@ def contextvars_eq_thread_ops() -> bool:
return False


def _get_thread_state():
def _get_thread_state() -> threading.local:
"""Get a thread-specific state variable that inherits from a global
state when it's created."""
thread: threading.Thread = threading.current_thread()
if not hasattr(thread, "__local"):
thread.__local = _create_thread_local(_GLOBAL_STATE)
return thread.__local


def _create_thread_local(
attrs: Dict[str, Any], local_class: Type[threading.local] = threading.local
):
obj = local_class()
for name, value in attrs.items():
setattr(obj, name, value)
return obj
if not hasattr(_LOCAL_STATE, "initialized") or not _LOCAL_STATE.initialized:
for name, value in _GLOBAL_STATE.items():
setattr(_LOCAL_STATE, name, value)
_LOCAL_STATE.initialized = True
return _LOCAL_STATE


__all__ = [
Expand Down

0 comments on commit 5c46b82

Please sign in to comment.