package de.ismll.hylap;

import de.ismll.core.InstanceUtils;
import de.ismll.core.Instances;
import de.ismll.core.Random;
import de.ismll.hylap.acquisitionFunction.ExpectedImprovement;
import de.ismll.hylap.surrogateModel.FNNE;
import de.ismll.hylap.surrogateModel.MLPE;
import de.ismll.hylap.surrogateModel.SurrogateModel;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math3.stat.descriptive.moment.Mean;

/* loaded from: input_file:de/ismll/hylap/SMBOMain.class */
/**
 * Command-line entry point that runs the SMBO framework on one target data set.
 *
 * <p>The target is selected by name from a folder; every other file in that folder is loaded as
 * well and handed to the surrogate model's constructor. Before use, each data set is passed through
 * {@code InstanceUtils.copyInstancesAndAddOrdinalFeatures} together with its index among all data
 * sets (the target always receives the last index) — presumably adding a data-set identifier
 * feature; confirm against {@code InstanceUtils}.
 */
public class SMBOMain {
    /**
     * Prints the command-line usage text to standard output and terminates the JVM with exit
     * status 0. This method never returns.
     */
    public static void help() {
        System.out.println("-f\tPath to the folder where your datasets are stored.\n-dataset\tName of the dataset to evaluate.\n-tries\tNumber of steps for the SMBO algorithm.\n-iter\tNumber of iterations, results are averaged.\n-output\tThe location where the results shall be saved.\n-seed\tRandom seed (Default: 0, Random: r)\n-learn\t Learning rate for all models\n-mom\t Momentum term\n-k\t Number of latent features for fmlp\n-nodes\t Number of nodes\n-layer\t Number of layers\n-models\t Number of models in ensemble (Default: 100)\n-epochs\t Number of epochs\n-s\tThe surrogate model, \"mlp\" for Multilayer Perceptron, or \"fmlp\" for Factorized Multilayer Perceptron");
        System.exit(0);
    }

    /**
     * Runs the experiment described by the command-line options.
     *
     * <p>Required options are {@code -f}, {@code -dataset}, {@code -tries} and {@code -iter}; if one
     * is missing, or the last option has no value, {@link #help()} is called. Without {@code -s} no
     * surrogate model is built and {@code null} is handed to {@code SMBO}. For every trial the mean
     * (over all iterations) of the best accuracy and best rank is printed as CSV, and additionally
     * written to the {@code -output} file when that option is given.
     *
     * @param strArr options as alternating {@code -name value} pairs
     * @throws IOException if a data set cannot be read or the {@code -output} file cannot be written
     */
    public static void main(String[] strArr) throws IOException {
        Map<String, String> options = parseOptions(strArr);
        int models = intOption(options, "-models", 100);
        double learningRate = doubleOption(options, "-learn", 0.01d);
        double momentum = doubleOption(options, "-mom", 0.01d);
        int latentFeatures = intOption(options, "-k", 5);
        int nodes = intOption(options, "-nodes", 5);
        int layers = intOption(options, "-layer", 5);
        int epochs = intOption(options, "-epochs", 100);
        if (!options.containsKey("-dataset") || !options.containsKey("-f")
                || !options.containsKey("-tries") || !options.containsKey("-iter")) {
            help();
        }
        String datasetName = options.get("-dataset");
        File folder = new File(options.get("-f"));
        int tries = Integer.parseInt(options.get("-tries"));
        int iterations = Integer.parseInt(options.get("-iter"));
        if (tries < 1 || iterations < 1) {
            // Negative sizes would throw NegativeArraySizeException; zero iterations yields NaN means.
            System.out.println("-tries and -iter must both be positive.");
            System.exit(1);
        }

        File[] files = folder.listFiles();
        if (files == null) {
            // Folder missing or not a directory: reported as "does not exist" below.
            files = new File[0];
        }
        // listFiles() guarantees no order, yet a file's position determines its ordinal index,
        // so sort to keep seeded runs reproducible across machines.
        Arrays.sort(files);
        int targetIndex = indexOfName(files, datasetName);
        if (targetIndex < 0) {
            System.out.println("Data set " + datasetName + " does not exist in folder " + folder.getAbsolutePath() + ".");
            System.exit(1);
        }

        configureSeed(options.get("-seed"));
        String outputPath = options.get("-output");

        System.out.println("Loading data sets from " + folder.getAbsolutePath() + ".");
        Instances[] metaData = new Instances[files.length - 1];
        int next = 0;
        for (int i = 0; i < files.length; i++) {
            if (i != targetIndex) {
                metaData[next++] = new Instances(files[i]);
            }
        }
        Instances target = new Instances(files[targetIndex]);
        int dataSetCount = metaData.length + 1;
        for (int i = 0; i < metaData.length; i++) {
            metaData[i] = InstanceUtils.copyInstancesAndAddOrdinalFeatures(metaData[i], dataSetCount, i);
        }
        target = InstanceUtils.copyInstancesAndAddOrdinalFeatures(target, dataSetCount, metaData.length);

        System.out.println("Starting the SMBO framework.");
        // Rows are trials, columns are iterations; results are averaged per row afterwards.
        double[][] accuracy = new double[tries][iterations];
        double[][] rank = new double[tries][iterations];
        for (int iteration = 0; iteration < iterations; iteration++) {
            System.out.println("Starting iteration " + (iteration + 1) + ".");
            // Constructed before the surrogate model to keep the original construction order.
            ExpectedImprovement acquisition = new ExpectedImprovement();
            SurrogateModel surrogate = createSurrogateModel(options.get("-s"), metaData, models,
                    learningRate, momentum, nodes, layers, epochs, latentFeatures);
            SMBO smbo = new SMBO(target, acquisition, surrogate);
            for (int trial = 0; trial < tries; trial++) {
                if (trial > 0 && rank[trial - 1][iteration] == 1.0d) {
                    // Best rank already reached 1 (presumably the optimum): skip further SMBO
                    // steps and carry the previous result forward.
                    accuracy[trial][iteration] = accuracy[trial - 1][iteration];
                    rank[trial][iteration] = rank[trial - 1][iteration];
                } else {
                    smbo.iterate();
                    System.out.println("Trial " + trial + " has been completed.");
                    accuracy[trial][iteration] = smbo.getBestAccuracy();
                    rank[trial][iteration] = smbo.getBestRank();
                }
            }
        }
        printResults(accuracy, rank, outputPath);
    }

