この記事はMLOps Advent Calendar 2022にリンクしてます。
SageMaker Endpointのレイテンシーを、負荷試験ツールLocustで計測してみました。
本文中のコード: https://github.com/nsakki55/code-for-blogpost/tree/main/sagemaker_endpoint_latency_check
記事の目的
プロダクションでML予測値を利用する場合、ビジネス要件を満たす推論速度に応じてインフラ・アルゴリズム選定を行う必要があります。
本記事ではSageMaker Endpointのレイテンシーを、負荷試験ツールLocustを用いて計測します。
モデルアルゴリズムによって推論速度は変わるため、推論処理を除いたレイテンシーを計測します。
アルゴリズムの推論処理高速化などの、レイテンシー高速化の調整は本記事では扱いません。
エンドポイント作成
モデルアルゴリズム選択による推論時間の差を除外するため、クライアントからデータを受け取り、定数0を返すだけの推論処理を実行します。
AWSリソース設定
使用するAWSリソースの設定を読み込んでおきます。
import yaml import sagemaker SETTING_FILE_PATH = "../config/settings.yaml" # AWSリソース設定読み込み with open(SETTING_FILE_PATH) as file: aws_info = yaml.safe_load(file) sess = sagemaker.Session() role = aws_info["aws"]["sagemaker"]["role"] bucket = aws_info["aws"]["sagemaker"]["s3bucket"] region = aws_info["aws"]["sagemaker"]["region"]
ダミーモデル作成
SageMaker Endpointを作成する際に、モデルファイルを指定する必要があるため、テキストファイルをtar.gz形式に圧縮したダミーのモデルを用意します。
! echo "this is dummy model" > dummy_model.txt |tar -czf dummy_model.tar.gz dummy_model.txt
ダミーのモデルをS3に保存します。
prefix = "latency_check" model_file = "dummy_model.tar.gz" s3_resource_session = boto3.Session().resource("s3").Bucket(bucket) s3_resource_session.Object(os.path.join(prefix, "model", model_file)).upload_file( model_file )
独自の推論スクリプトを実行するため、入力データを受け取り、定数0を返す推論処理を記述します。
SageMaker Endpointでは独自処理を実装する際のメソッドが用意されていて
- model_fn: ダミーモデルを読み込み、txtファイル中の文字列を出力
- predict_fn: 定数0を返す
処理を今回は記述します。
SageMaker Endpointが内部で何を行なっているかは、こちらの資料が参考になります。https://d1.awsstatic.com/webinars/jp/pdf/services/202208_AWS_Black_Belt_AWS_AIML_Dark_04_inference_part2.pdf
import os def model_fn(model_dir): with open(os.path.join(model_dir, "dummy_model.txt")) as f: model = f.read()[:-1] print(model) return model def predict_fn(input_data, model): return 0
用意したダミーモデルと、推論スクリプトを使用してModelを作成します。
SageMaker Endopoint で使用する Model は sagemaker toolkit をインストールした docker image を指定する必要があります。
Amazon SageMaker におけるカスタムコンテナ実装パターン詳説 〜推論編〜 | Amazon Web Services ブログ
今回は独自のimageを作らず、必要な環境が用意されている SKLearnModel を利用します。
model_data = f"s3://{bucket}/{prefix}/model/{model_file}" model = SKLearnModel( model_data=model_data, role=role, framework_version="0.23-1", py_version="py3", source_dir="model", entry_point="inference.py", sagemaker_session=sess, )
エンドポイント作成
ダミーモデルと独自推論スクリプトを設定したModelを使用して、SageMaker Endopointを作成します。
インスタンスはml.t2.medium (2vCPU, 4 GiB Memory)を指定しました。
timestamp = strftime("%Y%m%d-%H-%M-%S", gmtime()) model_name = "{}-{}".format("latency-check-model", timestamp) endpoint_name = "{}-{}".format("latency-check-endpoint", timestamp) sess.create_model( model_name, role, model.prepare_container_def(instance_type="ml.t2.medium") ) predictor = model.deploy( initial_instance_count=1, instance_type="ml.t2.medium", endpoint_name=endpoint_name, )
SageMaker Endopointからの処理を確認するために、リクエストを送ってみます。
期待通り定数0が返ってきているのを確認できます。
runtime = boto3.Session().client("sagemaker-runtime") response = runtime.invoke_endpoint( EndpointName=endpoint_name, ContentType="application/json", Accept="application/json", Body="0", ) predictions = json.loads(response["Body"].read().decode("utf-8")) print(predictions) #0
SageMaker Endopointの設定が完了したので、続いて負荷試験ツールlocustの準備を行います。
locustによる負荷試験
システム構成
Locust はPythonで書かれたオープンソースの負荷試験ツールです。
PythonでSageMaker Endpointへの負荷試験の内容を書けるため、今回はLocustを使用します。
下図のような構成で、SageMaker Endopoint と同一のリージョンにEC2を立てて、負荷試験を行います。
locustによる負荷試験を実施するEC2インスタンスはt2.medium(2 vCPU, 4 GiB Memory)を使用します。
SageMaker Endopointの負荷試験を実施するためのlocustのスクリプトを用意しました。
実装はこちらのレポジトリを参考にしました: https://github.com/C24IO/SageMaker-AutoLoad
レイテンシーはリクエストを投げ、取得するまでの時間として、Pythonのtimeモジュールで計算しています。
import time import boto3 import inspect from locust import task from botocore.config import Config from locust import TaskSet, task, events from locust.contrib.fasthttp import FastHttpUser from locust import task, events, constant def stopwatch(func): def wrapper(*args, **kwargs): previous_frame = inspect.currentframe().f_back _, _, task_name, task_func_name, _ = inspect.getframeinfo(previous_frame) task_func_name = task_func_name[0].split(".")[-1].split("(")[0] start = time.time() result = None try: result = func(*args, **kwargs) total = int((time.time() - start) * 1000) except Exception as e: events.request_failure.fire( request_type=task_name, name=task_func_name, response_time=total, response_length=len(result), exception=e, ) else: events.request_success.fire( request_type=task_name, name=task_func_name, response_time=total, response_length=len(result), ) return result return wrapper class ProtocolClient: def __init__(self, host): self.endpoint_name = host.split("/")[-1] self.region = "ap-northeast-1" self.content_type = "application/json" self.payload = "0" boto3config = Config(retries={"max_attempts": 100, "mode": "standard"}) self.sagemaker_client = boto3.client( "sagemaker-runtime", config=boto3config, region_name=self.region ) @stopwatch def sagemaker_client_invoke_endpoint(self): response = self.sagemaker_client.invoke_endpoint( EndpointName=self.endpoint_name, Body=self.payload, ContentType=self.content_type, ) response_body = response["Body"].read() return response_body class ProtocolLocust(FastHttpUser): abstract = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.client = ProtocolClient(self.host) self.client._locust_environment = self.environment class ProtocolTasks(TaskSet): @task def custom_protocol_boto3(self): self.client.sagemaker_client_invoke_endpoint() class ProtocolUser(ProtocolLocust): wait_time = constant(0) tasks = [ProtocolTasks]
以下のコマンドで負荷試験を実施します。
今回はレイテンシーに興味があるため、ユーザー数を1としてSageMaker Endopointの負荷をかけすぎないようにしています。
host名には http://
で始まる文字列を指定する必要があるため、便宜上 http://
をエンドポイント名の前に含めています。
locust -f locust_script.py -u 1 --headless --host='http://latency-check-endpoint-20221208-13-44-44' --stop-timeout 60 -L DEBUG -t 3m --logfile=logfile.log --csv=locust.csv --csv-full-history --reset-stats
結果
負荷試験を実施すると以下のように途中結果が標準出力されます。
100QPSでリクエストを送って、平均9msでレスポンスを受け取っています。
(data-science) [ec2-user@ip-10-0-1-232 locust]$ locust -f locust_script.py -u 1 --headless --host='http://latency-check-endpoint-20221208-13-44-44' --stop-timeout 60 -L[695/1612] 3m --logfile=logfile.log --csv=locust.csv --csv-full-history --reset-stats Name # reqs # fails | Avg Min Max Median | req/s failures/s -------------------------------------------------------------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------------------------------------------------------------- Aggregated 0 0(0.00%) | 0 0 0 0 | 0.00 0.00 Name # reqs # fails | Avg Min Max Median | req/s failures/s -------------------------------------------------------------------------------------------------------------------------------------------- custom_protocol_boto3 sagemaker_client_invoke_endpoint 180 0(0.00%) | 9 8 69 9 | 0.00 0.00 -------------------------------------------------------------------------------------------------------------------------------------------- Aggregated 180 0(0.00%) | 9 8 69 9 | 0.00 0.00 Name # reqs # fails | Avg Min Max Median | req/s failures/s -------------------------------------------------------------------------------------------------------------------------------------------- custom_protocol_boto3 sagemaker_client_invoke_endpoint 379 0(0.00%) | 9 8 69 9 | 82.00 0.00 -------------------------------------------------------------------------------------------------------------------------------------------- Aggregated 379 0(0.00%) | 9 8 69 9 | 82.00 0.00 Name # reqs # fails | Avg Min Max Median | req/s failures/s -------------------------------------------------------------------------------------------------------------------------------------------- custom_protocol_boto3 sagemaker_client_invoke_endpoint 579 0(0.00%) | 9 8 69 9 | 90.50 0.00 -------------------------------------------------------------------------------------------------------------------------------------------- Aggregated 579 0(0.00%) | 9 8 69 9 | 90.50 0.00
負荷試験が終わると、結果の統計量がcsvファイルで出力されます。
locust.csv_stats.csv
を出力すると以下の結果となりました。
percentaileごとの値を棒グラフで出力すると以下のようになりました。
以上、結果をまとめると
- 最小9ms
- 中央値9ms
- 99% percentile 14ms
となりました。
今回はモデルの推論処理を除いたため、実際にはさらに推論処理時間が加わります。
インスタンス・ネットワーク構成で異なる値をとると思いますが、SageMaker Endpointのレイテンシーのbaselineとして参考にしてもらえると幸いです。