机器学习中有一类问题叫做多标签分类(multi-label),其和多分类(multi-class)问题不同。多分类问题是将一个样本x分到某一个类别$y_i$,而多标签分类问题是将一个样本x分到某些类别$y_i$, .., $y_j$等,也就是说多分类问题的类别之间是互斥的,所有类别的概率和为1。而多标签分类问题的类别之间不互斥,所有类别的概率和不为1,多标签分类问题可以理解为n个二分类问题。

最近在做一个多标签分类问题的工作,涉及到多标签分类问题如何评估效果,这里进行一些总结。

首先看下**多分类问题的评估**,知乎用户王晋东不在家已经讲的非常好了。

image-20210902172039241

对于**多标签分类的评估**,上文已经说到,其可以看做是n个二分类问题,那么我们可以通过计算每个类别下的auc,综合起来对模型进行评估,这个其实很容易理解。但是在实现的时候需要注意几个问题:

  • tf2已经提供了多标签分类计算auc的函数,但是其综合了所有类别的auc,计算得到一个平均auc。如下所示,只需要指定multi_label=True就可以计算多标签分类的平均auc了。

    1
    2
    3
    4
    5
    6
    tf.keras.metrics.AUC(
    num_thresholds=200, curve='ROC',
    summation_method='interpolation', name=None, dtype=None,
    thresholds=None, multi_label=False, num_labels=None, label_weights=None,
    from_logits=False
    )
  • 虽然tf2提供了多标签分类计算auc的函数,但是其计算的是平均auc,更多情况下我们希望看下每个类别的auc情况。所以还需要自己实现下。思路就是分别拎出各个类别的样本计算auc,计算某个类别auc的时候需要剔除其它类别的样本。同时因为tf提供的auc计算函数计算的是累积auc(就是计算auc的中间变量:真正例/假正例/真负例/假负例等本地变量值会累积),所以需要对每个类别的auc确定一个变量作用域,eg with tf.compat.v1.variable_scope('class_' + str(class_id)):,避免计算出的各个类别auc之间相互干扰。分析源码得知,tf在计算真正例/假正例/真负例/假负例时,是通过metric_variable()函数定义了变量真正例,同时metric_variable()的实现也是依赖于variable_scope.variable,所以通过tf.compat.v1.variable_scope()可以避免各个类别的auc相互干扰。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    def metric_variable(shape, dtype, validate_shape=True, name=None):
    return variable_scope.variable(
    lambda: array_ops.zeros(shape, dtype),
    trainable=False,
    collections=[
    ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
    ],
    validate_shape=validate_shape,
    synchronization=variable_scope.VariableSynchronization.ON_READ,
    aggregation=variable_scope.VariableAggregation.SUM,
    name=name)
  • 上面也说了tf计算的是累积auc,所以需要区分自己想要的是batch auc还是累积auc。一个建议是,可以考虑在训练的时候计算batch_auc,观察auc的实时变化情况;预测的时候计算累积auc,观察整体效果。这个可以通过tf.control_dependencies()控制。代码如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    if is_training:
    tf.compat.v1.add_to_collection(
    tf.compat.v1.GraphKeys.UPDATE_OPS,
    tf.compat.v1.local_variables_initializer())
    update_ops = tf.compat.v1.get_collection(
    tf.compat.v1.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.compat.v1.variable_scope(
    "train", reuse=tf.compat.v1.AUTO_REUSE):
    train_op = optimizer.minimize(self.losses, global_step=global_step)
  • tf.compat.v1.metrics.auc()函数的返回值有两个auc_value, update_op。需要先执行update_op,再执行auc_value。因为auc_value依赖于update_op先执行才会被更新。参见代码计算混淆矩阵的实现:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    def _confusion_matrix_at_thresholds(labels,
    predictions,
    thresholds,
    weights=None,
    includes=None):
    ...
    if 'fp' in includes:
    false_p = metric_variable(
    [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.cast(
    math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
    is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
    math_ops.reduce_sum(
    is_false_positive, 1))
    values['fp'] = false_p

    return values, update_ops

    可以看到上述代码中,只有update_ops['fp']先被执行,false_p才会被更新。其实也可以直接拿update_ops的值作为当前数据更新之后的auc。

参考