Estimating causal effects using a simplified R implementation

How {marginaleffects} can make your causal inference scripts readable

R
marginaleffects
causal inference
ATE
Using marginaleffects, I show how to estimate ATE, ATT, ATU, and CATE in an RCT study
Author

Daniel S. Mazhari-Jensen

Published

November 20, 2025

Causal inference sounds cool - how easy is it to do?

I recently participated in a causal inference class and realized how difficult it is for some of my students and colleagues to actually compute causal effect estimates. Although the most important point is to understand when and why to use these methods, I also think making the computational part more accessible could help many.

Credits to Pausal Živference

Credits to Francisco Yirá

In this short blog post, I'll introduce four key estimands that can be computed using a framework called G-computation:

  1. the average treatment effect (ATE)
  2. the average treatment effect on the treated (ATT)
  3. the average treatment effect on the untreated (ATU)
  4. the conditional average treatment effect (CATE)

The estimands

The Average Treatment Effect (ATE) measures how the treatment impacts outcomes across the entire study population. It addresses the question: Would it be beneficial to offer this program or treatment to everyone?

The Average Treatment Effect on the Treated (ATT) focuses on the effect of the treatment specifically for those participants who actually received it. It helps evaluate: Should the treatment be continued for the group currently receiving it?

The Average Treatment Effect on the Untreated (ATU) examines the effect of the treatment on those who did not receive it. This helps answer: Would it be advantageous to expand the program or treatment to individuals who were initially excluded?

The Conditional Average Treatment Effect (CATE) estimates how the treatment affects outcomes for a specific subgroup of the population, defined by certain characteristics (e.g., age, gender, or baseline risk). It answers the question: How does the treatment work for people with particular traits or conditions?
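In potential-outcome notation, the four estimands can be stated compactly. The symbols are an assumed convention rather than something defined earlier in the post: $Y^1$ and $Y^0$ are the outcomes an individual would have with and without the treatment (the $Y_j^1$ and $Y_j^0$ used later), $A$ is the treatment actually received ($A = 1$ treated, $A = 0$ untreated), and $L$ are baseline characteristics.

```latex
\begin{aligned}
\text{ATE}     &= \mathbb{E}\left[Y^1 - Y^0\right] \\
\text{ATT}     &= \mathbb{E}\left[Y^1 - Y^0 \mid A = 1\right] \\
\text{ATU}     &= \mathbb{E}\left[Y^1 - Y^0 \mid A = 0\right] \\
\text{CATE}(l) &= \mathbb{E}\left[Y^1 - Y^0 \mid L = l\right]
\end{aligned}
```

All four average the same individual contrast $Y^1 - Y^0$; they differ only in which individuals the average is taken over.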

The assumptions

When estimating causal effects, such as the ATE or CATE, four fundamental assumptions must hold to ensure valid conclusions:

Exchangeability (or no unmeasured confounding) – This assumes that, after accounting for observed variables, the treated and untreated groups are comparable. In other words, there are no hidden factors that systematically affect both treatment assignment and the outcome.

Positivity (or overlap) – Every individual in the study population must have a nonzero chance of receiving each treatment option. Without this, it becomes impossible to estimate the effect for some groups because they never experience one of the treatments.

Non-interference (one component of the Stable Unit Treatment Value Assumption, SUTVA) – One person's outcome should not be influenced by whether anyone else receives the treatment. Each participant's outcome depends only on their own treatment status.

Consistency – The treatment is well-defined, and each individual's observed outcome equals the potential outcome they would have had under the treatment they actually received.
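Using the same assumed notation ($Y^a$ for the potential outcome under treatment $a$, $A$ for the treatment received, $L$ for baseline covariates), the four assumptions can be written as below. The final chain shows why they matter for G-computation. The first equality is the law of total expectation. The second uses exchangeability, with positivity guaranteeing that conditioning on $A = a$ is possible at every $L$. The third uses consistency. This leaves only an observable outcome regression averaged over the covariate distribution.

```latex
\begin{aligned}
&\text{Exchangeability:}  && Y^a \perp\!\!\!\perp A \mid L \quad \text{for } a \in \{0, 1\} \\
&\text{Positivity:}       && \Pr(A = a \mid L = l) > 0 \quad \text{whenever } \Pr(L = l) > 0 \\
&\text{Non-interference:} && Y_i^{(a_1, \ldots, a_n)} = Y_i^{a_i} \\
&\text{Consistency:}      && A_i = a \implies Y_i = Y_i^{a}
\end{aligned}

\mathbb{E}\left[Y^a\right]
  = \mathbb{E}_L\left\{ \mathbb{E}\left[Y^a \mid L\right] \right\}
  = \mathbb{E}_L\left\{ \mathbb{E}\left[Y^a \mid A = a, L\right] \right\}
  = \mathbb{E}_L\left\{ \mathbb{E}\left[Y \mid A = a, L\right] \right\}
```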

The implementation in R

Our dataset

# ChickWeight (built into R): body weights of 50 chicks on four diets.
# Keep day 0 (baseline, "pre") and day 12 ("post"), and compare
# Diet 1 (treatment '1') with Diets 2-4 pooled (treatment '2').
dat <- dplyr::filter(ChickWeight, Time %in% c(0, 12)) |>
  dplyr::mutate(
    treatment = dplyr::if_else(Diet == '1', '1', '2'),
    timepoint = dplyr::if_else(Time == 0, "pre", "post")
  ) |>
  # one row per chick, with the weights in columns `pre` and `post`
  tidyr::pivot_wider(
    id_cols = c(Chick, treatment), # Diet itself is not needed
    names_from = timepoint,
    values_from = weight
  )

dat$treatment <- as.factor(dat$treatment)
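Before modeling, it can help to check positivity (overlap) directly: both arms should contain chicks across the range of pre-weights. The sketch below rebuilds `dat` with base R instead of dplyr/tidyr. This is assumed to be equivalent to the pipeline above and lets the check run without extra packages. It then cross-tabulates pre-weight by arm. Empty cells in that table would point to an overlap problem.

```r
# Rebuild `dat` with base R (assumed equivalent of the dplyr/tidyr pipeline)
cw <- as.data.frame(datasets::ChickWeight)  # drop the groupedData class
pre <- setNames(cw[cw$Time == 0, c("Chick", "Diet", "weight")],
                c("Chick", "Diet", "pre"))
post <- setNames(cw[cw$Time == 12, c("Chick", "weight")],
                 c("Chick", "post"))
# all.x = TRUE keeps chicks without a day-12 weight (their post is NA)
dat <- merge(pre, post, by = "Chick", all.x = TRUE)
dat$treatment <- factor(ifelse(dat$Diet == "1", "1", "2"))

# Number of chicks per arm
print(table(dat$treatment))

# Overlap: number of chicks in each arm at each baseline weight
print(table(pre = dat$pre, treatment = dat$treatment))

# Chicks without a day-12 weight; lm() silently drops these rows
print(sum(is.na(dat$post)))
```

With so few distinct baseline weights, a cross-table is enough. With a continuous covariate, comparing histograms or ranges per arm would serve the same purpose.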

Specifying our model - the estimator

First, we fit a statistical model which controls for confounders. Then, we use the fitted model to “predict” or “impute” what would happen to an individual under alternative treatment scenarios. Finally, we compare counterfactual predictions to derive an estimate of the treatment effect (Hernán and Robins 2020).
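The three steps can be written out by hand with base R's `predict()`. This is essentially what {marginaleffects} automates later in the post. The sketch rests on two assumptions. First, `dat` is rebuilt with base R, keeping only chicks that have a day-12 weight, as `lm()` would. Second, the outcome model is the same `post ~ treatment * pre` specification used below. `set_treatment()` is a hypothetical helper written for this example.

```r
# Rebuild the complete-case data with base R (assumed equivalent of `dat`)
cw <- as.data.frame(datasets::ChickWeight)
pre <- setNames(cw[cw$Time == 0, c("Chick", "Diet", "weight")],
                c("Chick", "Diet", "pre"))
post <- setNames(cw[cw$Time == 12, c("Chick", "weight")],
                 c("Chick", "post"))
dat <- merge(pre, post, by = "Chick")  # inner join: complete cases only
dat$treatment <- factor(ifelse(dat$Diet == "1", "1", "2"))

# Step 1: fit an outcome model that includes the adjustment variable
fit <- lm(post ~ treatment * pre, data = dat)

# Hypothetical helper: copy of `data` with everyone set to one treatment level
set_treatment <- function(data, level) {
  data$treatment <- factor(rep(level, nrow(data)), levels = levels(data$treatment))
  data
}

# Step 2: predict every chick's outcome under each treatment level
y_1 <- predict(fit, newdata = set_treatment(dat, "1"))  # everyone on treatment 1
y_2 <- predict(fit, newdata = set_treatment(dat, "2"))  # everyone on treatment 2

# Step 3: average the counterfactual predictions and contrast them
mean_1 <- mean(y_1)
mean_2 <- mean(y_2)
ate <- mean(y_2 - y_1)
print(c(mean_1 = mean_1, mean_2 = mean_2, ATE = ate))
```

The estimate is population-averaged, or marginal, because it averages over the chicks' own covariate values instead of plugging in a single "typical" chick.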

To find an adjustment set for confounder control, we draw a directed acyclic graph (DAG) of the assumed causal structure:

flowchart LR
  A(Pre-weight) --> B(Post-weight)
  C(Treatment) --> B

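In this DAG, pre-weight affects post-weight but has no arrow into treatment, so it is not a confounder. There is no backdoor path from treatment to post-weight, and the minimal adjustment set is empty. We still include pre-weight in the model because it is a baseline predictor of the outcome, and adjusting for prognostic baseline covariates can make the treatment-effect estimate more precise.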

# model without adjustment set
unadj_mod <- lm(post ~ treatment, data = dat)
unadj_mod

Call:
lm(formula = post ~ treatment, data = dat)

Coefficients:
(Intercept)   treatment2  
     108.53        33.84  
# model with adjustment set
adj_mod <- lm(post ~ treatment * pre, data = dat)
adj_mod

Call:
lm(formula = post ~ treatment * pre, data = dat)

Coefficients:
   (Intercept)      treatment2             pre  treatment2:pre  
        14.835         283.241           2.256          -6.069  

Note that this model does not account for adherence, so it estimates the intention-to-treat effect: the effect of being assigned to a treatment. If we additionally tried to account for individuals who deviated from the protocol (for example, by adjusting for adherence), we would instead be targeting the per-protocol effect.

Computing estimates | model, estimand

Now we want to move from a “missing data” problem, in which each chick is observed under only one of the two treatments, to a complete matrix in which every chick has an outcome under both. Of course, this is not what happened in reality. The unobserved outcomes are counter to the fact (i.e., counterfactual), and we predict them from the fitted regression model. The observed data look like this:

# Show subjects 1, 20, 35, and 40: each has an observed outcome under one treatment and NA under the other
dat |>
  # create subject id
  dplyr::mutate(ID = dplyr::row_number()) |>
  # spread so each subject has columns for both treatments
  tidyr::pivot_wider(
    id_cols = ID,
    names_from = treatment,
    values_from = post,
    values_fill = NA
  ) |>
  dplyr::slice(c(1, 20, 35, 40))
# A tibble: 4 × 3
     ID   `1`   `2`
  <int> <dbl> <dbl>
1     1   106    NA
2    20    77    NA
3    35    NA   201
4    40    NA   154
# Counterfactual predictions with {marginaleffects}: every chick is predicted under both treatments
counter_factuals <- marginaleffects::predictions(
  adj_mod,
  variables = "treatment"
)

counter_factual_matrix <- counter_factuals |>
  tibble::as_tibble() |>
  # rowidcf indexes the original row, so it serves as the subject id
  dplyr::filter(rowidcf %in% c(1, 20, 35, 40)) |>
  # label each prediction by the counterfactual treatment it was made under
  dplyr::mutate(
    potential_outcome = dplyr::if_else(treatment == '2', 'Y_j^1', 'Y_j^0')
  ) |>
  dplyr::select(rowidcf, potential_outcome, estimate) |>
  tidyr::pivot_wider(
    id_cols = rowidcf,
    names_from = potential_outcome,
    values_from = estimate
  )

knitr::kable(
  counter_factual_matrix,
  caption = "Predicted weights for subjects 1, 20, 35, and 40 under treatment 1 (Y_j^0) and treatment 2 (Y_j^1)"
)
Predicted weights for subjects 1, 20, 35, and 40 under treatment 1 (Y_j^0) and treatment 2 (Y_j^1)
rowidcf     Y_j^0     Y_j^1
      1  109.5950  137.9178
     20  105.0826  145.5444
     35  102.8264  149.3577
     40  109.5950  137.9178

We are only interested in the average effect, because individual treatment effects cannot be identified from the data: each chick is only ever observed under one treatment.

marginaleffects::avg_predictions(
  adj_mod,
  variables = "treatment",
  by = "treatment"
)

 treatment Estimate Std. Error    z Pr(>|z|)     S 2.5 % 97.5 %
         1      108       7.89 13.6   <0.001 138.3  92.1    123
         2      141       5.72 24.7   <0.001 445.4 130.1    153

Type: response

Remember that the ATE is the difference between these two averages, which we can get directly with `avg_comparisons()` from {marginaleffects}:

ATE <- marginaleffects::avg_comparisons(
  adj_mod,
  variables = "treatment",
  newdata = dat
)
ATE

 Estimate Std. Error    z Pr(>|z|)    S 2.5 % 97.5 %
       34       9.86 3.45   <0.001 10.8  14.7   53.3

Term: treatment
Type: response
Comparison: 2 - 1
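Because `adj_mod` is linear with a treatment × pre interaction, the G-computation contrasts have a closed form, which makes a handy sanity check. In the notation below (our own), $A_i = 1$ if chick $i$ received treatment 2 and $A_i = 0$ otherwise. $\hat\beta_1$ is the `treatment2` coefficient and $\hat\beta_3$ the `treatment2:pre` coefficient. $\overline{\text{pre}}$ is the mean pre-weight over the chicks being averaged.

```latex
\begin{aligned}
\widehat{\text{post}}_i &= \hat\beta_0 + \hat\beta_1 A_i + \hat\beta_2\,\text{pre}_i + \hat\beta_3\,A_i\,\text{pre}_i \\
\hat{Y}_i^1 - \hat{Y}_i^0 &= \left(\hat\beta_0 + \hat\beta_1 + \hat\beta_2\,\text{pre}_i + \hat\beta_3\,\text{pre}_i\right) - \left(\hat\beta_0 + \hat\beta_2\,\text{pre}_i\right) = \hat\beta_1 + \hat\beta_3\,\text{pre}_i \\
\widehat{\text{ATE}} &= \hat\beta_1 + \hat\beta_3\,\overline{\text{pre}}, \qquad
\widehat{\text{CATE}}(\text{pre}) = \hat\beta_1 + \hat\beta_3\,\text{pre}
\end{aligned}
```

The ATT and ATU follow by replacing $\overline{\text{pre}}$ with the mean pre-weight among treated or untreated chicks. With the coefficients printed earlier ($\hat\beta_1 = 283.241$, $\hat\beta_3 = -6.069$), a chick with pre = 40 gets $283.241 - 6.069 \times 40 \approx 40.48$. This is in line with the 40.5 reported for pre = 40 in the CATE table.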

The ATT

ATT <- marginaleffects::avg_comparisons(
  adj_mod,
  variables = "treatment",
  newdata = subset(treatment == '2')
)
ATT

 Estimate Std. Error   z Pr(>|z|)    S 2.5 % 97.5 %
     35.4       10.7 3.3   <0.001 10.0  14.4   56.4

Term: treatment
Type: response
Comparison: 2 - 1
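By hand, the ATT is the same G-computation contrast averaged only over chicks that actually received treatment 2. This is what `newdata = subset(treatment == '2')` requests. The minimal sketch below again assumes a base-R rebuild of `dat`. `avg_contrast()` is a hypothetical helper written for this example, not part of {marginaleffects}.

```r
# Complete-case data rebuilt with base R (assumed equivalent of `dat`)
cw <- as.data.frame(datasets::ChickWeight)
pre <- setNames(cw[cw$Time == 0, c("Chick", "Diet", "weight")],
                c("Chick", "Diet", "pre"))
post <- setNames(cw[cw$Time == 12, c("Chick", "weight")],
                 c("Chick", "post"))
dat <- merge(pre, post, by = "Chick")
dat$treatment <- factor(ifelse(dat$Diet == "1", "1", "2"))
fit <- lm(post ~ treatment * pre, data = dat)

# Hypothetical helper: mean predicted (treatment 2 - treatment 1) over `rows`
avg_contrast <- function(fit, rows) {
  lv <- levels(rows$treatment)
  rows_1 <- rows
  rows_1$treatment <- factor(rep("1", nrow(rows)), levels = lv)
  rows_2 <- rows
  rows_2$treatment <- factor(rep("2", nrow(rows)), levels = lv)
  mean(predict(fit, newdata = rows_2) - predict(fit, newdata = rows_1))
}

# ATT: average only over the chicks that actually received treatment 2
treated <- dat[dat$treatment == "2", ]
att <- avg_contrast(fit, treated)
print(att)
```

The point estimate should agree with the `avg_comparisons()` output above. The standard error, which {marginaleffects} computes with the delta method by default, is not reproduced here.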

The ATU

ATU <- marginaleffects::avg_comparisons(
  adj_mod,
  variables = "treatment",
  newdata = subset(treatment == '1')
)
ATU

 Estimate Std. Error    z Pr(>|z|)   S 2.5 % 97.5 %
     31.2       9.55 3.27  0.00109 9.8  12.5   49.9

Term: treatment
Type: response
Comparison: 2 - 1
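The ATT and ATU split the sample into two groups. As a result, the ATE averaged over the same chicks is their size-weighted average. The sketch below checks this identity on a base-R rebuild of `dat` (an assumption, as before), using the hypothetical `avg_contrast()` helper. The identity is exact only when all three estimands are averaged over the same chicks. The ATE above was computed with `newdata = dat`, which can include a chick whose day-12 weight is missing, so the rounded printed values need not combine exactly.

```r
# Complete-case data rebuilt with base R (assumed equivalent of `dat`)
cw <- as.data.frame(datasets::ChickWeight)
pre <- setNames(cw[cw$Time == 0, c("Chick", "Diet", "weight")],
                c("Chick", "Diet", "pre"))
post <- setNames(cw[cw$Time == 12, c("Chick", "weight")],
                 c("Chick", "post"))
dat <- merge(pre, post, by = "Chick")
dat$treatment <- factor(ifelse(dat$Diet == "1", "1", "2"))
fit <- lm(post ~ treatment * pre, data = dat)

# Hypothetical helper: mean predicted (treatment 2 - treatment 1) over `rows`
avg_contrast <- function(fit, rows) {
  lv <- levels(rows$treatment)
  rows_1 <- rows
  rows_1$treatment <- factor(rep("1", nrow(rows)), levels = lv)
  rows_2 <- rows
  rows_2$treatment <- factor(rep("2", nrow(rows)), levels = lv)
  mean(predict(fit, newdata = rows_2) - predict(fit, newdata = rows_1))
}

is_treated <- dat$treatment == "2"
att <- avg_contrast(fit, dat[is_treated, ])   # chicks on treatment 2
atu <- avg_contrast(fit, dat[!is_treated, ])  # chicks on treatment 1
ate <- avg_contrast(fit, dat)                 # all chicks

# ATE = share treated * ATT + share untreated * ATU
p_treated <- mean(is_treated)
ate_from_parts <- p_treated * att + (1 - p_treated) * atu
print(c(ATT = att, ATU = atu, ATE = ate, weighted = ate_from_parts))
```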

The CATE (by pre-weight)

CATE <- marginaleffects::avg_comparisons(
  adj_mod,
  variables = "treatment",
  by = "pre"
)
CATE

 pre Estimate Std. Error    z Pr(>|z|)    S  2.5 % 97.5 %
  39     46.5       25.0 1.86   0.0626  4.0  -2.45   95.5
  40     40.5       16.3 2.48   0.0132  6.2   8.47   72.5
  41     34.4       10.0 3.42   <0.001 10.7  14.70   54.1
  42     28.3       11.3 2.50   0.0123  6.3   6.15   50.5
  43     22.3       18.6 1.19   0.2327  2.1 -14.29   58.8

Term: treatment
Type: response
Comparison: 2 - 1

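Note that we gain some additional information here. As one might expect, the estimated effect of the treatment (diets 2 to 4 versus diet 1) is largest for chicks with a low pre-weight (possibly susceptible to being underweight) and smaller for chicks with a high pre-weight (robust or privileged). Two caveats apply. First, the model's linear treatment × pre interaction forces the CATE to change linearly with pre-weight. Second, at the extremes (pre = 39 and pre = 43) the confidence intervals include zero.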
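Each chick's predicted contrast depends only on its pre-weight, so the CATE can also be computed on a small grid of pre-weights. The minimal sketch below uses base R and assumes the same rebuild of `dat`. `grid` is simply a data frame of the observed pre-weights.

```r
# Complete-case data rebuilt with base R (assumed equivalent of `dat`)
cw <- as.data.frame(datasets::ChickWeight)
pre <- setNames(cw[cw$Time == 0, c("Chick", "Diet", "weight")],
                c("Chick", "Diet", "pre"))
post <- setNames(cw[cw$Time == 12, c("Chick", "weight")],
                 c("Chick", "post"))
dat <- merge(pre, post, by = "Chick")
dat$treatment <- factor(ifelse(dat$Diet == "1", "1", "2"))
fit <- lm(post ~ treatment * pre, data = dat)

# One row per observed pre-weight, copied once per treatment level
lv <- levels(dat$treatment)
grid <- data.frame(pre = sort(unique(dat$pre)))
grid_1 <- transform(grid, treatment = factor(rep("1", nrow(grid)), levels = lv))
grid_2 <- transform(grid, treatment = factor(rep("2", nrow(grid)), levels = lv))

# CATE(pre): predicted difference between treatment 2 and treatment 1
grid$cate <- unname(predict(fit, newdata = grid_2) - predict(fit, newdata = grid_1))
print(grid)
```

A grid suffices here because the predicted contrast is the same for all chicks sharing a pre-weight, so averaging within each pre group adds nothing.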