Redshift ML

在本节,我们将使用 Redshift ML 训练机器学习模型并进行推理

数据准备

数据集信息https://archive.ics.uci.edu/ml/datasets/iris

这是鸢尾花数据集,该数据集包含 3 个类,每个类 50 个实例,其中每个类指一种鸢尾植物。其中一类与其他两类可线性分离;后者彼此不可线性分离。

预测属性:鸢尾属植物的类别。

属性信息:

  1. 萼片长度(厘米)
  2. 萼片宽度(厘米)
  3. 花瓣长度(厘米)
  4. 花瓣宽度(厘米)

类别: 山鸢尾 变色鸢尾 维吉尼亚鸢尾

执行以下语句在 Redshift 中创建并加载训练表和推理表。训练数据用于创建模型,推理数据用于进行预测:

DROP TABLE IF EXISTS iris_data_train;

CREATE TABLE iris_data_train (
  Id int,
  SepalLengthCm float,
  SepalWidthCm float,
  PetalLengthCm float,
  PetalWidthCm float,
  Species varchar(15)
  );

COPY iris_data_train from 's3://redshift-downloads/redshift-ml/workshop/iris-data/train/' REGION 'us-east-1' IAM_ROLE default CSV IGNOREHEADER 1 ;



DROP TABLE IF EXISTS iris_data_test;

CREATE TABLE iris_data_test (
  Id int,
  SepalLengthCm float,
  SepalWidthCm float,
  PetalLengthCm float,
  PetalWidthCm float,
  Species varchar(15)
  );

COPY iris_data_test from 's3://redshift-downloads/redshift-ml/workshop/iris-data/test/' REGION 'us-east-1' IAM_ROLE default CSV IGNOREHEADER 1 ;

表的数据内容:

image-20231127103549532

训练模型

训练模型时需要提供一些信息,例如 PROBLEM_TYPE(分类或回归)OBJECTIVE 作为创建模型过程的一部分:

执行以下语句来创建模型。创建一个S3 存储桶,并在下面的查询中提供S3桶名称:

CREATE MODEL model_iris
FROM (
SELECT
   Id,
   SepalLengthCm,
   SepalWidthCm,
   PetalLengthCm,
   PetalWidthCm,
   Species
FROM iris_data_train
)
TARGET Species
FUNCTION func_model_iris IAM_ROLE default
PROBLEM_TYPE multiclass_classification
OBJECTIVE 'accuracy'
SETTINGS (S3_BUCKET '<< REPLACE S3 bucket >>',
MAX_RUNTIME 1800 );
;

检查模型训练的进度。一开始模型状态为TRAINNING

show model model_iris;

image-20231127103648531

创建模型后,将需要约 30 分钟才能运行,状态会变成READY

image-20231127165036144

检查模型的推理/准确性。运行以下查询 : 第一个检查模型的准确性,第二个将使用预先构建的模型创建的函数进行推理:

--Check Model Accuracy

WITH infer_data AS (
    SELECT Species AS label,
        func_model_iris(Id, SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm) AS predicted,
        CASE WHEN label is NULL THEN NULL ELSE label END AS actual,
        CASE WHEN actual = predicted THEN 1::INT
        ELSE 0::INT END AS correct
    FROM iris_data_test
),
aggr_data AS (
    SELECT SUM(correct) as num_correct, COUNT(*) as total FROM infer_data
)
SELECT (num_correct::float/total::float) AS accuracy FROM aggr_data;



--Predict the class of iris flower

WITH class_data AS ( SELECT func_model_iris(
   Id,
   SepalLengthCm,
   SepalWidthCm,
   PetalLengthCm,
   PetalWidthCm) AS class
FROM iris_data_test )
SELECT
CASE WHEN class = 'Iris-versicolor'  THEN 'Class-Iris-versicolor'
     WHEN class = 'Iris-setosa'  THEN 'Class-Iris-setosa'
     WHEN class = 'Iris-virginica'  THEN 'Class-Iris-virginica'
     ELSE 'Class-Other' END as class_distribution,
COUNT(1) AS count
from class_data GROUP BY 1;

image-20231127165243010