肉球でキーボード

MLエンジニアの技術ブログです

SageMaker Endpointのレイテンシーを負荷試験ツールlocustで計測する

この記事はMLOps Advent Calendar 2022にリンクしてます。

SageMaker Endpointのレイテンシーを、負荷試験ツールLocustで計測してみました。

本文中のコード: https://github.com/nsakki55/code-for-blogpost/tree/main/sagemaker_endpoint_latency_check

記事の目的

プロダクションでML予測値を利用する場合、ビジネス要件を満たす推論速度に応じてインフラ・アルゴリズム選定を行う必要があります。
本記事ではSageMaker Endpointのレイテンシーを、負荷試験ツールLocustを用いて計測します
モデルアルゴリズムによって推論速度は変わるため、推論処理を除いたレイテンシーを計測します
アルゴリズムの推論処理高速化などの、レイテンシー高速化の調整は本記事では扱いません。

エンドポイント作成

モデルアルゴリズム選択による推論時間の差を除外するため、クライアントからデータを受け取り、定数0を返すだけの推論処理を実行します。

AWSリソース設定

使用するAWSリソースの設定を読み込んでおきます。

import yaml
import sagemaker

SETTING_FILE_PATH = "../config/settings.yaml"

# AWSリソース設定読み込み
with open(SETTING_FILE_PATH) as file:
    aws_info = yaml.safe_load(file)

sess = sagemaker.Session()

role = aws_info["aws"]["sagemaker"]["role"]
bucket = aws_info["aws"]["sagemaker"]["s3bucket"]
region = aws_info["aws"]["sagemaker"]["region"]

ダミーモデル作成

SageMaker Endpointを作成する際に、モデルファイルを指定する必要があるため、テキストファイルをtar.gz形式に圧縮したダミーのモデルを用意します。

! echo "this is dummy model" > dummy_model.txt |tar -czf dummy_model.tar.gz dummy_model.txt

ダミーのモデルをS3に保存します。

prefix = "latency_check"
model_file = "dummy_model.tar.gz"

s3_resource_session = boto3.Session().resource("s3").Bucket(bucket)
s3_resource_session.Object(os.path.join(prefix, "model", model_file)).upload_file(
    model_file
)

独自の推論スクリプトを実行するため、入力データを受け取り、定数0を返す推論処理を記述します。

SageMaker Endpointでは独自処理を実装する際のメソッドが用意されていて

  • model_fn: ダミーモデルを読み込み、txtファイル中の文字列を出力
  • predict_fn: 定数0を返す

処理を今回は記述します。

SageMaker Endpointが内部で何を行なっているかは、こちらの資料が参考になります。https://d1.awsstatic.com/webinars/jp/pdf/services/202208_AWS_Black_Belt_AWS_AIML_Dark_04_inference_part2.pdf

import os


def model_fn(model_dir):
    with open(os.path.join(model_dir, "dummy_model.txt")) as f:
        model = f.read()[:-1]
        print(model)
    return model


def predict_fn(input_data, model):
    return 0

用意したダミーモデルと、推論スクリプトを使用してModelを作成します。

SageMaker Endopoint で使用する Model は sagemaker toolkit をインストールした docker image を指定する必要があります。
Amazon SageMaker におけるカスタムコンテナ実装パターン詳説 〜推論編〜 | Amazon Web Services ブログ

今回は独自のimageを作らず、必要な環境が用意されている SKLearnModel を利用します。

model_data = f"s3://{bucket}/{prefix}/model/{model_file}"
model = SKLearnModel(
    model_data=model_data,
    role=role,
    framework_version="0.23-1",
    py_version="py3",
    source_dir="model",
    entry_point="inference.py",
    sagemaker_session=sess,
)

エンドポイント作成

ダミーモデルと独自推論スクリプトを設定したModelを使用して、SageMaker Endopointを作成します。

インスタンスはml.t2.medium (2vCPU, 4 GiB Memory)を指定しました。

timestamp = strftime("%Y%m%d-%H-%M-%S", gmtime())
model_name = "{}-{}".format("latency-check-model", timestamp)
endpoint_name = "{}-{}".format("latency-check-endpoint", timestamp)


sess.create_model(
    model_name, role, model.prepare_container_def(instance_type="ml.t2.medium")
)
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.t2.medium",
    endpoint_name=endpoint_name,
)

SageMaker Endopointからの処理を確認するために、リクエストを送ってみます。
期待通り定数0が返ってきているのを確認できます。

runtime = boto3.Session().client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Accept="application/json",
    Body="0",
)
predictions = json.loads(response["Body"].read().decode("utf-8"))
print(predictions) #0

SageMaker Endopointの設定が完了したので、続いて負荷試験ツールlocustの準備を行います。

locustによる負荷試験

システム構成

Locust はPythonで書かれたオープンソース負荷試験ツールです。
PythonでSageMaker Endpointへの負荷試験の内容を書けるため、今回はLocustを使用します。
下図のような構成で、SageMaker Endopoint と同一のリージョンにEC2を立てて、負荷試験を行います。
locustによる負荷試験を実施するEC2インスタンスはt2.medium(2 vCPU, 4 GiB Memory)を使用します。

負荷試験構成

SageMaker Endopointの負荷試験を実施するためのlocustのスクリプトを用意しました。
実装はこちらのレポジトリを参考にしました: https://github.com/C24IO/SageMaker-AutoLoad

