Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pkg/microservice/user/core/handler/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func ListUsers(c *gin.Context) {
}

if len(args.UIDs) > 0 {
ctx.Resp, ctx.RespErr = permission.SearchUsersByUIDs(args.UIDs, ctx.Logger)
ctx.Resp, ctx.RespErr = permission.SearchUsersByUIDs(args.UIDs, args.MFAEnabled, ctx.Logger)
} else if len(args.Account) > 0 {
if len(args.IdentityType) == 0 {
args.IdentityType = config.SystemIdentityType
Expand Down Expand Up @@ -321,6 +321,7 @@ func OpenAPIListUsersBrief(c *gin.Context) {
Roles: args.Roles,
Project: args.Project,
IdentityType: args.IdentityType,
MFAEnabled: args.MFAEnabled,
}

var resp *types.UsersResp
Expand Down Expand Up @@ -383,7 +384,7 @@ func ListUsersBrief(c *gin.Context) {

var resp *types.UsersResp
if len(args.UIDs) > 0 {
resp, err = permission.SearchUsersByUIDs(args.UIDs, ctx.Logger)
resp, err = permission.SearchUsersByUIDs(args.UIDs, args.MFAEnabled, ctx.Logger)
} else if len(args.Account) > 0 {
if len(args.IdentityType) == 0 {
args.IdentityType = config.SystemIdentityType
Expand Down
1 change: 1 addition & 0 deletions pkg/microservice/user/core/repository/models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type User struct {
Email string `json:"email"`
Phone string `json:"phone"`
Account string `json:"account"`
MFAEnabled bool `gorm:"->;column:mfa_enabled;-:migration" json:"mfa_enabled"`
APIToken string `gorm:"api_token" json:"api_token"`
APITokenEnabled bool `gorm:"column:api_token_enabled;default:0" json:"api_token_enabled"`

Expand Down
44 changes: 44 additions & 0 deletions pkg/microservice/user/core/repository/orm/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,50 @@ func ListRoleByUIDAndNamespace(uid, namespace string, db *gorm.DB) ([]*models.Ne
return resp, nil
}

// ListRoleByUIDsAndNamespace lists roles for the given users in a namespace with a single query.
func ListRoleByUIDsAndNamespace(uids []string, namespace string, db *gorm.DB) (map[string][]*models.NewRole, error) {
if len(uids) == 0 {
return map[string][]*models.NewRole{}, nil
}

type uidRole struct {
UID string `gorm:"column:uid"`
ID uint `gorm:"column:id"`
Name string `gorm:"column:name"`
Description string `gorm:"column:description"`
Type int64 `gorm:"column:type"`
Namespace string `gorm:"column:namespace"`
GlobalReadOnly bool `gorm:"column:global_read_only"`
}

rows := make([]*uidRole, 0)
err := db.Table("role").
Select("role_binding.uid, role.id, role.name, role.description, role.type, role.namespace, role.global_read_only").
Joins("INNER JOIN role_binding ON role.id = role_binding.role_id").
Where("role.namespace = ?", namespace).
Where("role_binding.uid IN ?", uids).
Order("role_binding.uid ASC").
Order("role.id ASC").
Scan(&rows).Error
if err != nil {
return nil, err
}

resp := make(map[string][]*models.NewRole, len(rows))
for _, row := range rows {
resp[row.UID] = append(resp[row.UID], &models.NewRole{
ID: row.ID,
Name: row.Name,
Description: row.Description,
Type: row.Type,
Namespace: row.Namespace,
GlobalReadOnly: row.GlobalReadOnly,
})
}

return resp, nil
}

// ListRoleByUID list a set of roles that is used by specific user in ALL namespace
func ListRoleByUID(uid string, db *gorm.DB) ([]*models.NewRole, error) {
resp := make([]*models.NewRole, 0)
Expand Down
138 changes: 87 additions & 51 deletions pkg/microservice/user/core/repository/orm/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ import (
"github.com/koderover/zadig/v2/pkg/types"
)

const (
userMFAJoinClause = "LEFT JOIN user_mfa ON user_mfa.uid = user.uid"
userMFAEnabledSelectExpr = "IFNULL(user_mfa.enabled, 0) AS mfa_enabled"
)

// CreateUser create a user
func CreateUser(user *models.User, db *gorm.DB) error {
if err := db.Create(&user).Error; err != nil {
Expand All @@ -46,6 +51,22 @@ func GetUser(account string, identityType string, db *gorm.DB) (*models.User, er
return &user, nil
}

func GetUserByAccountAndMFAEnabled(account string, identityType string, mfaEnabled *bool, db *gorm.DB) (*models.User, error) {
var user models.User
query := db.Model(&models.User{}).
Select("user.*, "+userMFAEnabledSelectExpr).
Where("account = ? and identity_type = ?", account, identityType)
err := applyMFAEnabledJoinFilter(query, mfaEnabled).
First(&user).Error
if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
}
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return &user, nil
}

// GetUserByUid Get a user based on uid
func GetUserByUid(uid string, db *gorm.DB) (*models.User, error) {
var user models.User
Expand Down Expand Up @@ -80,13 +101,17 @@ func ListAllUsers(db *gorm.DB) ([]*models.User, error) {
}

// ListUsers gets a list of users based on paging constraints
func ListUsers(page int, perPage int, name string, db *gorm.DB) ([]models.User, error) {
func ListUsers(page int, perPage int, name string, mfaEnabled *bool, db *gorm.DB) ([]models.User, error) {
var (
users []models.User
err error
)

err = db.Where("name LIKE ?", "%"+name+"%").Order("account ASC").Offset((page - 1) * perPage).Limit(perPage).Find(&users).Error
query := db.Model(&models.User{}).
Select("user.*, "+userMFAEnabledSelectExpr).
Where("name LIKE ?", "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.Order("account ASC").Offset((page - 1) * perPage).Limit(perPage).Find(&users).Error

if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
Expand All @@ -95,15 +120,18 @@ func ListUsers(page int, perPage int, name string, db *gorm.DB) ([]models.User,
return users, nil
}

func ListUsersByLoginTime(page int, perPage int, name string, order setting.ListUserOrder, db *gorm.DB) ([]models.UserWithLoginTime, error) {
func ListUsersByLoginTime(page int, perPage int, name string, order setting.ListUserOrder, mfaEnabled *bool, db *gorm.DB) ([]models.UserWithLoginTime, error) {
var (
users []models.UserWithLoginTime
err error
)

err = db.Select("user.uid, user.name, user.account, user.identity_type, user.api_token_enabled, IFNULL(user_login.last_login_time, 0) as last_login_time").
query := db.Model(&models.User{}).
Select("user.uid, user.name, user.account, user.identity_type, user.api_token_enabled, "+userMFAEnabledSelectExpr+", IFNULL(user_login.last_login_time, 0) as last_login_time").
Where("user.name LIKE ?", "%"+name+"%").
Joins("LEFT JOIN user_login on user_login.uid = user.uid").
Joins("LEFT JOIN user_login on user_login.uid = user.uid")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.
Order("IFNULL(user_login.last_login_time, 0) " + string(order)).
Offset((page - 1) * perPage).
Limit(perPage).
Expand All @@ -117,39 +145,25 @@ func ListUsersByLoginTime(page int, perPage int, name string, order setting.List
return users, nil
}

// listUIDsByRoles returns distinct user uids that have any of the given role names within the namespace.
func listUIDsByRoles(roles []string, namespace string, db *gorm.DB) ([]string, error) {
var uids []string
err := db.Table("role_binding").
Distinct("role_binding.uid").
func uidSubQueryByRoles(roles []string, namespace string, db *gorm.DB) *gorm.DB {
return db.Table("role_binding").
Select("DISTINCT role_binding.uid").
Joins("INNER JOIN role ON role.id = role_binding.role_id").
Where("role.name IN ? AND role.namespace = ?", roles, namespace).
Pluck("role_binding.uid", &uids).Error

if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
}
return uids, nil
Where("role.name IN ? AND role.namespace = ?", roles, namespace)
}

// ListUsersByNameAndRoleWithLoginTime gets a list of users filtered by name and roles,
// ordered by last_login_time with pagination. It is implemented in two simple steps:
// 1. Find the uids of users that have any of the given roles (role_binding + role) within the namespace.
// 2. Query user + user_login for those uids, filter by name, order by last_login_time and paginate.
func ListUsersByNameAndRoleWithLoginTime(page int, perPage int, name string, roles []string, namespace string, order setting.ListUserOrder, db *gorm.DB) ([]models.UserWithLoginTime, error) {
uids, err := listUIDsByRoles(roles, namespace, db)
if err != nil {
return nil, err
}
if len(uids) == 0 {
return []models.UserWithLoginTime{}, nil
}

// ordered by last_login_time with pagination.
func ListUsersByNameAndRoleWithLoginTime(page int, perPage int, name string, roles []string, namespace string, order setting.ListUserOrder, mfaEnabled *bool, db *gorm.DB) ([]models.UserWithLoginTime, error) {
var users []models.UserWithLoginTime
err = db.Table("user").
Select("user.uid, user.name, user.account, user.identity_type, user.api_token_enabled, IFNULL(user_login.last_login_time, 0) AS last_login_time").
roleUIDSubQuery := uidSubQueryByRoles(roles, namespace, db)
query := db.Model(&models.User{}).
Select("user.uid, user.name, user.account, user.identity_type, user.api_token_enabled, "+userMFAEnabledSelectExpr+", IFNULL(user_login.last_login_time, 0) AS last_login_time").
Joins("LEFT JOIN user_login ON user_login.uid = user.uid").
Where("user.uid IN ? AND user.name LIKE ?", uids, "%"+name+"%").
Where("user.uid IN (?)", roleUIDSubQuery).
Where("user.name LIKE ?", "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err := query.
Order("last_login_time " + string(order)).
Offset((page - 1) * perPage).
Limit(perPage).
Expand All @@ -162,16 +176,19 @@ func ListUsersByNameAndRoleWithLoginTime(page int, perPage int, name string, rol
}

// ListUsersByNameAndRole gets a list of users based on paging constraints, the name of the user, the roles, and namespace
func ListUsersByNameAndRole(page int, perPage int, name string, roles []string, namespace string, db *gorm.DB) ([]models.User, error) {
func ListUsersByNameAndRole(page int, perPage int, name string, roles []string, namespace string, mfaEnabled *bool, db *gorm.DB) ([]models.User, error) {
var (
users []models.User
err error
)

err = db.Where("user.name LIKE ? AND role.name IN ? AND role.namespace = ?", "%"+name+"%", roles, namespace).
Joins("INNER JOIN role_binding on role_binding.uid = user.uid").
Joins("INNER JOIN role on role_binding.role_id = role.id").Order("account ASC").Offset((page - 1) * perPage).
Group("user.uid").
roleUIDSubQuery := uidSubQueryByRoles(roles, namespace, db)
query := db.Model(&models.User{}).
Select("user.*, "+userMFAEnabledSelectExpr).
Where("user.uid IN (?)", roleUIDSubQuery).
Where("user.name LIKE ?", "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.Order("account ASC").Offset((page - 1) * perPage).
Limit(perPage).
Find(&users).
Error
Expand All @@ -183,6 +200,10 @@ func ListUsersByNameAndRole(page int, perPage int, name string, roles []string,
return users, nil
}

func joinUserMFA(db *gorm.DB) *gorm.DB {
return db.Joins(userMFAJoinClause)
}

func ListUsersByGroup(groupID string, db *gorm.DB) ([]*models.User, error) {
resp := make([]*models.User, 0)

Expand All @@ -199,13 +220,16 @@ func ListUsersByGroup(groupID string, db *gorm.DB) ([]*models.User, error) {
}

// ListUsersByUIDs gets a list of users based on paging constraints
func ListUsersByUIDs(uids []string, db *gorm.DB) ([]models.User, error) {
func ListUsersByUIDs(uids []string, mfaEnabled *bool, db *gorm.DB) ([]models.User, error) {
var (
users []models.User
err error
)

err = db.Find(&users, "uid in ?", uids).Error
query := db.Model(&models.User{}).
Select("user.*, "+userMFAEnabledSelectExpr).
Where("user.uid in ?", uids)
err = applyMFAEnabledJoinFilter(query, mfaEnabled).Find(&users).Error

if err != nil && err != gorm.ErrRecordNotFound {
return nil, err
Expand Down Expand Up @@ -251,14 +275,15 @@ func DeleteUserByUid(uid string, db *gorm.DB) error {
}

// GetUsersCount gets user count
func GetUsersCount(name string) (int64, error) {
func GetUsersCount(name string, mfaEnabled *bool) (int64, error) {
var (
users []models.User
err error
count int64
)

err = repository.DB.Where("name LIKE ?", "%"+name+"%").Find(&users).Count(&count).Error
query := repository.DB.Model(&models.User{}).Where("name LIKE ?", "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.Count(&count).Error

if err != nil {
return 0, err
Expand All @@ -268,20 +293,18 @@ func GetUsersCount(name string) (int64, error) {
}

// GetUsersCountByRoles gets user count filtered by roles and namespace
func GetUsersCountByRoles(name string, roles []string, namespace string) (int64, error) {
func GetUsersCountByRoles(name string, roles []string, namespace string, mfaEnabled *bool) (int64, error) {
var (
users []models.User
err error
count int64
)

err = repository.DB.Where("user.name LIKE ? AND role.name IN ? AND role.namespace = ?", "%"+name+"%", roles, namespace).
Joins("INNER JOIN role_binding on role_binding.uid = user.uid").
Joins("INNER JOIN role on role_binding.role_id = role.id").
Group("user.uid").
Find(&users).
Count(&count).
Error
roleUIDSubQuery := uidSubQueryByRoles(roles, namespace, repository.DB)
query := repository.DB.Model(&models.User{}).
Where("user.uid IN (?)", roleUIDSubQuery).
Where("user.name LIKE ?", "%"+name+"%")
query = applyMFAEnabledJoinFilter(query, mfaEnabled)
err = query.Count(&count).Error

if err != nil {
return 0, err
Expand All @@ -290,6 +313,19 @@ func GetUsersCountByRoles(name string, roles []string, namespace string) (int64,
return count, nil
}

func applyMFAEnabledJoinFilter(db *gorm.DB, mfaEnabled *bool) *gorm.DB {
db = joinUserMFA(db)
if mfaEnabled == nil {
return db
}

if *mfaEnabled {
return db.Where("user_mfa.enabled = ?", true)
}

return db.Where("user_mfa.enabled IS NULL OR user_mfa.enabled = ?", false)
}

// UpdateUser update user info
func UpdateUser(uid string, user *models.User, db *gorm.DB) error {
if err := db.Model(&models.User{}).Where("uid = ?", uid).Updates(user).Error; err != nil {
Expand Down
11 changes: 0 additions & 11 deletions pkg/microservice/user/core/repository/orm/user_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,6 @@ func GetUserMFA(uid string, db *gorm.DB) (*models.UserMFA, error) {
return res, nil
}

func ListUserMFAsByUIDs(uids []string, db *gorm.DB) ([]*models.UserMFA, error) {
if len(uids) == 0 {
return []*models.UserMFA{}, nil
}
res := make([]*models.UserMFA, 0)
if err := db.Where("uid IN ?", uids).Find(&res).Error; err != nil {
return nil, err
}
return res, nil
}

// EnableUserMFA enables MFA for a user without allowing overwrite of an already-enabled MFA config.
func EnableUserMFA(uid, secretCipher, recoveryCodesJSON string, db *gorm.DB) error {
now := time.Now().Unix()
Expand Down
26 changes: 26 additions & 0 deletions pkg/microservice/user/core/service/permission/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,32 @@ func ListRolesByNamespaceAndUserID(projectName, uid string, log *zap.SugaredLogg
return resp, nil
}

func ListRolesByNamespaceAndUserIDs(projectName string, uids []string, log *zap.SugaredLogger) (map[string][]*types.Role, error) {
rolesByUID, err := orm.ListRoleByUIDsAndNamespace(uids, projectName, repository.DB)
if err != nil {
log.Errorf("failed to list roles in project: %s, error: %s", projectName, err)
return nil, fmt.Errorf("failed to list roles in project: %s, error: %s", projectName, err)
}

resp := make(map[string][]*types.Role, len(rolesByUID))
for uid, roles := range rolesByUID {
roleList := make([]*types.Role, 0, len(roles))
for _, role := range roles {
roleList = append(roleList, &types.Role{
ID: role.ID,
Name: role.Name,
Namespace: role.Namespace,
Description: role.Description,
Type: convertDBRoleType(role.Type),
GlobalReadOnly: role.GlobalReadOnly,
})
}
resp[uid] = roleList
}

return resp, nil
}

func GetRole(ns, name string, log *zap.SugaredLogger) (*types.DetailedRole, error) {
role, err := orm.GetRole(name, ns, repository.DB)
if err != nil {
Expand Down
Loading
Loading