0

我正在尝试使用 R 中 stringdist 包中的 stringdist 函数计算余弦相似度。我想通过计算与baseline_dt 的每一行的余弦相似度并取所有值的平均值来获得 score_dt 中每一行的平均余弦相似度。我使用下面的代码成功地获得了结果。但是,我正在寻找效率更高的代码,因为对于大型数据集,下面的嵌套 for 循环非常慢。

 baseline_dt <- read.table(text="id Product.Group.Code   R1   R2   R3   R4   S1   S2   S3   U1   U2   U3 U4 U6
    91  65418                164 0.68 0.70 0.50 0.59   NA   NA 0.96   NA 0.68   NA NA NA
    93  57142                164   NA 0.94   NA   NA 0.83   NA   NA 0.54   NA   NA NA NA
    99  66740                164 0.68 0.68 0.74   NA 0.63 0.68 0.72   NA   NA   NA NA NA
    100 76712                164 0.54 0.54 0.40   NA 0.39 0.39 0.39 0.50   NA 0.50 NA NA
    101 56463                164 0.67 0.67 0.76   NA   NA 0.76 0.76 0.54   NA   NA NA NA
    125 11713                164   NA   NA   NA   NA   NA 0.88   NA   NA   NA   NA NA NA",header=TRUE)


 scoring_dt <- read.table(text="id Product.Group.Code   R1   R2   R3   R4   S1   S2   S3   U1   U2   U3 U4 U6
11  999                164 0.68 0.70 0.50 0.59   0.7   NA 0.96   NA 0.68   NA NA NA
22  555                164   0 0.94   0   NA 0.83   0.6   NA 0.54   NA   NA NA NA",header=TRUE)

请在下面找到 R 代码。

dc  <- setNames(data.frame(matrix(ncol = 3, nrow = 0)), c("baseline_id", "scoring_id", "cosine_score"))
    dt  <- setNames(data.frame(matrix(ncol = 2, nrow = 0)), c("scoring_id", "Avg_cosine_score"))
    predictor <- c("R1" ,"R2" ,"R3" ,"R4", "S1", "S2", "S3", "U1", "U2" ,"U3", "U4" ,"U6")

    id <-"id"
    baseline_dt <- data.table::setDT(baseline_dt)
    scoring_dt <- data.table::setDT(scoring_dt)

    for(i in 1:length(scoring_dt[[id]])){

      for(j in 1:length(baseline_dt[[id]])){

        dc[j,1] <- baseline_dt[[id]][j]
        dc[j,2] <- scoring_dt[[id]][i]
        cos <- stringdist::stringdist(as.character(baseline_dt[ ,predictor ,with=F][j,]),as.character(scoring_dt[,predictor,with=F][i,]),
                                      method=method,nthread=8)
        cos[is.na(cos)] <- 0
        dc[j,3] <- 1-mean(cos)
      }
      dt[i,1] <- scoring_dt[[id]][i]
      dt[i,2] <- mean(dc[,3])
    }

    View(dt)

我希望将我的代码转换为更高效的代码。我已经尝试过 foreach 并行循环,但似乎没有任何东西可以加快我的代码速度。

**注意 - 我有混合数据字符以及二进制(0 和 1),这就是我使用 stringdist 函数的原因。我不能使用 lsa 包中的余弦函数。

4

0 回答 0