218 lines
7.4 KiB
Python
218 lines
7.4 KiB
Python
import logging
|
|
|
|
from django.contrib.gis.geos import Point, GEOSGeometry
|
|
from django.utils import timezone
|
|
from rest_framework import status
|
|
from rest_framework.response import Response
|
|
from rest_framework.views import APIView
|
|
|
|
from apps.common.permissions import WebhookPermission
|
|
from .models import SplatJob
|
|
from .serializers import SplatJobSerializer, WebhookInputSerializer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class JobDetailView(APIView):
    """Return a single SplatJob to the owner of the splat it belongs to.

    A job that does not exist and a job owned by someone else both yield
    404, so callers cannot probe which job ids exist.
    """

    def get(self, request, pk):
        jobs = SplatJob.objects.select_related("splat__owner")
        try:
            job = jobs.get(pk=pk)
        except SplatJob.DoesNotExist:
            job = None

        # "Not yours" is deliberately indistinguishable from "doesn't exist".
        if job is None or job.splat.owner != request.user:
            return Response(status=status.HTTP_404_NOT_FOUND)

        return Response(SplatJobSerializer(job).data)
|
|
|
|
|
|
class JobWebhookView(APIView):
    """Receive pipeline status callbacks for SplatJobs.

    DRF authentication is disabled for this view (``authentication_classes``
    is empty); access control is left entirely to ``WebhookPermission``.
    """

    authentication_classes = []
    permission_classes = [WebhookPermission]

    def post(self, request):
        """Validate the payload, look up the job, and dispatch on ``status``.

        Returns 404 when ``job_id`` matches no SplatJob, and 200 otherwise,
        including for unrecognised statuses and for webhooks that arrive
        after the job has already finished (both are logged and ignored).
        Invalid payloads raise from ``is_valid(raise_exception=True)``.
        """
        serializer = WebhookInputSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        payload = serializer.validated_data

        try:
            job = SplatJob.objects.select_related(
                "splat__owner", "splat__challenge__creator"
            ).get(runpod_job_id=payload["job_id"])
        except SplatJob.DoesNotExist:
            logger.warning("Webhook received for unknown RunPod job ID: %s", payload["job_id"])
            return Response(status=status.HTTP_404_NOT_FOUND)

        webhook_status = payload["status"]

        # Webhook deliveries can be retried or duplicated. Re-processing a job
        # that already finished would re-send notifications and re-count
        # challenge submissions, and a late "step_complete" would roll a
        # finished job back to an earlier step. Acknowledge and ignore it.
        if job.status in (SplatJob.Status.SUCCEEDED, SplatJob.Status.FAILED):
            logger.info(
                "Ignoring %r webhook for job %s already in terminal state %s",
                webhook_status,
                job.id,
                job.status,
            )
            return Response(status=status.HTTP_200_OK)

        handlers = {
            "step_complete": _handle_step_complete,
            "succeeded": _handle_succeeded,
            "failed": _handle_failed,
        }
        handler = handlers.get(webhook_status)
        if handler is None:
            logger.warning("Unhandled webhook status %r for job %s", webhook_status, job.id)
        else:
            handler(job, payload)

        return Response(status=status.HTTP_200_OK)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal webhook handlers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _handle_step_complete(job, payload):
    """Record progress for an intermediate pipeline step.

    Stores the step's ``output`` under the step name in ``pipeline_logs``
    and updates ``current_step`` and ``progress`` on the SplatJob row.

    Args:
        job: the SplatJob the webhook refers to. Its in-memory
            ``pipeline_logs`` and ``progress`` serve as the base values.
        payload: validated webhook data; reads the optional ``step``,
            ``progress`` and ``output`` keys.
    """
    progress = payload.get("progress")
    if progress is None:
        # Absent or explicit null: keep the current progress rather than
        # writing None.
        progress = job.progress

    updates = {"progress": progress}

    step = payload.get("step") or ""
    if step:
        # Copy so the in-memory job's dict is not mutated.
        logs = dict(job.pipeline_logs or {})
        logs[step] = payload.get("output", {})
        updates["current_step"] = step
        updates["pipeline_logs"] = logs
    else:
        # Without a step name, don't clobber current_step with "" or log
        # output under an empty key.
        logger.warning("step_complete webhook for job %s has no step name", job.id)

    SplatJob.objects.filter(pk=job.pk).update(**updates)
