Skip to content

Commit

Permalink
Made update apis consistent with REST APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Oct 12, 2022
1 parent 9b08703 commit 2e28d0a
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 38 deletions.
19 changes: 10 additions & 9 deletions jupyter_scheduler/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ async def post(self):
@tornado.web.authenticated
async def patch(self, job_definition_id):
payload = self.get_json_body()
payload["job_definition_id"] = job_definition_id
await ensure_async(self.scheduler.update_job_definition(UpdateJobDefinition(**payload)))
await ensure_async(
self.scheduler.update_job_definition(job_definition_id, UpdateJobDefinition(**payload))
)
self.set_status(204)
self.finish()

Expand Down Expand Up @@ -163,16 +164,16 @@ async def post(self):
async def patch(self, job_id):
payload = self.get_json_body()

if "status" not in payload:
raise tornado.web.HTTPError(500, "Field 'status' missing in request body")
status = payload.get("status", None)
status = Status(status) if status else None

if status and status != Status.STOPPED:
raise tornado.web.HTTPError(500, "Value of 'STOPPED' only allowed for field 'status'")

status = Status(payload.get("status"))
if status == Status.STOPPED:
if status:
await ensure_async(self.scheduler.stop_job(job_id))
else:
await ensure_async(
self.scheduler.update_job(UpdateJob(job_id=job_id, status=str(status)))
)
await ensure_async(self.scheduler.update_job(job_id, UpdateJob(**payload)))

self.set_status(204)
self.finish()
Expand Down
4 changes: 0 additions & 4 deletions jupyter_scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,7 @@ class CountJobsQuery(BaseModel):


class UpdateJob(BaseModel):
job_id: str
end_time: Optional[int] = None
status: Optional[Status] = None
status_message: Optional[str] = None
name: Optional[str] = None
compute_type: Optional[str] = None

Expand Down Expand Up @@ -182,7 +179,6 @@ class Config:


class UpdateJobDefinition(BaseModel):
job_definition_id: str
input_uri: Optional[str]
output_prefix: Optional[str]
runtime_environment_name: Optional[str]
Expand Down
20 changes: 9 additions & 11 deletions jupyter_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def create_job(self, model: CreateJob) -> str:
pass

