#!/usr/bin/env python

"Program to generate and run MCS benchmarks"


import sys
import re
import argparse
import time
import gzip
import os
import random
import datetime
import itertools

# Status codes stored in MCSResult.status; they are also used as keys of the
# per-status statistics updated by MCSSearch.
COMPLETE, INCOMPLETE, FAILURE, KILLED = ".", "I", "F", "X"

# Statistics key that accumulates every search regardless of status.
TOTAL = "total"

class Error(Exception):
    """Problem reported to the user: bad input files, malformed benchmark
    statements, unparsable structures or missing identifiers."""

# A place to store the results of an MCS search
class MCSResult(object):
    """Outcome of a single MCS search.

    status       -- a status code such as COMPLETE, INCOMPLETE or FAILURE
    num_fragments, num_atoms, num_bonds
                 -- size of the match (0 when nothing was found; the Indigo
                    driver uses -1 when a search timed out)
    dt           -- elapsed wall-clock time of the search, in seconds
    description  -- SMARTS/SMILES of the match, or "-" when there is none
    """
    def __init__(self, status, num_fragments, num_atoms, num_bonds, dt, description):
        self.status = status
        self.num_fragments, self.num_atoms, self.num_bonds = (
            num_fragments, num_atoms, num_bonds)
        self.dt = dt
        self.description = description


#################### Support code for the per-toolkit loaders

class MolRecord(object):
    """An already-parsed structure plus where it came from.

    id       -- structure identifier
    mol      -- toolkit molecule object
    recno    -- record number within the file
    filename -- name of the source file
    """
    def __init__(self, id, mol, recno, filename):
        self.id, self.mol = id, mol
        self.recno, self.filename = recno, filename
        

class LazyMolRecord(object):
    """Structure record whose molecule is parsed only on first access.

    Subclasses implement _load_mol() to turn self.record into a molecule.
    """
    def __init__(self, id, record, recno, filename):
        self.id = id
        self.record = record
        self.recno = recno
        self.filename = filename
        self._mol = None

    @property
    def mol(self):
        """The parsed molecule, produced by _load_mol() and then cached."""
        cached = self._mol
        if cached is None:
            cached = self._mol = self._load_mol()
            # NOTE(review): self.record is kept after parsing; clearing it
            # here might save memory if nothing else needs the raw text.
        return cached

    def _load_mol(self):
        raise NotImplementedError("Implement in subclass")

class LazySmilesRecord(LazyMolRecord):
    """Lazily-parsed record for one line of a SMILES file.

    The SMILES string is the first whitespace-separated field of the line.
    """
    def __init__(self, id, record, recno, filename):
        super(LazySmilesRecord, self).__init__(id, record, recno, filename)
        fields = record.split()
        self.smiles = fields[0]


def _get_format(filename):
    lcfilename = filename.lower()
    compression = None
    if lcfilename.endswith(".gz"):
        compression = ".gz"
        lcfilename = lcfilename[:-3]
    for suffix in (".smi", ".can", ".smiles", ".ism"):
        if lcfilename.endswith(suffix):
            return ("smi", compression)
    for suffix in (".sdf", ".mol", ".sd", ".mdl"):
        if lcfilename.endswith(suffix):
            return ("sdf", compression)
    raise Error("Unknown structure filename type %r" % (filename.encode("utf-8"),))

def get_record_ids(filename, id_tag):
    """Return an iterator of (number, id, record_text) for each record in *filename*.

    For SMILES files the number is the 1-based line number; for SD files it
    is the 0-based record index.  The file is opened immediately, so open
    errors are raised here.  It is closed once the iterator is exhausted,
    fails, or is closed/garbage-collected.
    """
    filetype, compression = _get_format(filename)
    if compression == ".gz":
        infile = gzip.open(filename)
    else:
        infile = open(filename, "U")
    if filetype == "smi":
        records = _read_smiles_records(infile, filename)
    else:
        records = _read_sdf_records(infile, filename, id_tag)
    return _closing_iter(records, infile)


def _closing_iter(records, infile):
    """Yield everything from *records*, then close *infile* no matter how iteration ends."""
    try:
        for item in records:
            yield item
    finally:
        infile.close()
    


def _read_smiles_records(infile, filename):
    for lineno, line in enumerate(infile):
        fields = line.split()
        if not fields:
            raise Error("Missing SMILES data for line %d of %r" % (lineno+1, filename))
        if len(fields) == 1:
            raise Error("Missing identifier (column #2) for line %d of %r" % (lineno+1, filename))
        yield lineno+1, fields[1], line

def _record_id(record, id_pat, id_tag, recno, filename):
    if id_pat is None:
        pos = record.find("\n")
        return record[:pos].strip()
    m = id_pat.search(record)
    if m is None:
        raise Error("Cannot find id_tag %r in record #%d of %r" % (id_tag, recno+1, filename))
    return m.group(1)

def _read_sdf_records(infile, filename, id_tag):
    """Yield (recno, id, record_text) for each record of an SD file.

    recno counts from 0.  The id is the record's title line when *id_tag* is
    None; otherwise it is the line after the "> <id_tag>" data header (see
    _record_id).  Each record_text includes its terminating "$$$$" line.

    Raises Error if input ends in the middle of a record, or if more than
    about 1 MB accumulates without a record delimiter.  An id that cannot
    be found is reported by _record_id.
    """
    if id_tag is None:
        id_pat = None
    else:
        id_pat = re.compile(r"\n>\s+<" + re.escape(id_tag) + ">.*\n(.*)\n",
                            re.MULTILINE)
    # A record ends with a line containing only "$$$$".  Matching "\n$$$$\n"
    # (not "\n\n$$$$\n") also handles records without data items, where
    # "$$$$" directly follows the "M  END" line.
    delimiter = "\n$$$$\n"
    block = ""
    recno = 0
    while 1:
        # Read a large chunk, extended to the end of the current line so the
        # delimiter search only ever sees whole lines.
        new_block = infile.read(100000) + infile.readline()
        if not new_block:
            if not block:
                break
            raise Error("Incomplete final SD record in %r" % (filename,))
        block += new_block
        start = 0
        while 1:
            end = block.find(delimiter, start)
            if end == -1:
                break
            end += len(delimiter)
            record = block[start:end]
            yield recno, _record_id(record, id_pat, id_tag, recno, filename), record
            recno += 1
            start = end
        # Carry the unfinished tail over to the next read.
        block = block[start:]
        if len(block) > 1000000:
            raise Error("Block size too large; is this an SD file?")
        
    
def _load_dataset(reader, filename, lazy, verbose):
    dataset = {}
    prev_recnos = {}
    if verbose:
        prev_time = time.time()
        if lazy:
            msg = "Lazy loading of all structures from %s\n" % (filename.encode("utf-8"),)
            status_msg = "\rLoaded %d records (lazy)"
        else:
            msg = "Loading all structures from %s\n" % (filename.encode("utf-8"),)
            status_msg = "\rLoaded %d structures"
        sys.stderr.write(msg)
        sys.stderr.flush()

    record = None
    for record in reader:
        id = record.id
        if id in dataset:
            msg = "Duplicate structure identifier %r in record #%d %r. Previous was record #%d" % (
                id, record.recno, record.filename, prev_recnos[id])
            raise Error(msg)
        if not lazy:
            record.mol # Parse the structure record and see if it's valid
        dataset[id] = record
        prev_recnos[id] = record.recno

        if verbose:
            if record.recno % 100 == 0:
                dt = time.time() - prev_time
                if dt > 1.0:
                    prev_time = time.time()
                    sys.stderr.write(status_msg % (record.recno,))
                    sys.stderr.flush()

    if verbose:
        if record is None:
            sys.stderr.write("\rNo structures loaded.\n")
        else:
            sys.stderr.write("\rLoaded %d structures.     \n" % (record.recno,))
    return dataset


#################### fmcs code

# Filled in by init_fmcs(); None until then.
Chem = fmcs = rdkit_version = None
def init_fmcs():
    """Import fmcs and RDKit and bind them to the module globals.

    Sets `fmcs`, `Chem` and `rdkit_version` (RDKit's rdkitVersion string,
    or "unknown" if rdBase does not provide one).
    """
    global Chem, fmcs, rdkit_version
    import fmcs
    from rdkit import Chem, rdBase
    rdkit_version = getattr(rdBase, "rdkitVersion", "unknown")


