# Source code for epsproc.geomFunc.mfblmGeom


import numpy as np

# from epsproc.util import matEleSelector   # Circular/undefined import issue - call in function instead for now.
from epsproc.sphCalc import setPolGeoms
from epsproc.geomFunc import geomCalc
# from epsproc.geomFunc.geomCalc import (EPR, MFproj, betaTerm, remapllpL, w3jTable,)
from epsproc.geomFunc.geomUtils import genllpMatE

# Code as developed 16/17 March 2020.
# Needs some tidying, and should implement BLM Xarray attrs and format for output.
def mfblmXprod(matEin, QNs = None, EPRX = None, p = None, BLMtable = None, BLMtableResort = None,
               lambdaTerm = None, RX = None, eulerAngs = None, polProd = None,
               thres = 1e-2, thresDims = 'Eke', selDims = {'it':1, 'Type':'L'},
               sumDims = ['mu', 'mup', 'l','lp','m','mp'], sumDimsPol = ['P','R','Rp','p'], symSum = True,
               SFflag = False, squeeze = False, phaseConvention = 'E', basisReturn = "BLM", verbose = 0, **kwargs):
    r"""
    Implement :math:`\beta_{LM}^{MF}` calculation as product of tensors.

    .. math::
        \begin{eqnarray}
        \beta_{L,-M}^{\mu_{i},\mu_{f}} & = & \sum_{l,m,\mu}\sum_{l',m',\mu'}(-1)^{(\mu'-\mu_{0})}{B_{L,-M}}\nonumber \\
        & \times & \sum_{P,R',R}{E_{P-R}(\hat{e})\Lambda_{R',R}(R_{\hat{n}})}I_{l,m,\mu}^{p_{i}\mu_{i},p_{f}\mu_{f}}(E)I_{l',m',\mu'}^{p_{i}\mu_{i},p_{f}\mu_{f}*}(E)
        \end{eqnarray}

    Where each component is defined by functions in the :py:mod:`epsproc.geomFunc.geomCalc` module.

    12/08/22    Updating from afblmXprod routine for use with fitting functions.
                Added ProductBasis return as per afblmGeom case, for use in fitting.
                Added **kwargs, unused but allows for arb. basis dict unpack and passing from other functions.
                May want to pipe back to Full basis return however.
                Updated docs as per afblmXprod.
                TODO: still needs a tidy up and update, including BLM renorm options, again see afblmGeom code.

    16/03/20 In progress!

    Dev code:
        http://localhost:8888/lab/tree/dev/ePSproc/geometric_method_dev_Betas_090320.ipynb
        D:\code\ePSproc\python_dev\ePSproc_MFBLM_Numba_dev_tests_120220.PY

    Parameters
    ----------
    matEin : Xarray
        Xarray containing matrix elements, with QNs (l,m), as created by :py:func:`readMatEle`.
        Not modified; a copy is used internally.

    *** Optional calculation settings

    selDims : dict, default = {'it':1, 'Type':'L'}
        Selection parameters for calculations, passed to :py:func:`matEleSelector` as `inds`.

    sumDims : list, default = ['mu', 'mup', 'l','lp','m','mp']
        Main summation dims for the final tensor product.

    sumDimsPol : list, default = ['P','R','Rp','p']
        Summation dims for the polarization term (EPR x Lambda) product.

    symSum : bool, default = True
        Sum over symmetries sets (=={Cont, Targ, Total}) if true.

    degenDrop : bool
        Flag to set dropping of degenerate components.
        NOT IMPLEMENTED FOR MF CASE - see afblmXprod. (Accepted via **kwargs and ignored.)

    thres : float, default = 1e-2
        Set threshold value, used to set input matrix elements and again for outputs.

    thresDims : str, default = 'Eke'
        Set threshold dimension (set to be contiguous).

    verbose : bool or int
        Print output?

    *** Optional renormalisation settings (mainly for testing only)

    SFflag : bool, default = False
        Multiply input matrix elements by complex scale-factor if true
        (and divide outputs by the scale-factor).

    SFflagRenorm : bool, default = False
        Renorm output BLMs by complex scale-factor if true.
        NOT IMPLEMENTED FOR MF CASE - see afblmXprod. (Accepted via **kwargs and ignored.)

    BLMRenorm : int, default = 1
        Set different BLM renorm conventions.
        If 1 renorm by B00.
        See code for further details.
        NOT IMPLEMENTED FOR MF CASE - see afblmXprod. (Accepted via **kwargs and ignored.)

    squeeze : bool, default = False
        Squeeze output array after thresholding?
        Note: this may cause dim issues if True.

    *** Optional input data/basis functions (mainly for fitting routine use)

    QNs : np.array, optional, default = None
        List of QNs as generated by :py:func:`genllpMatE`.
        Will be generated if not passed.

    EPRX : Xarray, optional, default = None
        Resorted E-field term, as returned under the 'EPRX' key of the basis dict
        for basisReturn = 'Full'. If passed, it is used as-is (no further
        unstack/squeeze/phase corrections are applied).
        If None, generated by :py:func:`geomCalc.EPR` for polarization `p` (below).

    p : list or array, optional, default = None
        Specify polarization terms p. If None, [0] is used.
        Possibly currently only valid for p=0, TBC.

    BLMtable, BLMtableResort : Xarrays, optional, default = None
        Beta calculation parameters, as defined by :py:func:`geomCalc.betaTerm`.
        BLMtableResort includes phase settings & param renaming as herein;
        if passed, no further phase corrections are applied.

    lambdaTerm : Xarray, optional, default = None
        Resorted Lambda term, as returned under the 'lambdaTerm' key of the basis dict
        for basisReturn = 'Full'. If passed, it is used as-is.
        If None, generated by :py:func:`geomCalc.MFproj`.

    RX : Xarray, optional, default = None
        Polarization geometries as defined by :py:func:`epsproc.sphCalc.setPolGeoms`.
        If not set, defaults are used by :py:func:`epsproc.geomFunc.geomCalc.MFproj`.
        If not set, but Euler angles are set, then these will be used.

    eulerAngs : list or np.array of Euler angles (p(hi), t(heta), c(hi)), optional.
        Alternative definition for polarization geometries, as used by :py:func:`epsproc.sphCalc.setPolGeoms`.
        List or array [p,t,c...], shape (Nx3).
        List or array including set labels, [label,p,t,c...], shape (Nx4)

    polProd : Xarray, optional, default = None
        Polarization tensor as defined by EPRXresort * lambdaTermResort (with phase/degeneracy
        terms applied and summed over `sumDimsPol`). If passed, EPRX and lambdaTerm are not used.

    phaseConvention : optional, str, default = 'E'
        Set phase conventions with :py:func:`epsproc.geomCalc.setPhaseConventions`.
        To use preset phase conventions, pass existing dictionary.

    basisReturn : optional, str, default = "BLM"
        - 'BLM' return Xarray of results only.
        - 'Full' return Xarray of results + basis set dictionary as set during the run.
        - 'ProductBasis', as full, but minimal basis set with products only.
        - 'Results' or 'Legacy' direct return of various calc. results Xarrays,
          (mTermSumThres, mTermSum, mTerm).
        Unrecognised values print a message and return the BLM Xarray only.

    **kwargs, unused but allows for arb. basis dict unpack and passing from other functions.

    Returns
    -------
    Xarray
        Set of MFBLM calculation results, normalised by B00, with (l,m) stacked as 'BLM'.
        The unnormalised B00 term is kept in the 'XS' coordinate.

    dict
        Dictionary of basis functions, only if basisReturn is 'Full' or 'ProductBasis'
        (see basisReturn parameter notes).

    Notes
    -----

    Cross-section outputs currently defined as XS = direct MF calculation output (the B00 term).

    Optionally set SFflag = True to multiply by (complex) scale-factor.

    OTHER RENORM options not implemented as yet, see afblmXprod for details.

    """

    # Local import to avoid circular import issues with epsproc.util.
    from epsproc.util import matEleSelector

    # Default polarization term; None sentinel avoids a shared mutable default list.
    if p is None:
        p = [0]

    # Phase conventions - str (preset) or passed dict, both handled by setPhaseConventions().
    phaseCons = geomCalc.setPhaseConventions(phaseConvention = phaseConvention)

    #*** Threshold and selection
    # Make explicit copy of data to avoid any overwrite of the caller's array.
    matE = matEin.copy()
    matE.attrs = matEin.attrs  # May not be necessary with updated Xarray versions

    # Use SF (scale factor).
    # Write to data.values to make sure attribs are maintained. (Not the case for da = da*da.SF)
    if SFflag:
        matE.values = matE * matE.SF

    matEthres = matEleSelector(matE, thres = thres, inds = selDims, dims = thresDims, sq = True, drop = True)

    # Sum **AFTER** threshold and selection, to allow for subselection on symmetries via matEleSelector
    if symSum:
        if 'Sym' in matEthres.dims:
            matEthres = matEthres.sum('Sym')  # Sum over ['Cont','Targ','Total'] stacked dims.

    # Set terms if not passed to function
    if QNs is None:
        QNs = genllpMatE(matEthres, phaseConvention = phaseConvention)

    #*** Polarization terms
    # Passed EPRX/lambdaTerm are taken as the already-resorted terms (as returned in the 'Full'
    # basis dict). Binding these up-front also avoids NameErrors when terms are passed in.
    EPRXresort = EPRX
    lambdaTermResort = lambdaTerm

    if (EPRX is None) and (polProd is None):  # Skip if product term already passed
        # *** EPR
        # Set for R-p = 0 (redundant coord for p=0 case) - general (p, R-p) sums still need fixing in e-field mult term.
        EPRX = geomCalc.EPR(form = 'xarray', p = p).unstack().sel({'R-p':0}).drop_vars('R-p')
        EPRXresort = EPRX.squeeze(['l','lp']).drop_vars(['l','lp'])  # Safe squeeze & drop of selected singleton (photon) dims only.

        if phaseCons['mfblmCons']['negRcoordSwap']:
            EPRXresort['R'] *= -1

    if (lambdaTerm is None) and (polProd is None):  # Skip if product term already passed
        # Set polGeoms if Euler angles are passed.
        if eulerAngs is not None:
            RX = setPolGeoms(eulerAngs = eulerAngs)

        # *** Lambda term (only the Xarray term is used here; table & D outputs discarded).
        lambdaTerm, _, _, _ = geomCalc.MFproj(RX = RX, form = 'xarray', phaseConvention = phaseConvention)
        lambdaTermResort = lambdaTerm.squeeze(['l','lp']).drop_vars(['l','lp'])  # Safe squeeze & drop of selected singleton dims only.

    # *** Blm term with specified QNs
    if (BLMtable is None) and (BLMtableResort is None):  # Skip if BLMtableResort is passed
        QNsBLMtable = QNs.copy()

        # Switch signs (m,M) before 3j calcs.
        if phaseCons['mfblmCons']['BLMmPhase']:
            QNsBLMtable[:,3] *= -1
            QNsBLMtable[:,5] *= -1

        BLMtable = geomCalc.betaTerm(QNs = QNsBLMtable, form = 'xdaLM', phaseConvention = phaseConvention)

    # Resort & phase-correct only when constructing here; a passed BLMtableResort already includes these.
    if BLMtableResort is None:
        BLMtableResort = BLMtable.copy().unstack()

        if phaseCons['mfblmCons']['negMcoordSwap']:
            BLMtableResort['M'] *= -1

        if phaseCons['mfblmCons']['Mphase']:
            BLMtableResort *= np.power(-1, np.abs(BLMtableResort.M))  # Associated phase term

        if phaseCons['mfblmCons']['negmCoordSwap']:
            BLMtableResort['m'] *= -1

        if phaseCons['mfblmCons']['mPhase']:
            BLMtableResort *= np.power(-1, np.abs(BLMtableResort.m))  # Associated phase term

    #*** Products
    # Matrix element pair-wise multiplication by (l,m,mu) dims; unstack LM only.
    matEconj = matEthres.copy().conj()
    matEconj = matEconj.unstack('LM').rename({'l':'lp','m':'mp','mu':'mup'})
    matEmult = matEthres.unstack('LM') * matEconj
    matEmult.attrs['dataType'] = 'multTest'

    # Threshold product and drop dims.
    matEmult = matEleSelector(matEmult, thres = thres, dims = thresDims)

    # Product terms with similar dims.
    # NOTE: Xarray arithmetic takes the intersection of coords (automatic alignment), which may drop Sym terms.
    BLMprod = matEmult * BLMtableResort

    if polProd is None:
        # Keep pol. dims unsummed for now, since mupPhase below requires p.
        polProd = EPRXresort * lambdaTermResort

        # Additional phase term, (-1)^(mup-p).
        # NOTE: this might be spurious for the general EPR tensor case, and won't work if p is summed over first.
        if phaseCons['mfblmCons']['mupPhase']:
            polProd *= np.power(-1, np.abs(polProd.mup - polProd.p))

        # Additional [P]^1/2 degen term, NOT included in EPR defn.
        polProd *= np.sqrt(2*polProd.P+1)

        # Sum polarization terms here to keep total dims minimal in the final product.
        polProd = polProd.sum(sumDimsPol)

        polProd = matEleSelector(polProd, thres = thres)  # Select over dims for reduction.

    # Full product. Note this can be large (many dims/NaNs); sub-selecting terms first reduces size significantly.
    mTerm = polProd * BLMprod

    mTerm.attrs = matEin.attrs  # Propagate attrs from input matrix elements.
    mTerm.attrs['phaseCons'] = phaseCons  # Log phase conventions used.

    # Sum and threshold
    xDim = {'LM':['L','M']}
    mTermSum = mTerm.sum(sumDims)

    if squeeze is True:
        mTermSum = mTermSum.squeeze()  # Leave this as optional, since it can cause issues for M=0 only case

    mTermSumThres = matEleSelector(mTermSum.stack(xDim), thres=thres, dims = thresDims)

    # Normalise
    if SFflag:
        mTermSumThres.values = mTermSumThres/mTermSumThres.SF

    # Keep unnormalised B00 as XS coord (keeps all non-summed dims); .copy() so it is not a view.
    mTermSumThres['XS'] = mTermSumThres.sel({'L':0,'M':0}).drop_vars('LM').copy()
    mTermSumThres /= mTermSumThres.sel({'L':0,'M':0}).drop_vars('LM')

    # Propagate attrs
    mTermSum.attrs = mTerm.attrs
    mTermSum.attrs['dataType'] = 'multTest'

    mTermSumThres.attrs = mTerm.attrs
    mTermSumThres.attrs['dataType'] = 'multTest'

    #**** Reformat to standard BLM array (see ep.util.BLMdimList() )
    # TODO: finish this, and set this as standard output
    BetasNormX = mTermSumThres.unstack().rename({'L':'l','M':'m'}).stack({'BLM':['l','m']})

    # Set/propagate global properties
    BetasNormX.attrs = matE.attrs
    BetasNormX.attrs['thres'] = thres
    # TODO: update this for **vargs; selDims dict can't be written to netCDF3 attrs directly.
    BetasNormX.attrs['dataType'] = 'BLM'

    # Set return args based on basisReturn parameter
    if verbose:
        print(f"Return type {basisReturn}.")

    if basisReturn in ["Results", "Legacy"]:
        # Full results set, including all versions
        return mTermSumThres, mTermSum, mTerm

    elif basisReturn == "Full":
        # Return basis arrays/tensors. EPRX/lambdaTerm entries may be None if polProd was passed in.
        basis = {'QNs':QNs, 'EPRX':EPRXresort, 'lambdaTerm':lambdaTermResort,
                 'BLMtable':BLMtable, 'BLMtableResort':BLMtableResort,
                 'phaseConvention':phaseConvention, 'phaseCons':phaseCons}

        return BetasNormX, basis

    elif basisReturn == "ProductBasis":
        # Return product basis fns. for use in fitting routines
        basis = {'BLMtableResort':BLMtableResort, 'polProd':polProd, 'phaseConvention':phaseCons}

        return BetasNormX, basis

    elif basisReturn == "BLM":
        # Minimal return
        return BetasNormX

    else:
        print(f"Return type {basisReturn} not recognised, defaulting to BLM only.")
        return BetasNormX