package de.ismll.hylap;

import de.ismll.core.InstanceUtils;
import de.ismll.core.Instances;
import de.ismll.core.regression.FactorizationMachineRegression;
import de.ismll.core.regression.FactorizedMultilayerPerceptron;
import de.ismll.core.regression.MultiLayerPerceptron;
import de.ismll.core.regression.Regression;
import de.ismll.core.regression.neuralnet.ActivationTanh;
import de.ismll.hylap.util.ErrorMetrics;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.StringTokenizer;
import weka.core.TestInstances;

/* loaded from: input_file:de/ismll/hylap/RecHypMain.class */
public class RecHypMain {

    /**
     * Command-line help text. Printed to stdout for {@code -h}/{@code -help} or an empty command
     * line, and to stderr (after an error message) on invalid input.
     */
    private static final String USAGE =
            "-f\tPath to the datafolder where your datasets are stored.\n"
            + "-d\t Dataset that will be used for testing\n"
            + "-r\tThe regression model,\"fm\" (Factorization Machine), \"fmlp\" (Factorized Multilayer Perceptron) \"mlp\" (Multilayer Perceptron)\n"
            + "-epochs\t number of epochs for training \n"
            + "-stdev\t standard deviation of gaussian that initializes parameters for factorization machine \n"
            + "-reg\t regularizer of factorization machine \n"
            + "-k\t number of latent factors \n"
            + "-nodes\t number of nodes per layer for mlp and fmlp \n"
            + "-layer\t number of layers \n"
            + "-mom\t momentum for learning \n"
            + "-learn\t learn rate for all methods\n"
            + "-split\t the split that will be used \n"
            + "-num_splits\t the number of splits, default value is 10\n"
            + "-p_train\t percentage of how much data will be used for training\n"
            + "-p_val\t percentage of how much data will be used for validation\n";

