Spark MLlib linear regression: the printed weights and other coefficients are all NaN

  1. I am using Spark MLlib linear regression for traffic-flow forecasting. After training, the printed weights and other coefficients are all NaN.

data format:
520221 | 0009 | 0009 | 292 | 000541875150 | 2018 | 04 | 18 | 11 | 3 | 137
520626 | 0038 | 0038 | 520626 | 203030001000 | 2018 | 04 | 18 | 3 | 119
520621 | 0024 | 0024 | 005 | 000530002050 | 2018 | 04 | 18 | 11 | 3 | 91

The last field of each record is the label (the traffic flow).

2. The code is as follows:
package com.spark;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import scala.Tuple2;

public class CarPassRegression {

public static void main(String[] args){

    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR);//
    //
    SparkConf conf= new SparkConf();
    conf.setAppName("pass_regression").setMaster("local[*]")
            .set("spark.sql.warehouse.dir","file:///");

    JavaSparkContext sc = new JavaSparkContext(conf);

    String trainDataPath ="E://test_data//target_carPass//traindata//*";
    //
    JavaRDD<String> rdd= sc.textFile(trainDataPath);

    JavaRDD<LabeledPoint> traindata=rdd.map(new Function<String, LabeledPoint>() {
        @Override
        public LabeledPoint call(String s) throws Exception {

            String [] part = s.split("\\|");
            //label
            double lable =Double.parseDouble(part[part.length-1]);

            double [] features = new double[part.length-1];
            for(int i=0;i<features.length;iPP){
                features[i] =Double.parseDouble(part[i]);
            }

            return new LabeledPoint(lable, Vectors.dense(features));
        }
    });

    traindata.cache();

    /*
    
     */
    int numIterations = 10000;  //
    double stepSize = 0.000001;//

    final LinearRegressionModel model= LinearRegressionWithSGD.
            train(JavaRDD.toRDD(traindata),numIterations,stepSize);

    System.out.println(model.weights()); //

    //
    JavaRDD<Tuple2<Double, Double>> valuesAndPreds = traindata.map(
            new Function<LabeledPoint, Tuple2<Double, Double>>(){
                public Tuple2<Double, Double> call(LabeledPoint point){
                    double prediction = model.predict(point.features());
                    return new Tuple2<Double, Double>(prediction, point.label());
                }
            }
    );

    //
    double MSE = new JavaDoubleRDD(valuesAndPreds.map(
            new Function<Tuple2<Double, Double>, Object>(){
                public Object call(Tuple2<Double, Double> pair){
                    return Math.pow(pair._1() - pair._2(), 2.0);
                }
            }
    ).rdd()).mean();
    System.out.println("training MeanSquared Error = " + MSE);

    //
    JavaRDD<Tuple2<Object, Object>>  valuesAndPreds2= traindata.map(new Function<LabeledPoint, Tuple2<Object, Object>>(){
        public Tuple2<Object, Object> call(LabeledPoint point)
                throws Exception {
            double prediction = model.predict(point.features());
            return new Tuple2<Object, Object>(prediction, point.label());
        }

    });
    RegressionMetrics metrics = new RegressionMetrics(JavaRDD.toRDD(valuesAndPreds2));
    System.out.println("R2()= "+metrics.r2());
    System.out.println("MSE() = "+metrics.meanSquaredError());
    System.out.println("RMSE() "+metrics.rootMeanSquaredError());
    System.out.println("MAE()= "+metrics.meanAbsoluteError());

    // 
    model.save(sc.sc(), "target/tmp/carPassLinearRegressionWithSGDModel");

    LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(),
            "target/tmp/carPassLinearRegressionWithSGDModel");




}

}
3. Result:

clipboard.png

Menu