ctree()-如何获取每个终端节点的拆分条件列表?

kgsdhlau  于 2023-09-27  发布在  其他
关注(0)|答案(5)|浏览(88)

我有一个来自ctree()(party包)的输出,如下所示。如何获得每个终端节点的拆分条件列表,例如sns <= 0, dta <= 1; sns <= 0, dta > 1等?

1) sns <= 0; criterion = 1, statistic = 14655.021
  2) dta <= 1; criterion = 1, statistic = 3286.389
   3)*  weights = 153682 
  2) dta > 1
   4)*  weights = 289415 
1) sns > 0
  5) dta <= 2; criterion = 1, statistic = 1882.439
   6)*  weights = 245457 
  5) dta > 2
   7) dta <= 6; criterion = 1, statistic = 1170.813
     8)*  weights = 328582 
   7) dta > 6

谢谢

64jmpszr

64jmpszr1#

这个函数应该可以完成这项工作

# List the chain of split conditions that leads to every terminal node of a
# tree fitted with party::ctree().
#
# Arguments:
#   ct    fitted tree; it is queried through where(ct) and nodes(ct, id).
#   data  data frame the tree was fitted on; its columns are read by the
#         split variable names to decide which side of a split a node is on.
#
# Returns a data frame with one row per terminal node, in the order the node
# ids first appear in where(ct):
#   Node  terminal node id (as text)
#   Path  conditions joined by ", ", e.g. "Temp <= 82, Wind > 6.9";
#         "" for a tree that consists of the root only.
#
# The direction of each split is inferred by comparing the data with the
# split point via `<=`, so this function is meant for numeric splits.
CtreePathFunc <- function (ct, data) {
  terminal_ids <- unique(where(ct))

  describe_path <- function(Node) {
    leaf <- nodes(ct, Node)[[1]]

    # Inner nodes can only have ids smaller than the terminal node. seq_len()
    # keeps this empty for the root (Node == 1), where 1:(Node - 1) would
    # wrongly yield c(1, 0).
    candidates <- setdiff(seq_len(Node - 1), terminal_ids)

    # A candidate is on the path when it shares at least one weighted
    # observation with the terminal node. Element 2 of a node is presumably
    # its weights vector -- TODO confirm against the party node layout.
    on_path <- vapply(candidates, function(i) {
      any(leaf$weights & nodes(ct, i)[[1]][[2]] == 1)
    }, logical(1))
    Path <- candidates[on_path]

    conditions <- vapply(seq_along(Path), function(i) {
      # The node reached after this split: the next inner node on the path,
      # or the terminal node itself for the last split.
      child <- if (i == length(Path)) leaf else nodes(ct, Path[i + 1])[[1]]

      # Flattened primary split: the split point is read from position 3 and
      # the variable name from the last position.
      split_info <- unlist(nodes(ct, Path[i])[[1]][[5]])
      var_name <- as.character(split_info[length(split_info)])
      split_point <- as.character(split_info[3])

      in_child <- which(as.logical(child$weights))
      SB <- if (all(data[in_child, var_name] <= as.numeric(split_point))) {
        "<="
      } else {
        ">"
      }
      paste(var_name, SB, split_point)
    }, character(1))

    paste(conditions, collapse = ", ")
  }

  paths <- vapply(terminal_ids, describe_path, character(1))
  data.frame(Node = as.character(terminal_ids), Path = paths)
}

测试

library(party)
# Keep only the observations whose Ozone response is not missing
airq <- airquality[!is.na(airquality$Ozone), ]
ct <- ctree(
  Ozone ~ .,
  data = airq,
  controls = ctree_control(maxsurrogate = 3)
)
# One row per terminal node with the split conditions that lead to it
Result <- CtreePathFunc(ct, airq)
Result

##   Node                               Path
## 1    5 Temp <= 82, Wind > 6.9, Temp <= 77
## 2    3            Temp <= 82, Wind <= 6.9
## 3    6  Temp <= 82, Wind > 6.9, Temp > 77
## 4    9             Temp > 82, Wind > 10.3
## 5    8            Temp > 82, Wind <= 10.3
h79rfbju

h79rfbju2#

如果您使用ctree()的新推荐partykit实现,而不是旧的party包,则可以使用函数.list.rules.party()。这还没有正式导出,但可以用来提取所需的信息。

