This article discusses the implementation details of StartTriggerArgs, which may not interest most Airflow users. However, if you're a contributor or simply interested in how Airflow works, let's dive in.
I'll focus on PR Enhance start_trigger_args serialization #40993, which fixed an issue in Prevent start trigger initialization in scheduler #39585.
What's the issue?
We'll follow the same example as the previous article:
from airflow.sensors.date_time import DateTimeSensorAsync
from airflow.triggers.base import StartTriggerArgs
from airflow.utils import timezone


class ExecuteFromTriggerDateTimeSensor(DateTimeSensorAsync):
    # Class-level defaults; "moment" is a placeholder that __init__ overwrites.
    start_trigger_args: StartTriggerArgs = StartTriggerArgs(
        trigger_cls="airflow.triggers.temporal.DateTimeTrigger",
        trigger_kwargs={"moment": "", "end_from_trigger": False},
        next_method="execute_complete",
        next_kwargs=None,
        timeout=None,
    )
    start_from_trigger: bool = True

    def __init__(self, *, trigger_kwargs, **kwargs) -> None:
        super().__init__(**kwargs)
        # Fill in the real trigger arguments; "moment" becomes a datetime here.
        self.start_trigger_args.trigger_kwargs = dict(
            moment=timezone.parse(self.target_time),
            end_from_trigger=self.end_from_trigger,
        )

    def execute_complete(self, context, event) -> None:
        return
This works fine now, but if you check out commit fa92727~ (one commit before this fix), you'll encounter an error message like the following:
TypeError: Object of type datetime is not JSON serializable
Surely Airflow can do better.
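The error itself is not specific to Airflow. Here is a minimal sketch that reproduces it with only the standard library, assuming the kwargs eventually reach a plain json.dumps call. The date is arbitrary, and trigger_kwargs is just a stand-in for the dictionary built in __init__ above.

```python
import json
from datetime import datetime, timezone

# Stand-in for the trigger_kwargs built in __init__; the date is arbitrary.
trigger_kwargs = {
    "moment": datetime(2024, 7, 1, tzinfo=timezone.utc),
    "end_from_trigger": False,
}

try:
    json.dumps(trigger_kwargs)
    error_message = None
except TypeError as err:
    # The default JSON encoder has no rule for datetime objects.
    error_message = str(err)

print(error_message)
```

The boolean is fine; it is the datetime under "moment" that the default encoder rejects.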
Why is this happening?
In Prevent start trigger initialization in scheduler #39585, I implemented the StartTriggerArgs.serialize method this way:
@dataclass
class StartTriggerArgs:
    """Arguments required for start task execution from triggerer."""

    trigger_cls: str
    next_method: str
    trigger_kwargs: dict[str, Any] | None = None
    timeout: timedelta | None = None

    def serialize(self):
        return {
            "trigger_cls": self.trigger_cls,
            "trigger_kwargs": self.trigger_kwargs,
            "next_method": self.next_method,
            "timeout": self.timeout,
        }
and wished Airflow would magically handle everything for me.
But, ...
Why do we need serialization here?
Before we discuss how to make the serialization work correctly, we should probably also consider why we need serialization.
Let's bring back the diagram in the first article.
In step 3, we raise a TaskDeferred exception. It is caught in airflow/models/taskinstance.py::_run_raw_task, which calls defer_task and then _defer_task, which writes a trigger row into the trigger table in the database. That's why it needs to be serialized.
session.add(trigger_row)
session.flush()
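To make the "database row" point concrete, here is a rough analogy using the standard library's sqlite3 instead of Airflow's SQLAlchemy session. The demo_trigger table and its columns are hypothetical and are not Airflow's real trigger schema. The sketch only shows that a raw Python dict cannot be stored in a column, while its serialized form can.

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical table for illustration only; not Airflow's trigger schema.
conn.execute("CREATE TABLE demo_trigger (classpath TEXT, kwargs TEXT)")

classpath = "airflow.triggers.temporal.DateTimeTrigger"
kwargs = {"end_from_trigger": False}

# Binding a raw dict as a column value is rejected by the driver.
try:
    conn.execute("INSERT INTO demo_trigger VALUES (?, ?)", (classpath, kwargs))
    raw_insert_failed = False
except sqlite3.Error:
    raw_insert_failed = True

# The serialized (JSON text) form is accepted and can be read back.
conn.execute(
    "INSERT INTO demo_trigger VALUES (?, ?)", (classpath, json.dumps(kwargs))
)
stored = conn.execute("SELECT kwargs FROM demo_trigger").fetchone()[0]
restored = json.loads(stored)
```

Anything placed in that row has to survive the trip to text and back, which is exactly where a datetime value gets in the way.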
How do we get it resolved?
Let's return to PR Enhance start_trigger_args serialization #40993.
We know where StartTriggerArgs is serialized and deserialized. The current implementation is incorrect, so our objective is to rewrite the serialize/deserialize methods.
Among the arguments of StartTriggerArgs, trigger_cls and next_method are either strings or None, so there isn't much to do with them. However, we'll need to make some additional adjustments for trigger_kwargs and next_kwargs, which are dictionaries, and for timeout, which is a timedelta.
Since the logic for different data types is already implemented in airflow/serialization/serialized_objects.py, we should leverage it. BaseSerialization.serialize already handles dictionaries, so all we need to do is extract trigger_kwargs and next_kwargs and pass them into this method. As for timeout, we already know it's a timedelta, which can be handled quite easily, so there's no need to call the BaseSerialization.serialize method again.
def encode_start_trigger_args(var: StartTriggerArgs) -> dict[str, Any]:
    """
    Encode a StartTriggerArgs.

    :meta private:
    """
    serialize_kwargs = lambda key: (
        BaseSerialization.serialize(getattr(var, key)) if getattr(var, key) is not None else None
    )
    return {
        "__type": "START_TRIGGER_ARGS",
        "trigger_cls": var.trigger_cls,
        "trigger_kwargs": serialize_kwargs("trigger_kwargs"),
        "next_method": var.next_method,
        "next_kwargs": serialize_kwargs("next_kwargs"),
        "timeout": var.timeout.total_seconds() if var.timeout else None,
    }
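The timeout entry relies on plain standard-library behavior, which can be checked in isolation. In the sketch below, encode_timeout is a hypothetical helper that copies the expression used for "timeout" above. It is not part of Airflow.

```python
from datetime import timedelta


def encode_timeout(timeout):
    """Hypothetical helper mirroring the "timeout" expression above."""
    return timeout.total_seconds() if timeout else None


# A timedelta becomes a plain float, which JSON can store.
five_minutes = encode_timeout(timedelta(minutes=5))

# The float is enough to rebuild an equal timedelta later.
rebuilt = timedelta(seconds=five_minutes)

# No timeout stays None.
no_timeout = encode_timeout(None)

# Edge case: timedelta(0) is falsy in Python, so it is also encoded as None.
zero_timeout = encode_timeout(timedelta(0))
```

One consequence of the truthiness check is that a zero-length timeout is indistinguishable from no timeout after encoding. Whether that distinction matters in practice is not something this article covers.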
For deserialization, we're following a similar approach, but in reverse.
def decode_start_trigger_args(var: dict[str, Any]) -> StartTriggerArgs:
    """
    Decode a StartTriggerArgs.

    :meta private:
    """
    deserialize_kwargs = lambda key: BaseSerialization.deserialize(var[key]) if var[key] is not None else None
    return StartTriggerArgs(
        trigger_cls=var["trigger_cls"],
        trigger_kwargs=deserialize_kwargs("trigger_kwargs"),
        next_method=var["next_method"],
        next_kwargs=deserialize_kwargs("next_kwargs"),
        timeout=datetime.timedelta(seconds=var["timeout"]) if var["timeout"] else None,
    )
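To see the encode/decode pair working end to end without an Airflow installation, here is a self-contained sketch. Everything prefixed with Toy or toy_ is a simplified stand-in invented for this example: ToyStartTriggerArgs replaces StartTriggerArgs, and toy_serialize / toy_deserialize replace BaseSerialization with a much smaller rule set (only datetimes and dicts). The "__var" tag layout is an assumption of the toy, not a claim about Airflow's wire format.

```python
import json
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional


@dataclass
class ToyStartTriggerArgs:
    """Simplified stand-in for StartTriggerArgs (not the Airflow class)."""

    trigger_cls: str
    next_method: str
    trigger_kwargs: Optional[Dict[str, Any]] = None
    next_kwargs: Optional[Dict[str, Any]] = None
    timeout: Optional[timedelta] = None


def toy_serialize(value: Any) -> Any:
    """Toy replacement for BaseSerialization.serialize."""
    if isinstance(value, datetime):
        return {"__type": "datetime", "__var": value.timestamp()}
    if isinstance(value, dict):
        return {"__type": "dict", "__var": {k: toy_serialize(v) for k, v in value.items()}}
    return value


def toy_deserialize(value: Any) -> Any:
    """Toy replacement for BaseSerialization.deserialize."""
    if isinstance(value, dict) and value.get("__type") == "datetime":
        return datetime.fromtimestamp(value["__var"], tz=timezone.utc)
    if isinstance(value, dict) and value.get("__type") == "dict":
        return {k: toy_deserialize(v) for k, v in value["__var"].items()}
    return value


def encode(args: ToyStartTriggerArgs) -> Dict[str, Any]:
    def ser(key: str) -> Any:
        raw = getattr(args, key)
        return toy_serialize(raw) if raw is not None else None

    return {
        "trigger_cls": args.trigger_cls,
        "trigger_kwargs": ser("trigger_kwargs"),
        "next_method": args.next_method,
        "next_kwargs": ser("next_kwargs"),
        "timeout": args.timeout.total_seconds() if args.timeout else None,
    }


def decode(var: Dict[str, Any]) -> ToyStartTriggerArgs:
    def deser(key: str) -> Any:
        return toy_deserialize(var[key]) if var[key] is not None else None

    return ToyStartTriggerArgs(
        trigger_cls=var["trigger_cls"],
        trigger_kwargs=deser("trigger_kwargs"),
        next_method=var["next_method"],
        next_kwargs=deser("next_kwargs"),
        timeout=timedelta(seconds=var["timeout"]) if var["timeout"] else None,
    )


original = ToyStartTriggerArgs(
    trigger_cls="airflow.triggers.temporal.DateTimeTrigger",
    next_method="execute_complete",
    trigger_kwargs={
        "moment": datetime(2024, 7, 1, tzinfo=timezone.utc),  # arbitrary date
        "end_from_trigger": False,
    },
    timeout=timedelta(minutes=5),
)

# json.dumps now succeeds because every nested value is JSON-safe.
payload = json.dumps(encode(original))
restored = decode(json.loads(payload))
```

The key design point carried over from the real functions is that the nested dictionaries are delegated to a general-purpose serializer, while the flat fields are handled inline.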