
# Once snakemake is installed, use the following command to test the struct tree.
# TODO: the command this comment refers to is missing from the file.


import snakemake.utils
# Refuse to run on Snakemake releases older than 7.8.0.
snakemake.utils.min_version("7.8.0")

# workflow.basedir is the directory containing this Snakefile. rootdir is its
# parent directory with a trailing '/', used as a prefix for the paths below.
snake_dir = workflow.basedir
rootdir = '/'.join(snake_dir.split('/')[:-1] + [''])
print('benchmarking running in ', rootdir)

# Load the workflow configuration from <rootdir>workflow/config/config_vars.yaml.
configfile: rootdir + "workflow/config/config_vars.yaml"



# Resolve the list of dataset folders to process.
#  - config['folder'] given as a single string  -> wrapped in a one-element list
#  - config['folder'] given as anything else     -> used as-is (expected: a list)
#  - config['folder'] absent                     -> every directory that contains
#    an identifiers.txt, discovered with glob_wildcards; the result is also
#    written back into config['folder'].
if 'folder' in config:
	# isinstance() is the idiomatic type test and also accepts str subclasses,
	# which an exact type() comparison would wrongly treat as a list.
	if isinstance(config['folder'], str):
		folders = [config['folder']]
	else:
		folders = config['folder']
else:
	folders = glob_wildcards("{folders}/identifiers.txt").folders
	config['folder'] = folders

# Values substituted for the {mattype} wildcard in rule all.
mattypes = [
	'fident',
	'alntmscore',
	'lddt',
]

# Values substituted for the {alntype} wildcard (allvall_0.csv / allvall_1.csv).
alntypes = [
	'0',
	'1',
]

# Values substituted for the {exp} wildcard.
exp = [
	'raw',
	'exp',
]

# Values substituted for the {aligner} wildcard.
# NOTE(review): no rule in this file produces a '3dcoffee' alignment;
# presumably it comes from the ML3ditree module -- confirm.
aligners = [
	'clustalo',
	'muscle',
	'3dcoffee',
]

# Foldseek executable used by the foldseek_allvall_* rules. The config value
# 'provided' selects <rootdir>foldseek/foldseek; any other value is used
# unchanged.
foldseekpath = (
	rootdir + "foldseek/foldseek"
	if config["foldseek_path"] == 'provided'
	else config["foldseek_path"]
)

# Load the Snakefile "ML3ditree" as a module and hand it this workflow's config.
module ML3ditree:
	# The snakefile may be given as a plain path, a URL, or one of the special
	# markers for code hosting providers (see the Snakemake documentation on modules).
	snakefile: "ML3ditree"
	config: config

# Import every rule of the module; each imported rule name gets an "ML_" prefix.
use rule * from ML3ditree as ML_*

rule all:
	"""Aggregate target: requests every benchmark output for every folder."""
	# NOTE(review): this rule is declared after `use rule * from ML3ditree`, so it
	# may not be the first rule Snakemake registers; confirm it is still picked
	# as the default target, or request it explicitly on the command line.
	input:
		# disabled target: RF distance summaries
		#expand( "{folder}/RFdistances_{exp}_.json" , folder = folders , exp = exp) ,
		# tree scores of every structure tree (all mattype/alntype/exp combinations)
		expand(
			"{folder}/{mattype}_{alntype}_{exp}_treescores_struct_tree.json",
			folder=folders,
			mattype=mattypes,
			alntype=alntypes,
			exp=exp,
		),
		# tree scores of the sequence trees, one per aligner
		expand(
			"{folder}/treescores_sequences.{aligner}.json",
			folder=folders,
			aligner=aligners,
		),
		expand(
			"{folder}/treescores_sequences_iq.{aligner}.json",
			folder=folders,
			aligner=aligners,
		),
		# per-folder pLDDT summary
		expand(
			"{folder}/plddt.json",
			folder=folders,
		),
		# rooted structure trees themselves
		expand(
			"{folder}/{mattype}_{alntype}_{exp}_struct_tree.PP.nwk.rooted",
			folder=folders,
			mattype=mattypes,
			alntype=alntypes,
			exp=exp,
		),
		# disabled targets: seed-based foldseek searches
		#expand("{seed_folder}/foldseek_searches/foldseek_search_{seed}.tsv", seed = seed_structures, seed_folder = seed_folders),
		#expand("{seed_folder}/{seed}_prob_{prob}_qcov_{qcov}_scov_{scov}_eval_{eval}_uniref90_homologs_structs/identifiers.txt", seed = seed_structures, seed_folder = seed_folders, prob = config['prob_threshold'], qcov = config['qcov_threshold'], scov = config['scov_threshold'], eval = config['evalue_threshold'])





