1

我正在寻找一种优雅的解决方案来使用在一个数据集(例如您的训练集)中创建的决策规则来根据这些规则拆分另一个数据集(例如测试数据)的数据。

看这个例子:

# Load PimaIndiansDiabetes dataset from mlbench package
library("mlbench")
data("PimaIndiansDiabetes")
## Split in training and test (2/3 - 1/3)
idtrain <- c(sample(1:768,512))
PimaTrain <-PimaIndiansDiabetes[idtrain,]
Pimatest <-PimaIndiansDiabetes[-idtrain,]

m1 <- RWeka::J48(as.factor(as.character(PimaTrain$diabetes)) ~ .,
                 data = PimaTrain[,-c(9)],
                 control = RWeka::Weka_control(M = 10, C= 0.25))

这给出了以下输出:

> m1
J48 pruned tree
------------------

glucose <= 154
|   age <= 28
|   |   glucose <= 118: neg (157.0/11.0)
|   |   glucose > 118
|   |   |   pressure <= 52: pos (10.0/3.0)
|   |   |   pressure > 52: neg (54.0/12.0)
|   age > 28
|   |   glucose <= 103: neg (54.0/10.0)
|   |   glucose > 103
|   |   |   mass <= 41.3: neg (129.0/55.0)
|   |   |   mass > 41.3: pos (12.0/1.0)
glucose > 154: pos (96.0/19.0)

Number of Leaves  :     7

Size of the tree :  13

根据这些规则,您将拥有 7 个组(或叶)。我正在寻找的是对测试数据Pimatest应用这些规则(因此不重新训练决策树),以便实际上每个数据点都可以指定给用新变量group指示的 7 个组之一。

输出如下所示:

head(Pimatest)
   pregnant glucose pressure triceps insulin mass pedigree age diabetes group
3         8     183       64       0       0 23.3    0.672  32      pos     7
4         1      89       66      23      94 28.1    0.167  21      neg     1
6         5     116       74       0       0 25.6    0.201  30      neg     5
7         3      78       50      32      88 31.0    0.248  26      pos     1
8        10     115        0       0       0 35.3    0.134  29      neg     5
11        4     110       92       0       0 37.6    0.191  30      neg     5

我目前有一个工作解决方案,它的编码非常糟糕,所以这就是为什么我正在为这个问题寻找一个优雅的解决方案。

4

1 回答 1

2

据我了解,您希望能够将每个点与对该点进行分类的一组规则联系起来。您可以通过将J48树转换为party树并使用partykit包中的工具来实现。

因为您没有为随机数生成器设置种子,所以我们无法获得与您获得的完全相同的测试/训练拆分。我将设置种子以使我的示例可重现,但即使我使用您的代码,我的树也会与您的略有不同。

可重现的示例(主要是您的代码)

library(RWeka)
library("mlbench")
data("PimaIndiansDiabetes")

## Split in training and test (2/3 - 1/3)
set.seed(1234)
idtrain <- c(sample(1:768,512))
PimaTrain <-PimaIndiansDiabetes[idtrain,]
Pimatest <-PimaIndiansDiabetes[-idtrain,]

m1 <- RWeka::J48(as.factor(as.character(PimaTrain$diabetes)) ~ .,
                 data = PimaTrain[,-c(9)],
                 control = RWeka::Weka_control(M = 10, C= 0.25))
m1
J48 pruned tree
------------------
glucose <= 122
|   mass <= 26.8: neg (85.0/1.0)
|   mass > 26.8
|   |   pregnant <= 4: neg (137.0/19.0)
|   |   pregnant > 4
|   |   |   glucose <= 106: neg (44.0/10.0)
|   |   |   glucose > 106: pos (24.0/6.0)
glucose > 122
|   glucose <= 157
|   |   age <= 31
|   |   |   age <= 24: neg (30.0/5.0)
|   |   |   age > 24
|   |   |   |   pressure <= 72: pos (16.0/5.0)
|   |   |   |   pressure > 72: neg (22.0/5.0)
|   |   age > 31: pos (78.0/27.0)
|   glucose > 157: pos (76.0/13.0)