library("partykit")
# Keep only the observations whose Ozone response is not missing
airq <- airquality[!is.na(airquality$Ozone), ]
ct <- ctree(formula = Ozone ~ ., data = airq)
# Unexported helper, hence the triple colon
partykit:::.list.rules.party(ct)
##                                      3                                      5 
##             "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77" 
##                                      6                                      8 
##  "Temp <= 82 & Wind > 6.9 & Temp > 77"             "Temp > 82 & Wind <= 10.3" 
##                                      9 
##              "Temp > 82 & Wind > 10.3"
62o28rlo

62o28rlo3#

由于我需要这个函数来处理分类数据,我在@JoãoDaniel的回答基础上或多或少做了改写(我只测试了分类预测变量),得到下面的函数:

# Remove leading and trailing whitespace from every element of `x`.
# http://stackoverflow.com/questions/2261079/how-to-trim-leading-and-trailing-whitespace-in-r
trim <- function (x) {
  gsub(pattern = "^\\s+|\\s+$", replacement = "", x = x)
}
# First whitespace-delimited token of each rule (the variable name);
# strings without whitespace are returned unchanged.
getVariable <- function (x) {
  sub(pattern = "(.*?)[[:space:]].*", replacement = "\\1", x = x)
}
# Second whitespace-delimited token of each rule (the comparison symbol);
# strings with fewer than two whitespace characters are returned unchanged.
getSimbolo <- function (x) {
  sub(
    pattern = "(.*?)[[:space:]](.*?)[[:space:]].*",
    replacement = "\\2",
    x = x
  )
}

# Collapse a path of rules into its effective rules.
#
# `elemento` is a single string of rules separated by ";", for example
# "Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 }". Rules that share the same
# variable AND the same comparison symbol are redundant: only the last one on
# the path is kept. The kept rules are ordered by the first appearance of
# their (variable, symbol) pair and joined with "; ".
#
# Returns a single string; "" for an empty input and NA for an NA input.
getReglaFinal <- function(elemento) {
  texto <- as.character(elemento)
  stopifnot(length(texto) == 1L)
  if (is.na(texto)) {
    return(NA_character_)
  }

  reglas <- trim(strsplit(texto, ";")[[1]])
  if (length(reglas) == 0L) {
    # An empty path has no rules; the loop over 1:length(cortes) used to fail
    # here with "argument is of length zero".
    return("")
  }

  # Key identifying "same variable, same direction", e.g. "Temp<="
  tipo_corte <- paste0(getVariable(reglas), getSimbolo(reglas))

  # Position of the LAST rule for every key, keys in first-appearance order
  cortes <- unique(tipo_corte)
  ultimos <- length(tipo_corte) + 1L - match(cortes, rev(tipo_corte))

  paste(reglas[ultimos], collapse = "; ")
}#getReglaFinal

