肉球でキーボード

MLエンジニアの技術ブログです

実行中のECS TaskのCloudWatchログを標準出力し、タスクの正常終了を判定する

本文中のコードです github.com

やりたいこと

  • boto3 で実行したECS Taskが成功・失敗したか実行側で判定したい
  • ECS Task のCloud Watchログをポーリングし続けて、実行ログを標準出力し続けたい

実装したコード

task_arn, ecs_cluster を引数に受け取り、ECS TaskのCloud Watchログをポーリングして標準出力し、ECS Taskが正常・異常終了したか判定する EcsTaskWatcher クラスを作成しました。
10秒ごとboto3のdescribe_tasksを実行し、前回の実行から新しいログが追加された場合、新しいログを標準主力します。
exit_code = 0 の場合は正常終了、それ以外は異常終了としてログのポーリングを終了します。

import time
from typing import Any, List, Tuple

import boto3


class ECSTaskExecutionError(Exception):
    """ECS Task is failed."""

    pass


class EcsTaskWatcher:
    def __init__(self, cluster: str, task_arn: str) -> None:
        self.ecs = boto3.client("ecs")
        self.cloudwatch = boto3.client("logs")
        self.cluster = cluster
        self.task_arn = task_arn

        self.log_group_name, self.log_stream_name = self._get_log_setting()

        self.previous_logs = []

    def _get_log_setting(self) -> Tuple[str, str]:
        task_description = self.ecs.describe_tasks(
            cluster=self.cluster,
            tasks=[
                self.task_arn,
            ],
        )
        task_definition_arn = task_description["tasks"][0]["taskDefinitionArn"]
        task_definition = self.ecs.describe_task_definition(taskDefinition=task_definition_arn)

        log_group_name = task_definition["taskDefinition"]["containerDefinitions"][0]["logConfiguration"]["options"][
            "awslogs-group"
        ]

        task_name = task_definition["taskDefinition"]["containerDefinitions"][0]["name"]
        log_stream_prefix = task_definition["taskDefinition"]["containerDefinitions"][0]["logConfiguration"]["options"][
            "awslogs-stream-prefix"
        ]
        task_id = self.task_arn.split("/")[-1]
        log_stream_name = f"{log_stream_prefix}/{task_name}/{task_id}"

        return log_group_name, log_stream_name

    def _stream_log(self) -> None:
        try:
            logs = self.cloudwatch.get_log_events(
                logGroupName=self.log_group_name, logStreamName=self.log_stream_name, startFromHead=True
            )["events"]

            new_logs = self._subtract_list(logs, self.previous_logs)
            if new_logs:
                for line in new_logs:
                    print(line["message"])

            self.previous_logs = logs

        except:
            pass

    def _subtract_list(self, list1, list2) -> List[Any]:
        list_diff = list1.copy()
        for l in list2:
            try:
                list_diff.remove(l)
            except ValueError:
                continue
        return list_diff

    def watch_task_condition(self) -> None:
        running_status = True
        while running_status:
            response = self.ecs.describe_tasks(
                cluster=self.cluster,
                tasks=[
                    self.task_arn,
                ],
            )
            last_status = response["tasks"][0]["lastStatus"]

            if last_status == "STOPPED":
                running_status = False
                self._stream_log()
                exit_code = response["tasks"][0]["containers"][0]["exitCode"]
                if exit_code == 0:
                    print("ECS Task Success")
                else:
                    print("ECS Task Failed")
                    raise ECSTaskExecutionError
            else:
                self._stream_log()
                time.sleep(10)

実行例

boto3のrun_taskで実行したECS Taskのtask_arnと、ecs_clusterをEcsTaskWatcherに渡し、watch_task_conditionメソッドを呼び出します。
ECS Taskが終了するまで次の処理に進みたくない場合に便利です。

import boto3
from ecs_task_watcher import EcsTaskWatcher

ECS_CLUSTER = "cluster-name"
TASK_DEFINITION_ARN = "task_definition_arn"


def main():
    ecs = boto3.client("ecs")
    ecs_task_reponse = ecs.run_task(
        cluster=ECS_CLUSTER,
        taskDefinition=TASK_DEFINITION_ARN,
    )

    task_arn = ecs_task_reponse["tasks"][0]["taskArn"]
    ecs_task_watcher = EcsTaskWatcher(ECS_CLUSTER, task_arn)
    ecs_task_watcher.watch_task_condition()


if __name__ == "__main__":
    main()

参考