Source code for otter.task.task_reporter

"""TaskReporter class and report decorator for logging and updating tasks in the manifest."""

from __future__ import annotations

import asyncio
import sys
from collections.abc import Awaitable
from datetime import UTC, datetime
from functools import wraps
from typing import TYPE_CHECKING

from loguru import logger

from otter.manifest.model import Artifact, Result, TaskManifest
from otter.util.errors import TaskAbortedError

if TYPE_CHECKING:
    from collections.abc import Callable
    from typing import Any

    from otter.task.model import Task


class TaskReporter:
    """Class for logging and updating tasks in the manifest.

    A reporter owns the :class:`TaskManifest` of a single task. As the task
    moves through its run and validation steps, the reporter records
    timestamps and the final :class:`Result` on that manifest and logs each
    transition.

    Args:
        name: The task name; stored on the manifest and used in log messages.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.manifest: TaskManifest = TaskManifest(name=name)

    @property
    def artifacts(self) -> list[Artifact] | None:
        """Return the `Artifacts` associated with the `Task`."""
        return self.manifest.artifacts

    @artifacts.setter
    def artifacts(self, artifacts: list[Artifact]) -> None:
        """Set the `Artifact` associated with the `Task`."""
        self.manifest.artifacts = artifacts

    def start_run(self) -> None:
        """Update a task that has started running.

        Stamps ``manifest.started_run_at`` with the current UTC time.
        """
        self.manifest.started_run_at = datetime.now(UTC)
        logger.info(f'task {self.name} started running')

    def finish_run(self, done: bool = False) -> None:
        """Update a task that has finished running.

        Stamps ``manifest.finished_run_at`` with the current UTC time and logs
        the run duration.

        Args:
            done: Pass ``True`` when running is the task's last step; the
                manifest result is then set to ``Result.SUCCESS`` here.
        """
        self.manifest.finished_run_at = datetime.now(UTC)
        if done:
            self.manifest.result = Result.SUCCESS
        logger.success(f'task {self.name} finished running: took {self.manifest.run_elapsed:.3f}s')

    def start_validation(self) -> None:
        """Update a task that has started validation.

        Stamps ``manifest.started_validation_at`` with the current UTC time.
        """
        self.manifest.started_validation_at = datetime.now(UTC)
        logger.info(f'task {self.name} started validation')

    def finish_validation(self) -> None:
        """Update a task that has finished validation.

        Stamps ``manifest.finished_validation_at``, marks the task as
        ``Result.SUCCESS`` and logs both the validation and the total duration.
        """
        self.manifest.finished_validation_at = datetime.now(UTC)
        self.manifest.result = Result.SUCCESS
        logger.success(f'task {self.name} finished validation: took {self.manifest.validation_elapsed:.3f}s')
        logger.success(f'task {self.name} completed: took {self.manifest.elapsed:.3f}s')

    def abort(self) -> None:
        """Update a task that has been aborted.

        Sets the manifest result to ``Result.ABORTED`` and logs a warning.
        """
        self.manifest.result = Result.ABORTED
        logger.warning(f'task {self.name} aborted')

    def fail(self, error: Exception, where: str) -> None:
        """Update a task that has failed running or validation.

        Sets the manifest result to ``Result.FAILURE`` and logs the error with
        its traceback. The traceback is read from :func:`sys.exc_info`, so call
        this from inside the ``except`` block that is handling ``error``.

        Args:
            error: The exception that made the step fail.
            where: Name of the step that failed (e.g. ``'run'``, ``'validate'``).
        """
        self.manifest.result = Result.FAILURE
        # Name the task (not just the step) so the failure can be traced back to
        # a specific task, like every other message this class logs.
        logger.opt(exception=sys.exc_info()).error(f'task {self.name} failed during {where}: {error}')
def report(func: Callable[..., Task] | Callable[..., Awaitable[Task]]) -> Callable[..., Awaitable[Task]]:
    """Decorator for logging and updating tasks in the manifest.

    Wraps a ``Task`` method, sync or async, in an async wrapper. The wrapper
    calls the task's reporting hooks around the method, chosen by the
    method's ``__name__``:

    - ``run``: ``start_run()`` before, ``finish_run(done=not has_validation())``
      after.
    - ``validate``: ``start_validation()`` before, ``finish_validation()`` after.
    - any other name: no hooks, only the error handling below.

    If the method or any hook raises an ``Exception``, the wrapper:

    - sets the task's ``context.abort`` event;
    - marks the task aborted (``TaskAbortedError``) or failed (anything else);
    - returns the task itself instead of propagating the error.

    Args:
        func: The method to wrap.

    Returns:
        An async wrapper returning the method's result, or the task on error.

    Raises:
        ValueError: When the wrapper is called and ``func`` has no ``__name__``.
    """
    # Resolve the name once at decoration time instead of on every call; the
    # error is still raised lazily, when the wrapper is invoked.
    name = getattr(func, '__name__', None)

    @wraps(func)
    async def wrapper(self: Task, *args: Any, **kwargs: Any) -> Task:
        if name is None:
            raise ValueError('wrapped function must have a __name__ attribute')
        try:
            # perform these before the wrapped method runs
            if name == 'run':
                self.start_run()
            elif name == 'validate':
                self.start_validation()

            # Call the wrapped method and await the result if it is a coroutine.
            # Inspecting the result (not the function) also handles async
            # functions hidden behind a sync decorator, and avoids
            # asyncio.iscoroutinefunction, deprecated since Python 3.14.
            outcome = func(self, *args, **kwargs)
            if asyncio.iscoroutine(outcome):
                outcome = await outcome
            result: Task = outcome  # type: ignore[assignment]

            # perform these after the wrapped method runs
            if name == 'run':
                self.finish_run(done=not self.has_validation())
            elif name == 'validate':
                self.finish_validation()
            return result
        except Exception as e:
            # Set the context's abort event (presumably observed by the other
            # tasks in the run — confirm), then record how this task ended.
            self.context.abort.set()
            if isinstance(e, TaskAbortedError):
                self.abort()
            else:
                self.fail(e, name)
            return self

    return wrapper