gpt4 book ai didi

python - 在 Amazon SageMaker 中进行预测之前预处理输入数据

转载 作者:行者123 更新时间:2023-11-30 08:53:10 24 4
gpt4 key购买 nike

我有一个我们自己训练的 Keras/tensorflow 模型,它可以进行图像相关的预测。我已经关注了这个trained keras model tutorial在 Sagemaker 中部署模型并可以调用端点进行预测。

现在,在我的客户端代码上,在通过调用 Sagemaker 端点进行预测之前,我需要下载图像并进行一些预处理。我不想在客户端执行此操作,而是想在 SageMaker 中执行整个过程。我该怎么做?

看来我需要更新入口点python代码train.py正如这里提到的:

sagemaker_model = TensorFlowModel(model_data = 's3://' + sagemaker_session.default_bucket() + '/model/model.tar.gz',
role = role,
entry_point = 'train.py')

其他文章表明我需要覆盖 input_fn函数来捕获预处理。但这些articles请参阅使用 MXNet 框架时使用的步骤。但我的模型是基于Keras/tensorflow框架的。

所以我不知道如何覆盖 input_fn功能。有人可以建议一下吗?

最佳答案

我也遇到了同样的问题,最后找到了解决方法。

准备好 model_data 后,您可以使用以下几行代码来部署它。

from sagemaker.tensorflow.model import TensorFlowModel
sagemaker_model = TensorFlowModel(
model_data = 's3://path/to/model/model.tar.gz',
role = role,
framework_version = '1.12',
entry_point = 'train.py',
source_dir='my_src',
env={'SAGEMAKER_REQUIREMENTS': 'requirements.txt'}
)

predictor = sagemaker_model.deploy(
initial_instance_count=1,
instance_type='ml.m4.xlarge',
endpoint_name='resnet-tensorflow-classifier'
)

您的笔记本应该有一个 my_src 目录,其中包含文件 train.pyrequirements.txt 文件。 train.py 文件应该定义一个函数 input_fn。对我来说,该函数处理图像/jpeg 内容:

import io
import numpy as np
from PIL import Image
from keras.applications.resnet50 import preprocess_input
from keras.preprocessing import image

JPEG_CONTENT_TYPE = 'image/jpeg'

# Deserialize the Invoke request body into an object we can perform prediction on
def input_fn(request_body, content_type=JPEG_CONTENT_TYPE):
# process an image uploaded to the endpoint
if content_type == JPEG_CONTENT_TYPE:
img = Image.open(io.BytesIO(request_body)).resize((300, 300))
img_array = np.array(img)
expanded_img_array = np.expand_dims(img_array, axis=0)
x = preprocess_input(expanded_img_array)
return x


else:
raise errors.UnsupportedFormatError(content_type)

您的处理代码将取决于您使用的模型架构。我正在 resnet50 上进行迁移学习,因此我使用了 keras.applications.resnet50 中的 preprocess_input

请注意,由于我的 train.py 代码导入了一些模块,因此我必须提供定义这些模块的 requirements.txt (这是我在文档)。

希望这对将来的人有帮助。

我的requirements.txt:

absl-py==0.7.1
astor==0.8.0
backports.weakref==1.0.post1
enum34==1.1.6
funcsigs==1.0.2
futures==3.2.0
gast==0.2.2
grpcio==1.20.1
h5py==2.9.0
Keras==2.2.4
Keras-Applications==1.0.7
Keras-Preprocessing==1.0.9
Markdown==3.1.1
mock==3.0.5
numpy==1.16.3
Pillow==6.0.0
protobuf==3.7.1
PyYAML==5.1
scipy==1.2.1
six==1.12.0
tensorboard==1.13.1
tensorflow==1.13.1
tensorflow-estimator==1.13.0
termcolor==1.1.0
virtualenv==16.5.0
Werkzeug==0.15.4

关于python - 在 Amazon SageMaker 中进行预测之前预处理输入数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54880241/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com