71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
import logging
|
|
|
|
import requests
|
|
from celery import shared_task
|
|
from django.conf import settings
|
|
from django.utils import timezone
|
|
|
|
# Logger named after this module (``__name__``); used by the tasks below.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@shared_task(bind=True, max_retries=3, default_retry_delay=60, queue="splat_jobs")
def dispatch_splat_job(self, splat_job_id: str):
    """
    Submit a splatting job to the RunPod serverless endpoint.

    Stores the RunPod job ID on the SplatJob record so incoming
    webhooks can be matched back to it.

    Args:
        splat_job_id: Primary key of the ``SplatJob`` to dispatch.

    Behaviour:
        * Unknown ``SplatJob`` -> an error is logged and the task returns.
        * No presigned URL (development without S3) -> logged and skipped.
        * Network/HTTP errors from RunPod -> retried via ``self.retry``
          (``max_retries=3``, ``default_retry_delay=60``); once retries are
          exhausted Celery re-raises the original ``RequestException``.
        * A 2xx response whose body is not a JSON object with an ``"id"``
          is logged and re-raised *without* retrying, because RunPod may
          already have queued the job and a retry could submit a duplicate.
        * On success the job is marked RUNNING (with ``runpod_job_id`` and
          ``started_at``) and its splat PROCESSING, in one transaction.
    """
    from django.db import transaction

    from apps.jobs.models import SplatJob
    from apps.splats.models import Splat
    from apps.utils.storage import generate_presigned_get_url

    try:
        job = SplatJob.objects.select_related("splat").get(id=splat_job_id)
    except SplatJob.DoesNotExist:
        logger.error("SplatJob %s not found", splat_job_id)
        return

    splat = job.splat

    # Generate a time-limited download URL for RunPod to fetch the video
    video_url = generate_presigned_get_url(splat.video_key, expires_in=7200)
    if video_url is None:
        # Development — no real storage, bail out gracefully
        logger.info("Skipping RunPod dispatch in dev (no S3 storage): job %s", splat_job_id)
        return

    # Strip a trailing slash so a base such as "https://api.example.com/"
    # does not produce "https://api.example.com//api/v1/...".
    base_url = str(settings.API_BASE_URL).rstrip("/")
    webhook_url = f"{base_url}/api/v1/jobs/webhook/"

    payload = {
        "input": {
            "video_url": video_url,
            "splat_id": str(splat.id),
            "job_id": str(job.id),
            "webhook_url": webhook_url,
            # NOTE(review): presumably the worker sends this back so the
            # webhook endpoint can authenticate the callback — confirm.
            "webhook_secret": settings.WEBHOOK_SECRET,
        }
    }

    try:
        response = requests.post(
            f"https://api.runpod.io/v2/{settings.RUNPOD_ENDPOINT_ID}/run",
            json=payload,
            headers={"Authorization": f"Bearer {settings.RUNPOD_API_KEY}"},
            timeout=15,
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        logger.exception("RunPod dispatch failed for job %s", splat_job_id)
        # Schedules a retry; when max_retries is exceeded Celery raises ``exc``.
        raise self.retry(exc=exc)

    try:
        runpod_job_id = response.json()["id"]
    except (ValueError, KeyError, TypeError):
        # ValueError: body is not JSON (requests' JSONDecodeError subclasses it);
        # KeyError/TypeError: JSON is not an object with an "id" field.
        # The request itself succeeded, so do NOT retry — that could queue a
        # second RunPod job. Log enough context to reconcile manually.
        logger.error(
            "Unexpected RunPod response for job %s (HTTP %s): %.500s",
            splat_job_id,
            response.status_code,
            response.text,
        )
        raise

    # Update both records together so a failure part-way cannot leave the
    # job RUNNING while its splat still has its previous status.
    with transaction.atomic():
        SplatJob.objects.filter(pk=job.pk).update(
            runpod_job_id=runpod_job_id,
            status=SplatJob.Status.RUNNING,
            started_at=timezone.now(),
        )
        Splat.objects.filter(pk=splat.pk).update(status=Splat.Status.PROCESSING)

    logger.info("Dispatched RunPod job %s for splat %s", runpod_job_id, splat.id)
|