rule calc_tax_score:
	"""Run ../src/calctreescores.py on the folder's sequence dataset and one
	rooted structure tree, producing that tree's treescores JSON."""
	input:
		"{folder}/sequence_dataset.csv",
		"{folder}/{mattype}_{alntype}_{exp}_struct_tree.PP.nwk.rooted",
	output:
		"{folder}/{mattype}_{alntype}_{exp}_treescores_struct_tree.json",
	log:
		"{folder}/logs/{mattype}_{alntype}_{exp}_struct_tree_scoring.log",
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	script:
		"../src/calctreescores.py"

rule calc_tax_score_seq:
	"""Run ../src/calctreescores.py on the folder's sequence dataset and the
	rooted .fst.nwk sequence tree of one aligner."""
	input:
		"{folder}/sequence_dataset.csv",
		"{folder}/sequences.aln.{aligner}.fst.nwk.rooted",
	output:
		"{folder}/treescores_sequences.{aligner}.json",
	log:
		"{folder}/logs/sequences_scoring.{aligner}.log",
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	script:
		"../src/calctreescores.py"

rule calc_tax_score_iq:
	"""Run ../src/calctreescores.py on the folder's sequence dataset and the
	rooted .treefile sequence tree of one aligner."""
	input:
		"{folder}/sequence_dataset.csv",
		"{folder}/sequences.aln.{aligner}.fst.treefile.rooted",
	output:
		"{folder}/treescores_sequences_iq.{aligner}.json",
	log:
		"{folder}/logs/iq_scoring.{aligner}.log",
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	script:
		"../src/calctreescores.py"


rule mad_root_struct:
	"""Root a post-processed structure tree with <rootdir>madroot/mad.

	The .rooted output is not passed on the command line.
	NOTE(review): this relies on mad writing "<input>.rooted" next to its
	input -- confirm against the bundled mad version.
	"""
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	input:
		"{folder}/{mattype}_{alntype}_{exp}_struct_tree.PP.nwk"
	output:
		"{folder}/{mattype}_{alntype}_{exp}_struct_tree.PP.nwk.rooted"
	log:
		"{folder}/logs/{mattype}_{alntype}_{exp}_struct_madroot.log"
	shell:
		# {input} replaces the hand-copied path pattern so the command cannot
		# drift from the declared input. stdout/stderr are captured in the
		# declared log, which was previously never written.
		rootdir + 'madroot/mad {input} > {log} 2>&1'

rule mad_root_seq:
	"""Root the .fst.nwk sequence tree of one aligner with <rootdir>madroot/mad.

	The .rooted output is not passed on the command line.
	NOTE(review): this relies on mad writing "<input>.rooted" next to its
	input -- confirm against the bundled mad version.
	"""
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	input:
		# ancient(): the tree's modification time is ignored when deciding to re-run
		ancient("{folder}/sequences.aln.{aligner}.fst.nwk")
	output:
		"{folder}/sequences.aln.{aligner}.fst.nwk.rooted"
	log:
		"{folder}/logs/madrootseq.{aligner}.log"
	shell:
		# {input} replaces the hand-copied path pattern so the command cannot
		# drift from the declared input. stdout/stderr are captured in the
		# declared log, which was previously never written.
		rootdir + 'madroot/mad {input} > {log} 2>&1'

rule mad_root_iq:
	"""Root the .treefile sequence tree of one aligner with <rootdir>madroot/mad.

	The .rooted output is not passed on the command line.
	NOTE(review): this relies on mad writing "<input>.rooted" next to its
	input -- confirm against the bundled mad version.
	"""
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	input:
		# ancient(): the tree's modification time is ignored when deciding to re-run
		ancient("{folder}/sequences.aln.{aligner}.fst.treefile")
	output:
		"{folder}/sequences.aln.{aligner}.fst.treefile.rooted"
	log:
		"{folder}/logs/madrootiq.{aligner}.log"
	shell:
		# {input} replaces the hand-copied path pattern so the command cannot
		# drift from the declared input. stdout/stderr are captured in the
		# declared log, which was previously never written.
		rootdir + 'madroot/mad {input} > {log} 2>&1'

rule postprocess:
	"""Run ../src/postprocess.py on a structure tree (.nwk) to produce its
	post-processed counterpart (.PP.nwk)."""
	input:
		"{folder}/{mattype}_{alntype}_{exp}_struct_tree.nwk",
	output:
		"{folder}/{mattype}_{alntype}_{exp}_struct_tree.PP.nwk",
	log:
		"{folder}/logs/{mattype}_{alntype}_{exp}_struct_postprocess.log",
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	script:
		"../src/postprocess.py"

