在本节,我们将使用 Redshift ML 训练机器学习模型并进行推理
数据集信息: https://archive.ics.uci.edu/ml/datasets/iris
这是鸢尾花数据集,该数据集包含 3 个类,每个类 50 个实例,其中每个类指一种鸢尾植物。其中一类与其他两类可线性分离;后者彼此不可线性分离。
预测属性:鸢尾属植物的类别。
属性信息:
类别: 山鸢尾 变色鸢尾 维吉尼亚鸢尾
执行以下语句在 Redshift 中创建并加载训练表和推理表。训练数据用于创建模型,推理数据用于进行预测:
DROP TABLE IF EXISTS iris_data_train;
CREATE TABLE iris_data_train (
Id int,
SepalLengthCm float,
SepalWidthCm float,
PetalLengthCm float,
PetalWidthCm float,
Species varchar(15)
);
COPY iris_data_train from 's3://redshift-downloads/redshift-ml/workshop/iris-data/train/' REGION 'us-east-1' IAM_ROLE default CSV IGNOREHEADER 1 ;
DROP TABLE IF EXISTS iris_data_test;
CREATE TABLE iris_data_test (
Id int,
SepalLengthCm float,
SepalWidthCm float,
PetalLengthCm float,
PetalWidthCm float,
Species varchar(15)
);
COPY iris_data_test from 's3://redshift-downloads/redshift-ml/workshop/iris-data/test/' REGION 'us-east-1' IAM_ROLE default CSV IGNOREHEADER 1 ;
表的数据内容:
训练模型时需要提供一些信息,例如 PROBLEM_TYPE(分类或回归)
和 OBJECTIVE
作为创建模型过程的一部分:
ROBLEM_TYPE
:这里选multiclass
。适用于 SageMaker Autopilot
支持的所有 Problem_type参考: https://docs.aws.amazon.com/sagemaker/latest/dg/autopilot-automate-model-development-problem-types.html
OBJECTIVE
:这里选accuray
。适用于 xgboost 的所有目标:https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters
执行以下语句来创建模型。创建一个S3 存储桶,并在下面的查询中提供S3桶名称:
CREATE MODEL model_iris
FROM (
SELECT
Id,
SepalLengthCm,
SepalWidthCm,
PetalLengthCm,
PetalWidthCm,
Species
FROM iris_data_train
)
TARGET Species
FUNCTION func_model_iris IAM_ROLE default
PROBLEM_TYPE multiclass_classification
OBJECTIVE 'accuracy'
SETTINGS (S3_BUCKET '<< REPLACE S3 bucket >>',
MAX_RUNTIME 1800 );
;
检查模型训练的进度。一开始模型状态为TRAINNING
:
show model model_iris;
创建模型后,将需要约 30 分钟才能运行,状态会变成READY
:
检查模型的推理/准确性。运行以下查询 : 第一个检查模型的准确性,第二个将使用预先构建的模型创建的函数进行推理:
--Check Model Accuracy
WITH infer_data AS (
SELECT Species AS label,
func_model_iris(Id, SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm) AS predicted,
CASE WHEN label is NULL THEN NULL ELSE label END AS actual,
CASE WHEN actual = predicted THEN 1::INT
ELSE 0::INT END AS correct
FROM iris_data_test
),
aggr_data AS (
SELECT SUM(correct) as num_correct, COUNT(*) as total FROM infer_data
)
SELECT (num_correct::float/total::float) AS accuracy FROM aggr_data;
--Predict the class of iris flower
WITH class_data AS ( SELECT func_model_iris(
Id,
SepalLengthCm,
SepalWidthCm,
PetalLengthCm,
PetalWidthCm) AS class
FROM iris_data_test )
SELECT
CASE WHEN class = 'Iris-versicolor' THEN 'Class-Iris-versicolor'
WHEN class = 'Iris-setosa' THEN 'Class-Iris-setosa'
WHEN class = 'Iris-virginica' THEN 'Class-Iris-virginica'
ELSE 'Class-Other' END as class_distribution,
COUNT(1) AS count
from class_data GROUP BY 1;