pwncash/lib/Internal/Budget.hs

396 lines
13 KiB
Haskell

module Internal.Budget (readBudget) where
import Control.Monad.Except
import Data.Decimal hiding (allocate)
import Data.Foldable
import Data.Hashable
import Internal.Types.Main
import Internal.Utils
import RIO hiding (to)
import qualified RIO.List as L
import qualified RIO.Map as M
import qualified RIO.NonEmpty as NE
import qualified RIO.Text as T
import RIO.Time
-- | Expand a 'Budget' into the transactions it describes.
--
-- The budget is evaluated over the global budget scope, narrowed to the
-- budget's own interval when one is given. If that narrowing leaves no span,
-- no transactions are produced.
--
-- Within the span, the interval allocations are resolved and every allocation
-- target account is checked to not be an income account; then incomes and
-- transfers are expanded and shadow transfers are computed from the result.
readBudget :: (MonadInsertError m, MonadFinance m) => Budget -> m [Tx CommitR]
readBudget
  b@Budget
    { bgtLabel
    , bgtIncomes
    , bgtTransfers
    , bgtShadowTransfers
    , bgtPretax
    , bgtTax
    , bgtPosttax
    , bgtInterval
    } =
    maybe (return []) txsWithin =<< scope
    where
      -- Build every transaction of this budget inside the given span
      txsWithin budgetSpan = do
        (intAllos, _) <- combineError sortedAllos acntChecks (,)
        let incomes =
              mapErrors (readIncome commit bgtLabel intAllos budgetSpan) bgtIncomes
        let transfers = expandTransfers commit bgtLabel budgetSpan bgtTransfers
        txs <- combineError (concat <$> incomes) transfers (++)
        -- NOTE(review): addShadowTransfers maps over @txs@ and returns each
        -- one with 'txOther' replaced, so appending its result to @txs@
        -- appears to emit every transaction twice -- confirm this is intended.
        shadow <- addShadowTransfers bgtShadowTransfers txs
        return $ txs ++ shadow
      -- The commit is keyed on the hash of the entire budget
      commit = CommitR (hash b) CTBudget
      acntChecks = mapErrors isNotIncomeAcnt allAlloAcnts
      sortedAllos =
        combineError3
          (sortAllos bgtPretax)
          (sortAllos bgtTax)
          (sortAllos bgtPosttax)
          (,,)
      sortAllos = liftExcept . mapErrors sortAllo
      allAlloAcnts =
        concat
          [ alloAcnt <$> bgtPretax
          , alloAcnt <$> bgtTax
          , alloAcnt <$> bgtPosttax
          ]
      -- Global budget scope, intersected with the local interval (if any)
      scope = do
        globalSpan <- askDBState (unBSpan . csBudgetScope)
        case bgtInterval of
          Nothing -> return $ Just globalSpan
          Just bi -> do
            localSpan <- liftExcept $ resolveDaySpan bi
            return $ intersectDaySpan globalSpan localSpan
-- | Sort the amounts of an allocation by 'amtWhen' and replace each amount's
-- interval with a resolved day span.
--
-- Every amount except the last is resolved with 'resolveDaySpan_', which is
-- given the 'intStart' of the following amount together with the amount's own
-- interval; the last amount is resolved with 'resolveDaySpan' alone. The
-- result keeps the sorted order.
sortAllo :: MultiAllocation v -> InsertExcept (DaySpanAllocation v)
sortAllo a@Allocation {alloAmts = amts} = do
  resolved <- go $ L.sortOn amtWhen amts
  return $ a {alloAmts = resolved}
  where
    go [] = return []
    go (x : rest) = do
      let begin = amtWhen x
      spanRes <- case rest of
        [] -> resolveDaySpan begin
        (next : _) -> resolveDaySpan_ (intStart $ amtWhen next) begin
      (x {amtWhen = spanRes} :) <$> go rest
