add workflow handler and endpoint
andrii-i committed Sep 9, 2024
1 parent abedadf commit 6c9de64
Showing 5 changed files with 186 additions and 4 deletions.
93 changes: 90 additions & 3 deletions jupyter_scheduler/executors.py
@@ -4,17 +4,22 @@
import tarfile
import traceback
from abc import ABC, abstractmethod
from typing import Dict
from functools import lru_cache
from typing import Dict, List, Optional

import fsspec
import nbconvert
import nbformat
from nbconvert.preprocessors import CellExecutionError, ExecutePreprocessor
from prefect import flow, task
from prefect.futures import as_completed
from prefect_dask.task_runners import DaskTaskRunner

from jupyter_scheduler.models import DescribeJob, JobFeature, Status
from jupyter_scheduler.orm import Job, create_session
from jupyter_scheduler.orm import Job, Workflow, create_session
from jupyter_scheduler.parameterize import add_parameters
from jupyter_scheduler.utils import get_utc_timestamp
from jupyter_scheduler.workflows import DescribeWorkflow


class ExecutionManager(ABC):
@@ -29,14 +34,29 @@ class ExecutionManager(ABC):
_model = None
_db_session = None

def __init__(self, job_id: str, root_dir: str, db_url: str, staging_paths: Dict[str, str]):
def __init__(
self,
root_dir: str,
db_url: str,
job_id: Optional[str] = None,
workflow_id: Optional[str] = None,
staging_paths: Optional[Dict[str, str]] = None,
):
# job_id, workflow_id, and staging_paths default to None so that
# workflow-only runs, which have no job or staging paths, can omit them
self.job_id = job_id
self.workflow_id = workflow_id
self.staging_paths = staging_paths
self.root_dir = root_dir
self.db_url = db_url

@property
def model(self):
if self.workflow_id:
with self.db_session() as session:
workflow = (
session.query(Workflow).filter(Workflow.workflow_id == self.workflow_id).first()
)
self._model = DescribeWorkflow.from_orm(workflow)
return self._model
if self._model is None:
with self.db_session() as session:
job = session.query(Job).filter(Job.job_id == self.job_id).first()
@@ -65,6 +85,18 @@ def process(self):
else:
self.on_complete()

def process_workflow(self):
self.before_start_workflow()
try:
self.execute_workflow()
except Exception as e:
self.on_failure_workflow(e)
else:
self.on_complete_workflow()

@abstractmethod
def execute(self):
"""Performs notebook execution,
@@ -74,6 +106,11 @@ def execute(self):
"""
pass

@abstractmethod
def execute_workflow(self):
"""Performs workflow execution"""
pass

@classmethod
@abstractmethod
def supported_features(cls) -> Dict[JobFeature, bool]:
@@ -98,6 +135,15 @@ def before_start(self):
)
session.commit()

def before_start_workflow(self):
"""Called before start of execute"""
workflow = self.model
with self.db_session() as session:
session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
{"status": Status.IN_PROGRESS}
)
session.commit()

def on_failure(self, e: Exception):
"""Called after failure of execute"""
job = self.model
@@ -109,6 +155,17 @@ def on_failure(self, e: Exception):

traceback.print_exc()

def on_failure_workflow(self, e: Exception):
"""Called after failure of execute"""
workflow = self.model
with self.db_session() as session:
session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
{"status": Status.FAILED, "status_message": str(e)}
)
session.commit()

traceback.print_exc()

def on_complete(self):
"""Called after job is completed"""
job = self.model
@@ -118,10 +175,40 @@ def on_complete(self):
)
session.commit()

def on_complete_workflow(self):
"""Called after workflow is completed"""
workflow = self.model
with self.db_session() as session:
session.query(Workflow).filter(Workflow.workflow_id == workflow.workflow_id).update(
{"status": Status.COMPLETED}
)
session.commit()


class DefaultExecutionManager(ExecutionManager):
"""Default execution manager that executes notebooks"""

@task(task_run_name="{task_id}")
def execute_task(task_id: str):
print(f"Task {task_id} executed")
return task_id

@flow(task_runner=DaskTaskRunner())
def execute_workflow(self):
workflow: DescribeWorkflow = self.model
tasks = {task["id"]: task for task in workflow.tasks}

# Create Prefect task runs lazily; lru_cache memoizes each task so it is
# submitted exactly once before wait_for is evaluated on its dependents
@lru_cache(maxsize=None)
def make_task(task_id, execute_task):
deps = tasks[task_id]["dependsOn"]
return execute_task.submit(
task_id, wait_for=[make_task(dep_id, execute_task) for dep_id in deps]
)

final_tasks = [make_task(task_id, self.execute_task) for task_id in tasks]
for future in as_completed(final_tasks):
print(future.result())

def execute(self):
job = self.model

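A note on the dependency wiring above: because make_task is memoized, each task id maps to exactly one Prefect future, so a shared dependency is submitted once and awaited by every dependent through wait_for. A minimal standalone sketch of the same pattern, using the default task runner; the three-task diamond below is hypothetical and not part of this commit:

from functools import lru_cache

from prefect import flow, task
from prefect.futures import as_completed


@task(task_run_name="{task_id}")
def run_task(task_id: str):
    print(f"Task {task_id} executed")
    return task_id


@flow
def run_workflow(tasks: dict):
    @lru_cache(maxsize=None)
    def make_task(task_id):
        # memoized: each task is submitted once; dependents share its future
        deps = tasks[task_id]["dependsOn"]
        return run_task.submit(task_id, wait_for=[make_task(d) for d in deps])

    for future in as_completed([make_task(task_id) for task_id in tasks]):
        print(future.result())


if __name__ == "__main__":
    # diamond: b waits for a, c waits for a and b
    run_workflow({"a": {"dependsOn": []}, "b": {"dependsOn": ["a"]}, "c": {"dependsOn": ["a", "b"]}})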
2 changes: 2 additions & 0 deletions jupyter_scheduler/extension.py
@@ -6,6 +6,7 @@
from traitlets import Bool, Type, Unicode, default

from jupyter_scheduler.orm import create_tables
from jupyter_scheduler.workflows import WorkflowHandler

from .handlers import (
BatchJobHandler,
@@ -35,6 +36,7 @@ class SchedulerApp(ExtensionApp):
(r"scheduler/job_definitions/%s/jobs" % JOB_DEFINITION_ID_REGEX, JobFromDefinitionHandler),
(r"scheduler/runtime_environments", RuntimeEnvironmentsHandler),
(r"scheduler/config", ConfigHandler),
(r"scheduler/worklows", WorkflowHandler),
]

drop_tables = Bool(False, config=True, help="Drop the database tables before starting.")
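With the route registered, the endpoint can be smoke-tested over HTTP. A hedged sketch, assuming a local Jupyter Server on port 8888 and a token; none of these values come from the commit:

import requests

BASE_URL = "http://localhost:8888"  # assumed local server
TOKEN = "<your-jupyter-token>"      # assumed auth token

response = requests.post(
    f"{BASE_URL}/scheduler/workflows",
    headers={"Authorization": f"token {TOKEN}"},
    json={"tasks": ["task-a", "task-b"]},  # shape of CreateWorkflow
)
response.raise_for_status()
print(response.json())  # expected: {"workflow_id": "<uuid>"}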
10 changes: 10 additions & 0 deletions jupyter_scheduler/orm.py
@@ -107,6 +107,16 @@ class Job(CommonColumns, Base):
# Any default values specified for new columns will be ignored during the migration process.


class Workflow(Base):
__tablename__ = "workflows"
__table_args__ = {"extend_existing": True}
workflow_id = Column(String(36), primary_key=True, default=generate_uuid)
tasks = Column(JsonType(1024))
status = Column(String(64), default=Status.STOPPED)
# All new columns added to this table must be nullable to ensure compatibility during database migrations.
# Any default values specified for new columns will be ignored during the migration process.


class JobDefinition(CommonColumns, Base):
__tablename__ = "job_definitions"
__table_args__ = {"extend_existing": True}
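Since Workflow shares the same declarative Base, its table is created by the existing create_tables call and rows can be written through create_session. A minimal sketch, assuming a local SQLite db_url:

from jupyter_scheduler.orm import Workflow, create_session, create_tables

db_url = "sqlite:///scheduler.sqlite"  # assumed database location
create_tables(db_url)
db_session = create_session(db_url)

with db_session() as session:
    workflow = Workflow(tasks=["task-a", "task-b"])
    session.add(workflow)
    session.commit()
    print(workflow.workflow_id, workflow.status)  # generated uuid, expected "STOPPED"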
28 changes: 27 additions & 1 deletion jupyter_scheduler/scheduler.py
@@ -40,12 +40,13 @@
UpdateJob,
UpdateJobDefinition,
)
from jupyter_scheduler.orm import Job, JobDefinition, create_session
from jupyter_scheduler.orm import Job, JobDefinition, Workflow, create_session
from jupyter_scheduler.utils import (
copy_directory,
create_output_directory,
create_output_filename,
)
from jupyter_scheduler.workflows import CreateWorkflow


class BaseScheduler(LoggingConfigurable):
@@ -111,6 +112,10 @@ def create_job(self, model: CreateJob) -> str:
"""
raise NotImplementedError("must be implemented by subclass")

def create_workflow(self, model: CreateWorkflow) -> str:
"""Creates a new workflow record, may trigger execution of the workflow."""
raise NotImplementedError("must be implemented by subclass")

def update_job(self, job_id: str, model: UpdateJob):
"""Updates job metadata in the persistence store,
for example name, status etc. In case of status
@@ -526,6 +531,27 @@ def create_job(self, model: CreateJob) -> str:

return job_id

def create_workflow(self, model: CreateWorkflow) -> str:
with self.db_session() as session:
workflow = Workflow(**model.dict(exclude_none=True))
session.add(workflow)
session.commit()

execution_manager = self.execution_manager_class(
workflow_id=workflow.workflow_id,
root_dir=self.root_dir,
db_url=self.db_url,
)
execution_manager.process_workflow()

workflow_id = workflow.workflow_id

return workflow_id

def update_job(self, job_id: str, model: UpdateJob):
with self.db_session() as session:
session.query(Job).filter(Job.job_id == job_id).update(model.dict(exclude_none=True))
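In-process, the same path can be exercised directly; note that, unlike create_job, this runs process_workflow synchronously inside the request. A sketch assuming `scheduler` is an already-configured Scheduler instance:

from jupyter_scheduler.workflows import CreateWorkflow

# `scheduler` is assumed to be a configured Scheduler instance
workflow_id = scheduler.create_workflow(CreateWorkflow(tasks=["task-a", "task-b"]))
print(workflow_id)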
57 changes: 57 additions & 0 deletions jupyter_scheduler/workflows.py
@@ -0,0 +1,57 @@
import json
from typing import List, Optional

from jupyter_server.utils import ensure_async
from tornado.web import HTTPError, authenticated

from jupyter_scheduler.exceptions import (
IdempotencyTokenError,
InputUriError,
SchedulerError,
)
from jupyter_scheduler.handlers import (
APIHandler,
ExtensionHandlerMixin,
JobHandlersMixin,
)
from jupyter_scheduler.models import Status
from jupyter_scheduler.pydantic_v1 import BaseModel, ValidationError


class WorkflowHandler(ExtensionHandlerMixin, JobHandlersMixin, APIHandler):
@authenticated
async def post(self):
payload = self.get_json_body()
try:
workflow_id = await ensure_async(
self.scheduler.create_workflow(CreateWorkflow(**payload))
)
self.log.info(payload)
except ValidationError as e:
self.log.exception(e)
raise HTTPError(500, str(e)) from e
except InputUriError as e:
self.log.exception(e)
raise HTTPError(500, str(e)) from e
except IdempotencyTokenError as e:
self.log.exception(e)
raise HTTPError(409, str(e)) from e
except SchedulerError as e:
self.log.exception(e)
raise HTTPError(500, str(e)) from e
except Exception as e:
self.log.exception(e)
raise HTTPError(500, "Unexpected error occurred during creation of a workflow.") from e
else:
self.finish(json.dumps(dict(workflow_id=workflow_id)))


class CreateWorkflow(BaseModel):
tasks: List[str]


class DescribeWorkflow(BaseModel):
workflow_id: str
tasks: Optional[List[str]] = None
status: Status = Status.CREATED
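
For reference, the request and response shapes these models imply; the field values below are made up for illustration:

from jupyter_scheduler.workflows import CreateWorkflow, DescribeWorkflow

payload = {"tasks": ["task-a", "task-b"]}  # request body for POST scheduler/workflows
model = CreateWorkflow(**payload)

described = DescribeWorkflow(workflow_id="example-uuid", tasks=model.tasks)
print(described.json())  # expected: {"workflow_id": "example-uuid", "tasks": [...], "status": "CREATED"}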