def get_mem_mb(wildcards, attempt):
	"""Return the memory request in MB for a job: 20000 MB per attempt.

	Used as the ``mem_mb`` resource callable of the rules below, so the
	request grows linearly with the attempt number. ``wildcards`` is
	accepted but not used.
	"""
	mb_per_attempt = 20000
	return mb_per_attempt * attempt

rule quicktree:
	"""Build a structure tree (.nwk) from a distance matrix file by running
	`quicktree -i m` and capturing its stdout as the tree."""
	resources:
		mem_mb=get_mem_mb
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	input:
		"{folder}/{mattype}_{alntype}_{exp}_fastmemat.txt"
	output:
		"{folder}/{mattype}_{alntype}_{exp}_struct_tree.nwk"
	log:
		"{folder}/logs/{mattype}_{alntype}_{exp}_fastme.log"
	shell:
		# {input}/{output} replace the hand-copied path patterns. stdout is the
		# tree itself, so only stderr is sent to the declared log (previously
		# the log was never written).
		'quicktree -i m {input} > {output} 2> {log}'

rule fasttree:
	"""Build the .fst.nwk tree of one aligner's alignment through the
	v1.20.0/bio/fasttree wrapper."""
	input:
		# ancient(): the alignment's modification time is ignored when deciding to re-run
		alignment=ancient("{folder}/sequences.aln.{aligner}.fst"),
	output:
		tree="{folder}/sequences.aln.{aligner}.fst.nwk",
	log:
		"{folder}/logs/fasttree.{aligner}.log",
	params:
		extra="",
	resources:
		mem_mb=get_mem_mb,
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	wrapper:
		"v1.20.0/bio/fasttree"

rule iqtree:
	"""Infer a tree from one aligner's alignment with iqtree (model LG+I+G,
	seed 42, -n 50, -redo).

	The .treefile output is not passed on the command line.
	NOTE(review): this relies on iqtree naming its outputs after the
	alignment file -- confirm.
	"""
	threads: 8
	resources:
		mem_mb=get_mem_mb
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	input:
		# ancient(): the alignment's modification time is ignored when deciding to re-run
		ancient("{folder}/sequences.aln.{aligner}.fst")
	output:
		"{folder}/sequences.aln.{aligner}.fst.treefile",
	log:
		"{folder}/logs/iqtree.{aligner}.log"
	params:
		# extra command-line options appended to the iqtree call
		extra="",
	shell:
		# -ntmax {threads} caps iqtree's automatic thread selection at the
		# number of threads declared for this job; a bare "-nt AUTO" may use
		# more cores than the scheduler reserved. stdout/stderr are captured
		# in the declared log, which was previously never written.
		'iqtree -s {input} -redo -m LG+I+G -seed 42 -nt AUTO -ntmax {threads} -n 50 {params.extra} > {log} 2>&1'

rule foldseek2distmat:
	"""Run ../src/foldseekres2distmat.py on one all-vs-all foldseek table,
	producing six matrix files: fident / alntmscore / lddt, each in a
	'raw' and an 'exp' variant."""
	input:
		"{folder}/allvall_{alntype}.csv",
	output:
		# the order is part of the contract with the script (positional outputs)
		"{folder}/fident_{alntype}_raw_fastmemat.txt",
		"{folder}/alntmscore_{alntype}_raw_fastmemat.txt",
		"{folder}/lddt_{alntype}_raw_fastmemat.txt",
		"{folder}/fident_{alntype}_exp_fastmemat.txt",
		"{folder}/alntmscore_{alntype}_exp_fastmemat.txt",
		"{folder}/lddt_{alntype}_exp_fastmemat.txt",
	log:
		"{folder}/logs/{alntype}_foldseek2distmat.log",
	params:
		# same column list as the --format-output option of the foldseek_allvall_* rules
		fmt='query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,lddt,qaln,taln,cigar,alntmscore',
	resources:
		mem_mb=get_mem_mb,
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	script:
		"../src/foldseekres2distmat.py"


rule foldseek2stats:
	"""Run ../src/calcfident_foldtree_stats.py on one all-vs-all foldseek
	table, producing alnfident_dist_{alntype}_raw_.json."""
	input:
		"{folder}/allvall_{alntype}.csv",
	output:
		"{folder}/alnfident_dist_{alntype}_raw_.json",
	log:
		"{folder}/logs/{alntype}_foldseek_fident_stats.log",
	params:
		# same column list as the --format-output option of the foldseek_allvall_* rules
		fmt='query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,lddt,qaln,taln,cigar,alntmscore',
	resources:
		mem_mb=get_mem_mb,
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	script:
		"../src/calcfident_foldtree_stats.py"

