From 500e8d9300c4160b76c99ffc2c43efda992dfbba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Fri, 24 Apr 2026 12:23:29 -0700 Subject: [PATCH] add support for exhaustive checking of switch on enum --- ast/visitor.go | 5 +- bbq/compiler/compiler.go | 5 +- cmd/errors/errors.go | 6 + interpreter/switch_test.go | 53 +++++ sema/check_block.go | 25 ++- sema/check_composite_declaration.go | 4 + sema/check_exhaustive_switch.go | 46 +++++ sema/check_switch.go | 120 +++++++++++- sema/errors.go | 21 ++ sema/switch_test.go | 290 ++++++++++++++++++++++++++++ sema/type.go | 1 + 11 files changed, 570 insertions(+), 6 deletions(-) create mode 100644 sema/check_exhaustive_switch.go diff --git a/ast/visitor.go b/ast/visitor.go index 7f1cfa4536..abb590a491 100644 --- a/ast/visitor.go +++ b/ast/visitor.go @@ -38,13 +38,13 @@ type StatementDeclarationVisitor[T any] interface { VisitEntitlementDeclaration(*EntitlementDeclaration) T VisitEntitlementMappingDeclaration(*EntitlementMappingDeclaration) T VisitTransactionDeclaration(*TransactionDeclaration) T + VisitPragmaDeclaration(*PragmaDeclaration) T } type DeclarationVisitor[T any] interface { StatementDeclarationVisitor[T] VisitFieldDeclaration(*FieldDeclaration) T VisitEnumCaseDeclaration(*EnumCaseDeclaration) T - VisitPragmaDeclaration(*PragmaDeclaration) T VisitImportDeclaration(*ImportDeclaration) T } @@ -177,6 +177,9 @@ func AcceptStatement[T any](statement Statement, visitor StatementVisitor[T]) (_ case ElementTypeRemoveStatement: return visitor.VisitRemoveStatement(statement.(*RemoveStatement)) + + case ElementTypePragmaDeclaration: + return visitor.VisitPragmaDeclaration(statement.(*PragmaDeclaration)) } panic(errors.NewUnreachableError()) diff --git a/bbq/compiler/compiler.go b/bbq/compiler/compiler.go index dac1112336..b2228dfe40 100644 --- a/bbq/compiler/compiler.go +++ b/bbq/compiler/compiler.go @@ -4165,8 +4165,9 @@ func (c *Compiler[_, _]) VisitFieldDeclaration(_ *ast.FieldDeclaration) (_ struc } func (c *Compiler[_, _]) VisitPragmaDeclaration(_ *ast.PragmaDeclaration) (_ struct{}) { - // TODO - panic(errors.NewUnreachableError()) + // Pragmas are directives for the checker (e.g. #exhaustive). + // Nothing to compile. + return } func (c *Compiler[_, _]) VisitImportDeclaration(declaration *ast.ImportDeclaration) (_ struct{}) { diff --git a/cmd/errors/errors.go b/cmd/errors/errors.go index eb3b667f21..1e304ccd3c 100644 --- a/cmd/errors/errors.go +++ b/cmd/errors/errors.go @@ -1490,6 +1490,12 @@ func generateErrors() []namedError { Pos: placeholderPosition, }, }, + {"sema.MissingSwitchCasesError", + &sema.MissingSwitchCasesError{ + MissingCases: placeholderStrings, + Range: placeholderRange, + }, + }, {"sema.MissingTypeArgumentError", &sema.MissingTypeArgumentError{ TypeParameterName: placeholderString, diff --git a/interpreter/switch_test.go b/interpreter/switch_test.go index 5216d74ada..124f2344f5 100644 --- a/interpreter/switch_test.go +++ b/interpreter/switch_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/onflow/cadence/common" "github.com/onflow/cadence/interpreter" . "github.com/onflow/cadence/test_utils/interpreter_utils" ) @@ -233,4 +234,56 @@ func TestInterpretSwitchStatement(t *testing.T) { AssertValuesEqual(t, inter, testCase.expected, actual) } }) + + t.Run("exhaustive enum", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndPrepare(t, ` + enum Color: UInt8 { + case red + case green + case blue + } + + fun test(): [String] { + let results: [String] = [] + let rawValues: [UInt8] = [0, 1, 2] + for rawValue in rawValues { + let c = Color(rawValue: rawValue)! + results.append(name(c)) + } + return results + } + + fun name(_ c: Color): String { + #exhaustive + switch c { + case Color.red: + return "red" + case Color.green: + return "green" + case Color.blue: + return "blue" + } + } + `) + + actual, err := inter.Invoke("test") + require.NoError(t, err) + + AssertValuesEqual(t, inter, + interpreter.NewArrayValue( + inter, + &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + }, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue("red"), + interpreter.NewUnmeteredStringValue("green"), + interpreter.NewUnmeteredStringValue("blue"), + ), + actual, + ) + }) } diff --git a/sema/check_block.go b/sema/check_block.go index 294620edd1..c00e6c6034 100644 --- a/sema/check_block.go +++ b/sema/check_block.go @@ -32,7 +32,7 @@ func (checker *Checker) visitStatements(statements []ast.Statement) { functionActivation := checker.functionActivations.Current() // check all statements - for _, statement := range statements { + for i, statement := range statements { // Is this statement unreachable? Report it once for this statement, // but avoid noise and don't report it for all remaining unreachable statements @@ -60,6 +60,23 @@ func (checker *Checker) visitStatements(statements []ast.Statement) { // check statement + switch stmt := statement.(type) { + case *ast.SwitchStatement: + isExhaustive := isPrecedingStatementExhaustivePragma(statements, i) + checker.checkSwitchStatement(stmt, isExhaustive) + continue + + case *ast.PragmaDeclaration: + if isExhaustivePragma(stmt) && !isStatementFollowedBySwitch(statements, i) { + checker.report( + &InvalidPragmaError{ + Message: "the #exhaustive pragma must be placed directly before a switch statement", + Range: ast.NewRangeFromPositioned(checker.memoryGauge, stmt), + }, + ) + } + } + ast.AcceptStatement[struct{}](statement, checker) } } @@ -75,9 +92,13 @@ func (checker *Checker) checkValidStatement(statement ast.Statement) bool { // Only function and variable declarations are allowed locally - switch declaration.(type) { + switch decl := declaration.(type) { case *ast.FunctionDeclaration, *ast.VariableDeclaration: return true + case *ast.PragmaDeclaration: + if isExhaustivePragma(decl) { + return true + } } identifier := declaration.DeclarationIdentifier() diff --git a/sema/check_composite_declaration.go b/sema/check_composite_declaration.go index 16ba3899f4..a0a750a4e8 100644 --- a/sema/check_composite_declaration.go +++ b/sema/check_composite_declaration.go @@ -1090,6 +1090,8 @@ func (checker *Checker) declareEnumConstructor( memberCaseTypeAnnotation := NewTypeAnnotation(compositeType) + compositeType.EnumCases = make([]string, 0, len(enumCases)) + for _, enumCase := range enumCases { caseName := enumCase.Identifier.Identifier @@ -1097,6 +1099,8 @@ func (checker *Checker) declareEnumConstructor( continue } + compositeType.EnumCases = append(compositeType.EnumCases, caseName) + enumLookupFunctionType.Members.Set( caseName, &Member{ diff --git a/sema/check_exhaustive_switch.go b/sema/check_exhaustive_switch.go new file mode 100644 index 0000000000..396051159b --- /dev/null +++ b/sema/check_exhaustive_switch.go @@ -0,0 +1,46 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Flow Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sema + +import "github.com/onflow/cadence/ast" + +func isExhaustivePragma(pragma *ast.PragmaDeclaration) bool { + ident, ok := pragma.Expression.(*ast.IdentifierExpression) + return ok && ident.Identifier.Identifier == "exhaustive" +} + +// isPrecedingStatementExhaustivePragma checks whether the statement at the given index +// is immediately preceded by an #exhaustive pragma. +func isPrecedingStatementExhaustivePragma(statements []ast.Statement, index int) bool { + if index == 0 { + return false + } + pragma, ok := statements[index-1].(*ast.PragmaDeclaration) + return ok && isExhaustivePragma(pragma) +} + +// isStatementFollowedBySwitch checks whether the statement at the given index +// is immediately followed by a switch statement. +func isStatementFollowedBySwitch(statements []ast.Statement, index int) bool { + if index+1 >= len(statements) { + return false + } + _, ok := statements[index+1].(*ast.SwitchStatement) + return ok +} diff --git a/sema/check_switch.go b/sema/check_switch.go index 5cfd1cedc6..963001eea0 100644 --- a/sema/check_switch.go +++ b/sema/check_switch.go @@ -20,9 +20,15 @@ package sema import ( "github.com/onflow/cadence/ast" + "github.com/onflow/cadence/common" ) func (checker *Checker) VisitSwitchStatement(statement *ast.SwitchStatement) (_ struct{}) { + checker.checkSwitchStatement(statement, false) + return +} + +func (checker *Checker) checkSwitchStatement(statement *ast.SwitchStatement, isExhaustive bool) { testType := checker.VisitExpression(statement.Expression, statement, nil) @@ -39,6 +45,12 @@ func (checker *Checker) VisitSwitchStatement(statement *ast.SwitchStatement) (_ ) } + // If the #exhaustive pragma was set, verify exhaustiveness. + + if isExhaustive { + isExhaustive = checker.checkSwitchExhaustiveOverEnum(statement, testType) + } + // Check all cases checker.functionActivations.Current().WithSwitch(func() { @@ -47,10 +59,101 @@ func (checker *Checker) VisitSwitchStatement(statement *ast.SwitchStatement) (_ statement.Cases, testType, testTypeIsValid, + isExhaustive, ) }) +} - return +// checkSwitchExhaustiveOverEnum checks whether a switch statement on an enum type +// covers all enum cases. Reports errors if the test type is not an enum +// or if not all enum cases are covered. Returns true if the switch is exhaustive. +func (checker *Checker) checkSwitchExhaustiveOverEnum( + statement *ast.SwitchStatement, + testType Type, +) bool { + compositeType, ok := testType.(*CompositeType) + if !ok || compositeType.Kind != common.CompositeKindEnum { + checker.report( + &InvalidPragmaError{ + Message: "the #exhaustive pragma can only be used with enum types", + Range: ast.NewRangeFromPositioned(checker.memoryGauge, statement.Expression), + }, + ) + return false + } + + enumCases := compositeType.EnumCases + if len(enumCases) == 0 { + return true + } + + // Build a set of enum case names for quick lookup + enumCaseSet := make(map[string]struct{}, len(enumCases)) + for _, name := range enumCases { + enumCaseSet[name] = struct{}{} + } + + // Track which enum cases are covered by switch cases + coveredCases := make(map[string]struct{}) + + for _, switchCase := range statement.Cases { + if switchCase.Expression == nil { + // Default case — skip + continue + } + + memberExpr, ok := switchCase.Expression.(*ast.MemberExpression) + if !ok { + continue + } + + identExpr, ok := memberExpr.Expression.(*ast.IdentifierExpression) + if !ok { + continue + } + + // Look up the identifier in scope to verify it refers to the enum type + variable := checker.valueActivations.Find(identExpr.Identifier.Identifier) + if variable == nil { + continue + } + + funcType, ok := variable.Type.(*FunctionType) + if !ok { + continue + } + + if funcType.TypeFunctionType != compositeType { + continue + } + + // The member name is an enum case reference + memberName := memberExpr.Identifier.Identifier + if _, isEnumCase := enumCaseSet[memberName]; isEnumCase { + coveredCases[memberName] = struct{}{} + } + } + + if len(coveredCases) == len(enumCases) { + return true + } + + // Report which enum cases are missing + missingCases := make([]string, 0, len(enumCases)-len(coveredCases)) + for _, name := range enumCases { + if _, covered := coveredCases[name]; !covered { + missingCases = append(missingCases, name) + } + } + + checker.report( + &MissingSwitchCasesError{ + MissingCases: missingCases, + Range: ast.NewRangeFromPositioned(checker.memoryGauge, statement), + }, + ) + + return false } func (checker *Checker) checkSwitchCaseExpression( @@ -94,6 +197,7 @@ func (checker *Checker) checkSwitchCasesStatements( remainingCases []*ast.SwitchCase, testType Type, testTypeIsValid bool, + isExhaustive bool, ) { remainingCaseCount := len(remainingCases) if remainingCaseCount == 0 { @@ -106,6 +210,9 @@ func (checker *Checker) checkSwitchCasesStatements( // However, the default case's block must be checked directly as the "else", // because if a default case exists, the whole switch statement // will definitely have one case which will be taken. + // + // Similarly, if the switch is exhaustive over an enum type + // (via the #exhaustive pragma), the last case is treated like a default case. switchCase := remainingCases[0] @@ -137,6 +244,16 @@ func (checker *Checker) checkSwitchCasesStatements( testTypeIsValid, ) + // If this is the last case and the switch is exhaustive over an enum, + // treat this case like a default: it is guaranteed to be taken + // if none of the previous cases matched. + if remainingCaseCount == 1 && isExhaustive { + currentFunctionActivation.ReturnInfo.WithNewJumpTarget(func() { + checker.checkSwitchCaseStatements(switchCase) + }) + return + } + _, _ = checker.checkConditionalBranches( func() Type { @@ -153,6 +270,7 @@ func (checker *Checker) checkSwitchCasesStatements( remainingCases[1:], testType, testTypeIsValid, + isExhaustive, ) // ignored diff --git a/sema/errors.go b/sema/errors.go index 4df9c5492e..c7388fcf6c 100644 --- a/sema/errors.go +++ b/sema/errors.go @@ -5869,6 +5869,27 @@ func (e *MissingSwitchCaseStatementsError) EndPosition(common.MemoryGauge) ast.P return e.Pos } +// MissingSwitchCasesError + +type MissingSwitchCasesError struct { + MissingCases []string + ast.Range +} + +var _ SemanticError = &MissingSwitchCasesError{} +var _ errors.UserError = &MissingSwitchCasesError{} + +func (*MissingSwitchCasesError) isSemanticError() {} + +func (*MissingSwitchCasesError) IsUserError() {} + +func (e *MissingSwitchCasesError) Error() string { + return fmt.Sprintf( + "switch is not exhaustive: missing case(s): %s", + strings.Join(e.MissingCases, ", "), + ) +} + // MissingEntryPointError type MissingEntryPointError struct { diff --git a/sema/switch_test.go b/sema/switch_test.go index 197e1dd5d5..0b4b36e4cb 100644 --- a/sema/switch_test.go +++ b/sema/switch_test.go @@ -610,3 +610,293 @@ func TestCheckSwitchResourceInvalidation(t *testing.T) { assert.IsType(t, &sema.ResourceUseAfterInvalidationError{}, errs[0]) }) } + +func TestCheckSwitchStatementExhaustiveEnum(t *testing.T) { + + t.Parallel() + + t.Run("exhaustive, no default, definite return", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Color: UInt8 { + case red + case green + case blue + } + + fun test(c: Color): String { + #exhaustive + switch c { + case Color.red: + return "red" + case Color.green: + return "green" + case Color.blue: + return "blue" + } + } + `) + + require.NoError(t, err) + }) + + t.Run("exhaustive, with default, no error", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Color: UInt8 { + case red + case green + case blue + } + + fun test(c: Color): String { + #exhaustive + switch c { + case Color.red: + return "red" + case Color.green: + return "green" + case Color.blue: + return "blue" + default: + return "unknown" + } + } + `) + + require.NoError(t, err) + }) + + t.Run("exhaustive, unreachable code after switch", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Color: UInt8 { + case red + case green + case blue + } + + fun test(c: Color): String { + #exhaustive + switch c { + case Color.red: + return "red" + case Color.green: + return "green" + case Color.blue: + return "blue" + } + return "unreachable" + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.UnreachableStatementError{}, errs[0]) + }) + + t.Run("non-exhaustive, missing cases error", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Color: UInt8 { + case red + case green + case blue + } + + fun test(c: Color): String { + #exhaustive + switch c { + case Color.red: + return "red" + case Color.green: + return "green" + } + } + `) + + errs := RequireCheckerErrors(t, err, 2) + + assert.IsType(t, &sema.MissingSwitchCasesError{}, errs[0]) + assert.IsType(t, &sema.MissingReturnStatementError{}, errs[1]) + }) + + t.Run("single case enum, exhaustive", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Direction: UInt8 { + case up + } + + fun test(d: Direction): String { + #exhaustive + switch d { + case Direction.up: + return "up" + } + } + `) + + require.NoError(t, err) + }) + + t.Run("exhaustive, no return in body", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Color: UInt8 { + case red + case green + } + + fun test(c: Color) { + var x = 0 + #exhaustive + switch c { + case Color.red: + x = 1 + case Color.green: + x = 2 + } + } + `) + + require.NoError(t, err) + }) + + t.Run("non-member-access expression, missing cases", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Color: UInt8 { + case red + case green + } + + fun test(c: Color): String { + let r = Color.red + let g = Color.green + #exhaustive + switch c { + case r: + return "red" + case g: + return "green" + } + } + `) + + errs := RequireCheckerErrors(t, err, 2) + + assert.IsType(t, &sema.MissingSwitchCasesError{}, errs[0]) + assert.IsType(t, &sema.MissingReturnStatementError{}, errs[1]) + }) + + t.Run("without pragma, no exhaustiveness check", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Color: UInt8 { + case red + case green + case blue + } + + fun test(c: Color): String { + switch c { + case Color.red: + return "red" + case Color.green: + return "green" + case Color.blue: + return "blue" + } + } + `) + + // Without #exhaustive, the switch is not treated as exhaustive + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.MissingReturnStatementError{}, errs[0]) + }) + + t.Run("pragma on non-enum type", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test(x: Int): String { + #exhaustive + switch x { + case 1: + return "one" + default: + return "other" + } + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidPragmaError{}, errs[0]) + }) + + t.Run("pragma not followed by switch", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + fun test() { + #exhaustive + let x = 1 + } + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.InvalidPragmaError{}, errs[0]) + }) + + t.Run("duplicate case, not exhaustive", func(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + enum Color: UInt8 { + case red + case green + case blue + } + + fun test(c: Color): String { + #exhaustive + switch c { + case Color.red: + return "red" + case Color.red: + return "red again" + case Color.green: + return "green" + } + } + `) + + errs := RequireCheckerErrors(t, err, 2) + + assert.IsType(t, &sema.MissingSwitchCasesError{}, errs[0]) + assert.IsType(t, &sema.MissingReturnStatementError{}, errs[1]) + }) +} diff --git a/sema/type.go b/sema/type.go index 45bf8a58a0..424182ca09 100644 --- a/sema/type.go +++ b/sema/type.go @@ -5126,6 +5126,7 @@ type CompositeType struct { memberResolvers atomic.Pointer[map[string]MemberResolver] Identifier string Fields []string + EnumCases []string ConstructorParameters []Parameter // an internal set of field `effectiveInterfaceConformances` effectiveInterfaceConformanceSet atomic.Pointer[InterfaceSet]