--------------------------------------------------------------------------------
-- Income
-- TODO this will fully scan the interval allocations on every
-- iteration, which is a total waste, but the fix requires turning this
-- loop into a fold which I don't feel like doing now :(
-- | Generate one transaction for each pay day of an income source.
--
-- The source account must be an income account, and the destination plus all
-- allocation targets must not be. Pay days come from expanding 'incWhen'
-- within the given day span. For each pay day the gross amount is split into
-- pretax, tax and posttax allocations (the income's own allocations plus the
-- interval allocations selected for that day), and the primary entry set moves
-- the gross from the source to the destination with the allocations attached
-- as additional "to" entries.
readIncome
  :: (MonadInsertError m, MonadFinance m)
  => CommitR
  -> T.Text
  -> IntAllocations
  -> DaySpan
  -> Income
  -> m [Tx CommitR]
readIncome
  key
  name
  (intPre, intTax, intPost)
  ds
  Income
    { incWhen
    , incCurrency
    , incFrom = TaggedAcnt {taAcnt = srcAcnt, taTags = srcTags}
    , incPretax
    , incPosttax
    , incTaxes
    , incToBal = TaggedAcnt {taAcnt = destAcnt, taTags = destTags}
    , incGross
    , incPayPeriod
    , incPriority
    } =
    combineErrorM
      (combineError srcCheck otherChecks (,))
      (combineError currency payDays (,))
      $ \_ (cp, days) ->
        foldDays
          (payday cp (realFracToDecimal (cpPrec cp) incGross))
          periodStart
          days
    where
      srcCheck = isIncomeAcnt srcAcnt
      otherChecks =
        mapErrors isNotIncomeAcnt $
          destAcnt
            : (alloAcnt <$> incPretax)
              ++ (alloAcnt <$> incTaxes)
              ++ (alloAcnt <$> incPosttax)
      currency = lookupCurrency incCurrency
      payDays = liftExcept $ expandDatePat ds incWhen
      periodStart = fromGregorian' $ pStart incPayPeriod
      flatPre = concatMap flattenAllo incPretax
      flatTax = concatMap flattenAllo incTaxes
      flatPost = concatMap flattenAllo incPosttax
      -- Entry whose value is left as unit
      unitEntry acnt comment tags =
        Entry {eAcnt = acnt, eValue = (), eComment = comment, eTags = tags}
      -- Build the transaction for one pay day; @prevDay@ is the preceding pay
      -- day (or the period start for the first one)
      payday cp gross prevDay day = do
        scaler <- liftExcept $ periodScaler (pType incPayPeriod) prevDay day
        let precision = cpPrec cp
            (preDeductions, pre) =
              allocatePre precision gross $
                flatPre ++ concatMap (selectAllos day) intPre
            tax =
              allocateTax precision gross preDeductions scaler $
                flatTax ++ concatMap (selectAllos day) intTax
            aftertaxGross = gross - sum (faValue <$> (tax ++ pre))
            post =
              allocatePost precision aftertaxGross $
                flatPost ++ concatMap (selectAllos day) intPost
            primary =
              EntrySet
                { esTotalValue = gross
                , esCurrency = cpID cp
                , esFrom =
                    HalfEntrySet
                      { hesPrimary = unitEntry srcAcnt "gross income" srcTags
                      , hesOther = []
                      }
                , esTo =
                    HalfEntrySet
                      { hesPrimary =
                          unitEntry destAcnt "balance after deductions" destTags
                      , hesOther = allo2Trans <$> (pre ++ tax ++ post)
                      }
                }
        return
          Tx
            { txCommit = key
            , txDate = day
            , txPrimary = Left primary
            , txOther = []
            , txDescr = ""
            , txBudget = name
            , txPriority = incPriority
            }
-- | Build the function that scales a value to the pay period between two
-- days.
--
-- The number of working days @n@ between @prev@ and @cur@ is counted with
-- 'workingDays'. For an hourly period the value is divided by the annual
-- hours, rounded to the requested precision, then multiplied by the daily
-- hours and @n@. For a daily period the value is multiplied by @n / 365.25@
-- and then rounded.
periodScaler
  :: PeriodType
  -> Day
  -> Day
  -> InsertExcept PeriodScaler
periodScaler pt prev cur = return $ \prec x ->
  case pt of
    Hourly HourlyPeriod {hpAnnualHours, hpDailyHours} ->
      let hourly = realFracToDecimal prec (x / fromIntegral hpAnnualHours)
       in hourly * fromIntegral hpDailyHours * fromIntegral n
    Daily _ -> realFracToDecimal prec (x * fromIntegral n / 365.25)
  where
    n = workingDays activeDays prev cur
    activeDays = case pt of
      Hourly HourlyPeriod {hpWorkingDays} -> hpWorkingDays
      Daily wds -> wds
