跳转至

Python 分析函数

Table 1 — 基线特征表

import pandas as pd
import numpy as np
from scipy import stats


def describe_continuous(data, var, group, equal_var=True):
    """Summarise a continuous variable by group, with a two-group t-test.

    Parameters
    ----------
    data : pandas.DataFrame
        Analysis dataset.
    var : str
        Name of the continuous column to summarise.
    group : str
        Name of the grouping column (e.g. treatment arm).
    equal_var : bool, default True
        Forwarded to ``scipy.stats.ttest_ind``: ``True`` runs Student's
        t-test (the original behaviour), ``False`` runs Welch's t-test.

    Returns
    -------
    pandas.DataFrame
        One row per group with N, Mean, SD, Median, Q1, Q3, Min and Max,
        rounded to 2 decimals. When exactly two non-missing group labels
        exist, a ``p_value`` column is added. It holds the t-test p-value
        formatted as a string with 4 decimals, repeated on every row.
    """
    result = data.groupby(group)[var].agg([
        ("N", "count"),
        ("Mean", "mean"),
        ("SD", "std"),
        ("Median", "median"),
        ("Q1", lambda x: x.quantile(0.25)),
        ("Q3", lambda x: x.quantile(0.75)),
        ("Min", "min"),
        ("Max", "max")
    ]).round(2)

    # Between-group comparison (t-test). Missing group labels are dropped so
    # that, as in the groupby above, NaN does not count as an extra group and
    # prevent the test from running.
    groups = data[group].dropna().unique()
    if len(groups) == 2:
        g1 = data.loc[data[group] == groups[0], var].dropna()
        g2 = data.loc[data[group] == groups[1], var].dropna()
        _, pval = stats.ttest_ind(g1, g2, equal_var=equal_var)
        result["p_value"] = f"{pval:.4f}"

    return result


def describe_categorical(data, var, group):
    """Cross-tabulate a categorical variable by group and test association.

    Parameters
    ----------
    data : pandas.DataFrame
        Analysis dataset.
    var : str
        Name of the categorical column (table columns).
    group : str
        Name of the grouping column (table rows).

    Returns
    -------
    tuple
        ``(ct, pval)``. ``ct`` is the count table from ``pd.crosstab`` with
        ``margins=True``, so it has an "All" totals row and column. ``pval``
        is the p-value of ``scipy.stats.chi2_contingency`` computed on the
        same table without the margins.
    """
    row_labels = data[group]
    col_labels = data[var]

    # The chi-square test needs the bare observed counts; totals would
    # corrupt the expected-frequency computation.
    observed = pd.crosstab(row_labels, col_labels)
    pval = stats.chi2_contingency(observed)[1]

    with_totals = pd.crosstab(row_labels, col_labels, margins=True)
    return with_totals, pval


def table1(data, vars, cat_vars, group):
    """Render a plain-text Table 1 (baseline characteristics by group).

    Parameters
    ----------
    data : pandas.DataFrame
        Analysis dataset.
    vars : list of str
        Variables to report, in output order.
    cat_vars : collection of str
        Members of ``vars`` to summarise with ``describe_categorical``. All
        other variables go through ``describe_continuous``.
    group : str
        Grouping column passed to both describers.

    Returns
    -------
    str
        One section per variable, joined by newlines. Each section starts
        with a ``--- name ---`` header. Categorical headers also show the
        chi-square p-value.
    """
    sections = []
    for name in vars:
        if name in cat_vars:
            counts, p = describe_categorical(data, name, group)
            header = f"\n--- {name} (p={p:.4f}) ---\n"
            body = counts
        else:
            header = f"\n--- {name} ---\n"
            body = describe_continuous(data, name, group)
        sections.append(f"{header}{body}")
    return "\n".join(sections)


# 使用示例
# t1 = table1(adsl,
#             vars=["age", "bmi", "sex", "race"],
#             cat_vars=["sex", "race"],
#             group="treatment")
# print(t1)

MMRM 分析

import statsmodels.api as sm
from statsmodels.formula.api import mixedlm


