评估

用法

from causallearn.graph.ArrowConfusion import ArrowConfusion
from causallearn.graph.AdjacencyConfusion import AdjacencyConfusion
from causallearn.graph.SHD import SHD

# For arrows
arrow = ArrowConfusion(truth_cpdag, est)

arrowsTp = arrow.get_arrows_tp()
arrowsFp = arrow.get_arrows_fp()
arrowsFn = arrow.get_arrows_fn()
arrowsTn = arrow.get_arrows_tn()

arrowPrec = arrow.get_arrows_precision()
arrowRec = arrow.get_arrows_recall()

# For adjacency matrices
adj = AdjacencyConfusion(truth_cpdag, est)

adjTp = adj.get_adj_tp()
adjFp = adj.get_adj_fp()
adjFn = adj.get_adj_fn()
adjTn = adj.get_adj_tn()

adjPrec = adj.get_adj_precision()
adjRec = adj.get_adj_recall()

# Structural Hamming Distance
shd = SHD(truth_cpdag, est).get_shd()

参数

X: 具有 T*D 维度的数据。

truth_cpdag: 图类。

est: 图类。

返回值

arrowsTp/Fp/Fn/Tn: 真阳性/假阳性/假阴性/真阴性箭头。

arrowPrec: 箭头的精确度。

arrowRec: 箭头的召回率。

adjTp/Fp/Fn/Tn: 真阳性/假阳性/假阴性/真阴性边。

adjPrec: 邻接矩阵的精确度。

adjRec: 邻接矩阵的召回率。

shd: 结构汉明距离。