def rdkit_read_smiles_mols(infile, filename):
    """Return a lazy iterator of LazyRDKitSmilesRecord, one per SMILES line."""
    return (LazyRDKitSmilesRecord(ident, line, lineno, filename)
            for lineno, ident, line in _read_smiles_records(infile, filename))

class LazyRDKitSmilesRecord(LazySmilesRecord):
    """SMILES record parsed on demand with RDKit's Chem.MolFromSmiles."""
    def _load_mol(self):
        parsed = Chem.MolFromSmiles(self.smiles)
        if parsed is not None:
            return parsed
        # MolFromSmiles signals failure by returning None.
        raise Error("Cannot parse SMILES %r on line %d of %r" % (
            self.smiles, self.recno, self.filename))
        

def rdkit_read_sdf_mols(infile, filename, id_tag):
    """Yield a MolRecord for each molecule of an SD file read with RDKit.

    Records are numbered from 1.  The id is the molecule title ("_Name")
    when *id_tag* is None, otherwise the *id_tag* property; surrounding
    whitespace is removed in both cases.

    Raises Error for an unparsable record, a missing id tag, or an id that
    is empty or contains whitespace.
    """
    for molno, mol in enumerate(Chem.ForwardSDMolSupplier(infile)):
        recno = molno+1
        if mol is None:
            raise Error("Unable to parse structure #%r" % (recno,))
        if id_tag is None:
            id = mol.GetProp("_Name").strip()
        else:
            # GetProp() raises KeyError for an absent property; report it
            # as a user-facing Error instead.
            if not mol.HasProp(id_tag):
                raise Error("Cannot find tag %r for structure #%d of %r" % (id_tag, recno, filename))
            id = mol.GetProp(id_tag).strip()
        if not id:
            raise Error("Missing identifier for structure #%d of %r" % (recno, filename))
        if len(id.split()) != 1:
            raise Error("Identifier %r for structure #%d of %r must not contain a whitespace characters" % (id, recno, filename))
        yield MolRecord(id, mol, recno, filename)


def fmcs_load_dataset(filename, id_tag, lazy, verbose):
    """Load every structure in *filename* with RDKit; return {id: record}.

    SMILES files give LazyRDKitSmilesRecord entries, and SD files give
    MolRecord entries.  Gzip-compressed files (".gz") are supported.

    Raises Error for an unknown file type, an unopenable file, or any
    problem reported while loading the records.
    """
    filetype, compression = _get_format(filename)
    # Only opening the file can raise IOError here; the readers below are
    # generators and do no I/O until _load_dataset iterates them.
    try:
        if compression == ".gz":
            infile = gzip.open(filename)
        else:
            infile = open(filename, "U")
    except IOError as err:
        raise Error("Cannot open structure file: %s" % (err,))

    if filetype == "smi":
        reader = rdkit_read_smiles_mols(infile, filename)
    else:
        reader = rdkit_read_sdf_mols(infile, filename, id_tag)
    return _load_dataset(reader, filename, lazy, verbose)
    

def fmcs_find_mcs(query_mols, args):
    """Run fmcs on *query_mols* and return the outcome as an MCSResult.

    The search options come from *args*: min_num_atoms, maximize,
    atom_compare, bond_compare, ring_matches_ring_only, complete_rings_only
    and timeout.  The status is COMPLETE if fmcs reports completion,
    otherwise INCOMPLETE.  When no SMARTS is found, all counts are 0 and
    the description is "-".
    """
    started = time.time()
    mcs = fmcs.fmcs(query_mols,
                    min_num_atoms=args.min_num_atoms,
                    maximize=args.maximize,
                    atom_compare=args.atom_compare,
                    bond_compare=args.bond_compare,
                    ring_matches_ring_only=args.ring_matches_ring_only,
                    complete_rings_only=args.complete_rings_only,
                    timeout=args.timeout)
    elapsed = time.time() - started

    status = COMPLETE if mcs.completed else INCOMPLETE
    if mcs.smarts is None:
        return MCSResult(status, 0, 0, 0, dt=elapsed, description="-")
    return MCSResult(status, 1, mcs.num_atoms, mcs.num_bonds,
                     dt=elapsed, description=mcs.smarts)


######## Indigo code

# Filled in by init_indigo(); None until then.
indigo = IndigoException = None
def init_indigo():
    """Import Indigo and bind the module globals `indigo` and `IndigoException`.

    `indigo` becomes a shared Indigo() instance used by the Indigo loaders
    and searches in this module.
    """
    global indigo, IndigoException
    # Developer-specific location of an Indigo Python build.  It is put at
    # the front of sys.path only if it exists, and never more than once,
    # even when this function is called repeatedly.
    local_build = "/Users/dalke/ftps/indigo-python-1.1-rc-universal"
    if os.path.exists(local_build) and local_build not in sys.path:
        sys.path.insert(0, local_build)
    from indigo import Indigo, IndigoException
    indigo = Indigo()


def indigo_read_smiles_mols(infile, filename, aromatize, fold_hydrogens):
    """Return a lazy iterator of LazyIndigoSmilesMol, one per SMILES line.

    *aromatize* and *fold_hydrogens* are forwarded to every record.
    """
    return (LazyIndigoSmilesMol(ident, line, lineno, filename,
                                aromatize, fold_hydrogens)
            for lineno, ident, line in _read_smiles_records(infile, filename))

class LazyIndigoSmilesMol(LazySmilesRecord):
    """SMILES record parsed on demand with Indigo.

    After loading, the molecule is aromatized and/or has its hydrogens
    folded, according to the flags given to the constructor.
    """
    def __init__(self, id, data, recno, filename, aromatize, fold_hydrogens):
        # The base classes already store id, record, recno and filename.
        super(LazyIndigoSmilesMol, self).__init__(id, data, recno, filename)
        self.aromatize = aromatize
        self.fold_hydrogens = fold_hydrogens

    def _load_mol(self):
        """Parse self.smiles; raise Error if Indigo rejects it."""
        try:
            mol = indigo.loadMolecule(self.smiles)
        except IndigoException as err:
            # Report unparsable SMILES as an Error, like the RDKit loader.
            raise Error("Cannot parse SMILES %r on line %d of %r: %s" % (
                self.smiles, self.recno, self.filename, err))
        if self.aromatize:
            mol.aromatize()
        if self.fold_hydrogens:
            mol.foldHydrogens()
        return mol
        

def indigo_read_sdf_mols(filename, id_tag, aromatize, fold_hydrogens):
    """Yield a MolRecord for each structure of an SD file read by Indigo.

    Records are numbered from 1.  The id is the molecule name when *id_tag*
    is None, otherwise the *id_tag* property; surrounding whitespace is
    removed in both cases.  Each molecule is optionally aromatized and has
    its hydrogens folded before it is yielded.

    Raises Error when the id tag is absent, or the id is empty or contains
    whitespace.
    """
    for molno, mol in enumerate(indigo.iterateSDFile(filename)):
        recno = molno + 1
        if id_tag is None:
            id = mol.name().strip()
        else:
            try:
                id = mol.getProperty(id_tag)
            except IndigoException:
                raise Error("Cannot find tag %r for record #%d of %r" % (id_tag, recno, filename))
            id = id.strip()
        if not id:
            raise Error("Missing identifier for record #%d of %r" % (recno, filename))
        # split() breaks on any whitespace, so the id must be a single token.
        if len(id.split()) != 1:
            raise Error("Identifier %r for record #%d of %r must not contain whitespace characters" % (
                id, recno, filename))
        if aromatize:
            mol.aromatize()
        if fold_hydrogens:
            mol.foldHydrogens()
        yield MolRecord(id, mol, recno, filename)

def indigo_load_dataset(filename, id_tag, aromatize=True, fold_hydrogens=True, lazy=False, verbose=False):
    """Load every structure in *filename* with Indigo; return {id: record}.

    SMILES files (optionally gzip-compressed) are opened here.  SD files are
    opened by Indigo's iterateSDFile when the reader is first iterated.

    Raises Error for an unknown file type, an unopenable SMILES file, or any
    problem reported while loading the records.
    """
    filetype, compression = _get_format(filename)
    if filetype == "smi":
        try:
            if compression == ".gz":
                infile = gzip.open(filename)
            else:
                infile = open(filename, "U")
        except IOError as err:
            raise Error("Cannot open structure file: %s" % (err,))
        reader = indigo_read_smiles_mols(infile, filename, aromatize, fold_hydrogens)
    else:
        # indigo_read_sdf_mols is a generator, so nothing is read until
        # _load_dataset iterates it.
        reader = indigo_read_sdf_mols(filename, id_tag, aromatize, fold_hydrogens)

    return _load_dataset(reader, filename, lazy=lazy, verbose=verbose)