rule foldseek_allvall_0:
	"""All-vs-all `foldseek easy-search` of {folder}/structs/ against itself
	with --alignment-type 0, written to allvall_0.csv.

	finalset.csv is declared as input only to order this rule after the
	rule that produces it; the command itself reads the structs/ directory.
	"""
	resources:
		mem_mb=get_mem_mb
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	input:
		"{folder}/finalset.csv"
	output:
		"{folder}/allvall_0.csv"
	log:
		# distinct from foldseek_allvall_1's log so the two searches of one
		# folder cannot overwrite each other's log
		"{folder}/logs/foldseekallvall_0.log"
	shell:
		# stdout/stderr are captured in the declared log, which was previously never written
		foldseekpath + " easy-search {wildcards.folder}/structs/ {wildcards.folder}/structs/ {output} {wildcards.folder}/tmp --format-output 'query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,lddt,qaln,taln,cigar,alntmscore' --exhaustive-search --alignment-type 0 -e inf > {log} 2>&1"

rule foldseek_allvall_1:
	"""All-vs-all `foldseek easy-search` of {folder}/structs/ against itself
	with --alignment-type 2, written to allvall_1.csv.

	finalset.csv is declared as input only to order this rule after the
	rule that produces it; the command itself reads the structs/ directory.
	NOTE(review): the output label is "1" while the foldseek option is
	"--alignment-type 2" -- presumably intentional, confirm.
	"""
	resources:
		mem_mb=get_mem_mb
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	input:
		"{folder}/finalset.csv"
	output:
		"{folder}/allvall_1.csv"
	log:
		# distinct from foldseek_allvall_0's log so the two searches of one
		# folder cannot overwrite each other's log
		"{folder}/logs/foldseekallvall_1.log"
	shell:
		# stdout/stderr are captured in the declared log, which was previously never written
		foldseekpath + " easy-search {wildcards.folder}/structs/ {wildcards.folder}/structs/ {output} {wildcards.folder}/tmp --format-output 'query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,lddt,qaln,taln,cigar,alntmscore' --exhaustive-search --alignment-type 2 -e inf > {log} 2>&1"


rule clustalo:
	"""Align {folder}/sequences.fst through the v1.20.0/bio/clustalo wrapper
	(extra option: --force)."""
	input:
		# ancient(): the FASTA's modification time is ignored when deciding to re-run
		ancient("{folder}/sequences.fst"),
	output:
		"{folder}/sequences.aln.clustalo.fst",
	log:
		"{folder}/logs/clustalo.log",
	params:
		extra=" --force ",
	threads: 1
	resources:
		mem_mb=get_mem_mb,
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	wrapper:
		"v1.20.0/bio/clustalo"

rule muscle:
	"""Align {folder}/sequences.fst through the v2.3.2/bio/muscle wrapper."""
	input:
		# ancient(): the FASTA's modification time is ignored when deciding to re-run
		fasta=ancient("{folder}/sequences.fst"),
	output:
		alignment="{folder}/sequences.aln.muscle.fst",
	log:
		"{folder}/logs/muscle/muscle.log",
	threads: 1
	resources:
		mem_mb=get_mem_mb,
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	wrapper:
		"v2.3.2/bio/muscle"

rule dl_ids_sequences:
	"""Run ../src/dl_sequences.py on a folder's identifiers.txt to produce
	sequence_dataset.csv."""
	input:
		ids="{folder}/identifiers.txt",
	output:
		"{folder}/sequence_dataset.csv",
	log:
		"{folder}/logs/dlsequences.log",
	params:
		# values forwarded unchanged from the workflow config
		cath=config["cath"],
		custom_structs=config["custom_structs"],
		clean_folder=config["clean_folder"],
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	script:
		"../src/dl_sequences.py"
	
rule plddt:
	"""Run ../src/grabplddt.py for a folder once finalset.csv exists,
	producing plddt.json."""
	input:
		# ancient(): the file's modification time is ignored when deciding to re-run
		ancient("{folder}/finalset.csv"),
	output:
		"{folder}/plddt.json",
	log:
		"{folder}/logs/plddt.log",
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree"
	script:
		"../src/grabplddt.py"


rule dl_ids_structs:
	"""Run ../src/dl_structs.py on a folder's sequence_dataset.csv to produce
	finalset.csv."""
	input:
		"{folder}/sequence_dataset.csv",
	output:
		# disabled output: "{folder}/sequences.fst"
		"{folder}/finalset.csv",
	log:
		"{folder}/logs/dlstructs.log",
	params:
		# values forwarded unchanged from the workflow config
		filtervar=config["filter"],
		filtervar_min=config["filter_min"],
		filtervar_avg=config["filter_avg"],
		cath=config["cath"],
		custom_structs=config["custom_structs"],
		clean_folder=config["clean_folder"],
	conda:
		# disabled alternative: "config/fold_tree.yaml"
		"foldtree",
	script:
		"../src/dl_structs.py"