Number of Leaves  :     9
Size of the tree :      17

我的树有 9 个叶子而不是你的 7 个。这是由于为训练集选择了不同的实例。现在我们准备好获取规则了。

library(partykit)
Pm1 = as.party(m1)
Pm1
Fitted party:
[1] root
|   [2] glucose <= 122
|   |   [3] mass <= 26.8: neg (n = 85, err = 1.2%)
|   |   [4] mass > 26.8
|   |   |   [5] pregnant <= 4: neg (n = 137, err = 13.9%)
|   |   |   [6] pregnant > 4
|   |   |   |   [7] glucose <= 106: neg (n = 44, err = 22.7%)
|   |   |   |   [8] glucose > 106: pos (n = 24, err = 25.0%)
|   [9] glucose > 122
|   |   [10] glucose <= 157
|   |   |   [11] age <= 31
|   |   |   |   [12] age <= 24: neg (n = 30, err = 16.7%)
|   |   |   |   [13] age > 24
|   |   |   |   |   [14] pressure <= 72: pos (n = 16, err = 31.2%)
|   |   |   |   |   [15] pressure > 72: neg (n = 22, err = 22.7%)
|   |   |   [16] age > 31: pos (n = 78, err = 34.6%)
|   |   [17] glucose > 157: pos (n = 76, err = 17.1%)

Number of inner nodes:    8
Number of terminal nodes: 9

这与之前的树相同,但具有节点被标记的优点。我们还可以为每个叶子写出规则。

Pm1_rules = partykit:::.list.rules.party(Pm1)
Pm1_rules
                                                                       3 
                                         "glucose <= 122 & mass <= 26.8" 
                                                                       5 
                          "glucose <= 122 & mass > 26.8 & pregnant <= 4" 
                                                                       7 
          "glucose <= 122 & mass > 26.8 & pregnant > 4 & glucose <= 106" 
                                                                       8 
           "glucose <= 122 & mass > 26.8 & pregnant > 4 & glucose > 106" 
                                                                      12 
                "glucose > 122 & glucose <= 157 & age <= 31 & age <= 24" 
                                                                      14 
"glucose > 122 & glucose <= 157 & age <= 31 & age > 24 & pressure <= 72" 
                                                                      15 
 "glucose > 122 & glucose <= 157 & age <= 31 & age > 24 & pressure > 72" 
                                                                      16 
                             "glucose > 122 & glucose <= 157 & age > 31" 
                                                                      17 
                                         "glucose > 122 & glucose > 157" 

The decisions are written out as rules. The names of the rulesets are the numbers of the leaf nodes. To get the rules used for a test point, you just need to know which leaf node it ends up at. But the predict method for party object will give you that.

TestPred = predict(Pm1, newdata=Pimatest, type="node")
TestPred
  3   4   5   6   9  12  17  20  22  27  28  29  31  32  33  35  36  38  41  43 
 17   5  16   3  17  17   5   5   7  16   3  16   8  17   3   8   3   7  17   3 
 46  48  50  56  57  60  62  64  65  66  68  70  72  75  76  79  84  95  96  97 
 17   5   3   3  17   5  16  12   8   7   5  15  14   5   3  14   3  12  16   5 
...

I truncated the output because it was too long. Now, for example,
we see that the first test point went to node 17. We just need to use that to index into the rule sets. But a little care is needed. The 17 returned by predict is a number. The name of the ruleset is a string, so we need to use as.character to convert it.

Pm1_rules[as.character(TestPred[1])]
                             17 
"glucose > 122 & glucose > 157" 

We confirm:

Pimatest[1,]
  pregnant glucose pressure triceps insulin mass pedigree age diabetes
3        8     183       64       0       0 23.3    0.672  32      pos

So yes, glucose > 122 AND glucose > 157

You can get the rules for the other test points in the same way.

于 2018-08-30T20:46:37.100 回答