为了账号安全,请及时绑定邮箱和手机立即绑定

“无效的分类数据:期望标签值”

“无效的分类数据:期望标签值”

一只名叫tom的猫 2022-03-10 22:04:58
我正在尝试在 java 中使用深度学习来训练模型,当我开始训练训练数据时它会出错Invalid classification data: expect label value (at label index column = 0) to be in range 0 to 1 inclusive (0 to numClasses-1, with numClasses=2); got label value of 2我不明白这个错误,因为我是深度学习 4j 的初学者。我正在使用一个查看两个人之间关系的数据集(如果两个人之间存在关系,那么类标签将为 1,否则为 0)。Java 代码public class SNA {private static Logger log = LoggerFactory.getLogger(SNA.class);public static void main(String[] args) throws Exception {    int seed = 123;    double learningRate = 0.01;    int batchSize = 50;    int nEpochs = 30;    int numInputs = 2;    int numOutputs = 2;    int numHiddenNodes = 20;    //load the training data    RecordReader rr = new CSVRecordReader(0,",");    rr.initialize(new FileSplit(new File("C:\\Users\\GTS\\Desktop\\SNA project\\experiments\\First experiment\\train\\slashdotTrain.csv")));    DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize,0, 2);    // load test data    RecordReader rrTest = new CSVRecordReader();    rr.initialize(new FileSplit(new File("C:\\Users\\GTS\\Desktop\\SNA project\\experiments\\First experiment\\test\\slashdotTest.csv")));    DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize,0, 2);    log.info("**** Building Model ****");    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()            .seed(seed)            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)            .iterations(1)            .learningRate(learningRate)            .updater(Updater.NESTEROVS).momentum(0.9)            .list()            .layer(0, new DenseLayer.Builder()                    .nIn(numInputs)                    .nOut(numHiddenNodes)                    .activation("relu")                    .weightInit(WeightInit.XAVIER)                    .build())    }}}有什么帮助吗?多谢
查看完整描述

1 回答

?
繁花如伊

TA贡献2012条经验 获得超12个赞

解决问题:将RecordReaderDataSetIteratorin的第三个参数
DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize,0, 2);由0改为2;因为数据集有三列,类标签的索引是 2,因为它是第三列。

解决方案:

DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize,2, 2);


查看完整回答
反对 回复 2022-03-10
  • 1 回答
  • 0 关注
  • 148 浏览

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信