# Modeling public opinion over time and space: Trust in state institutions in Europe, 1989-2019
# 
# 31 January 2021
#
# Replication materials: Poststratified differences by age, sex, and education


# Packages
library(dplyr)
library(tidyr)
library(stringr)
library(brms)


### SETUP

# Column-bind helper: forwards all of its arguments unchanged to cbind().
bind <- function(...) {
  cbind(...)
}

# Number of draws requested from the models in the country loop:
# `nsamples_trust` is passed to posterior_linpred(), `nsamples_edu` to
# posterior_predict().
nsamples_trust <- 1000
nsamples_edu <- 100

# Demographic table joined in the loop on year/geo/age_cat/sex; its `time`
# column is renamed to `year` to match the join key.
dem <- rename(readRDS("dem_20201020.rds"), year = time)

# Object indexed by country in the loop (`mimp[[cntry]]`).
mimp <- readRDS("models_cntry_spline_20201020.rds")


### Loop over countries

# Summarise posterior draws of a between-group difference.
#
# `draws`: data frame with columns `geo`, `time` and `diff`, one row per
#          combination of draws; any existing grouping is discarded.
# `label`: value stored in the `type` column of the result.
#
# Returns one row per geo/time with the median, mean, SD and MAD of `diff`
# and the bounds of its central 95%, 90%, 75% and 50% intervals.
summarise_diff <- function(draws, label) {
  draws %>%
    ungroup() %>%
    group_by(geo, time) %>%
    summarise(diff_median = median(diff),
              diff_low_95 = quantile(diff, 0.025),
              diff_high_95 = quantile(diff, 0.975),
              diff_low_90 = quantile(diff, 0.05),
              diff_high_90 = quantile(diff, 0.95),
              diff_low_75 = quantile(diff, 0.125),
              # Upper bound of the central 75% interval is the 87.5th
              # percentile (previously 0.975, which duplicated the 95% bound).
              diff_high_75 = quantile(diff, 0.875),
              diff_low_50 = quantile(diff, 0.25),
              diff_high_50 = quantile(diff, 0.75),
              diff_mad = mad(diff),
              diff_mean = mean(diff),
              diff_sd = sd(diff)) %>%
    mutate(type = label)
}

# `countries` is not defined anywhere in this script. A value supplied by the
# caller is kept; otherwise fall back to the names of the per-country models.
# NOTE(review): assumes `mimp` is a named list keyed by country code -- confirm.
if (!exists("countries")) {
  countries <- names(mimp)
}

# The pooled survey file does not depend on the country, so read it once
# instead of once per iteration.
survey_all <- readRDS("all_cat_27_edu3_subset_2_20211212.rds") %>%
  filter(t_year >= 1989)

