简单用法如下:

1
2
3
4
5
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
iris = load_iris()
print(iris.data.shape)
print(iris.DESCR)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
(150, 4)
.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

:Number of Instances: 150 (50 in each of three classes)
:Number of Attributes: 4 numeric, predictive attributes and the class
:Attribute Information:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
- class:
- Iris-Setosa
- Iris-Versicolour
- Iris-Virginica

:Summary Statistics:

============== ==== ==== ======= ===== ====================
Min Max Mean SD Class Correlation
============== ==== ==== ======= ===== ====================
sepal length: 4.3 7.9 5.84 0.83 0.7826
sepal width: 2.0 4.4 3.05 0.43 -0.4194
petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
============== ==== ==== ======= ===== ====================

:Missing Attribute Values: None
:Class Distribution: 33.3% for each of 3 classes.
:Creator: R.A. Fisher
:Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
:Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the
pattern recognition literature. Fisher's paper is a classic in the field and
is referenced frequently to this day. (See Duda & Hart, for example.) The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant. One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

.. topic:: References

- Fisher, R.A. "The use of multiple measurements in taxonomic problems"
Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
Mathematical Statistics" (John Wiley, NY, 1950).
- Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
(Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
- Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
Structure and Classification Rule for Recognition in Partially Exposed
Environments". IEEE Transactions on Pattern Analysis and Machine
Intelligence, Vol. PAMI-2, No. 1, 67-71.
- Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions
on Information Theory, May 1972, 431-433.
- See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II
conceptual clustering system finds 3 classes in the data.
- Many, many more ...
1
X, y = load_iris(return_X_y=True)
1
2
3
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=33, stratify=y)
print(X_train.shape)
print(X_test.shape)
1
2
(112, 4)
(38, 4)
  • train_target:所要划分的样本结果

  • test_size:样本占比,如果是整数的话就是样本的数量

  • random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填 1,其他参数一样的情况下你得到的随机数组是一样的。但填 0 或不填,每次都会不一样。

  • stratify 是为了保持 split 前类的分布。比如有 100 个数据,80 个属于 A 类,20 个属于 B 类。如果 train_test_split(… test_size=0.25, stratify = y_all), 那么 split 之后数据如下:

training: 75 个数据,其中 60 个属于 A 类,15 个属于 B 类。
testing: 25 个数据,其中 20 个属于 A 类,5 个属于 B 类。

1
2
3
4
5
用了stratify参数,training集和testing集的类的比例是 A:B= 4:1,等同于split前的比例(80:20)。通常在这种类分布不平衡的情况下会用到stratify。

将stratify=X就是按照X中的比例分配

将stratify=y就是按照y中的比例分配

整体总结起来各个参数的设置及其类型如下:

主要参数说明:

*arrays:可以是列表、numpy 数组、scipy 稀疏矩阵或 pandas 的数据框

test_size:可以为浮点、整数或 None,默认为 None

① 若为浮点时,表示测试集占总样本的百分比

② 若为整数时,表示测试样本样本数

③ 若为 None 时,test size 自动设置成 0.25

train_size:可以为浮点、整数或 None,默认为 None

① 若为浮点时,表示训练集占总样本的百分比

② 若为整数时,表示训练样本的样本数

③ 若为 None 时,train_size 自动被设置成 0.75

random_state:可以为整数、RandomState 实例或 None,默认为 None

① 若为 None 时,每次生成的数据都是随机,可能不一样

② 若为整数时,每次生成的数据都相同

stratify:可以为类似数组或 None

① 若为 None 时,划分出来的测试集或训练集中,其类标签的比例也是随机的

② 若不为 None 时,划分出来的测试集或训练集中,其类标签的比例同输入的数组中类标签的比例相同,可以用于处理不均衡的数据集