#!/usr/bin/python3
import argparse
import Levenshtein
import os
import regex as re
import secrets
import shutil
import subprocess
import sys

# Command-line interface. The four path options all share the same shape:
# a stored string with an empty default.
parser = argparse.ArgumentParser(
	prog='kal-spell-heur',
	description='Attempts speller and other heuristic analysis for missing analyses',
	)
for short_opt, long_opt, help_text in (
	('-s', '--speller-bin', 'Path to hfst-ospell-office; defaults to searching $PATH'),
	('-S', '--speller-fst', 'Path to kl.zhfst; defaults to searching install paths'),
	('-t', '--tokenise-bin', 'Path to hfst-tokenise; defaults to searching $PATH'),
	('-T', '--tokenise-fst', 'Path to tokeniser-disamb-gt-desc.pmhfst; defaults to searching install paths'),
	):
	parser.add_argument(short_opt, long_opt, action='store', help=help_text, default='')
parser.add_argument('-D', '--debug', action='store_true', help='Enable Python stack traces and other debugging', default=False)
cmdargs = parser.parse_args()

# Unless debugging was requested, suppress traceback output for uncaught exceptions.
if not cmdargs.debug:
	sys.tracebacklimit = 0

# Installation prefix and the directory holding this script; both are used
# when building the FST search paths further down.
g_prefix = '/usr'
g_dir = os.path.dirname(os.path.abspath(__file__))

# hfst-ospell-office
# Candidates in priority order: the --speller-bin option, a $PATH lookup,
# then a few fixed locations.
bins = [
	cmdargs.speller_bin,
	shutil.which('hfst-ospell-office'),
	'/usr/local/bin/hfst-ospell-office',
	'/opt/local/bin/hfst-ospell-office',
	'/usr/bin/hfst-ospell-office',
	]
# The option defaults to '' and shutil.which() may return None, so falsy
# candidates are skipped before the executable check.
bin_spell = next((cand for cand in bins if cand and os.access(cand, os.X_OK)), '')

if not bin_spell:
	raise RuntimeError('No usable hfst-ospell-office given or found!')
if cmdargs.debug:
	print(f"kal-spell-heur: Using {bin_spell}", file=sys.stderr, flush=True)

# hfst-tokenise
# Candidates in priority order: the --tokenise-bin option, a $PATH lookup,
# then a few fixed locations.
bins = [
	cmdargs.tokenise_bin,
	shutil.which('hfst-tokenise'),
	'/usr/local/bin/hfst-tokenise',
	'/opt/local/bin/hfst-tokenise',
	'/usr/bin/hfst-tokenise',
	]
# Falsy candidates ('' from the option default, None from shutil.which) are
# skipped; the first remaining executable wins.
bin_tok = next((cand for cand in bins if cand and os.access(cand, os.X_OK)), '')

if not bin_tok:
	raise RuntimeError('No usable hfst-tokenise given or found!')
if cmdargs.debug:
	print(f"kal-spell-heur: Using {bin_tok}", file=sys.stderr, flush=True)

# kl.zhfst
# Candidates in priority order. The --speller-fst option comes first so that
# an explicitly given path is honoured; it is '' when not given and is then
# skipped by the truthiness test in the loop.
fsts = [
	cmdargs.speller_fst,
	'${prefix}/share/voikko/3/kl.zhfst'.replace('${prefix}', g_prefix),
	'/usr/local/share/voikko/3/kl.zhfst',
	'/usr/share/voikko/3/kl.zhfst',
	f'{g_dir}/../spellcheckers/3/kl.zhfst',
	]
fst_spell = ''
for f in fsts:
	# Accept the first candidate that exists and is not an empty file.
	if f and os.path.exists(f) and os.path.getsize(f):
		fst_spell = f
		break

if not fst_spell:
	raise RuntimeError('No usable kl.zhfst given or found!')
if cmdargs.debug:
	print(f"kal-spell-heur: Using {fst_spell}", file=sys.stderr, flush=True)

