diff --git a/src/main/java/com/andrewlalis/perfin/data/impl/JdbcAnalyticsRepository.java b/src/main/java/com/andrewlalis/perfin/data/impl/JdbcAnalyticsRepository.java index 108888a..55ace6c 100644 --- a/src/main/java/com/andrewlalis/perfin/data/impl/JdbcAnalyticsRepository.java +++ b/src/main/java/com/andrewlalis/perfin/data/impl/JdbcAnalyticsRepository.java @@ -11,6 +11,8 @@ import javafx.scene.paint.Color; import java.math.BigDecimal; import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; import java.util.*; public record JdbcAnalyticsRepository(Connection conn) implements AnalyticsRepository { @@ -76,7 +78,8 @@ public record JdbcAnalyticsRepository(Connection conn) implements AnalyticsRepos } private List> getTransactionAmountByCategoryAndType(TimestampRange range, Currency currency, AccountEntry.Type type) { - return DbUtil.findAll( + // First find totals for each category, using only transactions without any line items (should be most). + List> totalsBeforeLineItems = DbUtil.findAll( conn, """ SELECT @@ -95,24 +98,81 @@ public record JdbcAnalyticsRepository(Connection conn) implements AnalyticsRepos FROM transaction_tag tt LEFT JOIN transaction_tag_join ttj ON tt.id = ttj.tag_id WHERE ttj.transaction_id = transaction.id + ) AND + ( + SELECT COUNT(tli.id) = 0 + FROM transaction_line_item tli + WHERE tli.transaction_id = transaction.id ) GROUP BY tc.id ORDER BY total DESC;""", List.of(currency.getCurrencyCode(), type.name(), range.start(), range.end()), - rs -> { - BigDecimal total = rs.getBigDecimal(1); - TransactionCategory category = null; - long categoryId = rs.getLong(2); - if (!rs.wasNull()) { - Long parentId = rs.getLong(3); - if (rs.wasNull()) parentId = null; - String name = rs.getString(4); - Color color = Color.valueOf("#" + rs.getString(5)); - category = new TransactionCategory(categoryId, parentId, name, color); - } - return new Pair<>(category, total); - } + this::parseAmountAndCategory ); + // Then augment the data for any transactions which do have line items. + List> totalsFromLineItemsOnly = DbUtil.findAll( + conn, + """ + SELECT SUM(tli.value_per_item * tli.quantity) AS s, tc.* + FROM transaction_line_item tli + LEFT JOIN transaction_category tc ON tc.id = tli.category_id + LEFT JOIN transaction t ON t.id = tli.transaction_id + LEFT JOIN account_entry ae ON ae.transaction_id = t.id + WHERE + t.currency = ? AND + ae.type = ? AND + t.timestamp >= ? AND + t.timestamp <= ? AND + '!exclude' NOT IN ( + SELECT tt.name + FROM transaction_tag tt + LEFT JOIN transaction_tag_join ttj ON tt.id = ttj.tag_id + WHERE ttj.transaction_id = t.id + ) + GROUP BY tli.category_id + ORDER BY s DESC""", + List.of(currency.getCurrencyCode(), type.name(), range.start(), range.end()), + this::parseAmountAndCategory + ); + // Finally add data for any remaining value in transactions with line items, which wasn't accounted for in line items. + List> totalsFromLeftoverTransactions = DbUtil.findAll( + conn, + """ + SELECT SUM(s), c_id, c_parent_id, c_name, c_color + FROM ( + SELECT transaction.amount - SUM(tli.value_per_item * tli.quantity) AS s, + tc.id AS c_id, tc.parent_id AS c_parent_id, tc.name AS c_name, tc.color AS c_color + FROM transaction + LEFT JOIN transaction_line_item tli ON tli.transaction_id = transaction.id + LEFT JOIN transaction_category tc ON tc.id = transaction.category_id + LEFT JOIN account_entry ae ON ae.transaction_id = transaction.id + WHERE + transaction.currency = ? AND + ae.type = ? AND + transaction.timestamp >= ? AND + transaction.timestamp <= ? AND + '!exclude' NOT IN ( + SELECT tt.name + FROM transaction_tag tt + LEFT JOIN transaction_tag_join ttj ON tt.id = ttj.tag_id + WHERE ttj.transaction_id = transaction.id + ) AND + ( + SELECT COUNT(tli.id) > 0 + FROM transaction_line_item tli + WHERE tli.transaction_id = transaction.id + ) + GROUP BY transaction.id + ) + GROUP BY c_id""", + List.of(currency.getCurrencyCode(), type.name(), range.start(), range.end()), + this::parseAmountAndCategory + ); + return combineCategorizedAmounts(List.of( + totalsBeforeLineItems, + totalsFromLineItemsOnly, + totalsFromLeftoverTransactions + )); } private List> groupByRootCategory(List> spendByCategory) { @@ -140,4 +200,39 @@ public record JdbcAnalyticsRepository(Connection conn) implements AnalyticsRepos result.sort((p1, p2) -> p2.second().compareTo(p1.second())); return result; } + + private Pair parseAmountAndCategory(ResultSet rs) throws SQLException { + BigDecimal amount = rs.getBigDecimal(1); + long categoryId = rs.getLong(2); + if (rs.wasNull()) { + return new Pair<>(null, amount); + } + Long parentId = rs.getLong(3); + if (rs.wasNull()) parentId = null; + String name = rs.getString(4); + Color color = Color.valueOf("#" + rs.getString(5)); + return new Pair<>(new TransactionCategory(categoryId, parentId, name, color), amount); + } + + private List> combineCategorizedAmounts(List>> lists) { + BigDecimal uncategorizedAmount = BigDecimal.ZERO; + Map categorizedAmounts = new HashMap<>(); + for (var list : lists) { + for (var p : list) { + if (p.first() == null) { + uncategorizedAmount = uncategorizedAmount.add(p.second()); + } else { + BigDecimal value = categorizedAmounts.computeIfAbsent(p.first(), category -> BigDecimal.ZERO); + categorizedAmounts.put(p.first(), value.add(p.second())); + } + } + } + List> amountsByCategory = new ArrayList<>(); + amountsByCategory.add(new Pair<>(null, uncategorizedAmount)); + for (var entry : categorizedAmounts.entrySet()) { + amountsByCategory.add(new Pair<>(entry.getKey(), entry.getValue())); + } + amountsByCategory.sort((p1, p2) -> p2.second().compareTo(p1.second())); + return amountsByCategory; + } }