3 回答
data:image/s3,"s3://crabby-images/6ba37/6ba3798c8f48f736e1ae18439b001e178e37e63b" alt="?"
TA贡献1840条经验 获得超5个赞
使用它来获得每类准确性:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
class Metrics(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self._data = []
def on_epoch_end(self, batch, logs={}):
x_test, y_test = self.validation_data[0], self.validation_data[1]
y_predict = np.asarray(model.predict(x_test))
true = np.argmax(y_test, axis=1)
pred = np.argmax(y_predict, axis=1)
cm = confusion_matrix(true, pred)
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
self._data.append({
'classLevelaccuracy':cm.diagonal() ,
})
return
def get_data(self):
return self._data
metrics = Metrics()
history = model.fit(x_train, y_train, epochs=100, validation_data=(x_test, y_test), callbacks=[metrics])
metrics.get_data()
您可以在指标类中更改代码。随心所欲..并且这个工作。你只是用来metrics.get_data()获取所有信息..
data:image/s3,"s3://crabby-images/13790/13790e556928407dbcbe43259735fbf5ccffe916" alt="?"
TA贡献1858条经验 获得超8个赞
好吧,准确性是一个global
指标,没有per-class accuracy
. 也许你的意思是,这就是orproportion of the class correctly identified
的确切定义。TPR
recall
data:image/s3,"s3://crabby-images/5794c/5794c99cb06547647aa29905713e450c4ca58a54" alt="?"
TA贡献1828条经验 获得超3个赞
如果您想获得某个类别或一组特定类别的准确性,掩码可能是一个很好的解决方案。看这段代码:
def cus_accuracy(real, pred):
score = accuracy(real, pred)
mask = tf.math.greater_equal(real, 5)
mask = tf.cast(mask, dtype=real.dtype)
score *= mask
mask2 = tf.math.less_equal(real, 10)
mask2 = tf.cast(mask2, dtype=real.dtype)
score *= mask2
return tf.reduce_mean(score)
这个指标给出了 5 到 10 类的准确度。我用它来测量 seq2seq 模型中某些单词的准确度。
添加回答
举报