# tokeniser-disamb-gt-desc.pmhfst
# Candidates in priority order. The --tokenise-fst option comes first so that
# an explicitly given path is honoured; it is '' when not given and is then
# skipped by the truthiness test in the loop.
fsts = [
	cmdargs.tokenise_fst,
	'${prefix}/share/giella/kal/tokeniser-disamb-gt-desc.pmhfst'.replace('${prefix}', g_prefix),
	'/usr/local/share/giella/kal/tokeniser-disamb-gt-desc.pmhfst',
	'/usr/share/giella/kal/tokeniser-disamb-gt-desc.pmhfst',
	f'{g_dir}/../tokenisers/tokeniser-disamb-gt-desc.pmhfst',
	]
fst_tok = ''
for f in fsts:
	# Accept the first candidate that exists and is not an empty file.
	if f and os.path.exists(f) and os.path.getsize(f):
		fst_tok = f
		break

if not fst_tok:
	raise RuntimeError('No usable tokeniser.pmhfst given or found!')
if cmdargs.debug:
	print(f"kal-spell-heur: Using {fst_tok}", file=sys.stderr, flush=True)


# Start both helpers as long-running, line-buffered text-mode child processes.
# do_spell() and do_tokenise() exchange lines with them over stdin/stdout.
spell = subprocess.Popen([bin_spell, '-T', '-d', fst_spell], bufsize=1, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='UTF-8', text=True)
tok = subprocess.Popen([bin_tok, '-L', fst_tok], bufsize=1, stdin=subprocess.PIPE, stdout=subprocess.PIPE, encoding='UTF-8', text=True)

# Consume the speller's startup output: lines starting with '@@ ' are skipped
# until the "alive" line arrives; any other line is treated as an error.
while g := spell.stdout.readline():
	g = g.strip()
	if g == '@@ hfst-ospell-office is alive':
		break
	if g.startswith('@@ '):
		continue
	raise RuntimeError(f'Speller gave unexpected greeting: {g}')
else:
	# readline() returned '' (end of file) before the "alive" line was seen,
	# i.e. the speller closed its output. Fail here instead of continuing
	# with a speller that will never answer.
	raise RuntimeError('Speller exited before announcing that it is alive!')


# Cohorts buffered until the next call to process().
cohorts = []
# The cohort currently being assembled by the input loop:
#   'p'  text seen before the wordform line
#   'w'  the wordform line
#   'rs' the reading lines
#   's'  text seen after the readings
cohort = dict(p='', w='', rs=[], s='')

# Words grouped by first letter; kal_from() returns 0 for any word listed here.
old_words = {
	'b': {'baaja', 'baalia', 'baaliar', 'bajeri', 'biibili', 'biili', 'biiler', 'bussi'},
	'd': {'diaavulu', 'decembari'},
	'f': {'farisiiari', 'februaari', 'feeria', 'feeriar', 'freer'},
	'g': {'gassi', 'guuti'},
	'h': {'hiisti', 'horaa', 'horaartor', 'huaa', 'huaartor'},
	'j': {'januaari', 'joorli', 'joorlisior', 'jorngoq', 'juuli', 'juulli', 'juumooq', 'juuni', 'juuti'},
	'l': {'laaja', 'lakker', 'lakki', "lal'laaq", 'lappi', 'liimmer', 'liimmi'},
	'r': {'raaja', 'rinngi', 'rommi', 'russeq', 'ruua', 'ruujori', 'ruusa', 'ruusaar'},
	'v': {'viinnequt', 'viinni'},
	}

# Per-token result caches used by do_spell() and do_tokenise().
spell_cache = {}
tok_cache = {}

