R语言 高效地匹配一个向量在另一个向量中的所有值

b5lpy0ml  于 2023-06-03  发布在  其他
关注(0)|答案(5)|浏览(259)

我正在寻找一种 * 高效 * 的方法,来匹配向量y中向量x的所有值,而不仅仅是match()返回的第一个位置。我所追求的本质上是pmatch()的默认行为,但没有部分匹配:

x <- c(3L, 1L, 2L, 3L, 3L, 2L)
y <- c(3L, 3L, 3L, 3L, 1L, 3L)

预期输出:

pmatch(x, y)  
[1]  1  5 NA  2  3 NA

一种方法是使用ave(),但是随着组数量的增加,这会变得很慢并且内存效率很低:

ave(x, x, FUN = \(v) which(y == v[1])[1:length(v)])
[1]  1  5 NA  2  3 NA

有人能推荐一种有效的方法来实现这一点,最好(但不是强制性的)基于R?
用于基准测试的更大数据集:

set.seed(5)
x <- sample(5e3, 1e5, replace = TRUE)
y <- sample(x, replace = TRUE)
6bc51xsx

6bc51xsx1#

  • base* 中使用split的变体。
x <- c(3L, 1L, 2L, 3L, 3L, 2L)
y <- c(3L, 3L, 3L, 3L, 1L, 3L)

a <- split(seq_along(x), x)
b <- split(seq_along(y), y)[names(a)]
b[lengths(b)==0] <- NA
. <- do.call(rbind, Map(\(a, b) cbind(a, b[seq_along(a)]), a, b))
`[<-`(.[,2], .[,1], .[,2])
#[1]  1  5 NA  2  3 NA

或修改:

a <- split(seq_along(x), x)
b <- split(seq_along(y), y)[names(a)]
b[lengths(b)==0] <- NA
b <- unlist(Map(`length<-`, b, lengths(a)), FALSE, FALSE)
`[<-`(b, unlist(a, FALSE, FALSE), b)
#[1]  1  5 NA  2  3 NA

RCPP版本可能如下所示:

Rcpp::sourceCpp(code=r"(
#include <Rcpp.h>
#include <unordered_map>
#include <queue>

using namespace Rcpp;
// [[Rcpp::export]]
IntegerVector pm(const std::vector<int>& a, const std::vector<int>& b) {
  IntegerVector idx(no_init(a.size()));
  std::unordered_map<int, std::queue<int> > lut;
  for(int i = 0; i < b.size(); ++i) lut[b[i]].push(i);
  for(int i = 0; i < idx.size(); ++i) {
    auto search = lut.find(a[i]);
    if(search != lut.end() && search->second.size() > 0) {
      idx[i] = search->second.front() + 1;
      search->second.pop();
    } else {idx[i] = NA_INTEGER;}
  }
  return idx;
}
)")
pm(x, y)
#[1]  1  5 NA  2  3 NA

基准

set.seed(5)
x <- sample(5e3, 1e5, replace = TRUE)
y <- sample(x, replace = TRUE)

library(data.table)

matchall <- function(x, y) {
  data.table(y, rowid(y))[
    data.table(x, rowid(x)), on = .(y = x, V2), which = TRUE
  ]
}

rmatch <- function(x, y) {
  xp <- cbind(seq_along(x), x)[order(x),]
  yp <- cbind(seq_along(y), y)[order(y),]
  result <- numeric(length(x))
  
  xi <- yi <- 1
  Nx <- length(x)
  Ny <- length(y)
  while (xi <= Nx) {
    if (yi > Ny) {
      result[xp[xi,1]] <- NA
      xi <- xi + 1
    } else if (xp[xi,2] == yp[yi,2]) {
      result[xp[xi,1]] = yp[yi,1]
      xi <- xi + 1
      yi <- yi + 1
    } else if (xp[xi,2] < yp[yi,2]) {
      result[xp[xi,1]] <- NA
      xi <- xi + 1
    } else if (xp[xi,2] > yp[yi,2]) {
      yi <- yi + 1
    }
  }
  result  
}

bench::mark(
ave = ave(x, x, FUN = \(v) which(y == v[1])[1:length(v)]),
rmatch = rmatch(x, y),
make.name = match(make.names(x, TRUE), make.names(y, TRUE)),
make.unique = match(make.unique(as.character(x)), make.unique(as.character(y))),
split = {a <- split(seq_along(x), x)
  b <- split(seq_along(y), y)[names(a)]
  b[lengths(b)==0] <- NA
  . <- do.call(rbind, Map(\(a, b) cbind(a, b[seq_along(a)]), a, b))
  `[<-`(.[,2], .[,1], .[,2])},
splitB = {a <- split(seq_along(x), x)
  b <- split(seq_along(y), y)[names(a)]
  b[lengths(b)==0] <- NA
  b <- unlist(Map(`length<-`, b, lengths(a)), FALSE, FALSE)
  `[<-`(b, unlist(a, FALSE, FALSE), b)},
data.table = matchall(x, y),
RCPP = pm(x, y) )

