Skip to content

Commit 4a8d41d

Browse files
committed
Refactor TransformVisitor
1 parent 1a1a5bf commit 4a8d41d

4 files changed

Lines changed: 61 additions & 82 deletions

File tree

grade/internal/infrastructure/repositories/endorser/endorser_specifications.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,18 @@ type EndorserCanCompleteEndorsementSpecification struct {
1717
}
1818

1919
func (e *EndorserCanCompleteEndorsementSpecification) Compile() (sql string, params []driver.Valuer, err error) {
20-
v := s.NewPostgresqlVisitor(GlobalScopeContext{})
21-
err = e.Expression().Accept(v)
20+
exp := e.Expression()
21+
tv := s.NewTransformVisitor(GlobalScopeContext{})
22+
err = exp.Accept(tv)
23+
if err != nil {
24+
return "", nil, err
25+
}
26+
exp, err = tv.Result()
27+
if err != nil {
28+
return "", nil, err
29+
}
30+
v := s.NewPostgresqlVisitor()
31+
err = exp.Accept(v)
2232
if err != nil {
2333
return "", []driver.Valuer{}, err
2434
}
@@ -47,11 +57,11 @@ func (c EndorserIdContext) Extract(val any) (driver.Valuer, error) {
4757
case member.InternalMemberId:
4858
var ex exporters.UintExporter
4959
valTyped.Export(&ex)
50-
return nil, nil
60+
return ex, nil
5161
case tenant.TenantId:
5262
var ex exporters.UintExporter
5363
valTyped.Export(&ex)
54-
return nil, nil
64+
return ex, nil
5565
case member.MemberId:
5666
var ex MemberIdExporter
5767
valTyped.Export(&ex)
@@ -118,11 +128,11 @@ func (c GlobalScopeContext) Extract(val any) (driver.Valuer, error) {
118128
case member.InternalMemberId:
119129
var ex exporters.UintExporter
120130
valTyped.Export(&ex)
121-
return nil, nil
131+
return ex, nil
122132
case tenant.TenantId:
123133
var ex exporters.UintExporter
124134
valTyped.Export(&ex)
125-
return nil, nil
135+
return ex, nil
126136
case member.MemberId:
127137
var ex MemberIdExporter
128138
valTyped.Export(&ex)

grade/internal/infrastructure/seedwork/specification/postgresql_visitor.go

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package specification
33
import (
44
"database/sql/driver"
55
"fmt"
6+
"strings"
67

78
s "github.com/emacsway/grade/grade/internal/domain/seedwork/specification"
89
)
@@ -15,10 +16,9 @@ func PlaceholderIndex(index uint8) PostgresqlVisitorOption {
1516
}
1617
}
1718

18-
func NewPostgresqlVisitor(context Context, opts ...PostgresqlVisitorOption) *PostgresqlVisitor {
19+
func NewPostgresqlVisitor(opts ...PostgresqlVisitorOption) *PostgresqlVisitor {
1920
v := &PostgresqlVisitor{
2021
precedenceMapping: make(map[string]int),
21-
Context: context,
2222
}
2323
// https://www.postgresql.org/docs/14/sql-syntax-lexical.html#SQL-PRECEDENCE-TABLE
2424
v.setPrecedence(160, ". LEFT")
@@ -48,19 +48,6 @@ type PostgresqlVisitor struct {
4848
parameters []driver.Valuer
4949
precedence int
5050
precedenceMapping map[string]int
51-
// currentItem Context
52-
stack []Context
53-
Context
54-
}
55-
56-
func (v *PostgresqlVisitor) Push(ctx Context) {
57-
v.stack = append(v.stack, v.Context)
58-
v.Context = ctx
59-
}
60-
61-
func (v *PostgresqlVisitor) Pop() {
62-
v.Context = v.stack[len(v.stack)-1]
63-
v.stack = v.stack[:len(v.stack)-1]
6451
}
6552

6653
func (v PostgresqlVisitor) getNodePrecedenceKey(n s.Operable) string {
@@ -98,7 +85,6 @@ func (v *PostgresqlVisitor) visit(precedenceKey string, callable func() error) e
9885
}
9986

10087
func (v *PostgresqlVisitor) VisitGlobalScope(_ s.GlobalScopeNode) error {
101-
// v.push(v.Context)
10288
return nil
10389
}
10490

@@ -111,25 +97,17 @@ func (v *PostgresqlVisitor) VisitCollection(n s.CollectionNode) error {
11197
}
11298

11399
func (v *PostgresqlVisitor) VisitItem(n s.ItemNode) error {
114-
// v.push(v.currentItem)
115100
return nil
116101
}
117102

118103
func (v *PostgresqlVisitor) VisitField(n s.FieldNode) error {
119-
name, err := v.Context.NameByPath(s.ExtractFieldPath(n)...)
120-
// v.pop()
121-
if err != nil {
122-
return err
123-
}
104+
name := strings.Join(s.ExtractFieldPath(n), ".")
124105
v.sql += name
125106
return nil
126107
}
127108

128109
func (v *PostgresqlVisitor) VisitValue(n s.ValueNode) error {
129-
val, err := v.Extract(n.Value())
130-
if err != nil {
131-
return err
132-
}
110+
val := n.Value().(driver.Valuer)
133111
v.parameters = append(v.parameters, val)
134112
v.sql += fmt.Sprintf("$%d", len(v.parameters))
135113
return nil
@@ -167,8 +145,3 @@ func (v *PostgresqlVisitor) VisitInfix(n s.InfixNode) error {
167145
func (v PostgresqlVisitor) Result() (sql string, params []driver.Valuer, err error) {
168146
return v.sql, v.parameters, nil
169147
}
170-
171-
type Context interface {
172-
NameByPath(...string) (string, error)
173-
Extract(any) (driver.Valuer, error)
174-
}

grade/internal/infrastructure/seedwork/specification/transform_visitor.go

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,20 @@ import (
44
"errors"
55
"fmt"
66

7+
"database/sql/driver"
8+
79
s "github.com/emacsway/grade/grade/internal/domain/seedwork/specification"
810
)
911

1012
var (
1113
ErrCompositeExpressionsDifferentLength = errors.New("composite expressions have different length")
1214
)
1315

16+
type Context interface {
17+
NameByPath(...string) (string, error)
18+
Extract(any) (driver.Valuer, error)
19+
}
20+
1421
func NewTransformVisitor(context Context) *TransformVisitor {
1522
return &TransformVisitor{
1623
Context: context,
@@ -20,10 +27,22 @@ func NewTransformVisitor(context Context) *TransformVisitor {
2027
type TransformVisitor struct {
2128
compositeExpressions []CompositeExpression
2229
currentNode s.Visitable
30+
stack []Context
2331
Context
2432
}
2533

34+
func (v *TransformVisitor) Push(ctx Context) {
35+
v.stack = append(v.stack, v.Context)
36+
v.Context = ctx
37+
}
38+
39+
func (v *TransformVisitor) Pop() {
40+
v.Context = v.stack[len(v.stack)-1]
41+
v.stack = v.stack[:len(v.stack)-1]
42+
}
43+
2644
func (v *TransformVisitor) VisitGlobalScope(_ s.GlobalScopeNode) error {
45+
// v.push(v.Context)
2746
return nil
2847
}
2948

@@ -36,32 +55,34 @@ func (v *TransformVisitor) VisitCollection(n s.CollectionNode) error {
3655
}
3756

3857
func (v *TransformVisitor) VisitItem(n s.ItemNode) error {
58+
// v.push(v.currentItem)
3959
return nil
4060
}
4161

4262
func (v *TransformVisitor) VisitField(n s.FieldNode) error {
43-
_, err := v.Context.NameByPath(s.ExtractFieldPath(n)...)
63+
name, err := v.Context.NameByPath(s.ExtractFieldPath(n)...)
64+
// v.pop()
4465
if err != nil {
4566
if errTyped, ok := err.(MissingFieldsError); ok {
4667
names := errTyped.MissingFieldNames()
47-
o := s.Object(n.Object(), n.Name())
4868
compositeExpression := CompositeExpression{}
4969
for i := range names {
50-
// TODO: use n.Object() instead of o?
51-
compositeExpression.Add(s.Field(o, names[i]))
70+
compositeExpression.Add(s.Field(n.Object(), names[i]))
5271
}
5372
v.compositeExpressions = append(v.compositeExpressions, compositeExpression)
73+
v.currentNode = nil
5474
return nil
5575
} else {
5676
return err
5777
}
78+
} else {
79+
v.currentNode = s.Field(n.Object(), name)
5880
}
59-
v.currentNode = n
6081
return nil
6182
}
6283

6384
func (v *TransformVisitor) VisitValue(n s.ValueNode) error {
64-
_, err := v.Extract(n.Value())
85+
val, err := v.Extract(n.Value())
6586
if err != nil {
6687
if errTyped, ok := err.(MissingValuesError); ok {
6788
values := errTyped.MissingValues()
@@ -70,12 +91,14 @@ func (v *TransformVisitor) VisitValue(n s.ValueNode) error {
7091
compositeExpression.Add(s.Value(values[i]))
7192
}
7293
v.compositeExpressions = append(v.compositeExpressions, compositeExpression)
94+
v.currentNode = nil
7395
return nil
7496
} else {
7597
return err
7698
}
99+
} else {
100+
v.currentNode = s.Value(val)
77101
}
78-
v.currentNode = n
79102
return nil
80103
}
81104

@@ -136,36 +159,6 @@ func (v TransformVisitor) Result() (s.Visitable, error) {
136159
return v.currentNode, nil
137160
}
138161

139-
type CompositeExpression struct {
140-
nodes []s.Visitable
141-
}
142-
143-
func (n *CompositeExpression) Add(nodes ...s.Visitable) {
144-
n.nodes = append(n.nodes, nodes...)
145-
}
146-
147-
func (n CompositeExpression) Equal(other CompositeExpression) (s.Visitable, error) {
148-
var operands []s.Visitable
149-
if len(n.nodes) != len(other.nodes) {
150-
return nil, ErrCompositeExpressionsDifferentLength
151-
}
152-
for i := range n.nodes {
153-
operands = append(operands, s.Equal(n.nodes[i], other.nodes[i]))
154-
}
155-
return s.And(operands[0], operands[1:]...), nil
156-
}
157-
158-
func (n CompositeExpression) NotEqual(other CompositeExpression) (s.Visitable, error) {
159-
var operands []s.Visitable
160-
if len(n.nodes) != len(other.nodes) {
161-
return nil, ErrCompositeExpressionsDifferentLength
162-
}
163-
for i := range n.nodes {
164-
operands = append(operands, s.Equal(n.nodes[i], other.nodes[i]))
165-
}
166-
return s.Not(s.And(operands[0], operands[1:]...)), nil
167-
}
168-
169162
func NewMissingFieldsError(names ...string) MissingFieldsError {
170163
return MissingFieldsError{
171164
missingFieldNames: names,

grade/internal/infrastructure/seedwork/specification/transform_visitor_test.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func (ss SomethingSpecification) Evaluate( /* session session.PgxSession */ ) (
5050
sql string, params []driver.Valuer, err error,
5151
) {
5252
exp := ss.Expression()
53-
tv := NewTransformVisitor(TestContext{})
53+
tv := NewTransformVisitor(TestGlobalScopeContext{})
5454
err = exp.Accept(tv)
5555
if err != nil {
5656
return "", nil, err
@@ -59,7 +59,7 @@ func (ss SomethingSpecification) Evaluate( /* session session.PgxSession */ ) (
5959
if err != nil {
6060
return "", nil, err
6161
}
62-
v := NewPostgresqlVisitor(TestContext{})
62+
v := NewPostgresqlVisitor()
6363
err = exp.Accept(v)
6464
if err != nil {
6565
return "", nil, err
@@ -109,10 +109,10 @@ type SomethingId struct {
109109
identity.IntIdentity
110110
}
111111

112-
type TestContext struct {
112+
type TestGlobalScopeContext struct {
113113
}
114114

115-
func (c TestContext) NameByPath(path ...string) (string, error) {
115+
func (c TestGlobalScopeContext) NameByPath(path ...string) (string, error) {
116116
switch path[0] {
117117
case "something":
118118
return c.somethingPath("something", path[1:]...)
@@ -121,7 +121,7 @@ func (c TestContext) NameByPath(path ...string) (string, error) {
121121
}
122122
}
123123

124-
func (c TestContext) somethingPath(prefix string, path ...string) (string, error) {
124+
func (c TestGlobalScopeContext) somethingPath(prefix string, path ...string) (string, error) {
125125
switch path[0] {
126126
case "id":
127127
// FIXME: In case of stack implementation it will not work with member_id because this attrite is present on both cases:
@@ -132,13 +132,16 @@ func (c TestContext) somethingPath(prefix string, path ...string) (string, error
132132
// Кажется, решение в том, чтобы выделить TransformContext с правилами преобразования.
133133
// Нужно подумать что делать с полями сущностей 3-го и более глубокого уровня вложенности.
134134
// В принципе, там должны получаться многоуровневые JOINs.
135+
// Метод Extract() можно устранить, если значение возвращать тоже в контексте.
136+
// Кажется, TransformVisitor можно вообще выбросить, т.к. сам контекст может возвращать CompositeExpression.
137+
// Он все-равно управляет маппингом через err. Он создан для маппинга.
135138
return c.somethingIdPath(prefix, path[1:]...)
136139
default:
137140
return "", fmt.Errorf("can't get field \"%s\"", path[0])
138141
}
139142
}
140143

141-
func (c TestContext) somethingIdPath(prefix string, path ...string) (string, error) {
144+
func (c TestGlobalScopeContext) somethingIdPath(prefix string, path ...string) (string, error) {
142145
if len(path) == 0 {
143146
return "", NewMissingFieldsError("memberId", "somethingId")
144147
}
@@ -152,7 +155,7 @@ func (c TestContext) somethingIdPath(prefix string, path ...string) (string, err
152155
}
153156
}
154157

155-
func (c TestContext) somethingIdMemberIdPath(prefix string, path ...string) (string, error) {
158+
func (c TestGlobalScopeContext) somethingIdMemberIdPath(prefix string, path ...string) (string, error) {
156159
if len(path) == 0 {
157160
return "", NewMissingFieldsError("tenantId", "memberId")
158161
}
@@ -166,7 +169,7 @@ func (c TestContext) somethingIdMemberIdPath(prefix string, path ...string) (str
166169
}
167170
}
168171

169-
func (c TestContext) Extract(val any) (driver.Valuer, error) {
172+
func (c TestGlobalScopeContext) Extract(val any) (driver.Valuer, error) {
170173
switch valTyped := val.(type) {
171174
case InternalMemberId:
172175
var ex exporters.UintExporter

0 commit comments

Comments
 (0)