评估
用法
from causallearn.graph.ArrowConfusion import ArrowConfusion
from causallearn.graph.AdjacencyConfusion import AdjacencyConfusion
from causallearn.graph.SHD import SHD
# For arrows
arrow = ArrowConfusion(truth_cpdag, est)
arrowsTp = arrow.get_arrows_tp()
arrowsFp = arrow.get_arrows_fp()
arrowsFn = arrow.get_arrows_fn()
arrowsTn = arrow.get_arrows_tn()
arrowPrec = arrow.get_arrows_precision()
arrowRec = arrow.get_arrows_recall()
# For adjacency matrices
adj = AdjacencyConfusion(truth_cpdag, est)
adjTp = adj.get_adj_tp()
adjFp = adj.get_adj_fp()
adjFn = adj.get_adj_fn()
adjTn = adj.get_adj_tn()
adjPrec = adj.get_adj_precision()
adjRec = adj.get_adj_recall()
# Structural Hamming Distance
shd = SHD(truth_cpdag, est).get_shd()
参数
X: 具有 T*D 维度的数据。
truth_cpdag: 图类。
est: 图类。
返回值
arrowsTp/Fp/Fn/Tn: 真阳性/假阳性/假阴性/真阴性箭头。
arrowPrec: 箭头的精确度。
arrowRec: 箭头的召回率。
adjTp/Fp/Fn/Tn: 真阳性/假阳性/假阴性/真阴性边。
adjPrec: 邻接矩阵的精确度。
adjRec: 邻接矩阵的召回率。
shd: 结构汉明距离。