-- | Count how many of the given weekdays fall between two days.
--
-- Each distinct weekday is turned into its offset (0-6) from the weekday of
-- @start@. Every whole week in the interval contributes one hit per weekday;
-- the leftover days contribute the weekdays whose offset is below the
-- leftover length.
--
-- ASSUME start < end
workingDays :: [Weekday] -> Day -> Day -> Natural
workingDays wds start end = fromIntegral (wholeWeeks + partialWeek)
  where
    (nWeeks, nLeftover) = diffDays end start `divMod` 7
    startDay = dayOfWeek start
    offsetFromStart d = fromIntegral $ mod (fromEnum d - fromEnum startDay) 7
    offsets = L.sort [offsetFromStart (fromWeekday w) | w <- L.nub wds]
    wholeWeeks = fromIntegral (length offsets) * nWeeks
    partialWeek = fromIntegral $ length $ takeWhile (< nLeftover) offsets
-- | Apply a function to every pair of consecutive days, where the first pair
-- is the start day with the first element of the list.
--
-- An empty list yields no results. If any day is before the start day, a
-- 'PeriodError' carrying the start day and the earliest day is thrown.
--
-- ASSUME days is a sorted list
foldDays
  :: MonadInsertError m
  => (Day -> Day -> m a)
  -> Day
  -> [Day]
  -> m [a]
foldDays f start days = maybe (return []) checked $ NE.nonEmpty days
  where
    checked ds
      | any (< start) ds =
          throwError $
            InsertException [PeriodError start $ minimum ds]
      | otherwise = combineErrors $ zipWith f (start : days) days
-- | Succeed only if the account has type 'IncomeT'.
isIncomeAcnt :: (MonadInsertError m, MonadFinance m) => AcntID -> m ()
isIncomeAcnt = checkAcntTypes (IncomeT :| [])
-- | Succeed only if the account is an asset, equity, expense or liability
-- account.
isNotIncomeAcnt :: (MonadInsertError m, MonadFinance m) => AcntID -> m ()
isNotIncomeAcnt = checkAcntTypes nonIncomeTypes
  where
    nonIncomeTypes = AssetT :| [EquityT, ExpenseT, LiabilityT]
-- | Succeed only if the account has exactly the given type.
checkAcntType
  :: (MonadInsertError m, MonadFinance m)
  => AcntType
  -> AcntID
  -> m ()
checkAcntType = checkAcntTypes . (:| [])
-- | Look up the type of an account and throw an 'AccountError' (carrying the
-- account and the permitted types) unless it is one of the permitted types.
checkAcntTypes
  :: (MonadInsertError m, MonadFinance m)
  => NE.NonEmpty AcntType
  -> AcntID
  -> m ()
checkAcntTypes ts i = do
  t <- lookupAccountType i
  unless (t `L.elem` ts) $
    throwError $
      InsertException [AccountError i ts]
-- | Turn an allocation into one flat allocation per amount, each carrying the
-- allocation's target account together with the amount's value and
-- description.
flattenAllo :: SingleAllocation v -> [FlatAllocation v]
flattenAllo Allocation {alloAmts, alloTo} = toFlat <$> alloAmts
  where
    toFlat amt =
      FlatAllocation
        { faValue = amtValue amt
        , faDesc = amtDesc amt
        , faTo = alloTo
        }
-- | Flatten only those amounts of an allocation whose day span contains the
-- given day.
--
-- ASSUME allocations are sorted
selectAllos :: Day -> DaySpanAllocation v -> [FlatAllocation v]
selectAllos day Allocation {alloAmts, alloTo} =
  [ FlatAllocation
    { faValue = amtValue amt
    , faDesc = amtDesc amt
    , faTo = alloTo
    }
  | amt <- alloAmts
  , amtWhen amt `inDaySpan` day
  ]
