Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions errors/rbac_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (
ErrLinkNotFound = errors.New("error: link between name1 and name2 does not exist")
ErrUseDomainParameter = errors.New("error: useDomain should be 1 parameter")
ErrInvalidFieldValuesParameter = errors.New("fieldValues requires at least one parameter")
ErrInvalidTypeDefinition = errors.New("error: invalid type definition")

// GetAllowedObjectConditions errors.
ErrObjCondition = errors.New("need to meet the prefix required by the object condition")
Expand Down
87 changes: 85 additions & 2 deletions management_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,81 @@ package casbin
import (
"errors"
"fmt"
"sort"
"strings"

"github.com/casbin/casbin/v3/constant"
"github.com/casbin/casbin/v3/util"
"github.com/casbin/govaluate"
)

func (e *Enforcer) getTypedPrincipals() ([]string, []string, bool, error) {
userSet := map[string]struct{}{}
roleSet := map[string]struct{}{}

appendByType := func(values []string) error {
for _, value := range values {
entityType, enabled, err := e.model.GetEntityType(value)
if err != nil {
return err
}
if !enabled {
return nil
}
switch entityType {
case "user":
userSet[value] = struct{}{}
case "role":
roleSet[value] = struct{}{}
}
}
return nil
}

values, err := e.model.GetValuesForFieldInPolicyAllTypesByName("p", constant.SubjectIndex)
if err != nil {
return nil, nil, false, err
}
if err := appendByType(values); err != nil {
return nil, nil, false, err
}
_, enabled, err := e.model.GetEntityType("")
if err != nil {
return nil, nil, false, err
}
if !enabled {
return nil, nil, false, nil
}

if _, err := e.model.GetAssertion("g", "g"); err == nil {
groupingPolicy, err := e.GetNamedGroupingPolicy("g")
if err != nil {
return nil, nil, false, err
}
for _, rule := range groupingPolicy {
limit := 2
if len(rule) < limit {
limit = len(rule)
}
if err := appendByType(rule[:limit]); err != nil {
return nil, nil, false, err
}
}
}

users := make([]string, 0, len(userSet))
for user := range userSet {
users = append(users, user)
}
roles := make([]string, 0, len(roleSet))
for role := range roleSet {
roles = append(roles, role)
}
sort.Strings(users)
sort.Strings(roles)
return users, roles, true, nil
}

// GetAllSubjects gets the list of subjects that show up in the current policy.
func (e *Enforcer) GetAllSubjects() ([]string, error) {
return e.model.GetValuesForFieldInPolicyAllTypesByName("p", constant.SubjectIndex)
Expand Down Expand Up @@ -68,6 +136,13 @@ func (e *Enforcer) GetAllNamedActions(ptype string) ([]string, error) {

// GetAllRoles gets the list of roles that show up in the current policy.
func (e *Enforcer) GetAllRoles() ([]string, error) {
_, roles, enabled, err := e.getTypedPrincipals()
if err != nil {
return nil, err
}
if enabled {
return roles, nil
}
return e.model.GetValuesForFieldInPolicyAllTypes("g", 1)
}

Expand All @@ -79,6 +154,14 @@ func (e *Enforcer) GetAllNamedRoles(ptype string) ([]string, error) {
// GetAllUsers gets the list of users that show up in the current policy.
// Users are subjects that are not roles (i.e., subjects that do not appear as the second element in any grouping policy).
func (e *Enforcer) GetAllUsers() ([]string, error) {
users, _, enabled, err := e.getTypedPrincipals()
if err != nil {
return nil, err
}
if enabled {
return users, nil
}

subjects, err := e.GetAllSubjects()
if err != nil {
return nil, err
Expand All @@ -89,8 +172,8 @@ func (e *Enforcer) GetAllUsers() ([]string, error) {
return nil, err
}

users := util.SetSubtract(subjects, roles)
return users, nil
result := util.SetSubtract(subjects, roles)
return result, nil
}

// GetPolicy gets all the authorization rules in the policy.
Expand Down
11 changes: 11 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ var sectionNameMap = map[string]string{
"e": "policy_effect",
"m": "matchers",
"c": "constraint_definition",
"t": "type_definition",
}

// Minimal required sections for a model to be valid.
Expand Down Expand Up @@ -118,6 +119,12 @@ func getKeySuffix(i int) string {
}

func loadSection(model Model, cfg config.ConfigInterface, sec string) {
if sec == "t" {
loadAssertion(model, cfg, sec, userTypeKey)
loadAssertion(model, cfg, sec, roleTypeKey)
return
}

i := 1
for {
if !loadAssertion(model, cfg, sec, sec+getKeySuffix(i)) {
Expand Down Expand Up @@ -203,6 +210,10 @@ func (model Model) loadModelFromConfig(cfg config.ConfigInterface) error {
return err
}

if err := model.ValidateTypeDefinitions(); err != nil {
return err
}

return nil
}

Expand Down
11 changes: 11 additions & 0 deletions model/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ func (model Model) AddPolicy(sec string, ptype string, rule []string) error {
if err != nil {
return err
}
if err := model.ValidatePolicyTypes(sec, ptype, rule); err != nil {
return err
}
assertion.Policy = append(assertion.Policy, rule)
assertion.PolicyMap[strings.Join(rule, DefaultSep)] = len(model[sec][ptype].Policy) - 1

Expand Down Expand Up @@ -282,6 +285,9 @@ func (model Model) UpdatePolicy(sec string, ptype string, oldRule []string, newR
if err != nil {
return false, err
}
if err := model.ValidatePolicyTypes(sec, ptype, newRule); err != nil {
return false, err
}
oldPolicy := strings.Join(oldRule, DefaultSep)
index, ok := model[sec][ptype].PolicyMap[oldPolicy]
if !ok {
Expand All @@ -301,6 +307,11 @@ func (model Model) UpdatePolicies(sec string, ptype string, oldRules, newRules [
if err != nil {
return false, err
}
for _, newRule := range newRules {
if err := model.ValidatePolicyTypes(sec, ptype, newRule); err != nil {
return false, err
}
}
rollbackFlag := false
// index -> []{oldIndex, newIndex}
modifiedRuleIndex := make(map[int][]int)
Expand Down
150 changes: 150 additions & 0 deletions model/type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright 2026 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package model

import (
"fmt"
"strings"

"github.com/casbin/casbin/v3/constant"
Err "github.com/casbin/casbin/v3/errors"
)

const (
typeDefinitionSection = "t"
userTypeKey = "user"
roleTypeKey = "role"
)

type entityType string

const (
entityTypeUnknown entityType = ""
entityTypeUser entityType = "user"
entityTypeRole entityType = "role"
)

type typeDefinition struct {
userPrefix string
rolePrefix string
}

func (model Model) getTypeDefinition() (*typeDefinition, bool, error) {
section := model[typeDefinitionSection]
if len(section) == 0 {
return nil, false, nil
}

userAssertion, hasUser := section[userTypeKey]
roleAssertion, hasRole := section[roleTypeKey]
if !hasUser || !hasRole {
return nil, false, fmt.Errorf("%w: type_definition must define both user and role", Err.ErrInvalidTypeDefinition)
}

userPrefix := strings.TrimSpace(userAssertion.Value)
rolePrefix := strings.TrimSpace(roleAssertion.Value)
if userPrefix == "" || rolePrefix == "" {
return nil, false, fmt.Errorf("%w: user and role prefixes cannot be empty", Err.ErrInvalidTypeDefinition)
}
if userPrefix == rolePrefix {
return nil, false, fmt.Errorf("%w: user and role prefixes must be different", Err.ErrInvalidTypeDefinition)
}
if strings.HasPrefix(userPrefix, rolePrefix) || strings.HasPrefix(rolePrefix, userPrefix) {
return nil, false, fmt.Errorf("%w: user and role prefixes must not overlap", Err.ErrInvalidTypeDefinition)
}

return &typeDefinition{userPrefix: userPrefix, rolePrefix: rolePrefix}, true, nil
}

func (model Model) ValidateTypeDefinitions() error {
_, _, err := model.getTypeDefinition()
return err
}

func (model Model) GetEntityType(name string) (string, bool, error) {
def, enabled, err := model.getTypeDefinition()
if err != nil || !enabled {
return "", enabled, err
}

switch {
case strings.HasPrefix(name, def.userPrefix):
return string(entityTypeUser), true, nil
case strings.HasPrefix(name, def.rolePrefix):
return string(entityTypeRole), true, nil
default:
return "", true, nil
}
}

func (model Model) ValidatePolicyTypes(sec string, ptype string, rule []string) error {
def, enabled, err := model.getTypeDefinition()
if err != nil || !enabled {
return err
}

switch sec {
case "p":
index, err := model.GetFieldIndex(ptype, constant.SubjectIndex)
if err != nil {
return err
}
if index >= len(rule) {
return nil
}
return validateEntityType(rule[index], ptype+".sub", def, entityTypeUser, entityTypeRole)
case "g":
if ptype != "g" || len(rule) < 2 {
return nil
}
if err := validateEntityType(rule[0], ptype+"[0]", def, entityTypeUser, entityTypeRole); err != nil {
return err
}
return validateEntityType(rule[1], ptype+"[1]", def, entityTypeRole)
default:
return nil
}
}

func validateEntityType(name string, field string, def *typeDefinition, allowed ...entityType) error {
actual := getEntityType(name, def)
if actual == entityTypeUnknown {
return fmt.Errorf("type mismatch for %s: %q does not match any configured user/role prefix", field, name)
}

for _, allowedType := range allowed {
if actual == allowedType {
return nil
}
}

expected := make([]string, 0, len(allowed))
for _, allowedType := range allowed {
expected = append(expected, string(allowedType))
}

return fmt.Errorf("type mismatch for %s: %q is %s, expected %s", field, name, actual, strings.Join(expected, " or "))
}

func getEntityType(name string, def *typeDefinition) entityType {
switch {
case strings.HasPrefix(name, def.userPrefix):
return entityTypeUser
case strings.HasPrefix(name, def.rolePrefix):
return entityTypeRole
default:
return entityTypeUnknown
}
}
22 changes: 22 additions & 0 deletions rbac_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,28 @@ func (e *Enforcer) GetNamedImplicitPermissionsForUser(ptype string, gtype string
// GetImplicitUsersForPermission("data1", "read") will get: ["alice", "bob"].
// Note: only users will be returned, roles (2nd arg in "g") will be excluded.
func (e *Enforcer) GetImplicitUsersForPermission(permission ...string) ([]string, error) {
if _, _, enabled, err := e.getTypedPrincipals(); err != nil {
return nil, err
} else if enabled {
subjects, err := e.GetAllUsers()
if err != nil {
return nil, err
}

res := []string{}
for _, user := range subjects {
req := util.JoinSliceAny(user, permission...)
allowed, err := e.Enforce(req...)
if err != nil {
return nil, err
}
if allowed {
res = append(res, user)
}
}
return res, nil
}

pSubjects, err := e.GetAllSubjects()
if err != nil {
return nil, err
Expand Down
Loading