diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 5b51bb0d24df..c1373e5d6a12 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1632,11 +1632,12 @@ def _get_previous_ti( @internal_api_call @provide_session -def _update_rtif(ti, rendered_fields, session: Session | None = None): +def _update_rtif(ti, rendered_fields, session: Session = NEW_SESSION): from airflow.models.renderedtifields import RenderedTaskInstanceFields rtif = RenderedTaskInstanceFields(ti=ti, render_templates=False, rendered_fields=rendered_fields) RenderedTaskInstanceFields.write(rtif, session=session) + session.flush() RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session) diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index ea22d31871db..6ff87b28a89b 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -24,13 +24,16 @@ from datetime import date, timedelta from unittest import mock +import pendulum import pytest +from sqlalchemy import select from airflow import settings from airflow.configuration import conf from airflow.decorators import task as task_decorator -from airflow.models import Variable +from airflow.models import DagRun, Variable from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF +from airflow.operators.python import PythonOperator from airflow.providers.standard.operators.bash import BashOperator from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.timezone import datetime @@ -386,3 +389,48 @@ def test_redact(self, redact, dag_maker): "env": "val 2", "cwd": "val 3", } + + @pytest.mark.skip_if_database_isolation_mode + def test_rtif_deletion_stale_data_error(self, dag_maker, session): + """ + Here we verify bad behavior. When we rerun a task whose RTIF + will get removed, we get a stale data error. + """ + with dag_maker(dag_id="test_retry_handling"): + task = PythonOperator( + task_id="test_retry_handling_op", + python_callable=lambda a, b: print(f"{a}\n{b}\n"), + op_args=[ + "dag {{dag.dag_id}};", + "try_number {{ti.try_number}};yo", + ], + ) + + def run_task(date): + run_id = f"abc_{date.to_date_string()}" + dr = session.scalar(select(DagRun).where(DagRun.execution_date == date, DagRun.run_id == run_id)) + if not dr: + dr = dag_maker.create_dagrun(execution_date=date, run_id=run_id) + ti = dr.task_instances[0] + ti.state = None + ti.try_number += 1 + session.commit() + ti.task = task + ti.run() + return dr + + base_date = pendulum.datetime(2021, 1, 1) + exec_dates = [base_date.add(days=x) for x in range(40)] + for date_ in exec_dates: + run_task(date=date_) + + session.commit() + session.expunge_all() + + # find oldest date + date = session.scalar( + select(DagRun.execution_date).join(RTIF.dag_run).order_by(DagRun.execution_date).limit(1) + ) + date = pendulum.instance(date) + # rerun the old date. this will fail + run_task(date=date)