|
|
|
|
|
|
def _handle_succeeded(job, payload):
    """Finalize a successful pipeline run.

    Steps:
    - Copy the pipeline output onto the job's Splat.
    - Mark the Splat READY; it is published only if it passes the quality gate.
    - Mark the SplatJob SUCCEEDED.
    - Notify the splat owner.
    - For splats that pass the gate and are linked to a challenge, record
      the challenge submission.

    Args:
        job: the SplatJob the webhook refers to; ``job.splat`` and
            ``job.splat.owner`` are dereferenced.
        payload: validated webhook data. Its optional ``"output"`` mapping
            supplies file keys, metrics and spatial data.
    """
    from django.conf import settings

    from apps.splats.models import Splat
    from apps.utils.fcm import send_notification

    # `or {}` also covers an explicit null output; .get(..., {}) would not.
    output = payload.get("output") or {}
    splat = job.splat

    splat_updates = _splat_fields_from_output(output, splat.id)

    # Quality gate inputs: missing or null metrics are treated as zero.
    colmap_points = output.get("colmap_points") or 0
    quality_score = output.get("quality_score") or 0.0
    frame_count = output.get("frame_count") or 0
    passed = _passes_quality_gate(
        colmap_points, quality_score, frame_count, settings.SPLAT_QUALITY_THRESHOLDS
    )

    # The splat becomes READY either way; only splats that pass are published.
    splat_updates["status"] = Splat.Status.READY
    splat_updates["is_published"] = passed
    Splat.objects.filter(pk=splat.pk).update(**splat_updates)

    SplatJob.objects.filter(pk=job.pk).update(
        status=SplatJob.Status.SUCCEEDED,
        progress=100,
        current_step=SplatJob.Step.QUALITY_CHECK,
        colmap_points=colmap_points,
        pipeline_logs={**(job.pipeline_logs or {}), "quality_gate": {"passed": passed}},
        finished_at=timezone.now(),
    )

    if passed:
        title = "Your splat is ready!"
        body = "Your recording has been processed and is now visible on the map."
        kind = "splat_ready"
    else:
        title = "Splat processing complete"
        body = "Your recording was processed but did not meet quality thresholds."
        kind = "splat_quality_failed"

    # A push failure must not skip the challenge bookkeeping below.
    try:
        send_notification(
            splat.owner.fcm_token,
            title=title,
            body=body,
            data={"splat_id": str(splat.id), "type": kind},
        )
    except Exception:
        logger.exception("Failed to send completion notification for splat %s", splat.id)

    # If this splat is linked to a challenge, update submission count and notify.
    if passed and splat.challenge_id:
        _handle_challenge_submission(splat)


def _splat_fields_from_output(output, splat_id):
    """Build Splat field values (file keys, metrics, geometry) from pipeline output.

    A malformed ``location`` or ``coverage`` value is logged and left out
    rather than aborting the whole update.
    """
    import json

    fields = {
        "splat_key": output.get("splat_key", ""),
        "preview_key": output.get("preview_key", ""),
        "splat_file_size": output.get("splat_file_size"),
        "quality_score": output.get("quality_score"),
        "frame_count": output.get("frame_count"),
        # Always written, so a missing value clears any previous one.
        "altitude": output.get("altitude"),
        "heading": output.get("heading"),
    }

    location_coords = output.get("location")
    if location_coords:
        try:
            # Point(x, y): presumably [longitude, latitude] in GeoJSON order.
            # TODO confirm against the pipeline.
            fields["location"] = Point(location_coords[0], location_coords[1], srid=4326)
        except (IndexError, KeyError, TypeError, ValueError):
            logger.warning("Ignoring malformed location %r for splat %s", location_coords, splat_id)

    coverage_geojson = output.get("coverage")
    if coverage_geojson:
        try:
            fields["coverage"] = GEOSGeometry(json.dumps(coverage_geojson), srid=4326)
        except Exception:
            # Best-effort: a bad coverage shape should not block finalizing the splat.
            logger.exception("Failed to parse coverage GeoJSON for splat %s", splat_id)

    return fields