# Describe every terminal node of a party::ctree() tree by the rules that lead
# to it, for numeric, ordered-factor and nominal-factor splits.
#
# Argument:
#   ct  fitted tree; queried through where(), nodes(), weights() and Predict().
#
# Returns a list with
#   PreTable  one row per terminal node: Node, Path (every split on the way),
#             Means, Counts and TruePath (Path reduced by getReglaFinal()),
#             sorted by Means;
#   Table     the same table with Path replaced by TruePath;
#   SQL       a " CASE WHEN ... THEN 'Nodo <id>' ... END " expression built
#             from TruePath.
#
# Stops with an error when the tree has no split at all.
CtreePathFuncAllCat <- function (ct) {

  terminalNodes <- unique(where(ct))
  if (all(terminalNodes == 1)) {
    # Root-only tree: there is no path to describe. Previously this failed
    # later on with an unrelated message because 1:(Node - 1) is c(1, 0).
    stop("CtreePathFuncAllCat: the tree has no splits, so there are no paths to describe",
         call. = FALSE)
  }

  ResulTable <- data.frame(Node = character(), Path = character())

  for (Node in terminalNodes) {

    # All inner nodes with an id smaller than the selected terminal node
    NonTerminalNodes <- setdiff(seq_len(Node - 1), terminalNodes)

    # Weights of the terminal node
    NodeWeights <- nodes(ct, Node)[[1]]$weights

    # Inner nodes sharing a weighted observation with the terminal node form
    # its path. Element 2 of a node is presumably its weights vector --
    # TODO confirm against the party node layout.
    Path <- NULL
    for (i in NonTerminalNodes) {
      if (any(NodeWeights & nodes(ct, i)[[1]][2][[1]] == 1)) Path <- append(Path, i)
    }

    # Splitting criteria along that path
    Path2 <- SB <- NULL

    variablesNombres <- character()  # split variable of every step so far
    variablesPuntos <- list()        # split point / level mask of every step

    for (i in seq_along(Path)) {
      n <- nodes(ct, Path[i])[[1]]

      # Node reached after this split
      nextNodeID <- if (i == length(Path)) Node else Path[i + 1]

      vec_puntos  <- as.vector(n[[5]]$splitpoint)
      vec_nombre  <- n[[5]]$variableName
      vec_niveles <- attr(n[[5]]$splitpoint, "levels")

      # Kind of split, decided from the split point itself. The former
      # `index == 0` sentinel mistook a numeric split at exactly 0 for a
      # nominal split.
      esNumerico <- length(vec_niveles) == 0
      esNominal <- !esNumerico && length(vec_puntos) == length(vec_niveles)

      if (!esNumerico && !esNominal) {
        # Ordered factor: the split point is the position of a level; turn it
        # into a mask over the levels.
        index <- vec_puntos
        vec_puntos <- vector(length = length(vec_niveles))
        vec_puntos[index] <- TRUE
      }

      if (esNumerico) {
        vec_puntos <- n[[5]]$splitpoint
      }

      if (esNominal) {
        # The split point flags a set of levels; the right child gets the
        # complement.
        if (nextNodeID == n$right$nodeID) {
          vec_puntos <- !vec_puntos
        } else {
          vec_puntos <- !!vec_puntos
        }
        if (i != 1) {
          # Intersect with earlier splits on the same variable so the level
          # set only contains levels that can still reach this node.
          for (j in seq_len(i - 1)) {
            if (variablesNombres[j] == vec_nombre) {
              vec_puntos <- vec_puntos * variablesPuntos[[j]]
            }
          }
          vec_puntos <- vec_puntos == 1
        }
        SB <- "="
      } else {
        SB <- if (nextNodeID == n$right$nodeID) ">" else "<="
      }

      variablesPuntos[[i]] <- vec_puntos
      variablesNombres[i] <- vec_nombre

      if (esNumerico) {
        descripcion <- vec_puntos
      } else {
        descripcion <- paste(vec_niveles[vec_puntos], collapse = ", ")
      }
      Path2 <- paste(c(Path2, paste(c(variablesNombres[i], SB, "{", descripcion, "}"), collapse = " ")
                      ),
                     collapse = "; ")
    }

    # Output
    ResulTable <- rbind(ResulTable, cbind(Node = Node, Path = Path2))
  }

  # Size and prediction of every terminal node, taken in first-appearance
  # order. NOTE(review): Means relies on unique(Predict(ct)) lining up with
  # unique(where(ct)); that looks like it breaks when two terminal nodes share
  # the same prediction -- confirm.
  we <- weights(ct)
  c0 <- as.matrix(where(ct))
  c3 <- sapply(we, function(w) sum(w))
  c3 <- as.matrix(unique(cbind(c0, c3)))
  Counts <- as.matrix(c3[, 2])
  c2 <- drop(Predict(ct))
  Means <- as.matrix(unique(c2))

  ResulTable <- data.frame(ResulTable, Means, Counts)
  ResulTable <- ResulTable[order(ResulTable$Means), ]

  # Drop rules made redundant by a later rule on the same variable/direction
  ResulTable$TruePath <- apply(as.data.frame(ResulTable$Path), 1, getReglaFinal)

  ResulTable2 <- ResulTable

  # Turn "var = { a, b }; other <= { 3 }" into "var = ('a','b') AND other <= (3)".
  # The last step removes the quotes again around purely numeric values.
  sql <- gsub("\\;", " AND ", ResulTable2$TruePath)
  sql <- gsub("\\{ ", "('", sql)
  sql <- gsub(" \\}", "')", sql)
  sql <- gsub("\\, ", "','", sql)
  sql <- gsub("\\'([-+]?([0-9]*\\.[0-9]+|[0-9]+))\\'", "\\1", sql)
  ResulTable2$SQL <- paste("WHEN ", sql, " THEN ")

  cols <- c( 'SQL' , 'Node' )
  ResulTable2$SQL <- apply(  ResulTable2[ , cols ] ,1 , paste , collapse = "'Nodo " )

  ResulTable2$SQL <- gsub("THEN'", "THEN '", gsub(" '", "'",  paste(ResulTable2$SQL,"'")))

  ResultadoFinal <- list()

  ResultadoFinal$PreTable <- ResulTable
  ResultadoFinal$Table <- ResulTable
  ResultadoFinal$Table$Path <- ResultadoFinal$Table$TruePath
  ResultadoFinal$Table$TruePath <- NULL
  ResultadoFinal$SQL <- paste(" CASE ",paste(ResulTable2$SQL,sep="",collapse=" ")," END ",collapse="")

  return(ResultadoFinal)
}#CtreePathFuncAllCat