def run_mmrm(data, formula="chg ~ C(treatment) * C(visit) + base",
             subject="subject"):
    """Fit a repeated-measures linear mixed model and print its summary.

    NOTE: statsmodels ``MixedLM`` with ``re_formula="~1"`` fits a random
    intercept per subject, which implies a compound-symmetric
    within-subject covariance. It is not the unstructured-covariance MMRM
    of e.g. SAS ``PROC MIXED ... REPEATED / TYPE=UN``, so estimates and
    standard errors can differ from such a reference analysis.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-format data, one row per subject-visit. Must contain every
        column named in ``formula`` plus ``subject``.
    formula : str
        Patsy formula for the fixed effects.
    subject : str, default "subject"
        Column identifying subjects; used as the random-effect grouping.

    Returns
    -------
    statsmodels MixedLMResults
        The fitted model (REML estimation).
    """
    model = mixedlm(
        formula,
        data=data,
        groups=data[subject],
        re_formula="~1"
    )
    # REML is selected with ``reml=True``. The ``method`` argument of
    # MixedLM.fit names the numerical optimizer ("bfgs", "lbfgs", ...),
    # so the previous ``method="reml"`` was rejected as an unknown method.
    result = model.fit(reml=True)

    print(result.summary())
    return result


def extract_lsmeans(model, data, treatment="treatment", visit="visit",
                    covariate="base"):
    """Compute LS means (estimated marginal means) per treatment x visit.

    Each treatment-by-visit combination observed in ``data`` is predicted
    with ``covariate`` fixed at its overall mean in ``data``. Every cell is
    therefore evaluated at the same covariate value, which is what makes
    the means comparable. The old approach averaged predictions over the
    distinct covariate values seen in each cell, so every cell was evaluated
    at a different covariate value.

    Parameters
    ----------
    model : fitted results object
        Must provide ``predict(DataFrame)``, e.g. the object returned by
        ``run_mmrm``. Presumably this gives fixed-effects predictions;
        verify for your model type.
    data : pandas.DataFrame
        Data the model was fitted on.
    treatment, visit : str
        Names of the treatment and visit columns.
    covariate : str, default "base"
        Continuous covariate in the model formula, held at its mean.

    Returns
    -------
    pandas.Series
        Predicted values named ``"predicted"``, indexed by
        (treatment, visit).
    """
    grid = data[[treatment, visit]].drop_duplicates().reset_index(drop=True)
    grid[covariate] = data[covariate].mean()
    # np.asarray strips any index on the prediction so it lines up with
    # ``grid`` by position.
    grid["predicted"] = np.asarray(model.predict(grid))

    return grid.groupby([treatment, visit])["predicted"].mean()

生存分析

from lifelines import KaplanMeierFitter, CoxPHFitter
import matplotlib.pyplot as plt