def indigo_find_mcs_exact(query_mols, args):
    """Find the MCS of *query_mols* with Indigo's "exact" scaffold method."""
    exact_method = "exact"
    return _indigo_find_mcs(query_mols, args, exact_method)

def indigo_find_mcs_approx(query_mols, args):
    """Find the MCS of *query_mols* with Indigo's approximate scaffold method.

    args.iterations is passed as the method string "approx <iterations>".
    """
    return _indigo_find_mcs(query_mols, args, "approx %d" % (args.iterations,))


def _indigo_find_mcs(query_mols, args, method):
    """Run Indigo's extractCommonScaffold on *query_mols*; return an MCSResult.

    *method* is the Indigo method string ("exact" or "approx N").  When
    args.timeout is set, it is passed to Indigo in milliseconds.  Among the
    scaffolds found, the largest one is reported; size is ranked by
    args.maximize ("atoms" or "bonds"), with the other count as tie-breaker.

    Results:
      * a scaffold was found     -> COMPLETE with counts and its SMILES
      * no scaffold found        -> COMPLETE with zero counts and "-"
      * Indigo reports timeout   -> FAILURE with counts of -1 and "-"
    Any other IndigoException propagates.
    """
    if args.timeout is not None:
        indigo.setOption("timeout", int(round(args.timeout * 1000)))
    arr = indigo.createArray()
    for mol in query_mols:
        arr.arrayAdd(mol)

    t1 = time.time()
    try:
        scaf = indigo.extractCommonScaffold(arr, method)
        t2 = time.time()
    except IndigoException as err:
        t2 = time.time()
        message = str(err)
        if "There are no scaffolds found" in message:
            return MCSResult(COMPLETE, 0, 0, 0, t2-t1, "-")
        if "timed out" in message:
            return MCSResult(FAILURE, -1, -1, -1, t2-t1, "-")
        raise

    scaffolds = scaf.allScaffolds().iterateArray()
    if args.maximize == "atoms":
        candidates = [(s.countAtoms(), s.countBonds(), s.smiles()) for s in scaffolds]
    elif args.maximize == "bonds":
        candidates = [(s.countBonds(), s.countAtoms(), s.smiles()) for s in scaffolds]
    else:
        raise AssertionError(args.maximize)

    # Previously an empty list hit `assert data` / data[0] before this check.
    if not candidates:
        return MCSResult(COMPLETE, 0, 0, 0, t2-t1, "-")

    best = max(candidates)
    if args.maximize == "atoms":
        num_atoms, num_bonds, description = best
    else:
        num_bonds, num_atoms, description = best
    return MCSResult(COMPLETE, 1, num_atoms, num_bonds, t2-t1, description)
    

########
class _Writer(object):
    def __init__(self, outfile):
        self.outfile = outfile

    def _writeline(self, message):
        message = message.replace("\n", "")
        self.outfile.write(message + "\n")
        self.outfile.flush()

    def _writeline_progress(self, message):
        if self.progress_file is not None:
            message = message.replace("\n", "")
            self.progress_file.write(message + "\n")
            self.progress_file.flush()

    def comment(self, comment):
        self._writeline("#  " + comment)

    def summary(self, comment):
        self._writeline("#  " + comment)

    def progress(self, message):
        self._writeline("## " + message)

    def software(self, software):
        self._writeline("#software " + software)

    def options(self, options):
        self._writeline("#options " + options)

    def timestamp(self, now):
        self._writeline("#timestamp " + now.isoformat())

    def error(self, message):
        self._writeline("#Error " + message)

    def file(self, filename):
        self._writeline("#File " + filename.encode("utf-8"))

    def id_tag(self, tag):
        self._writeline("#Id-tag " + tag)
        
    def token(self, token):
        self._writeline(token.tostring())

class MCSBenchmarkWriter(_Writer):
    """Writer for '#MCS-Benchmark/1' request files.

    error(), file() and id_tag() are inherited from _Writer.
    """
    def magic(self):
        self._writeline("#MCS-Benchmark/1")

    def mcs_result(self, label, ids, result):
        """Write a request line for *ids*; add the search time if *result* is given."""
        self._writeline("%s %s" % (label, " ".join(ids)))
        if result is not None:
            self.comment("  Took %.2f seconds" % (result.dt,))

    def mcs_result_all(self, label, result):
        """Write a '<label> all' request; add the search time if *result* is given."""
        self._writeline("%s all" % (label,))
        if result is not None:
            self.comment("  Took %.2f seconds" % (result.dt,))
            
        
class MCSBenchmarkOutputWriter(_Writer):
    """Writer for '#MCS-Benchmark-Output/1' result files.

    Each result line is:
        <label> <status> <num_fragments> <num_atoms> <num_bonds> <dt> <description>
    The constructor and _writeline() are inherited from _Writer.
    """
    def magic(self):
        self._writeline("#MCS-Benchmark-Output/1")

    def _mcs_result(self, label, result):
        self._writeline(label + " " +
                "{0.status} {0.num_fragments} {0.num_atoms} {0.num_bonds} {0.dt:.2f} {0.description}".format(
                    result))

    def mcs_result(self, label, ids, result):
        """Write a result line, preceded by a 'Using ...' comment when *ids* is non-empty."""
        if ids:
            self.comment("Using " + " ".join(ids))
        self._mcs_result(label, result)

    def mcs_result_all(self, label, result):
        self.comment("Using all structures ...")
        return self.mcs_result(label, None, result)


class MCSSearch(object):
    """Base class that runs parsed benchmark statements against one MCS toolkit.

    Subclasses provide load_dataset(filename, id_tag), find_mcs(query_mols,
    args), write_software(output) and write_options(output).  process()
    dispatches each statement to the matching process_* handler.
    """
    # Header lines written by write_header(), in order; each name maps to a
    # write_<name> method.
    header_fields = ["magic", "software", "options", "date", "timerange"]

    def __init__(self, args):
        self.args = args
        self.dataset = None
        self.id_tag = None
        # Results are only reported when their elapsed time passes
        # _check_time; _timerange is the matching header comment, or None.
        if args.min_time <= 0.0:
            if args.max_time is None:
                check_time = lambda dt: True
                timerange = None
            else:
                check_time = lambda dt: dt <= args.max_time
                timerange = "Displaying searches which took at most %.1f seconds" % (args.max_time,)
        else:
            if args.max_time is None:
                check_time = lambda dt: args.min_time <= dt
                timerange = "Displaying searches which took at least %.1f seconds" % (args.min_time,)
            else:
                check_time = lambda dt: args.min_time <= dt <= args.max_time
                timerange = "Displaying searches which took between %.1f and %.1f seconds" % (
                    args.min_time, args.max_time)
        self._check_time = check_time
        self._timerange = timerange

    def write_header(self, output):
        for name in self.header_fields:
            getattr(self, "write_"+name)(output)

    def write_magic(self, output):
        output.magic()

    def write_software(self, output):
        raise NotImplementedError("Must be implemented in the subclass")

    def write_options(self, output):
        raise NotImplementedError("Must be implemented in the subclass")

    def write_date(self, output):
        output.timestamp(datetime.datetime.now())

    def write_timerange(self, output):
        if self._timerange is not None:
            output.comment(self._timerange)

    def process(self, output, token, stats):
        """Handle one parsed statement; any Error is also written to *output*, then re-raised."""
        try:
            # FileStmt and IdTagStmt subclass RequiredStmt, so they must be
            # tested before it.
            if isinstance(token, FileStmt):
                self.process_file(output, token)
            elif isinstance(token, IdTagStmt):
                self.process_id_tag(output, token)
            elif isinstance(token, CommentStmt):
                self.process_comment(output, token)
            elif isinstance(token, RequiredStmt):
                self.process_required(output, token)
            elif isinstance(token, OptionalStmt):
                self.process_optional(output, token)
            elif isinstance(token, MCSRequestAll):
                self.process_request_all(output, token, stats)
            elif isinstance(token, MCSRequest):
                self.process_request(output, token, stats)
            elif isinstance(token, MagicStmt):
                self.process_magic(output, token)
            elif isinstance(token, ProgressStmt):
                self.process_progress(output, token)
            else:
                raise Error("Unknown token %r" % (token,))
        except Error as err:
            output.error(str(err))
            raise

    def process_magic(self, output, token):
        pass

    def process_file(self, output, token):
        # Drop the previous dataset first, so a failed load cannot leave
        # stale structures for later requests.
        self.dataset = None
        try:
            dataset = self.load_dataset(token.filename, self.id_tag)
        except IOError as msg:
            raise Error("Cannot read %s: %r" % (token.filename, msg))
        output.file(token.filename)
        output.progress("  Loaded %d structures." % (len(dataset),))
        self.dataset = dataset

    def process_id_tag(self, output, token):
        self.id_tag = token.tag

    def process_comment(self, output, token):
        output.comment(token.comment)

    def process_required(self, output, token):
        raise Error("Unsupported required line %r" % (token.tostring(),))

    def process_optional(self, output, token):
        return

    def process_progress(self, output, token):
        return

    def _record_stats(self, stats, mcs_result):
        """Add *mcs_result* to its status bucket and to the TOTAL bucket."""
        for key in (mcs_result.status, TOTAL):
            stats[key].add(mcs_result.dt, mcs_result.num_atoms, mcs_result.num_bonds)

    def process_request_all(self, output, token, stats):
        dataset = self.dataset
        if dataset is None:
            raise Error("No structure filename specified; cannot do MCS search")

        # find_mcs expects molecules, not the dataset's record wrappers.
        query_mols = [record.mol for record in dataset.values()]
        mcs_result = self.find_mcs(query_mols, self.args)
        self._record_stats(stats, mcs_result)
        if self._check_time(mcs_result.dt):
            output.mcs_result_all(token.label, mcs_result)

    def process_request(self, output, token, stats):
        dataset = self.dataset
        if dataset is None:
            raise Error("No structure filename specified; cannot do MCS search")

        query_mols = []
        for id in token.ids:
            record = dataset.get(id, None)
            if record is None:
                raise Error("Cannot find id %r in the current dataset" % (id,))
            query_mols.append(record.mol)

        if len(query_mols) < 2:
            raise Error("Must have at least two molecules in order to find the MCS")

        mcs_result = self.find_mcs(query_mols, self.args)
        self._record_stats(stats, mcs_result)
        if self._check_time(mcs_result.dt):
            output.mcs_result(token.label, token.ids, mcs_result)