for (cntry in countries) {

  # Survey responses for this country, one row per respondent x institution.
  survey_data <- survey_all %>%
    filter(t_cntry == cntry) %>%
    mutate(caseid = row_number()) %>%
    select(caseid, t_project, t_wave, t_year, t_cntry, survey, sex, age_cat,
           educ3, thres, trust_parl, trust_jus, trust_polpart) %>%
    gather(inst, trust, trust_parl, trust_jus, trust_polpart)

  # Prediction grid: every sex x age category at 0.1-year steps.
  data_cntry <- expand.grid(geo = cntry,
                            time = seq(1989, 2019, by = 0.1),
                            sex = c("F", "M"),
                            age_cat = c("(19,34]", "(34,54]", "(54,74]")) %>%
    mutate_if(is.factor, as.character)

  educ_draws <- posterior_predict(mimp[[cntry]], newdata = data_cntry,
                                  nsamples = nsamples_edu)

  # Stack the three slices of the third array dimension (labelled educ3 = 1, 2,
  # 3 below) into rows, one column per draw, next to the grid they belong to.
  n_grid <- ncol(educ_draws)
  imputed_mult_educ3 <- cbind(rbind(t(educ_draws[, , 1]),
                                    t(educ_draws[, , 2]),
                                    t(educ_draws[, , 3])),
                              data_cntry,
                              educ3 = rep(c(1, 2, 3), each = n_grid)) %>%
    data.frame()

  print(paste0("Read model start: ", Sys.time()))

  # Placeholder survey: the one with the most cases. slice(1) guarantees a
  # single value when several surveys tie for the maximum.
  survey_ph <- survey_data %>%
    count(survey) %>%
    filter(n == max(n)) %>%
    slice(1) %>%
    pull(survey)

  # Placeholder thres: the value(s) observed for that survey.
  # NOTE(review): if the survey carries more than one distinct `thres`, this is
  # a vector and the `thres = thres_ph` assignment below needs a single value
  # -- confirm.
  thres_ph <- survey_data %>%
    filter(survey == survey_ph) %>%
    count(thres) %>%
    pull(thres)

  # Placeholder caseid: first case of that survey.
  caseid_ph <- survey_data %>%
    filter(survey == survey_ph) %>%
    summarise(caseid = first(caseid)) %>%
    pull(caseid)

  # Restrict predictions to the years covered by this country's surveys
  # (survey_data is already filtered to `cntry`).
  min_year <- min(survey_data$t_year)
  max_year <- max(survey_data$t_year)

  # Poststratification frame: one row per grid cell x educ3 x education draw,
  # with the cell population scaled by the drawn education share.
  dem_imp1 <- imputed_mult_educ3 %>%
    filter(time >= min_year & time <= max_year) %>%
    gather(snumber, prop_educ, 1:nsamples_edu) %>%
    mutate(year = floor(time)) %>%
    left_join(dem, by = c("year", "geo", "age_cat", "sex")) %>%
    mutate(npop_cat = prop_educ * npop) %>%
    select(geo, time, sex, age_cat, snumber, educ3, year, npop_cat) %>%
    mutate(educ3 = as.numeric(educ3)) %>%
    mutate_at(vars(age_cat, sex, educ3), factor) %>%
    mutate(project_f = NA,
           inst_f = NA,
           survey_inst = paste0(survey_ph, "trust_parl_"),
           thres = thres_ph,
           caseid = caseid_ph,
           t_year = time,
           Intercept = 1) %>%
    drop_na(npop_cat)

  print(paste0("Post samples start: ", Sys.time()))

  model <- readRDS(paste0("models/", tolower(cntry), "_des.rds"))

  # re_formula = NA: predictions exclude all group-level effects.
  post_samples <- posterior_linpred(model, re_formula = NA, newdata = dem_imp1,
                                    nsamples = nsamples_trust)

  # Long format: one row per frame row x trust-model draw.
  post <- cbind(t(post_samples),
                dem_imp1[, c("geo", "time", "sex", "age_cat", "educ3",
                             "npop_cat", "snumber")]) %>%
    data.frame() %>%
    gather(nsample, trust, 1:nsamples_trust)

  print(paste0("post_sex start: ", Sys.time()))

  # Population-weighted trust by sex, then the M - F difference per draw.
  post_sex <- post %>%
    group_by(snumber, nsample, geo, time, sex) %>%
    summarise(trust_est = sum(trust * npop_cat) / sum(npop_cat)) %>%
    spread(sex, trust_est) %>%
    mutate(diff = M - F) %>%
    summarise_diff("sex")

  print(paste0("post_age start: ", Sys.time()))

  # Oldest minus youngest age category.
  post_age <- post %>%
    group_by(snumber, nsample, geo, time, age_cat) %>%
    summarise(trust_est = sum(trust * npop_cat) / sum(npop_cat)) %>%
    spread(age_cat, trust_est) %>%
    mutate(diff = `(54,74]` - `(19,34]`) %>%
    summarise_diff("age")

  print(paste0("post_educ start: ", Sys.time()))

  # Education category 3 minus category 1.
  post_educ <- post %>%
    group_by(snumber, nsample, geo, time, educ3) %>%
    summarise(trust_est = sum(trust * npop_cat) / sum(npop_cat)) %>%
    spread(educ3, trust_est) %>%
    mutate(diff = `3` - `1`) %>%
    summarise_diff("educ")

  print(paste0("post_mrp start: ", Sys.time()))

  post_mrp <- bind_rows(post_age, post_sex, post_educ)

  saveRDS(post_mrp, paste0("bycntry/post_pp_", cntry, ".rds"))
}


# combine single-country MRP estimates of differences ------------

# List the per-country result files. `pattern` is a regular expression, so the
# dot is escaped and the match anchored to the end of the file name; the
# previous ".rds" matched any name containing "rds" after any character.
temp <- list.files(path = "bycntry/", pattern = "\\.rds$")
f <- file.path("bycntry", temp)

# Fail early with a clear message instead of an obscure error about a missing
# `geo` column further down.
if (length(f) == 0) {
  stop("No .rds files found in 'bycntry/'", call. = FALSE)
}

mrp_all_diffs <- lapply(f, readRDS)

# Stack all countries and add a country name (from the ISO 2-letter code in
# `geo`) plus a `year` copy of `time`.
mrp_all_diffs_df <- bind_rows(mrp_all_diffs) %>%
  mutate(country = countrycode::countrycode(geo, "iso2c", "country.name"),
         year = time)

saveRDS(mrp_all_diffs_df, "post_strat_diffs_des_20220410.rds")