    /**
     * Collects {@code -name value} pairs into a map; a later occurrence of a name overrides an
     * earlier one. Calls {@link #help()} (which exits) when the last option has no value.
     */
    private static Map<String, String> parseOptions(String[] args) {
        Map<String, String> options = new HashMap<String, String>();
        for (int i = 0; i < args.length; i += 2) {
            if (i + 1 >= args.length) {
                System.out.println("Missing value for option " + args[i] + ".");
                help();
            } else {
                options.put(args[i], args[i + 1]);
            }
        }
        return options;
    }

    /** Returns the option parsed as an int, or {@code defaultValue} when the option is absent. */
    private static int intOption(Map<String, String> options, String name, int defaultValue) {
        String value = options.get(name);
        return value == null ? defaultValue : Integer.parseInt(value);
    }

    /** Returns the option parsed as a double, or {@code defaultValue} when the option is absent. */
    private static double doubleOption(Map<String, String> options, String name, double defaultValue) {
        String value = options.get(name);
        return value == null ? defaultValue : Double.parseDouble(value);
    }

    /**
     * Seeds the project's {@code Random}: 0 when {@code seed} is null, the current time in
     * milliseconds for {@code "r"}, otherwise the parsed long value.
     */
    private static void configureSeed(String seed) {
        if (seed == null) {
            Random.setSeed(0L);
        } else if (seed.equals("r")) {
            Random.setSeed(System.currentTimeMillis());
        } else {
            Random.setSeed(Long.parseLong(seed));
        }
    }

    /** Returns the index of the first file whose name equals {@code name}, or -1 if none does. */
    private static int indexOfName(File[] files, String name) {
        for (int i = 0; i < files.length; i++) {
            if (files[i].getName().equals(name)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Builds the surrogate model named by the {@code -s} option: {@code "fmlp"} yields an
     * {@code FNNE}, {@code "mlp"} an {@code MLPE}, and {@code null} (option absent) yields
     * {@code null}. Any other name prints an error and exits with status 1.
     */
    private static SurrogateModel createSurrogateModel(String type, Instances[] metaData, int models,
            double learningRate, double momentum, int nodes, int layers, int epochs, int latentFeatures) {
        if (type == null) {
            return null;
        }
        if (type.equals("fmlp")) {
            System.out.println("Will learn a factorized multilayer perceptron.");
            return new FNNE(metaData, models, learningRate, momentum, nodes, layers, epochs, latentFeatures);
        }
        if (type.equals("mlp")) {
            System.out.println("Will learn a multilayer perceptron.");
            return new MLPE(metaData, models, learningRate, momentum, nodes, layers, epochs);
        }
        System.out.println("Unknown surrogate function \"" + type + "\"");
        System.exit(1);
        return null; // unreachable: System.exit does not return
    }

    /**
     * Prints, per trial, the mean best accuracy and mean best rank over all iterations as CSV, and
     * writes the same CSV to {@code outputPath} when it is non-null.
     *
     * @throws IOException if the output file cannot be opened or written
     */
    private static void printResults(double[][] accuracy, double[][] rank, String outputPath) throws IOException {
        String header = "Accuracy(mean),Rank(mean)";
        Mean mean = new Mean();
        String[] lines = new String[accuracy.length];
        for (int trial = 0; trial < accuracy.length; trial++) {
            lines[trial] = String.valueOf(mean.evaluate(accuracy[trial])) + "," + mean.evaluate(rank[trial]);
        }

        System.out.println("Printing results to console.");
        System.out.println(header);
        for (String line : lines) {
            System.out.println(line);
        }

        if (outputPath != null) {
            File outputFile = new File(outputPath);
            System.out.println("Writing results to " + outputFile.getAbsolutePath() + ".");
            try (PrintWriter writer = new PrintWriter(outputFile, "UTF-8")) {
                writer.println(header);
                for (String line : lines) {
                    writer.println(line);
                }
                // PrintWriter swallows write errors; surface them instead of losing results silently.
                if (writer.checkError()) {
                    throw new IOException("Failed to write results to " + outputFile.getAbsolutePath());
                }
            }
        }
    }
}