def do_spell(token):
	"""Ask the speller coprocess for suggestions for *token*.

	Sends one request line to ``spell`` and reads one reply line. Results are
	memoised in ``spell_cache``; the cache is emptied once it holds 20000
	entries.

	Returns a dict mapping each suggested string to the first number found in
	its parenthesised suffix, kept as the string the speller printed. The dict
	is empty when the reply does not start with '&'.
	"""
	if token in spell_cache:
		return spell_cache[token]

	# bufsize=1 on the Popen object makes stdin line-buffered, so the request
	# is sent as soon as the newline is written.
	spell.stdin.write(f'5 {token}\n')
	out = spell.stdout.readline()
	rv = {}

	if out.startswith('&'):
		# Drop the two leading characters; the rest is tab-separated entries
		# of the form "text (number;digits)".
		for sg in out[2:].strip().split('\t'):
			m = re.match(r'^(.+?) \(([\d.]+);\d+\)$', sg.strip())
			if not m:
				# An entry in any other shape is skipped; indexing a failed
				# match would raise TypeError and abort the whole run.
				continue
			rv[m[1]] = m[2]

	if len(spell_cache) >= 20000:
		spell_cache.clear()
	spell_cache[token] = rv
	return rv

def do_tokenise(token):
	"""Run *token* through the tokeniser coprocess and collect its analysis.

	The token is followed by a unique ``<s...>`` marker line; the tokeniser's
	output is read up to and including that marker, so every call consumes
	exactly its own answer. Results are memoised in ``tok_cache``; the cache
	is emptied once it holds 20000 entries.

	Returns a dict with 'w' (the last wordform line seen) and 'rs' (all
	reading lines seen). Both are empty if any reading contains '" ?'.

	Raises RuntimeError if the tokeniser closes its output before the marker
	is seen.
	"""
	if token in tok_cache:
		return tok_cache[token]

	# A random marker lets us recognise the end of this particular answer.
	nonce = secrets.token_urlsafe(4)
	marker = f'<s{nonce}>\n'
	tok.stdin.write(f'{token}\n{marker}')
	rv = {
		'w': '',
		'rs': [],
		}
	unknown = False
	while (l := tok.stdout.readline()) != marker:
		if not l:
			# End of file: without this check the loop would spin forever,
			# because '' never equals the marker.
			raise RuntimeError('Tokeniser closed its output unexpectedly!')
		if unknown:
			# Keep draining up to the marker so that leftover lines cannot
			# be mistaken for part of the next token's answer.
			continue
		if l.startswith('"<'):
			rv['w'] = l
		elif l.startswith('\t"'):
			if '" ?' in l:
				unknown = True
				continue
			rv['rs'].append(l)

	if unknown:
		# Any reading marked '" ?' makes the whole result count as empty.
		rv = {
			'w': '',
			'rs': [],
			}

	if len(tok_cache) >= 20000:
		tok_cache.clear()
	tok_cache[token] = rv
	return rv

def kal_from(token):
	"""Return a character offset into *token* derived from spelling rules.

	process() keeps ``token[:n]`` unchanged and passes ``token[n:]`` to
	kal_segment(). Each rule below proposes an offset and the largest one is
	returned. The token is lower-cased before the rules are applied.

	Returns 0 for an empty token and for any word listed in ``old_words``.
	"""
	# Nothing to inspect; token[0] and token[-1] below would raise IndexError.
	if not token:
		return 0

	token = token.lower()
	rv = 0

	first = token[0]
	if first in old_words:
		# Listed words are accepted whole; other words with such a first
		# letter get an offset of at least 1.
		if token in old_words[first]:
			return 0
		rv = 1

	# First letter outside this set: offset at least 1.
	if not re.match(r'[aeikmnopqstu]', first):
		rv = max(rv, 1)

	# Second letter in this set: offset at least 2.
	if re.match(r'^.[bcdwxyzæøå]', token):
		rv = max(rv, 2)

	# ToDo: 'ai' anywhere but the end signifies non-kal
	if re.match(r'^.ai', token):
		rv = max(rv, 3)

	# A run of e/o followed by a character other than e, o, r or q: the
	# offset is set two past the start of the run.
	eorq = re.finditer(r'[eo]+[^eorq]', token)
	for m in eorq:
		rv = max(rv, m.start()+2)

	# Last letter outside this set: the offset covers the whole token.
	if not re.match(r'[aikpqtu]', token[-1]):
		rv = max(rv, len(token))

	# NOTE(review): re.match only matches at the start of the token, so
	# m.start() is always 0 here and this rule can only yield 1. If the
	# intent was to find the last character outside the set, this would need
	# re.search -- confirm before changing.
	if m := re.match(r'[^aefgijklmnopqrstuvŋ][aefgijklmnopqrstuvŋ]+$', token):
		rv = max(rv, m.start()+1)

	# Pairs of adjacent consonants: pairs starting with 'r', and the pairs
	# 'ng' and 'ts', are ignored; any other pair of two different consonants
	# moves the offset to just after the pair.
	cons = re.finditer(r'([qwrtpsdfghjklzxcvbnmŋ])([qwrtpsdfghjklzxcvbnmŋ])', token)
	for m in cons:
		if m[1] == 'r':
			continue
		if m[1] == 'n' and m[2] == 'g':
			continue
		if m[1] == 't' and m[2] == 's':
			continue
		if m[1] != m[2]:
			rv = max(rv, m.start()+2)

	return rv