下面是一个测试:

library(party)
#With ordered factors
# mammoexp is not attached by library(party); the party examples load it from
# the TH.data package, so do the same here.
data("mammoexp", package = "TH.data")
TreeModel1 <- ctree(PB ~ ME + SYMPT + HIST + BSE + DECT, data = mammoexp)
Result2 <- CtreePathFuncAllCat(TreeModel1)
Result2
##$PreTable
##  Node                                                Path    Means Counts
##3    7    DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
##2    6   DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
##1    4  DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
##4    3 DECT <= { Somewhat likely }; DECT <= { Not likely } 9.833333     18
##                                          TruePath
##3   DECT > { Somewhat likely }; SYMPT > { Disagree }
##2  DECT > { Somewhat likely }; SYMPT <= { Disagree }
##1 DECT <= { Somewhat likely }; DECT > { Not likely }
##4                             DECT <= { Not likely }
##
##$Table
##  Node                                               Path    Means Counts
##3    7   DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
##2    6  DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
##1    4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
##4    3                             DECT <= { Not likely } 9.833333     18
##
##$SQL
##[1] " CASE  WHEN  DECT > ('Somewhat likely') AND  SYMPT > ('Disagree')  THEN 'Nodo 7' WHEN  DECT > ('Somewhat likely') AND  SYMPT <= ('Disagree')  THEN 'Nodo 6' WHEN  DECT <= ('Somewhat likely') AND  DECT > ('Not likely')  THEN 'Nodo 4' WHEN  DECT <= ('Not likely')  THEN 'Nodo 3'  END "

# Unordered (nominal) factor as the only predictor
TreeModel2 <- ctree(formula = count ~ spray, data = InsectSprays)
plot(TreeModel2, type = "simple")
Result2 <- CtreePathFuncAllCat(TreeModel2)
Result2
##$PreTable
##Node                                  Path     Means Counts            TruePath
##2    5 spray = { C, D, E }; spray = { C, E }  2.791667     24    spray = { C, E }
##3    4    spray = { C, D, E }; spray = { D }  4.916667     12       spray = { D }
##1    2                   spray = { A, B, F } 15.500000     36 spray = { A, B, F }
##
##$Table
##Node                Path     Means Counts
##2    5    spray = { C, E }  2.791667     24
##3    4       spray = { D }  4.916667     12
##1    2 spray = { A, B, F } 15.500000     36
##
##$SQL
##[1] " CASE  WHEN  spray = ('C','E')  THEN 'Nodo 5' WHEN  spray = ('D')  THEN 'Nodo 4' WHEN  spray = ('A','B','F')  THEN 'Nodo 2'  END "

# Continuous predictors; rows with a missing Ozone response are dropped first
airq <- airquality[!is.na(airquality$Ozone), ]
TreeModel3 <- ctree(
  Ozone ~ .,
  data = airq,
  controls = ctree_control(maxsurrogate = 3)
)
Result2 <- CtreePathFuncAllCat(TreeModel3)
Result2
##$PreTable
##  Node                                           Path    Means Counts
##1    5 Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 } 18.47917     48
##3    6  Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
##4    9                 Temp > { 82 }; Wind > { 10.3 } 48.71429      7
##2    3                Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
##5    8                Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
##                                     TruePath
##1                Temp <= { 77 }; Wind > { 6.9 }
##3 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 }
##4                Temp > { 82 }; Wind > { 10.3 }
##2               Temp <= { 82 }; Wind <= { 6.9 }
##5               Temp > { 82 }; Wind <= { 10.3 }
##
##$Table
##  Node                                          Path    Means Counts
##1    5                Temp <= { 77 }; Wind > { 6.9 } 18.47917     48
##3    6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
##4    9                Temp > { 82 }; Wind > { 10.3 } 48.71429      7
##2    3               Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
##5    8               Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
##
##$SQL
##[1] " CASE  WHEN  Temp <= (77) AND  Wind > (6.9)  THEN 'Nodo 5' WHEN  Temp <= (82) AND  Wind > (6.9) AND  Temp > (77)  THEN 'Nodo 6' WHEN  Temp > (82) AND  Wind > (10.3)  THEN 'Nodo 9' WHEN  Temp <= (82) AND  Wind <= (6.9)  THEN 'Nodo 3' WHEN  Temp > (82) AND  Wind <= (10.3)  THEN 'Nodo 8'  END "

