# Default target (first rule in the file): requesting the Auspice JSON below
# makes Snakemake run every rule needed to produce it. `rule export` reuses
# this path as its output via `rules.all.input.auspice`.
rule all:
    input:
        auspice = "auspice/progressive_align.json"

# Pseudo-rule that declares no output and no command; it exists only so the
# build's input paths can be grouped under named params.
rule files:
    params:
        aligned = "data/aligned.fasta",
        new_seqs = ["data/new1.fasta", "data/new2.fasta", "data/new3_present_in_alignment.fasta", "data/new4_present_in_seqs.fasta"],
        reference = "data/reference.gb",
        metadata = "data/metadata.tsv"

# Shorthand used by the rules below, e.g. `files.aligned`, `files.metadata`.
files = rules.files.params

# Run `augur align` to add the new sequence files to the existing alignment,
# writing the combined alignment to results/aligned.fasta.
rule align:
    message: "Aligning sequences to existing alignment"
    input:
        alignment = files.aligned,
        sequences = files.new_seqs,
        reference = files.reference
    output:
        alignment = "results/aligned.fasta"
    # `{input.sequences:q}` expands the list of new-sequence files with each
    # path shell-quoted.
    shell:
        """
        augur align \
            --existing-alignment {input.alignment} \
            --sequences {input.sequences:q} \
            --reference-sequence {input.reference} \
            --output {output.alignment} \
            --fill-gaps
        """

# Build a raw tree from the combined alignment with `augur tree`.
rule tree:
    message: "Building tree"
    input:
        alignment = rules.align.output.alignment
    output:
        tree = "results/tree_raw.nwk"
    # Tree-building method passed to `augur tree --method`.
    params:
        method = "iqtree"
    shell:
        """
        augur tree \
            --alignment {input.alignment} \
            --output {output.tree} \
            --method {params.method}
        """

# Run `augur refine` on the raw tree, rooting it on a fixed strain, and write
# the refined tree plus its node data (results/branch_lengths.json).
rule refine:
    message: "Refining tree"
    input:
        tree = rules.tree.output.tree,
        # Name the output explicitly (as `rule tree` does) so this keeps
        # resolving to a single file even if `align` gains more outputs.
        alignment = rules.align.output.alignment,
        metadata = files.metadata
    output:
        tree = "results/tree.nwk",
        node_data = "results/branch_lengths.json"
    # Strain name passed to `augur refine --root`.
    params:
        root = "KX369547.1"
    shell:
        """
        augur refine \
            --tree {input.tree} \
            --alignment {input.alignment} \
            --metadata {input.metadata} \
            --output-tree {output.tree} \
            --output-node-data {output.node_data} \
            --root {params.root}
        """

# Run `augur ancestral` on the refined tree and the combined alignment,
# writing nucleotide mutation node data to results/nt_muts.json.
rule ancestral:
    message: "Reconstructing ancestral sequences and mutations"
    input:
        tree = rules.refine.output.tree,
        # Name the output explicitly (as `rule tree` does) so this keeps
        # resolving to a single file even if `align` gains more outputs.
        alignment = rules.align.output.alignment
    output:
        node_data = "results/nt_muts.json"
    # Inference mode passed to `augur ancestral --inference`.
    params:
        inference = "joint"
    shell:
        """
        augur ancestral \
            --tree {input.tree} \
            --alignment {input.alignment} \
            --output-node-data {output.node_data} \
            --inference {params.inference}
        """

# Run `augur translate` using the refined tree, the ancestral node data and
# the reference annotation, writing results/aa_muts.json.
rule translate:
    message: "Translating amino acid sequences"
    input:
        tree = rules.refine.output.tree,
        node_data = rules.ancestral.output.node_data,
        reference = files.reference
    output:
        node_data = "results/aa_muts.json"
    # The final argument carries no trailing backslash: a dangling line
    # continuation would silently splice in any line appended after it.
    shell:
        """
        augur translate \
            --tree {input.tree} \
            --ancestral-sequences {input.node_data} \
            --reference-sequence {input.reference} \
            --output-node-data {output.node_data}
        """

# Auspice configuration kept inline in the Snakefile; `rule make_config_file`
# dumps this dict to JSON and `rule export` passes that file to augur.
auspice_config = {
  "title": "Add sequences to existing alignment test build",
  "colorings": [
    {"key": "gt", "title": "Genotype", "type": "categorical"},
    {"key": "country", "type": "ordinal"}
  ],
  "panels": ["tree", "entropy"],
  "filters": []
}

