我正在尝试在 java 中使用深度学习来训练模型,当我开始训练训练数据时它会出错Invalid classification data: expect label value (at label index column = 0) to be in range 0 to 1 inclusive (0 to numClasses-1, with numClasses=2); got label value of 2我不明白这个错误,因为我是深度学习 4j 的初学者。我正在使用一个查看两个人之间关系的数据集(如果两个人之间存在关系,那么类标签将为 1,否则为 0)。Java 代码public class SNA {private static Logger log = LoggerFactory.getLogger(SNA.class);public static void main(String[] args) throws Exception { int seed = 123; double learningRate = 0.01; int batchSize = 50; int nEpochs = 30; int numInputs = 2; int numOutputs = 2; int numHiddenNodes = 20; //load the training data RecordReader rr = new CSVRecordReader(0,","); rr.initialize(new FileSplit(new File("C:\\Users\\GTS\\Desktop\\SNA project\\experiments\\First experiment\\train\\slashdotTrain.csv"))); DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize,0, 2); // load test data RecordReader rrTest = new CSVRecordReader(); rr.initialize(new FileSplit(new File("C:\\Users\\GTS\\Desktop\\SNA project\\experiments\\First experiment\\test\\slashdotTest.csv"))); DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize,0, 2); log.info("**** Building Model ****"); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(1) .learningRate(learningRate) .updater(Updater.NESTEROVS).momentum(0.9) .list() .layer(0, new DenseLayer.Builder() .nIn(numInputs) .nOut(numHiddenNodes) .activation("relu") .weightInit(WeightInit.XAVIER) .build()) }}}有什么帮助吗?多谢
1 回答
繁花如伊
TA贡献2012条经验 获得超12个赞
解决问题:将RecordReaderDataSetIterator
in的第三个参数DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize,0, 2);
由0改为2;因为数据集有三列,类标签的索引是 2,因为它是第三列。
解决方案:
DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize,2, 2);
添加回答
举报
0/150
提交
取消