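I am a beginner with Apache Spark. I am currently developing a machine learning program that has to update an RDD iteratively and then collect nearly 10 KB of data from the executors to the driver on every iteration. Unfortunately, when it runs for more than 600 iterations, I get a StackOverflowError! My code is below. Once the number of iterations exceeds 400, the StackOverflowError is thrown inside the collectAsMap call! Here indexedDevF and indexedData are IndexedRDDs (a library developed by AMPLab, available at https://github.com/amplab/spark-indexedrdd).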
import scala.util.control.Breaks._

breakable {
  while (bLow > bHigh + 2 * tolerance) {
    // Incrementally update devF for every point from the two alphas changed in the last iteration
    indexedDevF = indexedDevF
      .innerJoin(indexedData) { (id, a, b) => (b, a) }
      .mapValues(x => x._2 +
        alphaHighDiff * broad_y.value(iHigh) * kernel(x._1, dataiHigh) +
        alphaLowDiff * broad_y.value(iLow) * kernel(x._1, dataiLow))
    if (iteration % 50 == 0) {
      indexedDevF.checkpoint()
    }
    indexedDevF.persist() // essential to get the correct answer
    // ~0.5 s per call according to the Spark UI (localhost:4040); this is where the StackOverflowError is thrown
    val devFMap = indexedDevF.collectAsMap()

    // Driver-side scan: iHigh = index of the minimum devF, iLow = index of the maximum devF,
    // each taken over the points whose label/alpha combination makes them eligible
    var min_value = Double.PositiveInfinity
    var max_value = -min_value
    var min_i = -1
    var max_i = -1
    i = 0
    while (i < m) {
      if (((y(i) > 0) && (alpha(i) < cEpsilon)) || ((y(i) < 0) && (alpha(i) > epsilon))) {
        if (devFMap(i) <= min_value) {
          min_value = devFMap(i)
          min_i = i
        }
      }
      if (((y(i) > 0) && (alpha(i) > epsilon)) || ((y(i) < 0) && (alpha(i) < cEpsilon))) {
        if (devFMap(i) >= max_value) {
          max_value = devFMap(i)
          max_i = i
        }
      }
      i = i + 1
    }
    iHigh = min_i
    iLow = max_i
    bHigh = devFMap(iHigh)
    bLow = devFMap(iLow)
    dataiHigh = indexedData.get(iHigh.toLong).get
    dataiLow = indexedData.get(iLow.toLong).get
    eta = 2 - 2 * kernel(dataiHigh, dataiLow)

    // Compute the feasible range for the new alpha(iLow), then update and clip it
    alphaHighOld = alpha(iHigh)
    alphaLowOld = alpha(iLow)
    val alphaDiff = alphaLowOld - alphaHighOld
    val lowLabel = y(iLow)
    val sign = y(iHigh) * lowLabel
    var alphaLowLowerBound = 0D
    var alphaLowUpperBound = 0D
    if (sign < 0) {
      if (alphaDiff < 0) {
        alphaLowLowerBound = 0
        alphaLowUpperBound = cost + alphaDiff
      } else {
        alphaLowLowerBound = alphaDiff
        alphaLowUpperBound = cost
      }
    } else {
      val alphaSum = alphaLowOld + alphaHighOld
      if (alphaSum < cost) {
        alphaLowUpperBound = alphaSum
        alphaLowLowerBound = 0
      } else {
        alphaLowLowerBound = alphaSum - cost
        alphaLowUpperBound = cost
      }
    }
    if (eta > 0) {
      alphaLowNew = alphaLowOld + lowLabel * (bHigh - bLow) / eta
      if (alphaLowNew < alphaLowLowerBound)
        alphaLowNew = alphaLowLowerBound
      else if (alphaLowNew > alphaLowUpperBound)
        alphaLowNew = alphaLowUpperBound
    } else {
      val slope = lowLabel * (bHigh - bLow)
      val delta = slope * (alphaLowUpperBound - alphaLowLowerBound)
      if (delta > 0) {
        if (slope > 0)
          alphaLowNew = alphaLowUpperBound
        else
          alphaLowNew = alphaLowLowerBound
      } else
        alphaLowNew = alphaLowOld
    }
    alphaLowDiff = alphaLowNew - alphaLowOld
    alphaHighDiff = -sign * alphaLowDiff
    alpha(iLow) = alphaLowNew
    alpha(iHigh) = alphaHighOld + alphaHighDiff
    if (iteration % 50 == 0)
      print(".")
    iteration = iteration + 1
  }
}
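A rough way to see why the error only appears after a few hundred iterations: every pass of the loop above builds the new indexedDevF from the previous one with innerJoin and mapValues, so its lineage (the chain of parent RDDs Spark keeps in order to recompute lost partitions) grows with each iteration. persist() keeps the computed data, but the lineage stays attached; according to the RDD.checkpoint() documentation, a completed checkpoint saves the RDD to the checkpoint directory and removes the references to its parent RDDs. Roughly speaking, Spark walks and serializes this chain recursively, so a long enough chain can exhaust the thread stack. The toy model below is pure Scala with no Spark dependency; Node and LineageDemo are hypothetical names, and the sketch only mimics the shape of the problem, not Spark's actual implementation.

```scala
// Toy model only: Node is NOT Spark's RDD, just an illustration of a lineage chain.
final class Node(val parent: Option[Node], val value: () => Int) {
  // Number of ancestors behind this node, found by walking the chain recursively
  def depth: Int = parent match {
    case Some(p) => p.depth + 1
    case None    => 0
  }

  // Like a transformation: the result remembers its parent and recomputes through it
  def map(f: Int => Int): Node = new Node(Some(this), () => f(value()))

  // Like a completed checkpoint: evaluate once and drop the reference to the parent
  def checkpoint(): Node = {
    val v = value()
    new Node(None, () => v)
  }
}

object LineageDemo {
  // Apply `iterations` increments; truncate the chain every `interval` steps (0 = never)
  def run(iterations: Int, interval: Int): Node = {
    var a = new Node(None, () => 0)
    for (i <- 1 to iterations) {
      a = a.map(_ + 1)
      if (interval > 0 && i % interval == 0) a = a.checkpoint()
    }
    a
  }

  def main(args: Array[String]): Unit = {
    val plain = run(1030, 0)
    val truncated = run(1030, 50)
    println(s"no truncation:  depth = ${plain.depth}, value = ${plain.value()}")
    println(s"every 50 steps: depth = ${truncated.depth}, value = ${truncated.value()}")
  }
}
```

Both runs compute the same value (1030), but the untruncated chain is 1030 levels deep while the periodically truncated one is only 30 deep, so the recursion needed to walk it stays bounded no matter how many iterations run.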
====================
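My original question was as follows: I found that checkpoint had no effect, and the program would still end with a StackOverflowError! I wrote a simple test program to illustrate the problem. Fortunately, someone kindly helped me solve it, and you can find the answer below! However, even though checkpointing now actually works, my program still ends with a StackOverflowError :(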
for (i <- 1 to 1000) {
  a = a.map(x => x + 1).persist()
  val b = a.collect()
  if (i % 100 == 0) {
    a.checkpoint()
  }
  print(".")
}
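The Spark API documentation for RDD.checkpoint() states that it must be called before any job has been executed on the RDD, that the data is saved inside the directory set with SparkContext.setCheckpointDir, and that the RDD should preferably be persisted in memory so that saving it does not require recomputation. In the test loop above, a.checkpoint() is only called after a.collect() has already run a job on that same RDD. Below is a minimal sketch of the loop with the documented ordering. CheckpointLoop, iterate, the local master setting and the /tmp/spark-checkpoints path are assumptions for illustration, and a plain RDD[Int] stands in for IndexedRDD.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

// Hypothetical helper mirroring the test loop, with checkpoint() moved
// before the first action on each marked RDD.
object CheckpointLoop {
  def iterate(sc: SparkContext, iterations: Int, interval: Int): RDD[Int] = {
    var a: RDD[Int] = sc.parallelize(1 to 10)
    for (i <- 1 to iterations) {
      // Persist first so writing the checkpoint does not recompute the lineage
      val next = a.map(x => x + 1).persist()
      if (i % interval == 0) {
        // Mark BEFORE any job runs on `next`, as the checkpoint() docs require
        next.checkpoint()
      }
      // First action on `next`: computes it and, if marked, writes the checkpoint
      val b = next.collect()
      a = next
    }
    a
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("checkpoint-loop").setMaster("local[2]"))
    // Assumed path; on a cluster this must be a shared, reliable location such as HDFS
    sc.setCheckpointDir("/tmp/spark-checkpoints")
    val result = iterate(sc, 1000, 100)
    println(result.isCheckpointed)
    println(result.toDebugString)
    sc.stop()
  }
}
```

If the ordering is the issue, result.toDebugString should show a short lineage that starts from the checkpointed data instead of a chain of 1000 map steps.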