#######

class FMCSSearch(MCSSearch):
    """MCS benchmark driver for the RDKit-based fmcs package."""
    def __init__(self, args):
        super(FMCSSearch, self).__init__(args)
        init_fmcs()

    def load_dataset(self, filename, id_tag):
        return fmcs_load_dataset(filename, id_tag, self.args.lazy, self.args.verbose)

    find_mcs = staticmethod(fmcs_find_mcs)

    def write_software(self, output):
        output.software("fmcs/" + fmcs.__version__ + " RDKit/" + rdkit_version)

    def write_options(self, output):
        """Write the #options header describing the settings passed to fmcs.

        maximize is included (it is passed to fmcs.fmcs and also reported
        by the Indigo drivers), so runs with different settings can be
        told apart.
        """
        args = self.args
        s = "maximize=%s atom-compare=%s bond-compare=%s min-num-atoms=%s" % (
            args.maximize, args.atom_compare, args.bond_compare, args.min_num_atoms)
        # complete_rings_only is reported in preference to
        # ring_matches_ring_only when both are set.
        if args.complete_rings_only:
            s += " complete-rings-only=True"
        elif args.ring_matches_ring_only:
            s += " ring-matches-ring-only=True"
        if args.timeout is not None:
            s += " timeout={0:.2f}".format(args.timeout)
        output.options(s)

class IndigoExactSearch(MCSSearch):
    """MCS benchmark driver for Indigo's exact extractCommonScaffold method."""
    find_mcs = staticmethod(indigo_find_mcs_exact)

    def __init__(self, args):
        super(IndigoExactSearch, self).__init__(args)
        init_indigo()

    def load_dataset(self, filename, id_tag):
        opts = self.args
        return indigo_load_dataset(filename, id_tag, opts.aromatize, opts.fold_hydrogens,
                                   opts.lazy, verbose=opts.verbose)

    def write_software(self, output):
        output.software("Indigo/" + indigo.version() + " extractCommonScaffold")

    def write_options(self, output):
        """Write the #options header line; atom/bond comparison are fixed for Indigo."""
        opts = self.args
        parts = ["method=exact",
                 "maximize=%s" % (opts.maximize,),
                 "atom-compare=elements",
                 "bond-compare=bondtypes",
                 "aromatize=%s" % (opts.aromatize,),
                 "fold-hydrogens=%s" % (opts.fold_hydrogens,)]
        if opts.timeout is not None:
            parts.append("timeout={0:.2f}".format(opts.timeout))
        output.options(" ".join(parts))


class IndigoApproxSearch(MCSSearch):
    """MCS benchmark driver for Indigo's approximate extractCommonScaffold method."""
    find_mcs = staticmethod(indigo_find_mcs_approx)

    def __init__(self, args):
        super(IndigoApproxSearch, self).__init__(args)
        init_indigo()

    def load_dataset(self, filename, id_tag):
        opts = self.args
        return indigo_load_dataset(filename, id_tag, opts.aromatize, opts.fold_hydrogens,
                                   opts.lazy, verbose=opts.verbose)

    def write_software(self, output):
        output.software("Indigo/" + indigo.version() + " extractCommonScaffold")

    def write_options(self, output):
        """Write the #options header line, including the iteration count."""
        opts = self.args
        parts = ["method=approx",
                 "iterations=%d" % (opts.iterations,),
                 "maximize=%s" % (opts.maximize,),
                 "atom-compare=elements",
                 "bond-compare=bondtypes",
                 "aromatize=%s" % (opts.aromatize,),
                 "fold-hydrogens=%s" % (opts.fold_hydrogens,)]
        if opts.timeout is not None:
            parts.append("timeout={0:.2f}".format(opts.timeout))
        output.options(" ".join(parts))


#######

class MagicStmt(object):
    """The header line that identifies a benchmark file."""
    type = "Magic"
    version = 1

    def __init__(self, line):
        # Stored verbatim; BenchmarkReader passes the line with its newline.
        self.line = line

    def tostring(self):
        line = self.line
        return line

class RequiredStmt(object):
    """A '#Name text' statement whose name starts with an uppercase letter.

    A processor must understand such statements or reject the file.
    """
    type = "Required"

    def __init__(self, name, text):
        self.name, self.text = name, text

    def tostring(self):
        return "#%s %s" % (self.name, self.text)

class FileStmt(RequiredStmt):
    """'#File <filename>' statement naming the structure file to load."""
    type = "File"

    def __init__(self, filename):
        # The generic text field holds the UTF-8 encoded name; the name as
        # given is kept separately in .filename.
        encoded = filename.encode("utf8")
        super(FileStmt, self).__init__("File", encoded)
        self.filename = filename

class IdTagStmt(RequiredStmt):
    """'#Id-tag <tag>' statement: the SD data tag to use as identifier."""
    type = "IdTag"

    def __init__(self, tag):
        super(IdTagStmt, self).__init__("Id-tag", tag)
        self.tag = tag
        
class OptionalStmt(object):
    """A '#name text' statement whose name starts with a lowercase letter.

    Processors may ignore these.
    """
    type = "Optional"

    def __init__(self, name, text):
        self.name, self.text = name, text

    def tostring(self):
        return "#%s %s" % (self.name, self.text)

class CommentStmt(object):
    """A '# text' comment line (or a bare '#')."""
    type = "Comment"

    def __init__(self, comment):
        self.comment = comment

    def tostring(self):
        return "# %s" % (self.comment,)

class ProgressStmt(object):
    """A '##...' progress line, kept verbatim."""
    type = "Progress"

    def __init__(self, line):
        self.line = line

    def tostring(self):
        # Return the stored line (previously the undefined name `line`,
        # which raised NameError).
        return self.line

class MCSRequest(object):
    """Request to find the MCS of the structures with the given ids."""
    type = "MCSRequest"

    def __init__(self, label, ids):
        self.label, self.ids = label, ids

    def tostring(self):
        id_text = " ".join(self.ids)
        return "%s %s" % (self.label, id_text)

class MCSRequestAll(object):
    """Request to find the MCS of every structure in the current dataset."""
    type = "MCSRequestAll"

    def __init__(self, label):
        self.label = label

    def tostring(self):
        return "%s all" % (self.label,)