# Write the inline `auspice_config` dict out as an indented JSON file so it
# can be handed to `augur export`.
rule make_config_file:
    message: "Making auspice config file (as specified in the Snakefile)"
    output:
        config = "results/auspice_config.json"
    params:
        auspice_config = auspice_config
    run:
        import json
        # Serialises the module-level `auspice_config` dict (the run block
        # reads the global directly rather than `params.auspice_config`).
        with open(output.config, 'w') as config_fh:
            config_fh.write(json.dumps(auspice_config, indent=2))

# Produce the final Auspice JSON with `augur export v2` from the refined
# tree, the metadata, the three node-data files and the generated config.
# The output path is taken from `rule all` so the two always agree.
rule export:
    message: "Exporting for auspice"
    input:
        tree = rules.refine.output.tree,
        metadata = files.metadata,
        branch_lengths = rules.refine.output.node_data,
        nt_muts = rules.ancestral.output.node_data,
        aa_muts = rules.translate.output.node_data,
        auspice_config = rules.make_config_file.output.config
    output:
        auspice = rules.all.input.auspice
    # The first command re-invokes snakemake to run `rule check` before the
    # export itself.
    # NOTE(review): this nested invocation presumably relies on `check`
    # declaring no input/output files so it does not clash with the outer
    # run's lock - confirm before giving `check` any declared files.
    shell:
        """
        snakemake --cores all check
        augur export v2 \
            --tree {input.tree} \
            --metadata {input.metadata} \
            --node-data {input.branch_lengths} {input.nt_muts} {input.aa_muts} \
            --auspice-config {input.auspice_config} \
            --output {output.auspice}
        """


# Remove the generated output directories so the build can be rerun from
# scratch. The directory names are plain (no padding whitespace) so they stay
# valid paths however `{params}` is expanded or quoted.
rule clean:
    message: "Removing directories: {params}"
    params:
        "results",
        "auspice"
    shell:
        "rm -rfv {params}"

# Sanity checks on the output of `rule align`:
#   * the existing alignment has the same length as the reference,
#   * the new alignment has the same length as the existing alignment,
#   * every sequence of the existing alignment appears unchanged in the new one.
# Deliberately declares no input/output files; it is invoked by name
# (`snakemake ... check`) from `rule export`.
rule check:
    run:
        from Bio import AlignIO, SeqIO
        # Paths come from the shared definitions rather than repeated literals.
        existing_alignment = AlignIO.read(files.aligned, 'fasta')
        reference = SeqIO.read(files.reference, 'genbank')
        # Only these two new-sequence files are loaded; they are used solely
        # for the length report printed below.
        new_sequences = [*SeqIO.parse("data/new1.fasta", "fasta"), *SeqIO.parse("data/new2.fasta", "fasta")]
        new_alignment = AlignIO.read(rules.align.output.alignment, "fasta")
        existing_alignment_len = existing_alignment.get_alignment_length()
        new_alignment_len = new_alignment.get_alignment_length()
        print("reference length: ", len(reference))
        reference_matches = existing_alignment_len == len(reference)
        print("existing alignment length: ", existing_alignment_len, "Correct:", reference_matches)
        print("new sequence lengths:\n", "\n".join(["\t{}: {}".format(x.name, len(x)) for x in new_sequences]))
        new_alignment_matches = existing_alignment_len == new_alignment_len
        print("new alignment length:", new_alignment_len, "Correct:", new_alignment_matches)
        # Separate assertions with messages so a failure says which check broke.
        assert reference_matches, (
            "existing alignment length %d != reference length %d"
            % (existing_alignment_len, len(reference))
        )
        assert new_alignment_matches, (
            "new alignment length %d != existing alignment length %d"
            % (new_alignment_len, existing_alignment_len)
        )
        new_alignment_dict = {x.name: x.seq for x in new_alignment}
        for s in existing_alignment:
            print("Checking %s is unchanged..."%s.name)
            # Report a dropped sequence explicitly instead of a bare KeyError.
            assert s.name in new_alignment_dict, (
                "%s is missing from the new alignment" % s.name
            )
            assert s.seq == new_alignment_dict[s.name], (
                "%s was modified in the new alignment" % s.name
            )