def _passes_quality_gate(colmap_points, quality_score, frame_count, thresholds):
    """Return True when every metric meets its minimum in ``thresholds``."""
    return (
        colmap_points >= thresholds["min_colmap_points"]
        and quality_score >= thresholds["min_quality_score"]
        and frame_count >= thresholds["min_frame_count"]
    )
|
|
|
|
|
|
def _handle_failed(job, payload):
    """Mark a job and its splat as failed and notify the splat owner.

    Args:
        job: the SplatJob the webhook refers to; ``job.splat.owner`` is
            dereferenced for the notification.
        payload: validated webhook data; its optional ``"error"`` value is
            stored as the job's error message.
    """
    from apps.splats.models import Splat
    from apps.utils.fcm import send_notification

    # `or` instead of a .get() default, so an explicit null or empty error
    # still records a meaningful message.
    error_message = payload.get("error") or "Unknown pipeline error"

    SplatJob.objects.filter(pk=job.pk).update(
        status=SplatJob.Status.FAILED,
        error_message=error_message,
        finished_at=timezone.now(),
    )
    Splat.objects.filter(pk=job.splat_id).update(status=Splat.Status.FAILED)

    # Log before notifying so the failure is recorded even if the push raises.
    logger.error("Splat job %s failed: %s", job.id, error_message)

    send_notification(
        job.splat.owner.fcm_token,
        title="Recording failed",
        body="There was a problem processing your recording. Please try again.",
        data={"splat_id": str(job.splat_id), "type": "splat_failed"},
    )
|
|
|
|
|
|
def _handle_challenge_submission(splat):
    """
    Increment challenge submission count and dispatch FCM notifications
    to the challenge creator and all other participants.

    The counter is incremented atomically in the database. Once it reaches
    ``max_submissions`` (when set), the challenge is closed. The submitter is
    never notified about their own submission, and a creator who is also a
    participant receives only the creator notification.
    """
    from django.db.models import F

    from apps.challenges.models import Challenge, ChallengeParticipant
    from apps.utils.fcm import send_notification

    # F() issues a single UPDATE ... SET submission_count = submission_count + 1.
    # The previous read-then-write could lose increments when webhooks ran
    # concurrently.
    updated = Challenge.objects.filter(pk=splat.challenge_id).update(
        submission_count=F("submission_count") + 1
    )
    if not updated:
        logger.warning(
            "Splat %s references missing challenge %s", splat.id, splat.challenge_id
        )
        return

    # Re-fetch to see the incremented count and check max_submissions.
    challenge = Challenge.objects.select_related("creator").get(pk=splat.challenge_id)
    if challenge.max_submissions and challenge.submission_count >= challenge.max_submissions:
        Challenge.objects.filter(pk=challenge.pk).update(status=Challenge.Status.CLOSED)

    # Notify the creator, unless they submitted this splat themselves.
    if challenge.creator_id is not None and challenge.creator_id != splat.owner_id:
        send_notification(
            challenge.creator.fcm_token,
            title="New submission to your challenge!",
            body=f'Someone submitted a splat for "{challenge.title}".',
            data={"challenge_id": str(challenge.id), "splat_id": str(splat.id), "type": "challenge_submission"},
        )

    # Notify other participants. Exclude the submitter, and exclude the creator
    # (handled above) so they are not notified twice.
    participants = (
        ChallengeParticipant.objects.filter(challenge=challenge)
        .exclude(user_id=splat.owner_id)
        .exclude(user_id=challenge.creator_id)
        .select_related("user")
    )
    for participant in participants:
        send_notification(
            participant.user.fcm_token,
            title="New splat on a challenge you joined",
            body=f'A new recording was submitted for "{challenge.title}".',
            data={"challenge_id": str(challenge.id), "type": "challenge_new_splat"},
        )
|