Skip to content

Commit aee0037

Browse files
committed
Refactor TransformVisitor
1 parent 1a1a5bf commit aee0037

6 files changed

Lines changed: 214 additions & 238 deletions

File tree

grade/internal/domain/seedwork/specification/nodes.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,23 +334,23 @@ func (n ItemNode) Accept(v Visitor) error {
334334
return v.VisitItem(n)
335335
}
336336

337-
func Field(object ObjectNode, name string) FieldNode {
337+
func Field(object EmptiableObject, name string) FieldNode {
338338
return FieldNode{
339339
object: object,
340340
name: name,
341341
}
342342
}
343343

344344
type FieldNode struct {
345-
object ObjectNode
345+
object EmptiableObject
346346
name string
347347
}
348348

349349
func (n FieldNode) Name() string {
350350
return n.name
351351
}
352352

353-
func (n FieldNode) Object() ObjectNode {
353+
func (n FieldNode) Object() EmptiableObject {
354354
return n.object
355355
}
356356

grade/internal/infrastructure/repositories/endorser/endorser_specifications.go

Lines changed: 38 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -8,125 +8,92 @@ import (
88
endorserVal "github.com/emacsway/grade/grade/internal/domain/endorser/values"
99
member "github.com/emacsway/grade/grade/internal/domain/member/values"
1010
"github.com/emacsway/grade/grade/internal/domain/seedwork/exporters"
11+
s "github.com/emacsway/grade/grade/internal/domain/seedwork/specification"
1112
tenant "github.com/emacsway/grade/grade/internal/domain/tenant/values"
12-
s "github.com/emacsway/grade/grade/internal/infrastructure/seedwork/specification"
13+
is "github.com/emacsway/grade/grade/internal/infrastructure/seedwork/specification"
1314
)
1415

1516
type EndorserCanCompleteEndorsementSpecification struct {
1617
endorser.EndorserCanCompleteEndorsementSpecification
1718
}
1819

1920
func (e *EndorserCanCompleteEndorsementSpecification) Compile() (sql string, params []driver.Valuer, err error) {
20-
v := s.NewPostgresqlVisitor(GlobalScopeContext{})
21-
err = e.Expression().Accept(v)
21+
exp := e.Expression()
22+
tv := is.NewTransformVisitor(GlobalScopeContext{})
23+
err = exp.Accept(tv)
2224
if err != nil {
23-
return "", []driver.Valuer{}, err
24-
}
25-
return v.Result()
26-
}
27-
28-
type EndorserIdContext struct {
29-
}
30-
31-
func (c EndorserIdContext) NameByPath(path ...string) (string, error) {
32-
if len(path) == 0 {
33-
return "", s.NewMissingFieldsError("tenantId", "memberId")
25+
return "", nil, err
3426
}
35-
switch path[0] {
36-
case "tenantId":
37-
return "tenant_id", nil
38-
case "memberId":
39-
return "member_id", nil
40-
default:
41-
return "", fmt.Errorf("can't get field \"%s\"", path[0])
27+
exp, err = tv.Result()
28+
if err != nil {
29+
return "", nil, err
4230
}
43-
}
44-
45-
func (c EndorserIdContext) Extract(val any) (driver.Valuer, error) {
46-
switch valTyped := val.(type) {
47-
case member.InternalMemberId:
48-
var ex exporters.UintExporter
49-
valTyped.Export(&ex)
50-
return nil, nil
51-
case tenant.TenantId:
52-
var ex exporters.UintExporter
53-
valTyped.Export(&ex)
54-
return nil, nil
55-
case member.MemberId:
56-
var ex MemberIdExporter
57-
valTyped.Export(&ex)
58-
return nil, s.NewMissingValuesError(ex.Values()...)
59-
default:
60-
return nil, fmt.Errorf("can't export \"%#v\"", val)
31+
v := is.NewPostgresqlVisitor()
32+
err = exp.Accept(v)
33+
if err != nil {
34+
return "", []driver.Valuer{}, err
6135
}
36+
return v.Result()
6237
}
6338

6439
type EndorserContext struct {
65-
id EndorserIdContext
6640
}
6741

68-
func (c EndorserContext) NameByPath(path ...string) (string, error) {
69-
prefix := "endorser"
70-
var name string
71-
var err error
42+
func (c EndorserContext) AttrNode(parent s.EmptiableObject, path []string) (s.Visitable, error) {
7243
switch path[0] {
7344
case "availableEndorsementCount":
74-
name = "available_endorsement_count"
45+
return s.Field(parent, "available_endorsement_count"), nil
7546
case "pendingEndorsementCount":
76-
name = "pending_endorsement_count"
47+
return s.Field(parent, "pending_endorsement_count"), nil
7748
case "id":
78-
name, err = c.id.NameByPath(path[1:]...)
79-
if err != nil {
80-
return "", err
81-
}
49+
return is.CompositeExpression(
50+
s.Field(parent, "tenant_id"),
51+
s.Field(parent, "member_id"),
52+
), nil
8253
default:
83-
return "", fmt.Errorf("can't get field \"%s\"", path[0])
84-
}
85-
return prefix + "." + name, nil
86-
}
87-
88-
func (c EndorserContext) Extract(val any) (driver.Valuer, error) {
89-
switch valTyped := val.(type) {
90-
case endorserVal.EndorsementCount:
91-
var ex exporters.UintExporter
92-
valTyped.Export(&ex)
93-
return ex, nil
94-
default:
95-
return nil, fmt.Errorf("can't export \"%#v\"", val)
54+
return nil, fmt.Errorf("can't get field \"%s\"", path[0])
9655
}
9756
}
9857

9958
type GlobalScopeContext struct {
10059
endorser EndorserContext
10160
}
10261

103-
func (c GlobalScopeContext) NameByPath(path ...string) (string, error) {
62+
func (c GlobalScopeContext) AttrNode(path []string) (s.Visitable, error) {
10463
switch path[0] {
10564
case "endorser":
106-
return c.endorser.NameByPath(path[1:]...)
65+
return c.endorser.AttrNode(s.Object(s.GlobalScope(), "endorser"), path[1:])
10766
default:
108-
return "", fmt.Errorf("can't get object \"%s\"", path[0])
67+
return nil, fmt.Errorf("can't get object \"%s\"", path[0])
10968
}
11069
}
11170

112-
func (c GlobalScopeContext) Extract(val any) (driver.Valuer, error) {
71+
func (c GlobalScopeContext) ValueNode(val any) (s.Visitable, error) {
11372
switch valTyped := val.(type) {
11473
case endorserVal.EndorsementCount:
11574
var ex exporters.UintExporter
11675
valTyped.Export(&ex)
117-
return ex, nil
76+
return s.Value(ex), nil
11877
case member.InternalMemberId:
11978
var ex exporters.UintExporter
12079
valTyped.Export(&ex)
121-
return nil, nil
80+
return s.Value(ex), nil
12281
case tenant.TenantId:
12382
var ex exporters.UintExporter
12483
valTyped.Export(&ex)
125-
return nil, nil
84+
return s.Value(ex), nil
12685
case member.MemberId:
12786
var ex MemberIdExporter
12887
valTyped.Export(&ex)
129-
return nil, s.NewMissingValuesError(ex.Values()...)
88+
nodes := []s.Visitable{}
89+
for _, v := range ex.Values() {
90+
node, err := c.ValueNode(v)
91+
if err != nil {
92+
return nil, err
93+
}
94+
nodes = append(nodes, node)
95+
}
96+
return is.CompositeExpression(nodes...), nil
13097
default:
13198
return nil, fmt.Errorf("can't export \"%#v\"", val)
13299
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package specification
2+
3+
import (
4+
s "github.com/emacsway/grade/grade/internal/domain/seedwork/specification"
5+
)
6+
7+
type ExpressionComposer interface {
8+
Equal(other CompositeExpressionNode) (s.Visitable, error)
9+
NotEqual(other CompositeExpressionNode) (s.Visitable, error)
10+
s.Visitable
11+
}
12+
13+
func CompositeExpression(nodes ...s.Visitable) CompositeExpressionNode {
14+
return CompositeExpressionNode{
15+
nodes: nodes,
16+
}
17+
}
18+
19+
type CompositeExpressionNode struct {
20+
nodes []s.Visitable
21+
}
22+
23+
func (n CompositeExpressionNode) Equal(other CompositeExpressionNode) (s.Visitable, error) {
24+
var operands []s.Visitable
25+
if len(n.nodes) != len(other.nodes) {
26+
return nil, ErrCompositeExpressionsDifferentLength
27+
}
28+
for i := range n.nodes {
29+
left, right := n.nodes[i], other.nodes[i]
30+
leftComposite, ok := left.(CompositeExpressionNode)
31+
if ok {
32+
rightComposite, ok := right.(CompositeExpressionNode)
33+
if !ok {
34+
return nil, ErrCompositeExpressionsDifferentLength
35+
}
36+
newNode, err := leftComposite.Equal(rightComposite)
37+
if err != nil {
38+
return nil, err
39+
}
40+
operands = append(operands, newNode)
41+
} else {
42+
operands = append(operands, s.Equal(left, right))
43+
}
44+
}
45+
return s.And(operands[0], operands[1:]...), nil
46+
}
47+
48+
func (n CompositeExpressionNode) NotEqual(other CompositeExpressionNode) (s.Visitable, error) {
49+
var operands []s.Visitable
50+
if len(n.nodes) != len(other.nodes) {
51+
return nil, ErrCompositeExpressionsDifferentLength
52+
}
53+
for i := range n.nodes {
54+
left, right := n.nodes[i], other.nodes[i]
55+
leftComposite, ok := left.(CompositeExpressionNode)
56+
if ok {
57+
rightComposite, ok := right.(CompositeExpressionNode)
58+
if !ok {
59+
return nil, ErrCompositeExpressionsDifferentLength
60+
}
61+
newNode, err := leftComposite.NotEqual(rightComposite)
62+
if err != nil {
63+
return nil, err
64+
}
65+
operands = append(operands, newNode)
66+
} else {
67+
operands = append(operands, s.NotEqual(left, right))
68+
}
69+
}
70+
return s.Not(s.And(operands[0], operands[1:]...)), nil
71+
}
72+
73+
func (n CompositeExpressionNode) Accept(v s.Visitor) error {
74+
return nil
75+
}

grade/internal/infrastructure/seedwork/specification/postgresql_visitor.go

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package specification
33
import (
44
"database/sql/driver"
55
"fmt"
6+
"strings"
67

78
s "github.com/emacsway/grade/grade/internal/domain/seedwork/specification"
89
)
@@ -15,10 +16,9 @@ func PlaceholderIndex(index uint8) PostgresqlVisitorOption {
1516
}
1617
}
1718

18-
func NewPostgresqlVisitor(context Context, opts ...PostgresqlVisitorOption) *PostgresqlVisitor {
19+
func NewPostgresqlVisitor(opts ...PostgresqlVisitorOption) *PostgresqlVisitor {
1920
v := &PostgresqlVisitor{
2021
precedenceMapping: make(map[string]int),
21-
Context: context,
2222
}
2323
// https://www.postgresql.org/docs/14/sql-syntax-lexical.html#SQL-PRECEDENCE-TABLE
2424
v.setPrecedence(160, ". LEFT")
@@ -48,19 +48,6 @@ type PostgresqlVisitor struct {
4848
parameters []driver.Valuer
4949
precedence int
5050
precedenceMapping map[string]int
51-
// currentItem Context
52-
stack []Context
53-
Context
54-
}
55-
56-
func (v *PostgresqlVisitor) Push(ctx Context) {
57-
v.stack = append(v.stack, v.Context)
58-
v.Context = ctx
59-
}
60-
61-
func (v *PostgresqlVisitor) Pop() {
62-
v.Context = v.stack[len(v.stack)-1]
63-
v.stack = v.stack[:len(v.stack)-1]
6451
}
6552

6653
func (v PostgresqlVisitor) getNodePrecedenceKey(n s.Operable) string {
@@ -98,7 +85,6 @@ func (v *PostgresqlVisitor) visit(precedenceKey string, callable func() error) e
9885
}
9986

10087
func (v *PostgresqlVisitor) VisitGlobalScope(_ s.GlobalScopeNode) error {
101-
// v.push(v.Context)
10288
return nil
10389
}
10490

@@ -111,25 +97,17 @@ func (v *PostgresqlVisitor) VisitCollection(n s.CollectionNode) error {
11197
}
11298

11399
func (v *PostgresqlVisitor) VisitItem(n s.ItemNode) error {
114-
// v.push(v.currentItem)
115100
return nil
116101
}
117102

118103
func (v *PostgresqlVisitor) VisitField(n s.FieldNode) error {
119-
name, err := v.Context.NameByPath(s.ExtractFieldPath(n)...)
120-
// v.pop()
121-
if err != nil {
122-
return err
123-
}
104+
name := strings.Join(s.ExtractFieldPath(n), ".")
124105
v.sql += name
125106
return nil
126107
}
127108

128109
func (v *PostgresqlVisitor) VisitValue(n s.ValueNode) error {
129-
val, err := v.Extract(n.Value())
130-
if err != nil {
131-
return err
132-
}
110+
val := n.Value().(driver.Valuer)
133111
v.parameters = append(v.parameters, val)
134112
v.sql += fmt.Sprintf("$%d", len(v.parameters))
135113
return nil
@@ -167,8 +145,3 @@ func (v *PostgresqlVisitor) VisitInfix(n s.InfixNode) error {
167145
func (v PostgresqlVisitor) Result() (sql string, params []driver.Valuer, err error) {
168146
return v.sql, v.parameters, nil
169147
}
170-
171-
type Context interface {
172-
NameByPath(...string) (string, error)
173-
Extract(any) (driver.Valuer, error)
174-
}

0 commit comments

Comments
 (0)