-- | Convert a flat allocation into an entry with a fixed deferred value,
-- keeping the allocation's account, tags and description.
allo2Trans :: FlatAllocation Decimal -> Entry AcntID LinkDeferred TagID
allo2Trans flat =
  Entry
    { eAcnt = taAcnt target
    , eTags = taTags target
    , eComment = faDesc flat
    , eValue = LinkDeferred (EntryFixed (faValue flat))
    }
  where
    target = faTo flat
-- | Resolve pretax allocations into concrete amounts.
--
-- A percentage allocation takes that percent of the gross; any other
-- allocation is its value converted to the given precision. Alongside the
-- resolved allocations, the amounts are accumulated per pretax category (via
-- 'mapAdd_') and returned as a map.
allocatePre
  :: Precision
  -> Decimal
  -> [FlatAllocation PretaxValue]
  -> (M.Map T.Text Decimal, [FlatAllocation Decimal])
allocatePre precision gross = L.mapAccumR step M.empty
  where
    step
      totals
      allo@FlatAllocation
        { faValue = PretaxValue {preCategory, preValue, prePercent}
        } =
        (mapAdd_ preCategory amount totals, allo {faValue = amount})
        where
          amount
            | prePercent = gross *. (preValue / 100)
            | otherwise = realFracToDecimal precision preValue
-- | Resolve tax allocations into concrete amounts.
--
-- For each allocation the adjusted gross income (AGI) is the gross minus the
-- pretax deductions recorded under the allocation's categories (categories
-- missing from the map are ignored). The tax is then either a flat percentage
-- of the AGI, or the result of 'foldBracket' applied to the AGI less the
-- period-scaled deductible.
--
-- Arguments: the currency precision, the gross income, the pretax deductions
-- keyed by category, the period scaler, and the allocations to resolve.
allocateTax
  :: Precision
  -> Decimal
  -> M.Map T.Text Decimal
  -> PeriodScaler
  -> [FlatAllocation TaxValue]
  -> [FlatAllocation Decimal]
allocateTax precision gross preDeds f = fmap (fmap go)
  where
    go TaxValue {tvCategories, tvMethod} =
      let agi = gross - sum (mapMaybe (`M.lookup` preDeds) tvCategories)
       in case tvMethod of
            -- Divide the percentage first so that '*.' rounds the result to
            -- the precision of @agi@. Writing @agi *. p / 100@ multiplies
            -- first and then divides the 'Decimal' by 100, which can leave up
            -- to two extra decimal places in the tax amount.
            TMPercent p -> agi *. (p / 100)
            TMBracket TaxProgression {tpDeductible, tpBrackets} ->
              let taxDed = f precision tpDeductible
               in foldBracket f precision (agi - taxDed) tpBrackets
-- | Compute the tax owed on an AGI under a set of progressive brackets.
--
-- The brackets are sorted by lower limit and visited from the highest limit
-- down, carrying the tax accumulated so far and the income not yet taxed:
--
-- 1. Brackets whose (period-scaled) lower limit is above the remaining income
--    are skipped.
-- 2. For any other bracket, the income above its lower limit (the remaining
--    income minus the limit) is taxed at the bracket's percentage and added
--    to the total.
-- 3. The remaining income is then lowered to that limit, so each lower
--    bracket only taxes the slice up to the bracket above it.
foldBracket :: PeriodScaler -> Precision -> Decimal -> [TaxBracket] -> Decimal
foldBracket f prec agi bs = fst $ foldr step (0, agi) $ L.sortOn tbLowerLimit bs
  where
    step TaxBracket {tbLowerLimit, tbPercent} state@(taxed, remaining)
      | remaining >= limit =
          (taxed + (remaining - limit) *. (tbPercent / 100), limit)
      | otherwise = state
      where
        limit = f prec tbLowerLimit
-- | Resolve posttax allocations into concrete amounts: a percentage
-- allocation takes that percent of the after-tax amount, any other allocation
-- is its value converted to the given precision.
allocatePost
  :: Precision
  -> Decimal
  -> [FlatAllocation PosttaxValue]
  -> [FlatAllocation Decimal]
allocatePost prec aftertax = fmap (fmap toAmount)
  where
    toAmount PosttaxValue {postValue, postPercent} =
      if postPercent
        then aftertax *. (postValue / 100)
        else realFracToDecimal prec postValue