更新!现在函数支持分类变量和数值变量的混合!

flseospp

flseospp4#

CtreePathFunc函数以Hadleyverse(即Hadley Wickham的tidyverse,我认为更容易理解)的风格重写。也处理分类变量。

library(magrittr)
# Read the split point of a node split in a printable form.
# When the split point carries a "levels" attribute, the split point is used
# as an index into those levels and the selected level name(s) are returned;
# otherwise the split point is returned as a number.
readSplitter <- function(nodeSplit) {
  splitPoint <- nodeSplit$splitpoint
  hasLevels <- "levels" %in% names(attributes(splitPoint))
  if (!hasLevels) {
    return(as.numeric(splitPoint))
  }
  attr(splitPoint, "levels")[splitPoint]
}

# Positions of the observations with a non-zero weight in the node that
# follows step `pathNumber` of `path`: the next node on the path, or
# `terminalNode` when `pathNumber` is the last step.
hasWeigths <- function(ct, path, terminalNode, pathNumber) {
  isLastStep <- pathNumber == length(path)
  followingNodeId <- ifelse(isLastStep, terminalNode, path[pathNumber + 1])
  followingNode <- nodes(ct, followingNodeId)[[1]]
  which(as.logical(followingNode$weights))
}

# Textual condition for step `pathNumber` of `path`: takes element 5 of the
# node at that step (its split) and dispatches on it via buildDataFilter(),
# together with the observations that end up in the following node.
dataFilter <- function(ct, dts, path, terminalNode, pathNumber) {
  whichWeights <- hasWeigths(ct, path, terminalNode, pathNumber)
  nodeSplit <- nodes(ct, path[pathNumber])[[1]][[5]]
  buildDataFilter(nodeSplit, dts, whichWeights)
}

# S3 generic: build the textual condition for one split; the method is chosen
# by the class of `nodeSplit`.
buildDataFilter <- function(nodeSplit, ...) {
  UseMethod("buildDataFilter")
}

# Condition for a nominal split: "<variable> == {<values>}", where the values
# are the distinct values of the split variable among the rows
# `whichWeights` of `dts`.
buildDataFilter.nominalSplit <- function(nodeSplit, dts, whichWeights) {
  varName <- nodeSplit$variableName
  includedLevels <- unique(dts[whichWeights, varName])
  levelSet <- paste0("{", paste(includedLevels, collapse = ", "), "}")
  paste(varName, "==", levelSet)
}

# Condition for an ordered split: "<variable> <= <split>" when every selected
# row of `dts` is at or below the split point, "<variable> > <split>"
# otherwise.
buildDataFilter.orderedSplit <- function(nodeSplit, dts, whichWeights) {
  varName <- nodeSplit$variableName
  splitter <- readSplitter(nodeSplit)
  allAtOrBelow <- all(dts[whichWeights, varName] <= splitter)
  paste(varName, ifelse(allAtOrBelow, "<=", ">"), splitter)
}