レイテンシーはリクエストを投げ、取得するまでの時間として、Pythonのtimeモジュールで計算しています。

import time
import boto3
import inspect
from locust import task
from botocore.config import Config
from locust import TaskSet, task, events
from locust.contrib.fasthttp import FastHttpUser
from locust import task, events, constant


def stopwatch(func):
    def wrapper(*args, **kwargs):
        previous_frame = inspect.currentframe().f_back
        _, _, task_name, task_func_name, _ = inspect.getframeinfo(previous_frame)
        task_func_name = task_func_name[0].split(".")[-1].split("(")[0]

        start = time.time()
        result = None

        try:
            result = func(*args, **kwargs)
            total = int((time.time() - start) * 1000)

        except Exception as e:
            events.request_failure.fire(
                request_type=task_name,
                name=task_func_name,
                response_time=total,
                response_length=len(result),
                exception=e,
            )
        else:
            events.request_success.fire(
                request_type=task_name,
                name=task_func_name,
                response_time=total,
                response_length=len(result),
            )
        return result

    return wrapper


class ProtocolClient:
    def __init__(self, host):
        self.endpoint_name = host.split("/")[-1]
        self.region = "ap-northeast-1"
        self.content_type = "application/json"
        self.payload = "0"

        boto3config = Config(retries={"max_attempts": 100, "mode": "standard"})
        self.sagemaker_client = boto3.client(
            "sagemaker-runtime", config=boto3config, region_name=self.region
        )

    @stopwatch
    def sagemaker_client_invoke_endpoint(self):
        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=self.payload,
            ContentType=self.content_type,
        )
        response_body = response["Body"].read()
        return response_body


class ProtocolLocust(FastHttpUser):
    abstract = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.client = ProtocolClient(self.host)
        self.client._locust_environment = self.environment


class ProtocolTasks(TaskSet):
    @task
    def custom_protocol_boto3(self):
        self.client.sagemaker_client_invoke_endpoint()


class ProtocolUser(ProtocolLocust):
    wait_time = constant(0)
    tasks = [ProtocolTasks]

以下のコマンドで負荷試験を実施します。
今回はレイテンシーに興味があるため、ユーザー数を1としてSageMaker Endopointの負荷をかけすぎないようにしています。
host名には http:// で始まる文字列を指定する必要があるため、便宜上 http:// をエンドポイント名の前に含めています。

locust -f locust_script.py -u 1 --headless --host='http://latency-check-endpoint-20221208-13-44-44' --stop-timeout 60 -L DEBUG -t 3m --logfile=logfile.log --csv=locust.csv --csv-full-history --reset-stats

結果

負荷試験を実施すると以下のように途中結果が標準出力されます。
100QPSでリクエストを送って、平均9msでレスポンスを受け取っています。

(data-science) [ec2-user@ip-10-0-1-232 locust]$ locust -f locust_script.py -u 1 --headless --host='http://latency-check-endpoint-20221208-13-44-44' --stop-timeout 60 -L[695/1612]
3m --logfile=logfile.log --csv=locust.csv --csv-full-history --reset-stats
 Name                                                          # reqs      # fails  |     Avg     Min     Max  Median  |   req/s failures/s
--------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------
 Aggregated                                                         0     0(0.00%)  |       0       0       0       0  |    0.00    0.00

 Name                                                          # reqs      # fails  |     Avg     Min     Max  Median  |   req/s failures/s
--------------------------------------------------------------------------------------------------------------------------------------------
 custom_protocol_boto3 sagemaker_client_invoke_endpoint           180     0(0.00%)  |       9       8      69       9  |    0.00    0.00
--------------------------------------------------------------------------------------------------------------------------------------------
 Aggregated                                                       180     0(0.00%)  |       9       8      69       9  |    0.00    0.00

 Name                                                          # reqs      # fails  |     Avg     Min     Max  Median  |   req/s failures/s
--------------------------------------------------------------------------------------------------------------------------------------------
 custom_protocol_boto3 sagemaker_client_invoke_endpoint           379     0(0.00%)  |       9       8      69       9  |   82.00    0.00
--------------------------------------------------------------------------------------------------------------------------------------------
 Aggregated                                                       379     0(0.00%)  |       9       8      69       9  |   82.00    0.00

 Name                                                          # reqs      # fails  |     Avg     Min     Max  Median  |   req/s failures/s
--------------------------------------------------------------------------------------------------------------------------------------------
 custom_protocol_boto3 sagemaker_client_invoke_endpoint           579     0(0.00%)  |       9       8      69       9  |   90.50    0.00
--------------------------------------------------------------------------------------------------------------------------------------------
 Aggregated                                                       579     0(0.00%)  |       9       8      69       9  |   90.50    0.00

負荷試験が終わると、結果の統計量がcsvファイルで出力されます。
locust.csv_stats.csv を出力すると以下の結果となりました。

locust.csv_stats.csv

percentaileごとの値を棒グラフで出力すると以下のようになりました。

latency percentile

以上、結果をまとめると

  • 最小9ms
  • 中央値9ms
  • 99% percentile 14ms

となりました。

今回はモデルの推論処理を除いたため、実際にはさらに推論処理時間が加わります。
インスタンス・ネットワーク構成で異なる値をとると思いますが、SageMaker Endpointのレイテンシーのbaselineとして参考にしてもらえると幸いです。

参考