##

class BenchmarkReader(object):
    """Turn an '#MCS-Benchmark/1' stream into statement objects.

    Iterating yields a MagicStmt first, then one statement per line, until
    end of input.  If *controlfile* is given, "#Ready" is written (and
    flushed) to it before every line is read.
    """
    def __init__(self, infile, controlfile=None):
        self.infile = infile
        self.controlfile = controlfile
        self.lineno = 0

    def _readline(self):
        control = self.controlfile
        if control is not None:
            control.write("#Ready\n")
            control.flush()
        text = self.infile.readline()
        if text:
            self.lineno += 1
        return text

    def __iter__(self):
        # _next_token returns None at end of input, which stops iteration.
        return iter(self._next_token, None)

    def _read_header(self):
        """Read and validate the first line; return it as a MagicStmt."""
        first_line = self._readline()
        if not first_line:
            raise Error("Empty input file; missing header line")
        if first_line != "#MCS-Benchmark/1\n":
            name = getattr(self.infile, "name", None)
            name = "input" if name is None else repr(name)
            raise Error("First line of %s must be '#MCS-Benchmark/1', not %r" % (name, first_line))
        return MagicStmt(first_line)

    def _parse_line(self, line):
        """Classify one line (without its newline) as a statement."""
        if line.startswith("#File "):
            return FileStmt(line.partition(" ")[2].decode("utf8"))
        if line.startswith("#Id-tag "):
            return IdTagStmt(line[8:].strip())
        if line.startswith("# "):
            return CommentStmt(line[2:])
        if line == "#":
            return CommentStmt("")
        if line.startswith("##"):
            return ProgressStmt(line)
        if line.startswith("#"):
            first = line[1]
            if not first.isalpha():
                raise Error("Unsupported statement %r" % (line,))
            name, _, text = line.partition(" ")
            # Uppercase names are required statements, lowercase optional.
            if first == first.upper():
                return RequiredStmt(name[1:], text)
            return OptionalStmt(name[1:], text)
        # Anything else is a request for an MCS comparison.
        return parse_mcs_request(line)

    def _next_token(self):
        if self.lineno == 0:
            return self._read_header()
        line = self._readline()
        if not line:
            return None  # end of file
        try:
            return self._parse_line(line.rstrip("\n"))
        except ValueError as err:
            raise ValueError("%s at line %d" % (err, self.lineno))
    
def parse_mcs_request(line):
    """Parse an MCS request line into an MCSRequest or MCSRequestAll.

    The line has one of the forms::

        label id[0] id[1] ...
        label all

    Raises ValueError when the line is blank, starts with whitespace, has
    no identifiers, or has a single non-"all" option.  BenchmarkReader adds
    the line number to these messages.
    """
    assert line[:1] != "#", "Should already have been checked"
    # These are user-input errors, so raise ValueError rather than relying on
    # assert (stripped under -O) or letting fields[0] raise IndexError.
    if not line.strip():
        raise ValueError("Missing MCS request label and identifiers")
    if line[:1].isspace():
        raise ValueError("MCS request must not start with whitespace: %r" % (line,))

    fields = line.split()
    if len(fields) == 1:
        raise ValueError("Missing identifiers: %r" % (line,))
    label = fields[0]

    if len(fields) == 2:
        option = fields[1]
        if option == 'all':
            return MCSRequestAll(label)
        raise ValueError("Unknown MCS request option %r" % (option,))

    return MCSRequest(label, fields[1:])


#####
# mcsbenchmark indigo --args x.mcsb
# mcsbenchmark indigo-approx --args x.mcsb
# mcsbenchmark fmcs --args x.mcsb
# mcsbenchmark pairs --seed N --num-tests 1000 structure_filename --prefix spam{.smi} {.mcsb}
# mcsbenchmark neighbors --seed N --num 1000 -k k --threshold T --prefix spam{.smi} {.mcsb} structure_filename fps_filename 


def parse_timeout(s):
    """argparse type for --timeout: seconds as a float, or "none" for no limit.

    "none" is matched case-insensitively.  Negative values and NaN raise
    argparse.ArgumentTypeError.  Non-numeric text raises ValueError, which
    argparse reports as an invalid value.
    """
    if s.strip().lower() == "none":
        return None
    timeout = float(s)
    # Written as "not >=" so NaN (for which every comparison is False) is
    # rejected along with negative values.
    if not timeout >= 0.0:
        raise argparse.ArgumentTypeError("Must be a non-negative value, not %r" % (s,))
    return timeout