def kal_segment(token):
	"""Split *token* into chunks at vowel/consonant boundaries.

	A token containing anything other than a-z, æ, ø, å or ŋ (in either case)
	is returned unchanged as a string. Otherwise a list of chunks is returned.

	'nng' and 'ng' are treated as the single letters 'ŋŋ' and 'ŋ' while
	scanning and are spelled out again in the result. A break is inserted:
	  * after a vowel that is followed by consonant + vowel,
	  * between the two consonants of vowel + consonant + consonant + vowel,
	  * between two different adjacent vowels.
	Scanning stops once fewer than three characters remain, although the
	first position is always examined.
	"""
	if re.match(r'^[a-zæøåŋ]+$', token, flags=re.I) is None:
		return token

	# Collapse the digraphs so that each counts as one consonant below.
	token = token.replace('nng', 'ŋŋ').replace('ng', 'ŋ')

	consonant = re.compile(r'[bcdfghjklmnŋpqrstvwxz]', flags=re.I)
	vowel = re.compile(r'[aeiouyæøå]', flags=re.I)

	def is_c(k):
		return consonant.match(token[k]) is not None

	def is_v(k):
		return vowel.match(token[k]) is not None

	size = len(token)
	pieces = []
	pos = 0
	while True:
		pieces.append(token[pos])
		left = size - pos  # characters from pos to the end, inclusive
		# Every rule requires a vowel at the current position.
		if is_v(pos):
			if left >= 3 and is_c(pos+1) and is_v(pos+2):
				pieces.append(' ')
			elif left >= 4 and is_c(pos+1) and is_c(pos+2) and is_v(pos+3):
				# The first consonant stays with the current chunk.
				pos += 1
				pieces.append(token[pos])
				pieces.append(' ')
			elif left >= 2 and is_v(pos+1) and token[pos].lower() != token[pos+1].lower():
				pieces.append(' ')
		pos += 1
		if pos >= size-2:
			break
	pieces.append(token[pos:])

	marked = ''.join(pieces)
	# Restore the digraph spellings, including one split across a break.
	marked = marked.replace('ŋ ŋ', 'n ng').replace('ŋŋ', 'nng').replace('ŋ', 'ng')
	return marked.split(' ')

def print_cohort(c):
	"""Write cohort *c* to stdout in stream order.

	Emits the prefix text 'p', the wordform line 'w', every reading in 'rs'
	and the trailing text 's'. No separators are added; the stored strings
	are printed exactly as they are.
	"""
	# Prefix and wordform are only printed when non-empty.
	head = [part for part in (c['p'], c['w']) if part]
	# Same for the trailing text; readings are always printed.
	tail = [c['s']] if c['s'] else []
	for chunk in (*head, *c['rs'], *tail):
		print(chunk, end='')