结果

expression       min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc
  <bch:expr>  <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>
1 ave            1.27s    1.27s     0.787    3.73GB    77.2      1    98
2 rmatch      272.36ms 273.06ms     3.66     5.34MB    45.8      2    25
3 make.name   144.37ms 148.71ms     6.75    14.06MB     1.69     4     1
4 make.unique  85.48ms  92.69ms    10.6      9.49MB     3.53     6     2
5 split        23.54ms  35.19ms    23.9       9.9MB    19.9     12    10
6 splitB       11.81ms  12.13ms    72.8      7.18MB    17.7     37     9
7 data.table    6.39ms   7.04ms   129.       5.13MB    25.8     65    13
8 RCPP          3.06ms   3.14ms   311.     393.16KB     3.98   156     2

在这种情况下,C++版本是最快的,分配的内存量最少。在使用 base 的情况下,splitB变体是最快的,rmatch分配的内存量最少。

dddzy1tm

dddzy1tm2#

需要指出的是,您可以使用match + make.unique来完成相同的任务。速度方面,它可能比data.table方法慢:

match(make.unique(as.character(x)), make.unique(as.character(y)))

[1]  1  5 NA  2  3 NA
match(make.names(x, TRUE), make.names(y, TRUE))
[1]  1  5 NA  2  3 NA
uelo1irk

uelo1irk3#

使用data.table连接,灵感来自this Q&A。

library(data.table)

matchall <- function(x, y) {
  data.table(y, rowid(y))[
    data.table(x, rowid(x)), on = .(y = x, V2), which = TRUE
  ]
}

检查行为

x <- c(3L, 1L, 2L, 3L, 3L, 2L)
y <- c(3L, 3L, 3L, 3L, 1L, 3L)

matchall(x, y)
#> [1]  1  5 NA  2  3 NA

较大向量上的时序:

set.seed(5)
x <- sample(5e3, 1e5, replace = TRUE)
y <- sample(x, replace = TRUE)

system.time(z1 <- matchall(x, y))
#>    user  system elapsed 
#>    0.06    0.00    0.01

system.time(z2 <- ave(x, x, FUN = \(v) which(y == v[1])[1:length(v)]))
#>    user  system elapsed 
#>    0.88    0.43    1.31

identical(z1, z2)
#> [1] TRUE
zvms9eto

zvms9eto4#

如果您有一些多余的内存,您可以通过对值进行排序来加速该过程,基本上可以执行两个指针遍历来匹配数据。这是我们的实验结果

rmatch <- function(x, y) {
  xp <- cbind(seq_along(x), x)[order(x),]
  yp <- cbind(seq_along(y), y)[order(y),]
  result <- numeric(length(x))
  
  xi <- yi <- 1
  Nx <- length(x)
  Ny <- length(y)
  while (xi <= Nx) {
    if (yi > Ny) {
      result[xp[xi,1]] <- NA
      xi <- xi + 1
    } else if (xp[xi,2] == yp[yi,2]) {
      result[xp[xi,1]] = yp[yi,1]
      xi <- xi + 1
      yi <- yi + 1
    } else if (xp[xi,2] < yp[yi,2]) {
      result[xp[xi,1]] <- NA
      xi <- xi + 1
    } else if (xp[xi,2] > yp[yi,2]) {
      yi <- yi + 1
    }
  }
  result  
}

我测试了一些其他基地的R选项张贴在这里

mbm <- microbenchmark::microbenchmark(
  ave = ave(x, x, FUN = \(v) which(y == v[1])[1:length(v)]),
  rmatch = rmatch(x, y),
  pmatch = pmatch(x, y),
  times = 20
)

看到它似乎表现不错

Unit: milliseconds
   expr        min         lq       mean     median         uq        max neval
    ave  1227.6743  1247.6980  1283.1024  1264.1485  1324.1569  1349.3276    20
 rmatch   198.1744   201.1058   208.3158   204.5933   209.4863   247.7279    20
 pmatch 39514.4227 39595.9720 39717.5887 39628.0892 39805.2405 40105.4337    20

这些都返回相同的值向量。

bwitn5fc

bwitn5fc5#

您可以简单地运行match + paste + ave

> do.call(match, lapply(list(x, y), \(v) paste(v, ave(v, v, FUN = seq_along))))
[1]  1  5 NA  2  3 NA

相关问题