Skip to content

Commit b007ef2

Browse files
committed
Refactor TransformVisitor
1 parent 1a1a5bf commit b007ef2

6 files changed

Lines changed: 216 additions & 191 deletions

File tree

grade/internal/domain/seedwork/specification/nodes.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,23 +334,23 @@ func (n ItemNode) Accept(v Visitor) error {
334334
return v.VisitItem(n)
335335
}
336336

337-
func Field(object ObjectNode, name string) FieldNode {
337+
func Field(object EmptiableObject, name string) FieldNode {
338338
return FieldNode{
339339
object: object,
340340
name: name,
341341
}
342342
}
343343

344344
type FieldNode struct {
345-
object ObjectNode
345+
object EmptiableObject
346346
name string
347347
}
348348

349349
func (n FieldNode) Name() string {
350350
return n.name
351351
}
352352

353-
func (n FieldNode) Object() ObjectNode {
353+
func (n FieldNode) Object() EmptiableObject {
354354
return n.object
355355
}
356356

grade/internal/infrastructure/repositories/endorser/endorser_specifications.go

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,18 @@ type EndorserCanCompleteEndorsementSpecification struct {
1717
}
1818

1919
func (e *EndorserCanCompleteEndorsementSpecification) Compile() (sql string, params []driver.Valuer, err error) {
20-
v := s.NewPostgresqlVisitor(GlobalScopeContext{})
21-
err = e.Expression().Accept(v)
20+
exp := e.Expression()
21+
tv := s.NewTransformVisitor(GlobalScopeContext{})
22+
err = exp.Accept(tv)
23+
if err != nil {
24+
return "", nil, err
25+
}
26+
exp, err = tv.Result()
27+
if err != nil {
28+
return "", nil, err
29+
}
30+
v := s.NewPostgresqlVisitor()
31+
err = exp.Accept(v)
2232
if err != nil {
2333
return "", []driver.Valuer{}, err
2434
}
@@ -28,30 +38,31 @@ func (e *EndorserCanCompleteEndorsementSpecification) Compile() (sql string, par
2838
type EndorserIdContext struct {
2939
}
3040

31-
func (c EndorserIdContext) NameByPath(path ...string) (string, error) {
41+
func (c EndorserIdContext) NameByPath(path []string) ([]string, error) {
3242
if len(path) == 0 {
33-
return "", s.NewMissingFieldsError("tenantId", "memberId")
43+
return nil, s.NewMissingFieldsError("tenantId", "memberId")
3444
}
3545
switch path[0] {
3646
case "tenantId":
37-
return "tenant_id", nil
47+
return []string{"tenant_id"}, nil
3848
case "memberId":
39-
return "member_id", nil
49+
return []string{"member_id"}, nil
4050
default:
41-
return "", fmt.Errorf("can't get field \"%s\"", path[0])
51+
return nil, fmt.Errorf("can't get field \"%s\"", path[0])
4252
}
4353
}
4454

4555
func (c EndorserIdContext) Extract(val any) (driver.Valuer, error) {
56+
// TODO: это не будет работать, т.к. ValueNode может идти первым операндом. Нужно разделять интерфейсы.
4657
switch valTyped := val.(type) {
4758
case member.InternalMemberId:
4859
var ex exporters.UintExporter
4960
valTyped.Export(&ex)
50-
return nil, nil
61+
return ex, nil
5162
case tenant.TenantId:
5263
var ex exporters.UintExporter
5364
valTyped.Export(&ex)
54-
return nil, nil
65+
return ex, nil
5566
case member.MemberId:
5667
var ex MemberIdExporter
5768
valTyped.Export(&ex)
@@ -65,24 +76,22 @@ type EndorserContext struct {
6576
id EndorserIdContext
6677
}
6778

68-
func (c EndorserContext) NameByPath(path ...string) (string, error) {
69-
prefix := "endorser"
70-
var name string
71-
var err error
79+
func (c EndorserContext) NameByPath(path []string) ([]string, error) {
80+
prefix := []string{"endorser"}
7281
switch path[0] {
7382
case "availableEndorsementCount":
74-
name = "available_endorsement_count"
83+
return append(prefix, "available_endorsement_count"), nil
7584
case "pendingEndorsementCount":
76-
name = "pending_endorsement_count"
85+
return append(prefix, "pending_endorsement_count"), nil
7786
case "id":
78-
name, err = c.id.NameByPath(path[1:]...)
87+
names, err := c.id.NameByPath(path[1:])
7988
if err != nil {
80-
return "", err
89+
return nil, err
8190
}
91+
return append(prefix, names...), nil
8292
default:
83-
return "", fmt.Errorf("can't get field \"%s\"", path[0])
93+
return nil, fmt.Errorf("can't get field \"%s\"", path[0])
8494
}
85-
return prefix + "." + name, nil
8695
}
8796

8897
func (c EndorserContext) Extract(val any) (driver.Valuer, error) {
@@ -100,12 +109,12 @@ type GlobalScopeContext struct {
100109
endorser EndorserContext
101110
}
102111

103-
func (c GlobalScopeContext) NameByPath(path ...string) (string, error) {
112+
func (c GlobalScopeContext) NameByPath(path []string) ([]string, error) {
104113
switch path[0] {
105114
case "endorser":
106-
return c.endorser.NameByPath(path[1:]...)
115+
return c.endorser.NameByPath(path[1:])
107116
default:
108-
return "", fmt.Errorf("can't get object \"%s\"", path[0])
117+
return nil, fmt.Errorf("can't get object \"%s\"", path[0])
109118
}
110119
}
111120

@@ -118,11 +127,11 @@ func (c GlobalScopeContext) Extract(val any) (driver.Valuer, error) {
118127
case member.InternalMemberId:
119128
var ex exporters.UintExporter
120129
valTyped.Export(&ex)
121-
return nil, nil
130+
return ex, nil
122131
case tenant.TenantId:
123132
var ex exporters.UintExporter
124133
valTyped.Export(&ex)
125-
return nil, nil
134+
return ex, nil
126135
case member.MemberId:
127136
var ex MemberIdExporter
128137
valTyped.Export(&ex)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package specification
2+
3+
import (
4+
s "github.com/emacsway/grade/grade/internal/domain/seedwork/specification"
5+
)
6+
7+
type CompositeExpression interface {
8+
Equal(other CompositeExpressionNode) (s.Visitable, error)
9+
NotEqual(other CompositeExpressionNode) (s.Visitable, error)
10+
Add(nodes ...s.Visitable)
11+
s.Visitable
12+
}
13+
14+
type CompositeExpressionNode struct {
15+
nodes []s.Visitable
16+
}
17+
18+
func (n *CompositeExpressionNode) Add(nodes ...s.Visitable) {
19+
n.nodes = append(n.nodes, nodes...)
20+
}
21+
22+
func (n CompositeExpressionNode) Equal(other CompositeExpressionNode) (s.Visitable, error) {
23+
var operands []s.Visitable
24+
if len(n.nodes) != len(other.nodes) {
25+
return nil, ErrCompositeExpressionsDifferentLength
26+
}
27+
for i := range n.nodes {
28+
left, right := n.nodes[i], other.nodes[i]
29+
leftComposite, ok := left.(CompositeExpressionNode)
30+
if ok {
31+
rightComposite, ok := right.(CompositeExpressionNode)
32+
if !ok {
33+
return nil, ErrCompositeExpressionsDifferentLength
34+
}
35+
newNode, err := leftComposite.Equal(rightComposite)
36+
if err != nil {
37+
return nil, err
38+
}
39+
operands = append(operands, newNode)
40+
} else {
41+
operands = append(operands, s.Equal(left, right))
42+
}
43+
}
44+
return s.And(operands[0], operands[1:]...), nil
45+
}
46+
47+
func (n CompositeExpressionNode) NotEqual(other CompositeExpressionNode) (s.Visitable, error) {
48+
var operands []s.Visitable
49+
if len(n.nodes) != len(other.nodes) {
50+
return nil, ErrCompositeExpressionsDifferentLength
51+
}
52+
for i := range n.nodes {
53+
left, right := n.nodes[i], other.nodes[i]
54+
leftComposite, ok := left.(CompositeExpressionNode)
55+
if ok {
56+
rightComposite, ok := right.(CompositeExpressionNode)
57+
if !ok {
58+
return nil, ErrCompositeExpressionsDifferentLength
59+
}
60+
newNode, err := leftComposite.NotEqual(rightComposite)
61+
if err != nil {
62+
return nil, err
63+
}
64+
operands = append(operands, newNode)
65+
} else {
66+
operands = append(operands, s.NotEqual(left, right))
67+
}
68+
}
69+
return s.Not(s.And(operands[0], operands[1:]...)), nil
70+
}
71+
72+
func (n CompositeExpressionNode) Accept(v s.Visitor) error {
73+
return nil
74+
}

grade/internal/infrastructure/seedwork/specification/postgresql_visitor.go

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package specification
33
import (
44
"database/sql/driver"
55
"fmt"
6+
"strings"
67

78
s "github.com/emacsway/grade/grade/internal/domain/seedwork/specification"
89
)
@@ -15,10 +16,9 @@ func PlaceholderIndex(index uint8) PostgresqlVisitorOption {
1516
}
1617
}
1718

18-
func NewPostgresqlVisitor(context Context, opts ...PostgresqlVisitorOption) *PostgresqlVisitor {
19+
func NewPostgresqlVisitor(opts ...PostgresqlVisitorOption) *PostgresqlVisitor {
1920
v := &PostgresqlVisitor{
2021
precedenceMapping: make(map[string]int),
21-
Context: context,
2222
}
2323
// https://www.postgresql.org/docs/14/sql-syntax-lexical.html#SQL-PRECEDENCE-TABLE
2424
v.setPrecedence(160, ". LEFT")
@@ -48,19 +48,6 @@ type PostgresqlVisitor struct {
4848
parameters []driver.Valuer
4949
precedence int
5050
precedenceMapping map[string]int
51-
// currentItem Context
52-
stack []Context
53-
Context
54-
}
55-
56-
func (v *PostgresqlVisitor) Push(ctx Context) {
57-
v.stack = append(v.stack, v.Context)
58-
v.Context = ctx
59-
}
60-
61-
func (v *PostgresqlVisitor) Pop() {
62-
v.Context = v.stack[len(v.stack)-1]
63-
v.stack = v.stack[:len(v.stack)-1]
6451
}
6552

6653
func (v PostgresqlVisitor) getNodePrecedenceKey(n s.Operable) string {
@@ -98,7 +85,6 @@ func (v *PostgresqlVisitor) visit(precedenceKey string, callable func() error) e
9885
}
9986

10087
func (v *PostgresqlVisitor) VisitGlobalScope(_ s.GlobalScopeNode) error {
101-
// v.push(v.Context)
10288
return nil
10389
}
10490

@@ -111,25 +97,17 @@ func (v *PostgresqlVisitor) VisitCollection(n s.CollectionNode) error {
11197
}
11298

11399
func (v *PostgresqlVisitor) VisitItem(n s.ItemNode) error {
114-
// v.push(v.currentItem)
115100
return nil
116101
}
117102

118103
func (v *PostgresqlVisitor) VisitField(n s.FieldNode) error {
119-
name, err := v.Context.NameByPath(s.ExtractFieldPath(n)...)
120-
// v.pop()
121-
if err != nil {
122-
return err
123-
}
104+
name := strings.Join(s.ExtractFieldPath(n), ".")
124105
v.sql += name
125106
return nil
126107
}
127108

128109
func (v *PostgresqlVisitor) VisitValue(n s.ValueNode) error {
129-
val, err := v.Extract(n.Value())
130-
if err != nil {
131-
return err
132-
}
110+
val := n.Value().(driver.Valuer)
133111
v.parameters = append(v.parameters, val)
134112
v.sql += fmt.Sprintf("$%d", len(v.parameters))
135113
return nil
@@ -167,8 +145,3 @@ func (v *PostgresqlVisitor) VisitInfix(n s.InfixNode) error {
167145
func (v PostgresqlVisitor) Result() (sql string, params []driver.Valuer, err error) {
168146
return v.sql, v.parameters, nil
169147
}
170-
171-
type Context interface {
172-
NameByPath(...string) (string, error)
173-
Extract(any) (driver.Valuer, error)
174-
}

0 commit comments

Comments
 (0)