def process():
	"""Add heuristic readings to unanalysed cohorts, print the buffer, reset it.

	For every buffered cohort that has a reading containing '" ?':
	  1. each speller suggestion for the wordform is tokenised and its
	     readings are added, tagged ``Heur/Spell``;
	  2. the part of the wordform after the offset from kal_from() is
	     segmented with kal_segment(), and ever shorter tails are tokenised
	     behind an 'Xxx'/'xxx' placeholder; Prop/iProp readings are added,
	     tagged ``Heur/Xxx``, with the placeholder replaced by the part of
	     the wordform in front of the tail.
	If any reading was produced it replaces the cohort's original readings.

	Afterwards all buffered cohorts are printed and the global ``cohorts``
	and ``cohort`` are reset to their empty state.
	"""
	global cohorts, cohort

	for c in cohorts:
		# Only cohorts with an unknown ('" ?') reading are touched.
		if not any('" ?' in r for r in c['rs']):
			continue
		# Cut off the two leading and three trailing characters; assumes the
		# wordform line has the shape '"<...>"' followed by a newline.
		w = c['w'][2:-3]
		rs = []

		# 1) Speller suggestions.
		for k, v in do_spell(w).items():
			d = Levenshtein.distance(w, k)
			nc = do_tokenise(k)
			for r in nc['rs']:
				r = r.rstrip()
				rs.append(f'{r} Heur/Spell <h:{k}> <hw:{v}> <hd:{d}>\n')

		# 2) Placeholder + tail.
		n = kal_from(w)
		nw = w[0:n]
		# kal_segment() returns a list of chunks, or a plain string (which
		# then iterates character by character) for tokens it leaves unsplit.
		sgs = kal_segment(w[n:])
		# The placeholder copies the case of the wordform's first letter.
		placeholder = 'Xxx' if w[:1].isupper() else 'xxx'
		for si, sg in enumerate(sgs):
			rem = ''.join(sgs[si:])
			nc = do_tokenise(placeholder + rem)
			if nc['rs']:
				ncw = nc['w'][2:-3]
				d = len(w) - len(rem)
				for r in nc['rs']:
					if ' Prop' not in r and ' iProp' not in r:
						continue
					# A callable replacement inserts nw literally. Passing nw
					# as a replacement template would make a backslash in the
					# word raise an error or be expanded as an escape.
					r = re.sub(r'\t"Xxx', lambda _m: f'\t"{nw}', r.rstrip(), flags=re.I)
					rs.append(f'{r} Heur/Xxx <h:{ncw}> <hn:{si}> <hd:{d}>\n')
				# Stop at the first tail with readings once anything
				# (including speller readings) has been collected.
				if rs:
					break
			nw += sg

		if rs:
			c['rs'] = rs

	for c in cohorts:
		print_cohort(c)

	cohorts = []
	cohort = {
		'p': '',
		'w': '',
		'rs': [],
		's': '',
		}

# Read the cohort stream line by line. Lines keep their trailing newline so
# that print_cohort() can echo them verbatim.
for line in sys.stdin:
	if line.startswith('"<'):
		# A wordform line starts a new cohort; store the one in progress first.
		if cohort['w']:
			cohorts.append(cohort)
			# The stored wordform still carries its newline, so strip it
			# before comparing; otherwise the paragraph marker would never
			# be recognised and nothing would be processed until <s, a flush
			# command or end of input.
			if cohorts[-1]['w'].rstrip() == '"<¶>"':
				process()
			cohort = {
				'p': '',
				'w': '',
				'rs': [],
				's': '',
				}
		cohort['w'] = line

	elif line.startswith('\t"'):
		# A reading belongs to the current wordform; before any wordform it
		# is kept as prefix text.
		if cohort['w']:
			cohort['rs'].append(line)
		else:
			cohort['p'] += line

	elif line.startswith('<STREAMCMD:FLUSH>'):
		cohorts.append(cohort)
		process()
		# Pass the command on immediately and flush our own output, so the
		# reader of our output is not left waiting for buffered data. The
		# line lands at the same place in the output as it would if it were
		# kept as the next cohort's prefix.
		print(line, end='', flush=True)

	elif line.startswith('<s'):
		cohorts.append(cohort)
		process()
		# process() has reset `cohort`; the tag becomes the prefix of the
		# next cohort.
		cohort['p'] = line

	else:
		# Any other text goes after the readings of the current cohort, or
		# into the prefix when no wordform has been seen yet.
		if cohort['w']:
			cohort['s'] += line
		else:
			cohort['p'] += line

# End of input: process whatever is still buffered.
cohorts.append(cohort)
process()