--------------------------------------------------------------------------------
-- shadow transfers
-- TODO this is going to be O(n*m), which might be a problem?
-- | For every transaction, evaluate every shadow transfer against it and
-- store the resulting entry sets (wrapped in 'Right') in the transaction's
-- 'txOther' field, replacing whatever was there.
addShadowTransfers
  :: (MonadInsertError m, MonadFinance m)
  => [ShadowTransfer]
  -> [Tx CommitR]
  -> m [Tx CommitR]
addShadowTransfers ms = mapErrors withShadows
  where
    withShadows tx = do
      candidates <- mapErrors (fromShadow tx) ms
      return $ tx {txOther = Right <$> catMaybes candidates}
-- | Build the shadow entry set of a shadow transfer for a transaction, or
-- 'Nothing' if the transfer's matcher rejects the transaction.
--
-- The currency lookup and the match are both evaluated (and their errors
-- combined) regardless of whether the transaction matches.
fromShadow
  :: (MonadInsertError m, MonadFinance m)
  => Tx CommitR
  -> ShadowTransfer
  -> m (Maybe ShadowEntrySet)
fromShadow tx ShadowTransfer {stFrom, stTo, stDesc, stRatio, stCurrency, stMatch} =
  combineErrorM currency matched $ \cur isMatch ->
    return $
      if isMatch
        then Just $ entryPair stFrom stTo cur stDesc stRatio ()
        else Nothing
  where
    currency = lookupCurrencyKey stCurrency
    matched = liftExcept $ shadowMatches stMatch tx
-- | Decide whether a transaction satisfies a transfer matcher.
--
-- The source and destination accounts of the primary entry set are tested
-- against the matcher's account sets, the date against the date matcher (when
-- present), and the total value against the value matcher. The value is only
-- tested when the primary entry set is a 'Left'; a 'Right' always passes.
--
-- NOTE this will only match against the primary entry set since those
-- are what are guaranteed to exist from a transfer
shadowMatches :: TransferMatcher -> Tx CommitR -> InsertExcept Bool
shadowMatches TransferMatcher {tmFrom, tmTo, tmDate, tmVal} Tx {txPrimary, txDate} = do
  valOk <-
    either
      (valMatches tmVal . toRational . esTotalValue)
      (const $ return True)
      txPrimary
  return $ fromOk && toOk && dateOk && valOk
  where
    fromOk = inSet fromAcnt tmFrom
    toOk = inSet toAcnt tmTo
    dateOk = maybe True (`dateMatches` txDate) tmDate
    fromAcnt = either (primaryAcnt esFrom) (primaryAcnt esFrom) txPrimary
    toAcnt = either (primaryAcnt esTo) (primaryAcnt esTo) txPrimary
    -- Account of the primary entry on the chosen side of an entry set
    primaryAcnt side = eAcnt . hesPrimary . side
    -- Membership when the set is inclusive, non-membership otherwise
    inSet acnt AcntSet {asList, asInclude} =
      (acnt `elem` asList) == asInclude
--------------------------------------------------------------------------------
-- random
-- | The account an allocation pays into.
alloAcnt :: Allocation w v -> AcntID
alloAcnt allo = taAcnt (alloTo allo)
-- | The budget-level allocations after 'sortAllo' has resolved their day
-- spans, grouped as (pretax, tax, posttax).
type IntAllocations =
( [DaySpanAllocation PretaxValue]
, [DaySpanAllocation TaxValue]
, [DaySpanAllocation PosttaxValue]
)
-- | An 'Allocation' whose first type parameter is fixed to 'DaySpan'.
type DaySpanAllocation = Allocation DaySpan
-- | Function produced by 'periodScaler': takes a precision and a value and
-- returns the value scaled to a pay period as a 'Decimal'. It is applied to
-- tax deductibles and bracket limits.
type PeriodScaler = Precision -> Double -> Decimal
-- | A single allocated amount paired with its target account; an allocation
-- with several amounts is flattened into one of these per amount.
data FlatAllocation v = FlatAllocation
{ faValue :: !v
-- ^ the allocated value
, faDesc :: !T.Text
-- ^ description copied from the amount
, faTo :: !TaggedAcnt
-- ^ account (and tags) the value is allocated to
}
deriving (Functor, Show)