@abstractmethod
def update_job(self, model: UpdateJob):
def update_job(self, job_id: str, model: UpdateJob):
"""Updates job metadata in the persistence store,
for example name, status etc. In case of status
change to STOPPED, should call stop_job
Expand Down Expand Up @@ -96,7 +96,7 @@ def create_job_definition(self, model: CreateJobDefinition) -> str:
pass

@abstractmethod
def update_job_definition(self, model: UpdateJobDefinition):
def update_job_definition(self, job_definition_id: str, model: UpdateJobDefinition):
"""Updates job definition metadata in the persistence store,
should only impact all future jobs.
"""
Expand Down Expand Up @@ -180,11 +180,9 @@ def create_job(self, model: CreateJob) -> str:

return job_id

def update_job(self, model: UpdateJob):
def update_job(self, job_id: str, model: UpdateJob):
with self.db_session() as session:
session.query(Job).filter(Job.job_id == model.job_id).update(
model.dict(exclude_none=True)
)
session.query(Job).filter(Job.job_id == job_id).update(model.dict(exclude_none=True))
session.commit()

def list_jobs(self, query: ListJobsQuery) -> ListJobsResponse:
Expand Down Expand Up @@ -280,21 +278,21 @@ def create_job_definition(self, model: CreateJobDefinition) -> str:

return job_definition_id

def update_job_definition(self, model: UpdateJobDefinition):
def update_job_definition(self, job_definition_id: str, model: UpdateJobDefinition):
with self.db_session() as session:
session.query(JobDefinition).filter(
JobDefinition.job_definition_id == model.job_definition_id
).update(model.dict(exclude_none=True, exclude={"job_definition_id"}))
JobDefinition.job_definition_id == job_definition_id
).update(model.dict(exclude_none=True))
session.commit()

schedule = (
session.query(JobDefinition.schedule)
.filter(JobDefinition.job_definition_id == model.job_definition_id)
.filter(JobDefinition.job_definition_id == job_definition_id)
.scalar()
)

if self.task_runner and schedule:
self.task_runner.update_job_definition(model)
self.task_runner.update_job_definition(job_definition_id, model)

def delete_job_definition(self, job_definition_id: str):
with self.db_session() as session:
Expand Down
17 changes: 7 additions & 10 deletions jupyter_scheduler/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def add_job_definition(self, job_definition_id: str):
pass

@abstractmethod
def update_job_definition(self, model: UpdateJobDefinition):
def update_job_definition(self, job_definition_id: str, model: UpdateJobDefinition):
"""This handles update to job definitions"""
pass

Expand All @@ -162,6 +162,8 @@ def resume_jobs(self, job_definition_id: str):


class TaskRunner(BaseTaskRunner):
"""Default task runner"""

def __init__(self, scheduler, run_interval: int) -> None:
self.run_interval = run_interval
self.scheduler = scheduler
Expand Down Expand Up @@ -221,25 +223,23 @@ def add_job_definition(self, job_definition_id: str):
)
)

def update_job_definition(self, model: UpdateJobDefinition):
cache = self.cache.get(model.job_definition_id)
def update_job_definition(self, job_definition_id: str, model: UpdateJobDefinition):
cache = self.cache.get(job_definition_id)
schedule = model.schedule or cache.schedule
timezone = model.timezone or cache.timezone
active = model.active if model.active is not None else cache.active
next_run_time = self.compute_next_run_time(schedule, timezone)

self.cache.update(
model.job_definition_id,
job_definition_id,
UpdateJobDefinitionCache(
timezone=timezone, next_run_time=next_run_time, active=active, schedule=schedule
),
)

if cache.next_run_time != next_run_time and active:
self.queue.push(
JobDefinitionTask(
job_definition_id=model.job_definition_id, next_run_time=next_run_time
)
JobDefinitionTask(job_definition_id=job_definition_id, next_run_time=next_run_time)
)

def delete_job_definition(self, job_definition_id: str):
Expand Down Expand Up @@ -283,7 +283,6 @@ def process_queue(self):
continue

time_diff = self.compute_time_diff(queue_run_time, cache.timezone)
print(f"time difference is : {(time_diff)/1000}")

# if run time is in future
if time_diff < 0:
Expand All @@ -292,10 +291,8 @@ def process_queue(self):
elif time_diff >= (self.run_interval * 1000):
break
else:
print(f"Going to create job with job definition id: {task.job_definition_id}")
self.create_job(task.job_definition_id)
self.queue.pop()
print(f"heap after popping...{self.queue}")
run_time = self.compute_next_run_time(cache.schedule, cache.timezone)
self.cache.update(
task.job_definition_id, UpdateJobDefinitionCache(next_run_time=run_time)
Expand Down
30 changes: 27 additions & 3 deletions jupyter_scheduler/tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SortDirection,
SortField,
Status,
UpdateJob,
)
from jupyter_scheduler.tests.utils import expected_http_error

Expand Down Expand Up @@ -162,11 +163,34 @@ async def test_get_jobs(jp_fetch, params, list_query, jobs_list):
assert actual_job["url"] == expected_job["url"]


async def test_patch_jobs_for_missing_status(jp_fetch):
async def test_patch_jobs_for_status(jp_fetch):
with patch("jupyter_scheduler.scheduler.Scheduler.stop_job") as mock_stop_job:
job_id = "542e0fac-1274-4a78-8340-a850bdb559c8"
body = {"status": "STOPPED"}
response = await jp_fetch(
"scheduler", "jobs", job_id, method="PATCH", body=json.dumps(body)
)
assert response.code == 204
mock_stop_job.assert_called_once_with(job_id)


async def test_patch_jobs_for_invalid_status(jp_fetch):
with pytest.raises(HTTPClientError) as e:
job_id = "542e0fac-1274-4a78-8340-a850bdb559c8"
response = await jp_fetch("scheduler", "jobs", job_id, method="PATCH", body="{}")
assert expected_http_error(e, 500, "Field 'status' missing in request body")
body = {"status": "IN_PROGRESS"}
await jp_fetch("scheduler", "jobs", job_id, method="PATCH", body=json.dumps(body))
assert expected_http_error(e, 500, "Value of 'STOPPED' only allowed for field 'status'")


async def test_patch_jobs(jp_fetch):
with patch("jupyter_scheduler.scheduler.Scheduler.update_job") as mock_update_job:
job_id = "542e0fac-1274-4a78-8340-a850bdb559c8"
body = {"name": "hello world", "compute_type": "compute_type_a"}
response = await jp_fetch(
"scheduler", "jobs", job_id, method="PATCH", body=json.dumps(body)
)
assert response.code == 204
mock_update_job.assert_called_once_with(job_id, UpdateJob(**body))


async def test_patch_jobs_for_stop_job(jp_fetch):
Expand Down
2 changes: 1 addition & 1 deletion jupyter_scheduler/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_update_job_definition(jp_scheduler, load_job_definitions, jp_scheduler_
update = UpdateJobDefinition(
job_definition_id=job_definition_id, schedule=schedule, timezone=timezone
)
jp_scheduler.update_job_definition(update)
jp_scheduler.update_job_definition(job_definition_id, update)

with jp_scheduler_db() as session:
definition = session.get(JobDefinition, job_definition_id)
Expand Down

0 comments on commit 2e28d0a

Please sign in to comment.