# List the split conditions leading to every terminal node of a tree fitted
# with party::ctree().
#
# Arguments:
#   ct   fitted tree; queried through where(ct) and nodes(ct, id).
#   dts  data frame the tree was fitted on (passed on to dataFilter()).
#
# Returns a data frame with one row per terminal node, in the order the node
# ids first appear in where(ct): `Node` (id) and `Path` (conditions joined by
# " & "; "" for a tree that consists of the root only).
readTerminalNodePaths <- function (ct, dts) {

  sgmnts <- unique(where(ct))

  # Inner nodes on the way to `terminalNode`: nodes with a smaller id that are
  # not terminal themselves and share a weighted observation with it.
  pathForTerminalNode <- function(terminalNode) {
    # seq_len() keeps this empty for the root, where 1:(Node - 1) is c(1, 0)
    candidates <- setdiff(seq_len(terminalNode - 1), sgmnts)
    terminalWeights <- nodes(ct, terminalNode)[[1]]$weights
    # Element 2 of a node is presumably its weights vector -- TODO confirm
    onPath <- vapply(candidates, function(innerNode) {
      any(terminalWeights & nodes(ct, innerNode)[[1]][[2]] == 1)
    }, logical(1))
    candidates[onPath]
  }

  # One single-row data frame per terminal node
  rows <- lapply(sgmnts, function(terminalNode) {
    path <- pathForTerminalNode(terminalNode)
    filters <- unlist(lapply(seq_along(path), function(nodeNumber) {
      dataFilter(ct, dts, path, terminalNode, nodeNumber)
    }))
    data.frame(Node = terminalNode, Path = paste(filters, collapse = " & "))
  })

  Reduce(f = rbind, rows)
}

测试

# Divide the leading part of a vector by a constant.
#
# Arguments:
#   vctr        numeric vector
#   divideBy    divisor applied to the leading elements
#   proportion  share of the vector (from the start) that is divided; the
#               number of affected elements is round(length(vctr) * proportion)
#
# Returns `vctr` with its leading elements divided by `divideBy`. When the
# rounded count is 0 nothing is changed (seq(0) used to select element 1).
shiftFirstPart <- function(vctr, divideBy, proportion = .5) {
  leading <- seq_len(round(length(vctr) * proportion))
  vctr[leading] <- vctr[leading] / divideBy
  vctr
}
set.seed(11)
n <- 13000
# Simulated data with a binary response, an ordered factor, a nominal factor
# and a numeric predictor. The first half of every uniform draw is scaled down
# by shiftFirstPart().
# cut() takes `include.lowest` (with a dot); the former `include_lowest` was
# silently absorbed by `...` and had no effect.
gdt <- 
  data.frame( is_buyer = runif(n) %>% shiftFirstPart(1.5) %>% round %>% factor(labels = c("no", "yes"))
             ,age = runif(n) %>% shiftFirstPart(1.5) %>%
               cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE, ordered_result = TRUE, labels = c("low", "mid", "high"))
             ,city = runif(n) %>% shiftFirstPart(1.5) %>%
               cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE, labels = c("Chigaco", "Boston", "Memphis"))
             ,point = runif(n) %>% shiftFirstPart(1.2)
             )

gct <- ctree( is_buyer ~ ., data = gdt)
readTerminalNodePaths(gct, gdt)
oxiaedzo

oxiaedzo5#

这是我的解决方案,它是递归的,所以它在更大的树上也运行得很好。下面是一个demo,后面是代码本身:

library(partykit)
# Fit a tree on one categorical and one continuous predictor
a <- ctree(formula = Sepal.Length ~ Species + Sepal.Width, data = iris)
# Rules from the unexported partykit helper: printing partykit_res shows
# every split on the path, not only the resulting final splits.
partykit_res <- partykit:::.list.rules.party(a)
# Rules from the function defined below
suggested_res <- tree_leaves_segments(a)

suggested_res

[1] "Sepal.Width <= 3.2, Species in {setosa}"         
[2] "3.6 >= Sepal.Width > 3.2, Species in {setosa}"   
[3] "Sepal.Width > 3.6, Species in {setosa}"          
[4] "Sepal.Width <= 2.7, Species in {versicolor}"     
[5] "Sepal.Width <= 2.7, Species in {virginica}"      
[6] "Sepal.Width > 2.7, Species in {versicolor}"      
[7] "2.8 >= Sepal.Width > 2.7, Species in {virginica}"
[8] "Sepal.Width > 2.8, Species in {virginica}"

下面是实现代码:

# Describe every leaf of `tree` by its merged split constraints.
# Returns an unnamed character vector with one element per leaf; constraints
# on different variables are joined by ", ". A tree without any split yields
# the single description "root".
tree_leaves_segments <- function(tree) {
  # Empty constraint table (columns var / op / val) to start the walk from
  no_constraints <- setNames(
    data.frame(matrix(ncol = 3, nrow = 0)),
    nm = c("var", "op", "val")
  )
  walked <- parent_child_dfs(list(tree = tree, group_df = no_constraints))

  # parent_child_dfs() returns a flat list alternating tree, group_df, ...;
  # the even positions hold the constraint tables.
  constraints_by_leaf <- walked[seq(2, length(walked), 2)]
  if (length(constraints_by_leaf) == 1) { # root
    constraints_by_leaf <- list(
      group_df = data.frame(var = "root", op = "", val = "")
    )
  }

  describe_leaf <- function(leaf) {
    # One collapsed description per variable of this leaf
    per_variable <- tapply(1:nrow(leaf), leaf$var, function(rows) {
      collapse_group_df(leaf[rows, ])
    })
    paste0(per_variable, collapse = ", ")
  }

  unname(sapply(constraints_by_leaf, describe_leaf))
}

# Turn the constraint rows of ONE variable into a single description.
# Two rows (an upper and a lower bound) become "<upper> >= var > <lower>";
# otherwise the row becomes "var in {val}" for op "in", or "var op val".
collapse_group_df <- function(g) {
  if (length(g$var) != 2) {
    if (g$op == "in") {
      return(trimws(paste0(g$var, " ", g$op, " ", "{", g$val, "}")))
    }
    g$op <- trimws(g$op)
    return(trimws(paste0(g$var, " ", g$op, " ", "", g$val, "")))
  }

  first_char <- substr(g$op, 1, 1)

  # The "<" row is the upper bound; its operator is mirrored because the
  # value is printed to the left of the variable.
  upper <- which(first_char == "<")
  upper_op <- trimws(sub(pattern = "<", replacement = ">", x = g$op[upper]))

  lower <- which(first_char == ">")
  lower_op <- trimws(g$op[lower])

  paste(g$val[upper], upper_op, g$var[1], lower_op, g$val[lower], sep = " ")
}

# Depth-first walk that collects the split constraints of every leaf.
#
# `parent` is a list with
#   tree      a (sub)tree fitted with partykit
#   group_df  data frame (var / op / val) of the constraints gathered so far
#
# A node without a split returns `parent` unchanged. Otherwise the result is
# a flat list alternating tree, group_df, tree, group_df, ... with one pair
# per leaf below `parent`.
parent_child_dfs <- function(parent) {
  node_split <- parent$tree[[1]]$node$split
  if (is.null(node_split)) {
    return(parent)
  }

  slabs <- partykit::character_split(node_split, data = parent$tree[[1]]$data)

  # The first two characters of a label tell a comparison from a level set
  op <- substr(slabs$levels, 1, 2)
  if (all(op %in% c("< ", "> ", "<=", ">="))) {
    val <- sub("^<= |^>= |^< |^> ", "", slabs$levels)
  } else {
    val <- slabs$levels
    op <- rep("in", length(op))
  }

  children <- lapply(seq_along(val), function(i) {
    # Keep plain character columns: factors (the data.frame() default before
    # R 4.0) would break the value overwrite done in merge_group_dfs().
    own <- data.frame(var = slabs$name, op = op[i], val = val[i],
                      stringsAsFactors = FALSE)
    kid_id <- parent$tree$node$kids[[i]]$id
    list(tree = parent$tree[[kid_id]],
         group_df = merge_group_dfs(own, parent$group_df))
  })

  # Flatten one level only, so every leaf contributes its (tree, group_df) pair
  unlist(lapply(children, parent_child_dfs), recursive = FALSE)
}

# Merge the constraint of one split (`group_df`, a single row) into the
# constraints inherited from the ancestors (`parent_group_df`).
# A constraint on a new variable is appended. For a variable that is already
# constrained, an "in" constraint or one with the same operator overwrites
# the stored value; one with a different operator is appended when there was
# exactly one matching row.
merge_group_dfs <- function(group_df, parent_group_df) {
  same_var <- which(group_df$var == parent_group_df$var)

  if (length(same_var) == 0) {
    return(rbind(parent_group_df, group_df))
  }

  if (group_df$op == "in") {
    parent_group_df$val[same_var] <- group_df$val
    return(parent_group_df)
  }

  single_match <- length(same_var) == 1
  for (row in same_var) {
    if (group_df$op == parent_group_df$op[[row]]) {
      parent_group_df$val[[row]] <- group_df$val
    } else if (single_match) {
      parent_group_df <- rbind(parent_group_df, group_df)
    }
  }
  parent_group_df
}

相关问题