def fit_km(data, duration_col, event_col, group_col):
    """Plot one Kaplan-Meier survival curve per level of ``group_col``.

    Parameters
    ----------
    data : pandas.DataFrame
        Survival dataset.
    duration_col : str
        Column with follow-up time.
    event_col : str
        Column with the event indicator.
    group_col : str
        Column whose levels each get their own curve (used as the legend
        label).

    Returns
    -------
    matplotlib.figure.Figure
        An 8x6 figure with all curves drawn on a single shared axes.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    for label, subset in data.groupby(group_col):
        fitter = KaplanMeierFitter()
        fitter.fit(subset[duration_col], subset[event_col], label=label)
        fitter.plot_survival_function(ax=ax)

    # Label through the axes object rather than pyplot's "current axes".
    ax.set_title("Kaplan-Meier Survival Curve")
    ax.set_xlabel("Time")
    ax.set_ylabel("Survival Probability")
    return fig


def fit_cox(data, duration_col, event_col, covariates):
    """Fit a Cox proportional-hazards model and print its summary.

    Parameters
    ----------
    data : pandas.DataFrame
        Survival dataset. Only the duration, event and covariate columns
        are passed to the fitter.
    duration_col : str
        Column with follow-up time.
    event_col : str
        Column with the event indicator.
    covariates : list of str
        Model covariates. Must be a list, because it is concatenated onto
        the list of required columns.

    Returns
    -------
    lifelines.CoxPHFitter
        The fitted model.
    """
    needed_columns = [duration_col, event_col] + covariates
    model = CoxPHFitter()
    model.fit(
        data[needed_columns],
        duration_col=duration_col,
        event_col=event_col,
    )
    model.print_summary()
    return model


def _treatment_hr(frame, duration_col, event_col, treatment_col):
    """Fit a univariable Cox model on ``treatment_col``; return (HR, lower, upper)."""
    cph = CoxPHFitter()
    cph.fit(frame[[duration_col, event_col, treatment_col]],
            duration_col=duration_col, event_col=event_col)
    # params_ and confidence_intervals_ are on the log-hazard scale, indexed
    # by covariate name. (The old ``hazards_`` attribute no longer exists in
    # current lifelines.)
    log_hr = cph.params_[treatment_col]
    lower, upper = cph.confidence_intervals_.loc[treatment_col].to_numpy()
    return np.exp(log_hr), np.exp(lower), np.exp(upper)


def forest_plot_data(data, duration_col, event_col, subgroup_col,
                     treatment_col="treatment"):
    """Build a table of treatment hazard ratios for a forest plot.

    A separate univariable Cox model on ``treatment_col`` is fitted to the
    full data ("Overall") and to each non-missing level of ``subgroup_col``.

    Parameters
    ----------
    data : pandas.DataFrame
        Survival dataset.
    duration_col : str
        Column with follow-up time.
    event_col : str
        Column with the event indicator.
    subgroup_col : str
        Column whose levels define the subgroups.
    treatment_col : str, default "treatment"
        Treatment column. Assumed to be numerically coded (e.g. 0/1), since
        lifelines needs numeric covariates; confirm for your data.

    Returns
    -------
    pandas.DataFrame
        Columns Subgroup, HR, Lower, Upper. HR is exp(coef). Lower and Upper
        are the exponentiated confidence bounds at the CoxPHFitter default
        alpha (95%).
    """
    hr, lower, upper = _treatment_hr(data, duration_col, event_col,
                                     treatment_col)
    results = [{"Subgroup": "Overall", "HR": hr, "Lower": lower,
                "Upper": upper}]

    # dropna: a NaN label never matches ``==``, so it would give an empty
    # subset and a failed fit.
    for subg in data[subgroup_col].dropna().unique():
        sub = data[data[subgroup_col] == subg]
        hr, lower, upper = _treatment_hr(sub, duration_col, event_col,
                                         treatment_col)
        results.append({"Subgroup": subg, "HR": hr, "Lower": lower,
                        "Upper": upper})

    return pd.DataFrame(results)

模型诊断

import matplotlib.pyplot as plt
import scipy.stats as stats


def diagnose_linear_model(model, residuals):
    """Draw four residual-diagnostic panels for a fitted model.

    Parameters
    ----------
    model : fitted results object
        Only ``model.fittedvalues`` is read; it is the x-axis of the two
        scatter panels.
    residuals : array-like
        Residuals aligned with ``model.fittedvalues``. Must support
        ``.std()``.

    Returns
    -------
    matplotlib.figure.Figure
        A 2x2 grid: residuals vs fitted, normal Q-Q, scale-location
        (sqrt of |residual / residual SD|) and a 20-bin residual histogram.
    """
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    ax_rvf, ax_qq, ax_scale, ax_hist = axes.ravel()
    fitted = model.fittedvalues

    # Top-left: residuals vs fitted values, with a zero reference line.
    ax_rvf.scatter(fitted, residuals, alpha=0.5)
    ax_rvf.axhline(y=0, color="r", linestyle="--")
    ax_rvf.set(xlabel="Fitted Values", ylabel="Residuals",
               title="Residuals vs Fitted")

    # Top-right: normal Q-Q plot. probplot sets its own title, so it is
    # overwritten afterwards.
    stats.probplot(residuals, dist="norm", plot=ax_qq)
    ax_qq.set_title("Normal Q-Q")

    # Bottom-left: scale-location, using residuals scaled by their SD.
    root_abs_scaled = np.sqrt(np.abs(residuals / residuals.std()))
    ax_scale.scatter(fitted, root_abs_scaled, alpha=0.5)
    ax_scale.set(xlabel="Fitted Values",
                 ylabel="√|Standardized Residuals|",
                 title="Scale-Location")

    # Bottom-right: residual histogram.
    ax_hist.hist(residuals, bins=20, edgecolor="black", alpha=0.7)
    ax_hist.set(xlabel="Residuals", ylabel="Frequency",
                title="Residuals Distribution")

    fig.tight_layout()
    return fig