1
+ from typing import Union
2
+
3
+ from posthog .hogql import ast
4
+ from posthog .schema import CurrencyCode , RevenueTrackingConfig , RevenueTrackingEventItem
1
5
from posthog .hogql .database .models import (
2
6
StringDatabaseField ,
3
7
DateDatabaseField ,
@@ -19,3 +23,120 @@ def to_printed_clickhouse(self, context):
19
23
20
24
def to_printed_hogql (self ):
21
25
return "exchange_rate"
26
+
27
+
28
+ def convert_currency_call (
29
+ amount : ast .Expr , currency_from : ast .Expr , currency_to : ast .Expr , timestamp : ast .Expr | None = None
30
+ ) -> ast .Expr :
31
+ args = [currency_from , currency_to , amount ]
32
+ if timestamp :
33
+ args .append (timestamp )
34
+
35
+ return ast .Call (name = "convertCurrency" , args = args )
36
+
37
+
38
+ def revenue_comparison_and_value_exprs (
39
+ event : RevenueTrackingEventItem , config : RevenueTrackingConfig
40
+ ) -> tuple [ast .Expr , ast .Expr ]:
41
+ # Check whether the event is the one we're looking for
42
+ comparison_expr = ast .CompareOperation (
43
+ left = ast .Field (chain = ["event" ]),
44
+ op = ast .CompareOperationOp .Eq ,
45
+ right = ast .Constant (value = event .eventName ),
46
+ )
47
+
48
+ # If there's a revenueCurrencyProperty, convert the revenue to the base currency from that property
49
+ # Otherwise, assume we're already in the base currency
50
+ # Also, assume that `base_currency` is USD by default, it'll be empty for most customers
51
+ if event .revenueCurrencyProperty :
52
+ value_expr = ast .Call (
53
+ name = "if" ,
54
+ args = [
55
+ ast .Call (
56
+ name = "isNull" , args = [ast .Field (chain = ["events" , "properties" , event .revenueCurrencyProperty ])]
57
+ ),
58
+ ast .Call (
59
+ name = "toDecimal" ,
60
+ args = [
61
+ ast .Field (chain = ["events" , "properties" , event .revenueProperty ]),
62
+ ast .Constant (value = 10 ),
63
+ ],
64
+ ),
65
+ convert_currency_call (
66
+ ast .Field (chain = ["events" , "properties" , event .revenueProperty ]),
67
+ ast .Field (chain = ["events" , "properties" , event .revenueCurrencyProperty ]),
68
+ ast .Constant (value = (config .baseCurrency or CurrencyCode .USD ).value ),
69
+ ast .Call (name = "DATE" , args = [ast .Field (chain = ["events" , "timestamp" ])]),
70
+ ),
71
+ ],
72
+ )
73
+ else :
74
+ value_expr = ast .Call (
75
+ name = "toDecimal" ,
76
+ args = [ast .Field (chain = ["events" , "properties" , event .revenueProperty ]), ast .Constant (value = 10 )],
77
+ )
78
+
79
+ return (comparison_expr , value_expr )
80
+
81
+
82
+ def revenue_expression (config : Union [RevenueTrackingConfig , dict , None ]) -> ast .Expr :
83
+ if isinstance (config , dict ):
84
+ config = RevenueTrackingConfig .model_validate (config )
85
+
86
+ if not config or not config .events :
87
+ return ast .Constant (value = None )
88
+
89
+ exprs : list [ast .Expr ] = []
90
+ for event in config .events :
91
+ comparison_expr , value_expr = revenue_comparison_and_value_exprs (event , config )
92
+ exprs .extend ([comparison_expr , value_expr ])
93
+
94
+ # Else clause, make sure there's a None at the end
95
+ exprs .append (ast .Constant (value = None ))
96
+
97
+ return ast .Call (name = "multiIf" , args = exprs )
98
+
99
+
100
+ def revenue_sum_expression (config : Union [RevenueTrackingConfig , dict , None ]) -> ast .Expr :
101
+ if isinstance (config , dict ):
102
+ config = RevenueTrackingConfig .model_validate (config )
103
+
104
+ if not config or not config .events :
105
+ return ast .Constant (value = None )
106
+
107
+ exprs : list [ast .Expr ] = []
108
+ for event in config .events :
109
+ comparison_expr , value_expr = revenue_comparison_and_value_exprs (event , config )
110
+
111
+ exprs .append (
112
+ ast .Call (
113
+ name = "sumIf" ,
114
+ args = [
115
+ ast .Call (name = "ifNull" , args = [value_expr , ast .Constant (value = 0 )]),
116
+ comparison_expr ,
117
+ ],
118
+ )
119
+ )
120
+
121
+ if len (exprs ) == 1 :
122
+ return exprs [0 ]
123
+
124
+ return ast .Call (name = "plus" , args = exprs )
125
+
126
+
127
+ def revenue_events_where_expr (config : Union [RevenueTrackingConfig , dict , None ]) -> ast .Expr :
128
+ if isinstance (config , dict ):
129
+ config = RevenueTrackingConfig .model_validate (config )
130
+
131
+ if not config or not config .events :
132
+ return ast .Constant (value = False )
133
+
134
+ exprs : list [ast .Expr ] = []
135
+ for event in config .events :
136
+ comparison_expr , _value_expr = revenue_comparison_and_value_exprs (event , config )
137
+ exprs .append (comparison_expr )
138
+
139
+ if len (exprs ) == 1 :
140
+ return exprs [0 ]
141
+
142
+ return ast .Or (exprs = exprs )
0 commit comments