def _add_standard_mcs_options(parser):
    """Register the options shared by every MCS search subcommand on *parser*.

    Adds, in this order: --timeout, --min-time, --max-time, --lazy,
    --output-format, --client, --progress and --verbose.
    """
    # Options that take a number of seconds: (flag, converter, default, help).
    seconds_options = (
        ("--timeout", parse_timeout, None,
         "Quit the MCS calculation after SECONDS seconds"),
        ("--min-time", float, 0.0,
         "Do not report searches taking less than SECONDS seconds"),
        ("--max-time", float, None,
         "Do not report searches taking more than SECONDS seconds"),
    )
    for flag, converter, default_value, help_text in seconds_options:
        parser.add_argument(flag, type=converter, default=default_value,
                            metavar="SECONDS", help=help_text)

    parser.add_argument("--lazy", action="store_true",
                        help="Do not parse the structure records until needed")
    parser.add_argument("--output-format", choices=["mcs-output", "mcs-benchmark"],
                        default=None)

    # Plain on/off switches.
    switches = (
        ("--client", "Enable experimental client protocol"),
        ("--progress", "Write partial progress information to the benchmark output file"),
        ("--verbose", "Write status and summary information to stderr"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

# Top-level command-line parser. Each subcommand registers its own
# sub-parser on `subparsers` below.
parser = argparse.ArgumentParser(
    description="Run an MCS benchmark or generate benchmark data")
subparsers = parser.add_subparsers(
    title="subcommands",
    description="Valid subcommands",
)

# "fmcs" subcommand: run the benchmark with fmcs.
parser_fmcs = subparsers.add_parser("fmcs", help="Benchmark fmcs")

parser_fmcs.add_argument(
    "--maximize", choices=["atoms", "bonds"], default="atoms",
    help="Should the MCS maximize the number of atoms or the number of bonds?")
parser_fmcs.add_argument(
    "--atom-compare", choices=["any", "elements", "isotopes"], default="elements",
    help=("Specify the atom comparison method. With 'any', every atom matches every "
          "other atom. With 'elements', atoms match only if they contain the same element. "
          "With 'isotopes', atoms match only if they have the same isotope number; element "
          "information is ignored so [5C] and [5P] are identical. (Default: elements)"))
parser_fmcs.add_argument(
    "--bond-compare", choices=["any", "ignore-aromaticity", "bondtypes"], default="bondtypes",
    help=("Specify the bond comparison method. With 'any', every bond matches every "
          "other bond. With 'ignore-aromaticity', aromatic bonds match single, aromatic, "
          "and double bonds, but no other types match each other. With 'bondtypes', bonds "
          "are the same only if their bond type is the same. (Default: bondtypes)"))
# The ">= 2" requirement is enforced in main(), not by argparse.
parser_fmcs.add_argument(
    "--min-num-atoms", type=int, default=2, metavar="INT",
    help="Minimum number of atoms in the MCS. Must be at least 2. (Default: 2)")
parser_fmcs.add_argument(
    "--ring-matches-ring-only", action="store_true",
    help=("Modify the bond comparison so that ring bonds only match ring bonds and chain "
          "bonds only match chain bonds. (Ring atoms can still match non-ring atoms.)"))
parser_fmcs.add_argument(
    "--complete-rings-only", action="store_true",
    help=("If a bond is a ring bond in the input structures and a bond is in the MCS "
          "then the bond must also be in a ring in the MCS. Selecting this option also "
          "enables --ring-matches-ring-only."))
_add_standard_mcs_options(parser_fmcs)

parser_fmcs.add_argument("mcsb_filename", nargs="?")
parser_fmcs.set_defaults(cmd="fmcs")

# "indigo-exact" subcommand: benchmark Indigo's exact extractCommonScaffold search.
parser_indigo_exact = subparsers.add_parser(
    "indigo-exact",
    help="Benchmark Indigo's extractCommonScaffold exact search")

parser_indigo_exact.add_argument(
    "--maximize", choices=["atoms", "bonds"], default="atoms",
    help="Should the MCS maximize the number of atoms or the number of bonds?")
parser_indigo_exact.add_argument(
    "--no-aromatize", dest="aromatize", action="store_false",
    help="Don't reperceive input aromaticity")
parser_indigo_exact.add_argument(
    "--no-fold-hydrogens", dest="fold_hydrogens", action="store_false",
    help="Don't remove hydrogens from the input structure")
_add_standard_mcs_options(parser_indigo_exact)

parser_indigo_exact.add_argument("mcsb_filename", nargs="?")
# main() dispatches this subcommand under the shorter name "indigo".
parser_indigo_exact.set_defaults(cmd="indigo")

# "indigo-approx" subcommand: benchmark Indigo's approximate extractCommonScaffold search.
parser_indigo_approx = subparsers.add_parser(
    "indigo-approx",
    help="Benchmark Indigo's extractCommonScaffold approximate search")

parser_indigo_approx.add_argument(
    "--maximize", choices=["atoms", "bonds"], default="atoms",
    help="Should the MCS maximize the number of atoms or the number of bonds?")
parser_indigo_approx.add_argument(
    "--iterations", type=int, default=1000, metavar="N",
    help="Stop the search after N iterations (Default: 1000)")
parser_indigo_approx.add_argument(
    "--no-aromatize", dest="aromatize", action="store_false",
    help="Don't reperceive input aromaticity")
parser_indigo_approx.add_argument(
    "--no-fold-hydrogens", dest="fold_hydrogens", action="store_false",
    help="Don't remove hydrogens from the input structure")
_add_standard_mcs_options(parser_indigo_approx)

parser_indigo_approx.add_argument("mcsb_filename", nargs="?")
parser_indigo_approx.set_defaults(cmd="indigo-approx")


# "random" subcommand: build a benchmark whose tests are random samples of
# -k record ids from a structure file (see do_random).
parser_random = subparsers.add_parser(
    "random",
    help="Generate an MCS benchmark file using randomly selected records from a structure file")

parser_random.add_argument("--seed", type=int, help="initial random number seed")
# Each test samples -k records, so a test is not necessarily a "pair";
# "records" also avoids confusion with chemical elements.
parser_random.add_argument("-k", type=int, default=2,
                           help="number of records to select for each test (Default: 2)")
parser_random.add_argument("--num-tests", "-n", type=int, default=100, metavar="N",
                           help="number of tests to generate (Default: 100)")
parser_random.add_argument("--id-tag", metavar="TAG",
                           help="SD tag name containing the primary identifier")
parser_random.add_argument("--subset-filename", metavar="FILENAME",
                           help="Save the subset of the structures used for the tests into FILENAME")
parser_random.add_argument("--verbose", action="store_true",
                           help="Write status information to stderr")
parser_random.add_argument("structure_filename", help="input SD or SMILES file")
parser_random.set_defaults(cmd="random")

# "neighbors" subcommand: build a benchmark from the hits of Tanimoto
# similarity searches of a fingerprint file (see do_neighbors).
parser_neighbors = subparsers.add_parser(
    "neighbors",
    help="Generate an MCS benchmark file using nearest-neighbor searches of a fingerprint file")

parser_neighbors.add_argument("--seed", type=int, help="initial random number seed")
parser_neighbors.add_argument("--num-tests", "-n", type=int, default=100, metavar="N",
                              help="number of test cases to generate")
parser_neighbors.add_argument("--k", "-k", type=int, default=None,
                              help="maximum number of neighbors to use")
parser_neighbors.add_argument("--k-min", type=int, default=2,
                              help="minimum number of neighbors a search must find "
                                   "to become a test (Default: 2)")
parser_neighbors.add_argument("--threshold", type=float, default=None,
                              help="minimum Tanimoto similarity score for a neighbor")
# NOTE(review): --prefix is parsed, but do_neighbors never reads args.prefix.
parser_neighbors.add_argument("--prefix", help="output prefix for the MCS and structure filenames")
parser_neighbors.add_argument("--structures", metavar="FILENAME",
                              help="input SD or SMILES file (Default: use the FPS source field)")
parser_neighbors.add_argument("--id-tag", metavar="TAG",
                              help="SD tag name containing the primary identifier")
parser_neighbors.add_argument("--subset-filename", metavar="FILENAME",
                              help="Save the subset of the structures used for the tests into FILENAME")
parser_neighbors.add_argument("--verbose", action="store_true",
                              help="Write progress information to stderr")
parser_neighbors.add_argument("fps_filename", help="structure fingerprints for the similarity search")
parser_neighbors.set_defaults(cmd="neighbors")

# "subset" subcommand: rewrite a benchmark so it refers to a smaller
# structure file holding only the records its tests use (see do_subset).
parser_subset = subparsers.add_parser(
    "subset",
    help=("Given a benchmark file with one '#File', create a new structure "
          "file containing only the records from the file which are used by the benchmark."))

parser_subset.add_argument(
    "--subset-filename", metavar="FILENAME", required=True,
    help="Save the subset of the structures used for the tests into FILENAME")
parser_subset.add_argument(
    "--verbose", action="store_true",
    help="Write progress information to stderr")
parser_subset.add_argument("mcsb_filename")
parser_subset.set_defaults(cmd="subset")

def get_benchmark_reader(filename, as_client):
    """Create a BenchmarkReader for an MCS benchmark file.

    filename -- path of the benchmark file, or None to read from stdin
    as_client -- if true, stdout is handed to the reader as its control
                 file (do_search passes args.client, the experimental
                 --client protocol)
    """
    # NOTE(review): the file opened here is never explicitly closed; the
    # reader presumably keeps it open for as long as it is iterated.
    source = sys.stdin if filename is None else open(filename)
    control = sys.stdout if as_client else None
    return BenchmarkReader(source, control)

class Stat(object):
    """Running totals (count, time, MCS sizes) for one category of search result.

    The properties return pre-formatted strings, or "N/A" when the value is
    undefined, for use in progress and summary messages.
    """
    def __init__(self, name):
        self.name = name
        self.count = 0
        self.total_time = 0.0
        self.total_num_atoms = self.total_num_bonds = 0

    def __len__(self):
        """Return the number of searches recorded so far."""
        return self.count

    def add(self, dt, num_atoms=-1, num_bonds=-1):
        """Record one search that took *dt* seconds.

        num_atoms and num_bonds are the size of the MCS found. Both must be
        -1 (the default: no size to record) or both must be given. Raises
        AssertionError, leaving every total unchanged, if only one is -1.
        """
        # Validate before touching any counter so that a rejected call
        # cannot leave count/total_time updated without the size totals.
        if (num_atoms == -1) != (num_bonds == -1):
            raise AssertionError("Both num_atoms and num_bonds must be -1, or both not -1")
        self.count += 1
        self.total_time += dt
        if num_atoms != -1:
            self.total_num_atoms += num_atoms
            self.total_num_bonds += num_bonds

    @property
    def per_second(self):
        """Searches per second with one decimal, or "N/A" if no time was recorded."""
        if self.total_time == 0.0:
            return "N/A"
        return "{0:.1f}".format(self.count / self.total_time)

    @property
    def average_time(self):
        """Mean seconds per search with one decimal, or "N/A" if nothing was recorded."""
        if self.count == 0:
            return "N/A"
        return "{0:.1f}".format(self.total_time / self.count)

    # NOTE(review): the two averages below divide by every recorded search,
    # including those added without a size (num_atoms == -1).
    @property
    def average_num_atoms(self):
        """Mean MCS atom count with one decimal, or "N/A" if nothing was recorded."""
        if self.count == 0:
            return "N/A"
        return "{0:.1f}".format(self.total_num_atoms / float(self.count))

    @property
    def average_num_bonds(self):
        """Mean MCS bond count with one decimal, or "N/A" if nothing was recorded."""
        if self.count == 0:
            return "N/A"
        return "{0:.1f}".format(self.total_num_bonds / float(self.count))
    

def format_stat(stat, show_division=True):
    """Return a one-line "name: count/time" summary of *stat*.

    When show_division is true (the default), the per-second rate is
    appended in parentheses.
    """
    text = "{0.name}: {0.count}/{0.total_time:.2f}s".format(stat)
    if show_division:
        text += " ({0.per_second})".format(stat)
    return text

def _report_search_progress(args, output, stats):
    """Write one progress line to *output* (--progress) and/or stderr (--verbose)."""
    message = (
        "{0.name}: {0.count}/{0.total_time:.1f}s ({0.per_second}/s) "
        "{1.name}: {1.count}/{1.total_time:.1f}s ({1.per_second}/s) "
        "{2.name}: {2.count}/{2.total_time:.1f}s "
        "{3.name}: {3.count} "
        "{4.name}: {4.count}").format(
            stats[TOTAL], stats[COMPLETE], stats[INCOMPLETE], stats[FAILURE], stats[KILLED])
    if args.progress:
        output.progress(message)
    if args.verbose:
        sys.stderr.write("## " + message + "\n")
        sys.stderr.flush()


def _write_search_summary(args, output, stats):
    """Write the end-of-run statistics to *output*, and also to stderr with --verbose."""
    if args.verbose:
        def emit(message):
            output.summary(message)
            sys.stderr.write(message + "\n")
    else:
        emit = output.summary

    # The header block only goes to the benchmark output.
    for header_line in ("", "         Summary", ""):
        output.summary(header_line)

    prefix = "{0.name} {0.count} in {0.total_time:.2f} seconds"
    rate_line = prefix + " ({0.per_second}/second)"
    average_line = prefix + " (average {0.average_time} sec)"
    size_line = "    {0.total_num_atoms} atoms {0.total_num_bonds} bonds; average {0.average_num_atoms} atoms {0.average_num_bonds} bonds"

    # Categories that produce an MCS get a rate line and a size line ...
    for key in (TOTAL, COMPLETE, INCOMPLETE):
        emit(rate_line.format(stats[key]))
        emit(size_line.format(stats[key]))
    # ... the others only report the average time.
    for key in (FAILURE, KILLED):
        emit(average_line.format(stats[key]))


def do_search(args, searcher, output):
    """Run every token of the benchmark file through *searcher*.

    The searcher writes its header and per-request results to *output* and
    updates the per-status Stat objects. Each time the total count reaches a
    new multiple of 25, a progress line is reported. A statistics summary is
    written at the end.
    """
    benchmark_reader = get_benchmark_reader(args.mcsb_filename, args.client)
    searcher.write_header(output)
    stats = {
        TOTAL: Stat("Total"),
        COMPLETE: Stat("Complete"),
        INCOMPLETE: Stat("Incomplete"),
        FAILURE: Stat("Fail"),
        KILLED: Stat("Kill"),  # Not yet needed ...
    }

    last_count = 0
    for token in benchmark_reader:
        searcher.process(output, token, stats)
        current = stats[TOTAL].count
        # Tokens that are not searches leave the count unchanged; only
        # report when the count actually moved onto a multiple of 25.
        if current != last_count and current % 25 == 0:
            _report_search_progress(args, output, stats)
        last_count = current

    _write_search_summary(args, output, stats)


def open_search_output(args):
    """Return the stdout writer for search results selected by --output-format.

    None (the default), "mcs-output" and "mcs-benchmark-output" select
    MCSBenchmarkOutputWriter; "mcs-benchmark" selects MCSBenchmarkWriter.
    Raises AssertionError for any other name.
    """
    output_format = args.output_format
    # "mcs-output" is the spelling that --output-format accepts (see
    # _add_standard_mcs_options); "mcs-benchmark-output" is kept as an alias.
    if output_format in (None, "mcs-output", "mcs-benchmark-output"):
        return MCSBenchmarkOutputWriter(sys.stdout)
    if output_format == "mcs-benchmark":
        return MCSBenchmarkWriter(sys.stdout)
    raise AssertionError("Unknown format name %r" % (output_format,))

def do_random(args):
    """Write an MCS benchmark to stdout whose tests are random samples of record ids.

    Each of the --num-tests tests is -k distinct ids sampled from the records
    of args.structure_filename, using --seed (or a random seed, which is
    recorded in the output comment). With --subset-filename, the records
    used by the tests are copied into that file and the benchmark refers
    to it.

    Exits with a usage error for -k < 2 or --num-tests < 1. Raises
    SystemExit if the file has fewer than -k records.
    """
    if args.k < 2:
        parser_random.error("-k must be at least 2")
    if args.num_tests <= 0:
        parser_random.error("--num-tests must be at least 1")

    ids = [rec_id for (_recno, rec_id, _record)
           in get_record_ids(args.structure_filename, args.id_tag)]

    if args.k > len(ids):
        raise SystemExit("-k is %d but there are only %d records in %r" % (
            args.k, len(ids), args.structure_filename.encode("utf-8")))

    seed = args.seed
    if seed is None:
        seed = random.randrange(2**32)
    rng = random.Random(seed)
    all_query_ids = [rng.sample(ids, args.k) for _ in range(args.num_tests)]

    needed_ids = set(itertools.chain.from_iterable(all_query_ids))
    # Pass --verbose through, as do_neighbors does, so progress is reported.
    file_filename = _make_subset_structure_file(
        args.structure_filename, args.subset_filename,
        args.id_tag, needed_ids, verbose=args.verbose)

    output = _start_select_output(args, file_filename)
    output.comment("%d randomly generated tests, each with %d ids. Seed=%d" % (
        args.num_tests, args.k, seed))
    _write_all_queries(output, all_query_ids, None)

def do_neighbors(args):
    """Write an MCS benchmark to stdout built from Tanimoto similarity searches.

    The function repeatedly picks a random fingerprint from args.fps_filename
    and searches the same fingerprint set with it. When a search finds at
    least --k-min hits, the hit ids become one test. This repeats until
    there are --num-tests tests.

    The search depends on -k and --threshold:
      neither given   -> k-nearest (k = max(3, --k-min)) with threshold 0.7
      only threshold  -> threshold search, hits ordered by decreasing score
      only -k         -> k-nearest with threshold 0.0
      both            -> k-nearest with that threshold

    Raises Error if the fingerprint file is empty, if the structure file
    cannot be determined from the FPS metadata, or if more than 1000
    searches fail before the first one finds enough hits.
    """
    # Validate the options before the potentially slow fingerprint load.
    if args.num_tests <= 0:
        parser_neighbors.error("--num-tests must be at least 1")
    if args.k_min < 2:
        parser_neighbors.error("--k-min must be at least 2")
    k_min = args.k_min
    if args.k is not None and args.k < k_min:
        parser_neighbors.error("--k must be at least equal to --k-min, or left unspecified")

    try:
        import chemfp
    except ImportError:
        sys.stderr.write("Please install chemfp from http://code.google.com/p/chem-fingerprints/\n")
        raise
    if args.verbose:
        sys.stderr.write("Reading fingerprints from %s\n" % (args.fps_filename.encode("utf-8"),))
        sys.stderr.flush()
    fps = chemfp.load_fingerprints(args.fps_filename)
    if len(fps) == 0:
        raise Error("No fingerprints found in %r" % (args.fps_filename,))

    if args.k is None:
        if args.threshold is None:
            # Never search for fewer neighbors than --k-min requires, or no
            # search could ever produce a test.
            k = max(3, k_min)
            threshold = 0.7
            description = "%d-nearest Tanimoto search with threshold 0.700" % (k,)
        else:
            k = None
            threshold = args.threshold
            description = "Tanimoto search with threshold %.3f" % (threshold,)
    else:
        k = args.k
        if args.threshold is None:
            threshold = 0.0
            description = "%d-nearest Tanimoto search" % (k,)
        else:
            threshold = args.threshold
            description = "%d-nearest Tanimoto search with threshold %.3f" % (k, threshold)

    if args.structures is None:
        sources = fps.metadata.sources
        if len(sources) == 1:
            structure_filename = sources[0]
        else:
            if not sources:
                raise Error("No --structures specified and no sources listed in %r" % (
                    args.fps_filename,))
            raise Error("No --structures specified and too many sources listed in %r" % (
                    args.fps_filename,))
    else:
        structure_filename = args.structures

    seed = args.seed
    if seed is None:
        seed = random.randrange(2**32)
    rng = random.Random(seed)
    # The result is unused, but this draw advances the RNG. Keeping it means
    # a given --seed still produces the same tests as before.
    rng.randrange(len(fps))

    # Failures are only counted until the first success; after that,
    # no_hits is None and the loop retries without limit.
    no_hits = 0
    all_query_ids = []
    messages = []
    if args.verbose:
        sys.stderr.write(description + "\n")
        sys.stderr.flush()

    while len(all_query_ids) < args.num_tests:
        if args.verbose:
            n = len(all_query_ids) + 1
            if n == 1 or (n % 10 == 0):
                sys.stderr.write("\rNeighbor search %d/%d" % (n, args.num_tests))
                sys.stderr.flush()

        query_id, query_fp = rng.choice(fps)
        if k is None:
            hits = fps.threshold_tanimoto_search_fp(query_fp, threshold=threshold)
            hits.reorder("decreasing-score")
        else:
            hits = fps.knearest_tanimoto_search_fp(query_fp, k=k, threshold=threshold)
        if len(hits) < k_min:
            if no_hits is not None:
                no_hits += 1
                if no_hits > 1000:
                    raise Error("Tried %d times to find a similarity match, and failed" % (
                        no_hits,))
            continue
        no_hits = None

        all_query_ids.append(hits.get_ids())
        # The last score is taken as the minimum. Threshold hits were
        # reordered above; k-nearest hits are presumably already in
        # decreasing-score order (chemfp behavior, TODO confirm).
        messages.append("query=%s num hits=%d minimum score=%.2f " % (
            query_id, len(hits), hits.get_scores()[-1]))
    if args.verbose:
        sys.stderr.write("\rCompleted %d similarity searches.\n" % (args.num_tests,))

    # Save the needed structures to a new structure file
    needed_ids = set(itertools.chain.from_iterable(all_query_ids))
    file_filename = _make_subset_structure_file(
        structure_filename, args.subset_filename,
        args.id_tag, needed_ids, args.verbose)

    if args.verbose:
        sys.stderr.write("Writing MCS benchmark to stdout\n")
    output = _start_select_output(args, file_filename)
    output.comment("Test cases found using " + description + ". Seed=%d" % (seed,))
    _write_all_queries(output, all_query_ids, messages)
    if args.verbose:
        sys.stderr.write("Done.\n")

def _make_subset_structure_file(structure_filename, subset_filename, id_tag, needed_ids, verbose=False):
    """Copy the records whose ids are in *needed_ids* into *subset_filename*.

    Returns the filename the benchmark should reference. If subset_filename
    is None, nothing is written and structure_filename is returned;
    otherwise subset_filename is returned. The output must have the same
    file type (per _get_format) as the input, and is gzip-compressed when
    _get_format reports ".gz" compression for subset_filename.

    Raises SystemExit if the input and output file types differ. Raises
    Error if an id occurs in more than one record, or if some needed ids
    are not found.
    """
    if subset_filename is None:
        return structure_filename

    input_filetype, _input_compression = _get_format(structure_filename)
    output_filetype, output_compression = _get_format(subset_filename)
    if input_filetype != output_filetype:
        raise SystemExit("Input is a %r file but the output is a %r file" % (
            input_filetype, output_filetype))
    if output_compression == ".gz":
        outfile = gzip.open(subset_filename, "w")
    else:
        outfile = open(subset_filename, "w")

    found_ids = set()
    recno = 0  # defined even when the input has no records (verbose report)
    try:
        for recno, rec_id, record in get_record_ids(structure_filename, id_tag):
            if verbose:
                if recno == 1 or recno % 1000 == 0:
                    sys.stderr.write("\rProcessing record %d" % (recno,))
                    sys.stderr.flush()
            if rec_id in needed_ids:
                if rec_id in found_ids:
                    raise Error("Multiple records with the id %r" % (rec_id,))
                outfile.write(record)
                found_ids.add(rec_id)
                # Stop reading once every needed record has been copied.
                if len(found_ids) == len(needed_ids):
                    break
        if verbose:
            sys.stderr.write("\rProcessed %d records\n" % (recno,))
            sys.stderr.flush()
    finally:
        outfile.close()

    # Report the ids that are actually missing, not every needed id.
    missing_ids = needed_ids - found_ids
    if missing_ids:
        examples = sorted(missing_ids)
        if len(examples) < 10:
            raise Error("Could not find %d required identifiers from %r: %s" % (
                len(missing_ids), structure_filename.encode("utf-8"), " ".join(examples)))
        raise Error("Could not find %d required identifiers from %r, examples: %s ..." % (
            len(missing_ids), structure_filename.encode("utf-8"), " ".join(examples[:10])))
    return subset_filename

def _start_select_output(args, file_filename):
    """Create a stdout MCSBenchmarkWriter and write the benchmark preamble.

    Calls the writer's magic(), then id_tag() (only when args.id_tag is
    set), then file(file_filename). Returns the writer so the caller can
    add comments and requests.
    """
    writer = MCSBenchmarkWriter(sys.stdout)
    writer.magic()
    tag = args.id_tag
    if tag is not None:
        writer.id_tag(tag)
    writer.file(file_filename)
    return writer

def _write_all_queries(output, all_query_ids, messages=None):
    """Write one MCS request per entry of *all_query_ids*, labelled "1", "2", ...

    *messages*, if given, runs parallel to all_query_ids. Each non-None
    message is written as a comment just before its request. As with zip(),
    output stops at the shorter of the two sequences.
    """
    if messages is None:
        # No messages: pair every query with None, so no comments are written.
        messages = itertools.repeat(None)
    for label, (query_ids, message) in enumerate(zip(all_query_ids, messages), 1):
        if message is not None:
            output.comment(message)
        output.mcs_result(str(label), query_ids, None)


def do_subset(args):
    """Rewrite a benchmark so that it refers to a subset structure file.

    Every token of args.mcsb_filename is echoed to stdout. The single #File
    token is changed to point at args.subset_filename. The records of the
    original #File that the MCS requests use are then written to
    args.subset_filename; if no request needs any record, an empty file is
    created instead.

    Raises Error if there is more than one #File, if an MCS request appears
    before any #File, or if there is an 'all' request.
    """
    benchmark_reader = get_benchmark_reader(args.mcsb_filename, as_client=False)
    outfile = sys.stdout
    benchmark_writer = MCSBenchmarkWriter(outfile)
    old_structures = None
    id_tag = None
    needed_ids = set()
    for token in benchmark_reader:
        if isinstance(token, FileStmt):
            if old_structures is None:
                old_structures = token.filename
                token.filename = args.subset_filename
            else:
                raise Error("'subset' only supports one #File but multiple #File lines found")
        elif isinstance(token, MCSRequestAll):
            raise Error("Cannot make a subset; there is an MCS request 'all' line: %s" % (token.tostring(),))
        elif isinstance(token, MCSRequest):
            if old_structures is None:
                raise Error("MCS request %r but no #File specified" % (token.tostring(),))
            needed_ids.update(token.ids)
        elif isinstance(token, IdTagStmt):
            # This does not do a rigorous validity test
            id_tag = token.tag

        benchmark_writer.token(token)

    # Flush rather than close: this function does not own sys.stdout, and
    # main() may be called programmatically.
    outfile.flush()

    if needed_ids:
        _make_subset_structure_file(old_structures, args.subset_filename, id_tag,
                                    needed_ids, verbose=args.verbose)
    elif args.subset_filename.lower().endswith(".gz"):
        # Nothing is needed: create an empty file. For a ".gz" name, write an
        # empty gzip stream so the result is still a valid compressed file.
        gzip.open(args.subset_filename, "w").close()
    else:
        open(args.subset_filename, "w").close()
        
    

def main(args=None):
    """Command-line entry point.

    args -- list of command-line arguments. With None, argparse reads
            sys.argv[1:].

    Dispatches to the search or benchmark-generation subcommand named by
    the parsed `cmd` value.
    """
    args = parser.parse_args(args)
    cmd = args.cmd

    if cmd in ("fmcs", "indigo", "indigo-approx"):
        if cmd == "fmcs" and args.min_num_atoms < 2:
            # Report through the subcommand's parser, as do_random and
            # do_neighbors do, so the usage line shown is the fmcs one.
            parser_fmcs.error("--min-num-atoms must be at least 2")
        output = open_search_output(args)
        if cmd == "fmcs":
            searcher = FMCSSearch(args)
        elif cmd == "indigo":
            searcher = IndigoExactSearch(args)
        else:
            searcher = IndigoApproxSearch(args)
        do_search(args, searcher, output)

    elif cmd == "random":
        do_random(args)

    elif cmd == "neighbors":
        do_neighbors(args)

    elif cmd == "subset":
        do_subset(args)

    else:
        raise AssertionError("not implemented: %r" % (cmd,))
        
if __name__ == "__main__":
    # With no argument, argparse parses sys.argv[1:] itself.
    main()