    /**
     * Command-line entry point.
     *
     * <p>Loads every dataset file found at {@code -f} and orders the files with a fixed seed. It
     * splits the target datasets named by {@code -d} (comma-separated) into train/validation/test
     * parts. It trains the model selected by {@code -r} on the train parts of the targets plus all
     * instances of every other dataset. For each of {@code -epochs} epochs it prints the train,
     * validation and test RMSE. Finally it reports the test RMSE at the (0-based) epoch with the
     * lowest validation RMSE and at the last epoch.
     *
     * @param strArr command-line options as {@code -key value} pairs (see {@link #USAGE})
     * @throws IOException if a dataset file cannot be read
     */
    public static void main(String[] strArr) throws IOException {
        // Show help when run without any arguments.
        if (strArr.length == 0) {
            printUsageAndExit(null, 0);
        }
        Map<String, String> options = parseArguments(strArr);
        if (options.containsKey("-help") || options.containsKey("-h")) {
            printUsageAndExit(null, 0);
        }
        int numSplits = intOption(options, "-num_splits", 10);
        int epochs = intOption(options, "-epochs", 1);
        double trainFraction = doubleOption(options, "-p_train", 0.8d);
        double validationFraction = doubleOption(options, "-p_val", 0.1d);
        int split = intOption(options, "-split", 0);
        if (epochs < 1) {
            // The final report reads the last epoch's results, so at least one epoch is required.
            printUsageAndExit("-epochs must be at least 1, got " + epochs, 1);
        }

        File[] files = listDatasetFiles(options.get("-f"));
        // Fixed seed on a sorted listing: each dataset gets the same index on every run.
        shuffleFileArray(files, 100L);

        ArrayList<Integer> targetIndices = options.containsKey("-d")
                ? resolveTargetIndices(files, options.get("-d"))
                : new ArrayList<Integer>();

        int numDatasets = files.length;
        Instances[] datasets = new Instances[numDatasets];
        int totalInstances = 0;
        for (int i = 0; i < numDatasets; i++) {
            datasets[i] = new Instances(files[i]);
            totalInstances += datasets[i].numInstances();
        }
        // FM input dimension: the first dataset's value count plus one per dataset (presumably the
        // extra features added by InstanceUtils.copyInstancesAndAddOrdinalFeatures -- TODO confirm).
        int fmNumValues = datasets[0].numValues() + numDatasets;
        for (Integer target : targetIndices) {
            int index = target.intValue();
            datasets[index].computeSplits(numSplits, trainFraction, validationFraction, index);
        }

        Instances[] trainParts = getTrainInstances(datasets, targetIndices, true, split);
        Instances[] validationParts = getValidationInstances(datasets, targetIndices, true, split);
        Instances[] testParts = getPredictInstances(datasets, targetIndices, true, split);
        Instances train = InstanceUtils.combineInstances(trainParts);
        Instances test = InstanceUtils.combineInstances(testParts);
        Instances validation = InstanceUtils.combineInstances(validationParts);
        train.shuffle(1000);
        System.out.println("Number of Train instances: " + train.numInstances());
        System.out.println("Number of Validation instances: " + validation.numInstances());
        System.out.println("Number of Predict instances: " + test.numInstances());
        System.out.println("Number of all instances available: " + totalInstances);

        Regression regression = buildRegression(options, train.numValues(), fmNumValues);

        double[] trainTargets = train.getTargets();
        double[] validationTargets = validation.getTargets();
        double[] testTargets = test.getTargets();
        double[] validationRmse = new double[epochs];
        double[] testRmse = new double[epochs];
        System.out.println("Will perform " + epochs + " training epochs");
        for (int epoch = 0; epoch < epochs; epoch++) {
            regression.train(train);
            double[] trainPredictions = regression.predict(train);
            double[] validationPredictions = regression.predict(validation);
            double[] testPredictions = regression.predict(test);
            double trainRmse = ErrorMetrics.computeRMSE(trainPredictions, trainTargets);
            validationRmse[epoch] = ErrorMetrics.computeRMSE(validationPredictions, validationTargets);
            testRmse[epoch] = ErrorMetrics.computeRMSE(testPredictions, testTargets);
            System.out.println("TRAIN: " + trainRmse + "\t VAL: " + validationRmse[epoch]
                    + "\t TEST: " + testRmse[epoch]);
        }
        System.out.println("Training completed!");

        // Model selection: the first epoch reaching the lowest validation RMSE wins.
        double bestValidationRmse = Double.MAX_VALUE;
        int bestEpoch = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            if (validationRmse[epoch] < bestValidationRmse) {
                bestValidationRmse = validationRmse[epoch];
                bestEpoch = epoch;
            }
        }
        System.out.println("Best-Validation-RMSE:" + bestValidationRmse);
        System.out.println("Best-Predict-RMSE:" + testRmse[bestEpoch]);
        System.out.println("Final-Predict-RMSE:" + testRmse[epochs - 1]);
        System.out.println("Best-Validation-Epoch:" + bestEpoch);
    }

    /**
     * Parses {@code -key value} pairs into a map. {@code -h} and {@code -help} take no value and
     * are stored mapped to themselves. A repeated key keeps its last value. Exits with status 1 if
     * the last token is a key that is missing its value.
     */
    private static Map<String, String> parseArguments(String[] args) {
        Map<String, String> options = new HashMap<String, String>();
        for (int i = 0; i < args.length; i++) {
            String key = args[i];
            if (key.equals("-help") || key.equals("-h")) {
                options.put(key, key);
            } else if (i + 1 < args.length) {
                i++;
                options.put(key, args[i]);
            } else {
                printUsageAndExit("Missing value for option " + key, 1);
            }
        }
        return options;
    }

    /** Returns the integer value of {@code key}, or {@code defaultValue} when it is absent. */
    private static int intOption(Map<String, String> options, String key, int defaultValue) {
        String value = options.get(key);
        return value == null ? defaultValue : Integer.parseInt(value);
    }

    /** Returns the double value of {@code key}, or {@code defaultValue} when it is absent. */
    private static double doubleOption(Map<String, String> options, String key, double defaultValue) {
        String value = options.get(key);
        return value == null ? defaultValue : Double.parseDouble(value);
    }

    /** Returns the value of a mandatory option; exits with status 1 if it was not given. */
    private static String requireOption(Map<String, String> options, String key) {
        String value = options.get(key);
        if (value == null) {
            printUsageAndExit("Missing required option " + key, 1);
        }
        return value;
    }

    /**
     * Resolves the {@code -f} path to dataset files. A directory yields all of its entries, and any
     * other existing path yields just that file. The result is sorted by path because
     * {@link File#listFiles()} gives no ordering guarantee, and main's seeded shuffle is only
     * reproducible from a fixed starting order. Exits with status 1 if no path was given, the path
     * does not exist, or no files were found.
     */
    private static File[] listDatasetFiles(String path) {
        if (path == null) {
            printUsageAndExit("No dataset specified!", 1);
        }
        File location = new File(path);
        if (!location.exists()) {
            printUsageAndExit("The path " + path + " does not exist!", 1);
        }
        // listFiles() returns null on an I/O error.
        File[] files = location.isDirectory() ? location.listFiles() : new File[] {location};
        if (files == null || files.length == 0) {
            printUsageAndExit("No dataset files found in " + path, 1);
        }
        Arrays.sort(files);
        return files;
    }

    /**
     * Maps the comma-separated dataset names in {@code spec} to their indices in {@code files}.
     * Unknown names are reported on stderr and skipped. The accepted names are echoed on stdout.
     *
     * @return indices into {@code files} of the target datasets, in ascending order
     */
    private static ArrayList<Integer> resolveTargetIndices(File[] files, String spec) {
        List<String> fileNames = new ArrayList<String>(files.length);
        for (File file : files) {
            fileNames.add(file.getName());
        }
        List<String> targetNames = new ArrayList<String>();
        StringTokenizer tokenizer = new StringTokenizer(spec, ",");
        while (tokenizer.hasMoreTokens()) {
            String name = tokenizer.nextToken();
            if (fileNames.contains(name)) {
                targetNames.add(name);
            } else {
                System.err.println("The dataset " + name + " does not exist, therefore cannot be a target dataset!");
            }
        }
        ArrayList<Integer> indices = new ArrayList<Integer>();
        for (int i = 0; i < files.length; i++) {
            if (targetNames.contains(files[i].getName())) {
                indices.add(Integer.valueOf(i));
            }
        }
        StringBuilder message = new StringBuilder("Predictions will be made for dataset: ");
        for (String name : targetNames) {
            message.append(name).append(TestInstances.DEFAULT_SEPARATORS);
        }
        System.out.println(message);
        return indices;
    }

    /**
     * Instantiates the regression model named by {@code -r} ("mlp", "fmlp" or "fm") from its
     * hyper-parameter options. Exits with status 1 if the model name or one of its required
     * options is missing, or if the model name is unknown.
     *
     * @param options     parsed command-line options
     * @param inputDim    input dimension passed to the MLP and FMLP constructors
     * @param fmNumValues input dimension passed to the factorization machine constructor
     */
    private static Regression buildRegression(Map<String, String> options, int inputDim, int fmNumValues) {
        String model = requireOption(options, "-r");
        if (model.equals("mlp")) {
            Regression mlp = new MultiLayerPerceptron(
                    Double.parseDouble(requireOption(options, "-learn")),
                    Double.parseDouble(requireOption(options, "-mom")),
                    inputDim,
                    Integer.parseInt(requireOption(options, "-layer")),
                    Integer.parseInt(requireOption(options, "-nodes")),
                    new ActivationTanh());
            System.out.println("Will learn a multilayer perceptron...");
            return mlp;
        }
        if (model.equals("fmlp")) {
            Regression fmlp = new FactorizedMultilayerPerceptron(
                    Double.parseDouble(requireOption(options, "-learn")),
                    Double.parseDouble(requireOption(options, "-mom")),
                    Integer.parseInt(requireOption(options, "-k")),
                    inputDim,
                    Integer.parseInt(requireOption(options, "-layer")),
                    Integer.parseInt(requireOption(options, "-nodes")),
                    new ActivationTanh());
            System.out.println("Will learn a factorized multilayer perceptron...");
            return fmlp;
        }
        if (model.equals("fm")) {
            Regression fm = new FactorizationMachineRegression(
                    Double.parseDouble(requireOption(options, "-reg")),
                    Double.parseDouble(requireOption(options, "-learn")),
                    Integer.parseInt(requireOption(options, "-k")),
                    fmNumValues,
                    Double.parseDouble(requireOption(options, "-stdev")),
                    1.0d,
                    0.0d);
            System.out.println("Will learn a Factorization Machine...");
            return fm;
        }
        printUsageAndExit("Unknown regression function \"" + model + "\"", 1);
        return null; // not reached: printUsageAndExit terminates the JVM
    }

    /**
     * Prints the usage text and terminates the JVM. With a {@code null} error the text goes to
     * stdout. Otherwise the error and the usage text both go to stderr.
     *
     * @param error  message to print first, or {@code null} for a plain help request
     * @param status process exit status
     */
    private static void printUsageAndExit(String error, int status) {
        if (error == null) {
            System.out.println(USAGE);
        } else {
            System.err.println(error);
            System.err.println(USAGE);
        }
        System.exit(status);
    }

    /**
     * Builds the training part of every dataset. Each part is passed through
     * {@link InstanceUtils#copyInstancesAndAddOrdinalFeatures} together with the dataset count and
     * the dataset's index (presumably to append a per-dataset indicator feature -- confirm).
     * Datasets whose index is in {@code targetIndices} contribute only their train split
     * {@code split}. Every other dataset contributes all of its instances. main calls
     * {@code computeSplits} on the targets before calling this.
     *
     * @param datasets       all datasets
     * @param targetIndices  indices of the target datasets
     * @param trainSplitFlag forwarded unchanged to {@code Instances.getTrainSplit(int, boolean)}
     * @param split          which split to take from each target dataset
     * @return one entry per dataset, in the order of {@code datasets}
     */
    public static Instances[] getTrainInstances(Instances[] datasets, ArrayList<Integer> targetIndices, boolean trainSplitFlag, int split) {
        Instances[] result = new Instances[datasets.length];
        for (int i = 0; i < datasets.length; i++) {
            Instances source = targetIndices.contains(Integer.valueOf(i))
                    ? datasets[i].getTrainSplit(split, trainSplitFlag)
                    : datasets[i];
            result[i] = InstanceUtils.copyInstancesAndAddOrdinalFeatures(source, datasets.length, i);
        }
        return result;
    }

    /**
     * Builds the validation split {@code split} of each target dataset, passed through
     * {@link InstanceUtils#copyInstancesAndAddOrdinalFeatures} with the dataset count and index.
     *
     * @param datasets      all datasets
     * @param targetIndices indices of the target datasets
     * @param unused        ignored; kept for symmetry with {@link #getTrainInstances}
     * @param split         which split to take
     * @return one entry per target, in the order of {@code targetIndices}
     */
    public static Instances[] getValidationInstances(Instances[] datasets, ArrayList<Integer> targetIndices, boolean unused, int split) {
        Instances[] result = new Instances[targetIndices.size()];
        for (int i = 0; i < targetIndices.size(); i++) {
            int index = targetIndices.get(i).intValue();
            result[i] = InstanceUtils.copyInstancesAndAddOrdinalFeatures(datasets[index].getValidationSplit(split), datasets.length, index);
        }
        return result;
    }

    /**
     * Builds the test split {@code split} of each target dataset, passed through
     * {@link InstanceUtils#copyInstancesAndAddOrdinalFeatures} with the dataset count and index.
     *
     * @param datasets      all datasets
     * @param targetIndices indices of the target datasets
     * @param unused        ignored; kept for symmetry with {@link #getTrainInstances}
     * @param split         which split to take
     * @return one entry per target, in the order of {@code targetIndices}
     */
    public static Instances[] getPredictInstances(Instances[] datasets, ArrayList<Integer> targetIndices, boolean unused, int split) {
        Instances[] result = new Instances[targetIndices.size()];
        for (int i = 0; i < targetIndices.size(); i++) {
            int index = targetIndices.get(i).intValue();
            result[i] = InstanceUtils.copyInstancesAndAddOrdinalFeatures(datasets[index].getTestSplit(split), datasets.length, index);
        }
        return result;
    }

    /**
     * Builds the train split {@code split} of every dataset (no target/non-target distinction),
     * passed through {@link InstanceUtils#copyInstancesAndAddOrdinalFeatures} with the dataset
     * count and index.
     *
     * @param trainSplitFlag forwarded unchanged to {@code Instances.getTrainSplit(int, boolean)}
     */
    public static Instances[] getTrainInstances(Instances[] datasets, boolean trainSplitFlag, int split) {
        Instances[] result = new Instances[datasets.length];
        for (int i = 0; i < datasets.length; i++) {
            result[i] = InstanceUtils.copyInstancesAndAddOrdinalFeatures(datasets[i].getTrainSplit(split, trainSplitFlag), datasets.length, i);
        }
        return result;
    }

    /**
     * Builds the validation split {@code split} of every dataset, passed through
     * {@link InstanceUtils#copyInstancesAndAddOrdinalFeatures} with the dataset count and index.
     *
     * @param unused ignored; kept for symmetry with {@link #getTrainInstances}
     */
    public static Instances[] getValidationInstances(Instances[] datasets, boolean unused, int split) {
        Instances[] result = new Instances[datasets.length];
        for (int i = 0; i < datasets.length; i++) {
            result[i] = InstanceUtils.copyInstancesAndAddOrdinalFeatures(datasets[i].getValidationSplit(split), datasets.length, i);
        }
        return result;
    }

    /**
     * Builds the test split {@code split} of every dataset, passed through
     * {@link InstanceUtils#copyInstancesAndAddOrdinalFeatures} with the dataset count and index.
     *
     * @param unused ignored; kept for symmetry with {@link #getTrainInstances}
     */
    public static Instances[] getPredictInstances(Instances[] datasets, boolean unused, int split) {
        Instances[] result = new Instances[datasets.length];
        for (int i = 0; i < datasets.length; i++) {
            result[i] = InstanceUtils.copyInstancesAndAddOrdinalFeatures(datasets[i].getTestSplit(split), datasets.length, i);
        }
        return result;
    }

    /**
     * Shuffles {@code files} in place with a Fisher-Yates shuffle driven by
     * {@code new Random(seed)}. The same seed and input order always yield the same permutation.
     *
     * @param files array to permute in place
     * @param seed  seed for the random generator
     */
    public static void shuffleFileArray(File[] files, long seed) {
        Random random = new Random(seed);
        for (int last = files.length - 1; last > 0; last--) {
            int pick = random.nextInt(last + 1);
            File tmp = files[pick];
            files[pick] = files[last];
            files[last] = tmp;
        }
    }
}
