Scala实现朴素贝叶斯算法的改正(一)

来源:互联网 发布:淘宝上最好卖的东西 编辑:程序博客网 时间:2024/06/02 10:12

前面写了用Scala实现朴素贝叶斯算法,其实是不完善的,没有真正的实现。
简单说,朴素贝叶斯算法假设个体的每个属性都相互独立,分别计算每个属性在各个类别下的条件概率,这个不细说,可以参看:http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html
实现贝叶斯算法需要分两步走:
首先训练样本得到模型,然后根据模型,对个体的行为进行判断预测。
模型返回的实际就是每个类别的先验概率,以及各个属性在该类别下取某个值的条件概率;预测时把它们相乘,得到事件发生的概率,再与事件不会发生的概率比较即可。

/**
 * Naive Bayes classifier over binary ("0"/"1") features with a binary class label,
 * reading its training and prediction data from text files through Spark.
 *
 * Every input line is split on "," and only the second field is used. That field is
 * split on single spaces: token 0 is the class label ("1" = will buy, "0" = will not
 * buy) and tokens 1..n are the feature values. In the prediction file token 0 is
 * ignored and only tokens 1..n are read.
 *
 * No smoothing is applied: a feature value that was never observed for a class has
 * probability 0 for that class.
 */
object MyNaiveBayes2 {
  // Neither field is read or written inside this object; they are kept only so that
  // any existing external reference keeps compiling.
  var class1count: Float = 0
  var class2count: Float = 0

  /**
   * Trained parameters.
   *
   * @param priorYes     fraction of training records labelled "1"
   * @param priorNo      fraction of training records labelled "0"
   * @param yesGivenOne  per feature: P(feature = "1" | label = "1")
   * @param yesGivenZero per feature: P(feature = "0" | label = "1")
   * @param noGivenOne   per feature: P(feature = "1" | label = "0")
   * @param noGivenZero  per feature: P(feature = "0" | label = "0")
   */
  private final case class Model(
      priorYes: Float,
      priorNo: Float,
      yesGivenOne: Array[Float],
      yesGivenZero: Array[Float],
      noGivenOne: Array[Float],
      noGivenZero: Array[Float])

  /** Trains on one hard-coded file and prints predictions for another. */
  def main(args: Array[String]): Unit = {
    // NOTE(review): these paths look like malformed file URIs ("file///" rather than
    // "file:///") -- confirm against the environment this is run in.
    changeArray("file///F:/1/buycomputer.txt", "file///F:/1/buycomputer3.txt")
  }

  /**
   * Placeholder: creates a local Spark context and does nothing with it.
   * Neither `url1` nor `url2` is used. The context is stopped before returning so
   * that it is not leaked.
   */
  def forecast(url1: String, url2: String): Unit = {
    val conf = new SparkConf().setAppName("WordCount1").setMaster("local")
    val sc = new SparkContext(conf)
    sc.stop()
  }

  /**
   * Trains a model from `url`, classifies every record of `url2`, and prints one line
   * per classified record in the form `<1-based index>***<0 or 1>`.
   *
   * @param url  path of the training file
   * @param url2 path of the file whose records are to be classified
   */
  def changeArray(url: String, url2: String): Unit = {
    val conf = new SparkConf().setAppName("WordCount").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      val model = train(readRecords(sc, url))
      val predictions = readRecords(sc, url2).map(record => classify(model, record))
      predictions.zipWithIndex.foreach { case (label, idx) =>
        println(s"${idx + 1}***$label")
      }
    } finally {
      // Always release the context, also when reading or parsing fails.
      sc.stop()
    }
  }

  /**
   * Loads a file and returns, per line, the space-separated tokens of its second
   * comma-separated field. A line without a second field fails with an
   * ArrayIndexOutOfBoundsException.
   */
  private def readRecords(sc: SparkContext, path: String): Array[Array[String]] =
    // collect() replaces the deprecated RDD.toArray().
    sc.textFile(path).map(_.split(",")).collect().map(fields => fields(1).split(" "))

  /**
   * Estimates class priors and per-feature conditional probabilities.
   *
   * Counts are kept in fixed-size arrays indexed by feature position, so a count can
   * never land in the wrong slot and every feature has an entry for both classes and
   * both values. Records whose label is neither "1" nor "0" contribute no counts but
   * still count towards the total used for the priors. A feature token other than
   * "0"/"1" is skipped without affecting the remaining features of that record.
   */
  private def train(records: Array[Array[String]]): Model = {
    // Widest record decides the number of features; an empty record counts as zero.
    val featureCount = records.foldLeft(0)((widest, r) => math.max(widest, r.length - 1))

    val yesOnes = new Array[Float](featureCount)
    val yesZeros = new Array[Float](featureCount)
    val noOnes = new Array[Float](featureCount)
    val noZeros = new Array[Float](featureCount)
    var yesTotal = 0f
    var noTotal = 0f

    def tally(record: Array[String], ones: Array[Float], zeros: Array[Float]): Unit =
      for (i <- 1 until record.length) {
        record(i) match {
          case "1" => ones(i - 1) += 1
          case "0" => zeros(i - 1) += 1
          case _   => // not a binary value: nothing to count for this feature
        }
      }

    for (record <- records if record.nonEmpty) {
      record(0) match {
        case "1" =>
          yesTotal += 1
          tally(record, yesOnes, yesZeros)
        case "0" =>
          noTotal += 1
          tally(record, noOnes, noZeros)
        case _ => // unknown label: ignored for counting
      }
    }

    // A class without any training record gets probability 0 instead of NaN.
    def ratios(counts: Array[Float], total: Float): Array[Float] =
      if (total > 0) counts.map(_ / total) else new Array[Float](counts.length)

    def prior(total: Float): Float =
      if (records.isEmpty) 0f else total / records.length

    Model(
      prior(yesTotal),
      prior(noTotal),
      ratios(yesOnes, yesTotal),
      ratios(yesZeros, yesTotal),
      ratios(noOnes, noTotal),
      ratios(noZeros, noTotal))
  }

  /**
   * Returns 1 when the "will buy" score is strictly greater than the "will not buy"
   * score, otherwise 0. Token 0 of `record` is not read. Any feature token other than
   * "1" is treated as "0". Features beyond the number seen during training carry no
   * information and are ignored.
   */
  private def classify(model: Model, record: Array[String]): Int = {
    def score(givenOne: Array[Float], givenZero: Array[Float], prior: Float): Float = {
      val usable = math.min(record.length - 1, givenOne.length)
      // Multiply the feature likelihoods left to right, then apply the prior last.
      val likelihood = (1 to usable).foldLeft(1f) { (acc, i) =>
        acc * (if (record(i) == "1") givenOne(i - 1) else givenZero(i - 1))
      }
      likelihood * prior
    }

    val buyScore = score(model.yesGivenOne, model.yesGivenZero, model.priorYes)
    val noBuyScore = score(model.noGivenOne, model.noGivenZero, model.priorNo)
    if (buyScore > noBuyScore) 1 else 0
  }
}
0 0