Support persisting the LoadMode.VIRTUALENV directory (#1079)
## Description

Added `virtualenv_dir` as an option to `ExecutionConfig` which is then
propagated downstream to `DbtVirtualenvBaseOperator`.

The following now happens:
- If the option is set (see the usage sketch below), the operator attempts to locate the `venv`'s `python` binary under the provided `virtualenv_dir`.
    - If the binary is found, it concludes that the `venv` already exists and continues without creating a new one.
    - If not, it creates a new `venv` at `virtualenv_dir`.
- If the option is not set, the operator keeps using the temporary-directory approach that was already in place.
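For illustration, here is a minimal sketch of how the new option is wired up from a DAG definition (a sketch only: the project and profile paths, the Postgres profile, and the pinned requirements are placeholders, not part of this PR):

```python
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig
from cosmos.constants import ExecutionMode

example_virtualenv = DbtDag(
    dag_id="example_virtualenv",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    project_config=ProjectConfig(dbt_project_path="/usr/local/airflow/dags/dbt/jaffle_shop"),  # placeholder path
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dags/dbt/jaffle_shop/profiles.yml",  # placeholder path
    ),
    execution_config=ExecutionConfig(
        execution_mode=ExecutionMode.VIRTUALENV,
        # New in this PR: persist and reuse the virtualenv at this path instead of
        # creating a fresh one in a temporary directory on every task run.
        virtualenv_dir=Path("/opt/airflow/persistent-venv"),  # placeholder path
    ),
    operator_args={
        "py_requirements": ["dbt-postgres"],  # installed into the (cached) virtualenv
        "py_system_site_packages": False,
    },
)
```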

## Impact
A very basic test using a local `docker compose` set-up (as per the
contribution guide) and the
[example_virtualenv](https://github.com/astronomer/astronomer-cosmos/blob/main/dev/dags/example_virtualenv.py)
DAG saw the DAG's runtime go down from **2m31s** to just **32s**. I'd
expect this improvement to be even more noticeable with more complex graphs and
more Python requirements.
 
## Related Issue(s)
Closes: #610 
Partially solves: #1042
Follow up ticket: #1157

## Breaking Change?
None — the new `virtualenv_dir` option is optional and is ignored (with a
[warning](https://github.com/astronomer/astronomer-cosmos/compare/main...LennartKloppenburg:astronomer-cosmos:feature/cache-virtualenv?expand=1#diff-61b585fb903927b6868b9626c95e0ec47e3818eb477d795ebd13b0276d4fd76cR125))
when used outside of the `VirtualEnv` execution mode.

## Important notice

Most of the changes in this PR were originally implemented in PR #611 by
@LennartKloppenburg. It became stale over the last few months due to
limited maintainer availability. Our sincere apologies to the original
author.

What was accomplished since then:
1. Rebased
2. Fixed conflicts
3. Fixed failing tests
4. Introduced new tests

Co-authored-by: Lennart Kloppenburg <[email protected]>
tatiana and LennartKloppenburg authored Aug 16, 2024
1 parent 41053ed commit 4273d99
Showing 10 changed files with 523 additions and 152 deletions.
4 changes: 4 additions & 0 deletions cosmos/config.py
@@ -359,6 +359,8 @@ class ExecutionConfig:
:param test_indirect_selection: The mode to configure the test behavior when performing indirect selection.
:param dbt_executable_path: The path to the dbt executable for runtime execution. Defaults to dbt if available on the path.
:param dbt_project_path: Configures the DBT project location accessible at runtime for dag execution. This is the project path in a docker container for ExecutionMode.DOCKER or ExecutionMode.KUBERNETES. Mutually Exclusive with ProjectConfig.dbt_project_path
:param virtualenv_dir: Directory path to locate the (cached) virtual env that
should be used for execution when execution mode is set to `ExecutionMode.VIRTUALENV`
"""

execution_mode: ExecutionMode = ExecutionMode.LOCAL
@@ -367,6 +369,8 @@ class ExecutionConfig:
dbt_executable_path: str | Path = field(default_factory=get_system_dbt)

dbt_project_path: InitVar[str | Path | None] = None
virtualenv_dir: str | Path | None = None

project_path: Path | None = field(init=False)

def __post_init__(self, dbt_project_path: str | Path | None) -> None:
22 changes: 19 additions & 3 deletions cosmos/converter.py
@@ -220,12 +220,26 @@ def __init__(
validate_initial_user_config(execution_config, profile_config, project_config, render_config, operator_args)

if project_config.dbt_project_path:
execution_config, render_config = migrate_to_new_interface(execution_config, project_config, render_config)
# We copy the configuration so the change does not affect other DAGs or TaskGroups
# that may reuse the same original configuration
render_config = copy.deepcopy(render_config)
execution_config = copy.deepcopy(execution_config)
render_config.project_path = project_config.dbt_project_path
execution_config.project_path = project_config.dbt_project_path

validate_adapted_user_config(execution_config, project_config, render_config)

env_vars = project_config.env_vars or operator_args.get("env")
dbt_vars = project_config.dbt_vars or operator_args.get("vars")
env_vars = copy.deepcopy(project_config.env_vars or operator_args.get("env"))
dbt_vars = copy.deepcopy(project_config.dbt_vars or operator_args.get("vars"))

if execution_config.execution_mode != ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None:
logger.warning(
"`ExecutionConfig.virtualenv_dir` is only supported when \
ExecutionConfig.execution_mode is set to ExecutionMode.VIRTUALENV."
)

if not operator_args:
operator_args = {}

cache_dir = None
cache_identifier = None
@@ -275,6 +289,8 @@ def __init__(
task_args,
execution_mode=execution_config.execution_mode,
)
if execution_config.execution_mode == ExecutionMode.VIRTUALENV and execution_config.virtualenv_dir is not None:
task_args["virtualenv_dir"] = execution_config.virtualenv_dir

build_airflow_graph(
nodes=self.dbt_graph.filtered_nodes,
12 changes: 11 additions & 1 deletion cosmos/dbt/graph.py
@@ -191,6 +191,15 @@ class DbtGraph:
Supports different ways of loading the `dbt` project into this representation.
Different loading methods can result in different `nodes` and `filtered_nodes`.
Example of how to use:
dbt_graph = DbtGraph(
project=ProjectConfig(dbt_project_path=DBT_PROJECT_PATH),
render_config=RenderConfig(exclude=["*orders*"], select=[]),
dbt_cmd="/usr/local/bin/dbt"
)
dbt_graph.load(method=LoadMode.DBT_LS, execution_mode=ExecutionMode.LOCAL)
"""

nodes: dict[str, DbtNode] = dict()
@@ -207,6 +216,7 @@ def __init__(
cache_identifier: str = "",
dbt_vars: dict[str, str] | None = None,
airflow_metadata: dict[str, str] | None = None,
operator_args: dict[str, Any] | None = None,
):
self.project = project
self.render_config = render_config
@@ -219,6 +229,7 @@
else:
self.dbt_ls_cache_key = ""
self.dbt_vars = dbt_vars or {}
self.operator_args = operator_args or {}

@cached_property
def env_vars(self) -> dict[str, str]:
@@ -568,7 +579,6 @@ def load_via_dbt_ls_without_cache(self) -> None:
self.run_dbt_deps(dbt_cmd, tmpdir_path, env)

nodes = self.run_dbt_ls(dbt_cmd, self.project_path, tmpdir_path, env)

self.nodes = nodes
self.filtered_nodes = nodes

180 changes: 135 additions & 45 deletions cosmos/operators/virtualenv.py
@@ -1,12 +1,17 @@
from __future__ import annotations

from functools import cached_property
import os
import shutil
import time
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Callable

import psutil
from airflow.utils.python_virtualenv import prepare_virtualenv

from cosmos import settings
from cosmos.exceptions import CosmosValueError
from cosmos.hooks.subprocess import FullOutputSubprocessResult
from cosmos.log import get_logger
from cosmos.operators.local import (
@@ -25,10 +30,19 @@
if TYPE_CHECKING:
from airflow.utils.context import Context


PY_INTERPRETER = "python3"
LOCK_FILENAME = "cosmos_virtualenv.lock"
logger = get_logger(__name__)


PY_INTERPRETER = "python3"
def depends_on_virtualenv_dir(method: Callable[[Any], Any]) -> Callable[[Any], Any]:
def wrapper(operator: DbtVirtualenvBaseOperator, *args: Any) -> Any:
if operator.virtualenv_dir is None:
raise CosmosValueError(f"Method relies on value of parameter `virtualenv_dir` which is None.")
return method(operator, *args)

return wrapper


class DbtVirtualenvBaseOperator(DbtLocalBaseOperator):
@@ -42,71 +56,147 @@ class DbtVirtualenvBaseOperator(DbtLocalBaseOperator):
:param py_system_site_packages: Whether or not all the Python packages from the Airflow instance will be accessible
within the virtual environment (if py_requirements argument is specified).
Avoid using unless the dbt job requires it.
:param virtualenv_dir: Directory path where Cosmos will create/update Python virtualenv. If defined, will persist the Python virtualenv in the Airflow worker node.
:param is_virtualenv_dir_temporary: Tells Cosmos if virtualenv should be persisted or not.
"""

template_fields = DbtLocalBaseOperator.template_fields + ("virtualenv_dir", "is_virtualenv_dir_temporary") # type: ignore[operator]

def __init__(
self,
py_requirements: list[str] | None = None,
pip_install_options: list[str] | None = None,
py_system_site_packages: bool = False,
virtualenv_dir: Path | None = None,
is_virtualenv_dir_temporary: bool = False,
**kwargs: Any,
) -> None:
self.py_requirements = py_requirements or []
self.pip_install_options = pip_install_options or []
self.py_system_site_packages = py_system_site_packages
self.virtualenv_dir = virtualenv_dir
self.is_virtualenv_dir_temporary = is_virtualenv_dir_temporary
self.max_retries_lock = settings.virtualenv_max_retries_lock
super().__init__(**kwargs)
self._venv_tmp_dir: None | TemporaryDirectory[str] = None

@cached_property
def venv_dbt_path(
self,
) -> str:
"""
Path to the dbt binary within a Python virtualenv.
The first time this property is called, it creates a virtualenv and installs the dependencies based on the
self.py_requirements, self.pip_install_options, and self.py_system_site_packages. This value is cached for future calls.
"""
# We are reusing the virtualenv directory for all subprocess calls within this task/operator.
# For this reason, we are not using contexts at this point.
# The deletion of this directory is done explicitly at the end of the `execute` method.
self._venv_tmp_dir = TemporaryDirectory(prefix="cosmos-venv")
py_interpreter = prepare_virtualenv(
venv_directory=self._venv_tmp_dir.name,
python_bin=PY_INTERPRETER,
system_site_packages=self.py_system_site_packages,
requirements=self.py_requirements,
pip_install_options=self.pip_install_options,
)
dbt_binary = Path(py_interpreter).parent / "dbt"
cmd_output = self.subprocess_hook.run_command(
[
py_interpreter,
"-c",
"from importlib.metadata import version; print(version('dbt-core'))",
]
)
dbt_version = cmd_output.output
self.log.info("Using dbt version %s available at %s", dbt_version, dbt_binary)
return str(dbt_binary)
if not self.py_requirements:
self.log.error("Cosmos virtualenv operators require the `py_requirements` parameter")

def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> FullOutputSubprocessResult:
if self.py_requirements:
command[0] = self.venv_dbt_path

subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(
# No virtualenv_dir set, so create a temporary virtualenv
if self.virtualenv_dir is None or self.is_virtualenv_dir_temporary:
self.log.info("Creating temporary virtualenv")
with TemporaryDirectory(prefix="cosmos-venv") as tempdir:
self.virtualenv_dir = Path(tempdir)
py_bin = self._prepare_virtualenv()
dbt_bin = str(Path(py_bin).parent / "dbt")
command[0] = dbt_bin # type: ignore
subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(
command=command,
env=env,
cwd=cwd,
output_encoding=self.output_encoding,
)
return subprocess_result

# Use a reusable virtualenv
self.log.info(f"Checking if the virtualenv lock {str(self._lock_file)} exists")
while not self._is_lock_available() and self.max_retries_lock:
logger.info("Waiting for virtualenv lock to be released")
time.sleep(1)
self.max_retries_lock -= 1

self.log.info(f"Acquiring the virtualenv lock")
self._acquire_venv_lock()
py_bin = self._prepare_virtualenv()
dbt_bin = str(Path(py_bin).parent / "dbt")
command[0] = dbt_bin # type: ignore
subprocess_result = self.subprocess_hook.run_command(
command=command,
env=env,
cwd=cwd,
output_encoding=self.output_encoding,
)
self.log.info("Releasing virtualenv lock")
self._release_venv_lock()
return subprocess_result

def clean_dir_if_temporary(self) -> None:
"""
Delete the virtualenv directory if it is temporary.
"""
if self.is_virtualenv_dir_temporary and self.virtualenv_dir and self.virtualenv_dir.exists():
self.log.info(f"Deleting the Python virtualenv {self.virtualenv_dir}")
shutil.rmtree(str(self.virtualenv_dir), ignore_errors=True)

def execute(self, context: Context) -> None:
output = super().execute(context)
if self._venv_tmp_dir:
self._venv_tmp_dir.cleanup()
self.log.info(output)
try:
output = super().execute(context)
self.log.info(output)
finally:
self.clean_dir_if_temporary()

def on_kill(self) -> None:
self.clean_dir_if_temporary()

def _prepare_virtualenv(self) -> str:
self.log.info(f"Creating or updating the virtualenv at `{self.virtualenv_dir}")
py_bin = prepare_virtualenv(
venv_directory=str(self.virtualenv_dir),
python_bin=PY_INTERPRETER,
system_site_packages=self.py_system_site_packages,
requirements=self.py_requirements,
pip_install_options=self.pip_install_options,
)
return py_bin

@property
def _lock_file(self) -> Path:
filepath = Path(f"{self.virtualenv_dir}/{LOCK_FILENAME}")
return filepath

@property
def _pid(self) -> int:
return os.getpid()

@depends_on_virtualenv_dir
def _is_lock_available(self) -> bool:
is_available = True
if self._lock_file.is_file():
with open(self._lock_file) as lf:
pid = int(lf.read())
self.log.info(f"Checking for running process with PID {pid}")
try:
_process_running = psutil.Process(pid).is_running()
self.log.info(f"Process {pid} running: {_process_running} and has the lock {self._lock_file}.")
except psutil.NoSuchProcess:
self.log.info(f"Process {pid} is not running. Lock {self._lock_file} was outdated.")
is_available = True
else:
is_available = not _process_running
return is_available

@depends_on_virtualenv_dir
def _acquire_venv_lock(self) -> None:
if not self.virtualenv_dir.is_dir(): # type: ignore
os.mkdir(str(self.virtualenv_dir))

with open(self._lock_file, "w") as lf:
logger.info(f"Acquiring lock at {self._lock_file} with pid {str(self._pid)}")
lf.write(str(self._pid))

@depends_on_virtualenv_dir
def _release_venv_lock(self) -> None:
if not self._lock_file.is_file():
logger.warning(f"Lockfile {self._lock_file} not found, perhaps deleted by other concurrent operator?")
return

with open(self._lock_file) as lf:
lock_file_pid = int(lf.read())

if lock_file_pid == self._pid:
return self._lock_file.unlink()

logger.warning(f"Lockfile owned by process of pid {lock_file_pid}, while operator has pid {self._pid}")


class DbtBuildVirtualenvOperator(DbtVirtualenvBaseOperator, DbtBuildLocalOperator): # type: ignore[misc]
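As context for the locking code in `cosmos/operators/virtualenv.py` above: when `virtualenv_dir` is persistent, concurrent tasks on the same worker coordinate through a plain PID lock file inside that directory. Below is a standalone sketch of the same protocol, with illustrative function names that are not part of the operator's API:

```python
import os
from pathlib import Path

import psutil


def lock_is_available(lock_file: Path) -> bool:
    """The lock is free if no lock file exists or the recorded PID is no longer running."""
    if not lock_file.is_file():
        return True
    pid = int(lock_file.read_text())
    try:
        return not psutil.Process(pid).is_running()
    except psutil.NoSuchProcess:
        # Stale lock left behind by a process that died without cleaning up.
        return True


def acquire(lock_file: Path) -> None:
    # Record the current process as the lock owner.
    lock_file.parent.mkdir(parents=True, exist_ok=True)
    lock_file.write_text(str(os.getpid()))


def release(lock_file: Path) -> None:
    # Only the recorded owner removes the lock; a foreign owner is only warned about.
    if lock_file.is_file() and int(lock_file.read_text()) == os.getpid():
        lock_file.unlink()
```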
1 change: 1 addition & 0 deletions cosmos/settings.py
@@ -27,6 +27,7 @@
dbt_docs_index_file_name = conf.get("cosmos", "dbt_docs_index_file_name", fallback="index.html")
enable_cache_profile = conf.getboolean("cosmos", "enable_cache_profile", fallback=True)
dbt_profile_cache_dir_name = conf.get("cosmos", "profile_cache_dir_name", fallback="profile")
virtualenv_max_retries_lock = conf.getint("cosmos", "virtualenv_max_retries_lock", fallback=120)

# Experimentally adding `remote_cache_dir` as a separate entity in the Cosmos 1.6 release to gather feedback.
# This will be merged with the `cache_dir` config parameter in upcoming releases.
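For reference, the new limit is read through Airflow's standard configuration mechanism, so it should be overridable via the `[cosmos]` section of `airflow.cfg` or the `AIRFLOW__COSMOS__VIRTUALENV_MAX_RETRIES_LOCK` environment variable; with the default of 120 and the 1-second sleep between checks, the operator waits up to roughly two minutes for the lock. A minimal sketch of the equivalent lookup:

```python
from airflow.configuration import conf

# Mirrors the new settings.py entry: how many times the virtualenv operator
# re-checks the lock file (sleeping 1 second between checks) before acquiring it anyway.
max_retries = conf.getint("cosmos", "virtualenv_max_retries_lock", fallback=120)
```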