diff --git a/internal/storage/batch.go b/internal/storage/batch.go index 9e342ab1d6e..2d9f0663f21 100644 --- a/internal/storage/batch.go +++ b/internal/storage/batch.go @@ -16,76 +16,92 @@ import ( type batchBuilder struct { db *sql.DB - args []any - conditions []string + args argsBuilder + where whereBuilder batchSize int limitPerHost int } func (s *Storage) NewBatchBuilder() *batchBuilder { - return &batchBuilder{ + b := batchBuilder{ db: s.db, } + + return &b } func (b *batchBuilder) WithBatchSize(batchSize int) *batchBuilder { + if batchSize <= 0 { + return b + } + b.batchSize = batchSize + return b } func (b *batchBuilder) WithUserID(userID int64) *batchBuilder { - b.conditions = append(b.conditions, "user_id = $"+strconv.Itoa(len(b.args)+1)) - b.args = append(b.args, userID) + nArgs := b.args.append(userID) + b.where.and("user_id = $" + strconv.Itoa(nArgs)) return b } func (b *batchBuilder) WithCategoryID(categoryID int64) *batchBuilder { - b.conditions = append(b.conditions, "category_id = $"+strconv.Itoa(len(b.args)+1)) - b.args = append(b.args, categoryID) + nArgs := b.args.append(categoryID) + b.where.and("category_id = $" + strconv.Itoa(nArgs)) return b } func (b *batchBuilder) WithErrorLimit(limit int) *batchBuilder { - if limit > 0 { - b.conditions = append(b.conditions, "parsing_error_count < $"+strconv.Itoa(len(b.args)+1)) - b.args = append(b.args, limit) + if limit <= 0 { + return b } + + nArgs := b.args.append(limit) + b.where.and("parsing_error_count < $" + strconv.Itoa(nArgs)) + return b } func (b *batchBuilder) WithNextCheckExpired() *batchBuilder { - b.conditions = append(b.conditions, "next_check_at < now()") + b.where.and("next_check_at < now()") return b } func (b *batchBuilder) WithoutDisabledFeeds() *batchBuilder { - b.conditions = append(b.conditions, "disabled IS false") + b.where.and("disabled IS false") return b } func (b *batchBuilder) WithLimitPerHost(limit int) *batchBuilder { - if limit > 0 { - b.limitPerHost = limit + if limit <= 0 { + return b } + + b.limitPerHost = limit + return b } // FetchJobs retrieves a batch of jobs based on the conditions set in the builder. // When limitPerHost is set, it limits the number of jobs per feed hostname to prevent overwhelming a single host. func (b *batchBuilder) FetchJobs() (model.JobList, error) { - query := `SELECT id, user_id, feed_url FROM feeds` + var qb strings.Builder - if len(b.conditions) > 0 { - query += " WHERE " + strings.Join(b.conditions, " AND ") - } + qb.WriteString(` + SELECT id, user_id, feed_url + FROM feeds + `) + + qb.WriteString(" " + b.where.String()) - query += " ORDER BY next_check_at ASC" + qb.WriteString(` ORDER BY next_check_at ASC`) if b.batchSize > 0 { - query += " LIMIT " + strconv.Itoa(b.batchSize) + qb.WriteString(` LIMIT ` + strconv.Itoa(b.batchSize)) } - rows, err := b.db.Query(query, b.args...) + rows, err := b.db.Query(qb.String(), b.args.all()...) if err != nil { return nil, fmt.Errorf(`store: unable to fetch batch of jobs: %v`, err) } diff --git a/internal/storage/entry_pagination_builder.go b/internal/storage/entry_pagination_builder.go index 18fa44a5346..e16b2daaa58 100644 --- a/internal/storage/entry_pagination_builder.go +++ b/internal/storage/entry_pagination_builder.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "strconv" - "strings" "github.com/lib/pq" "miniflux.app/v2/internal/model" @@ -16,58 +15,66 @@ import ( // entryPaginationBuilder is a builder for entry prev/next queries. type entryPaginationBuilder struct { - db *sql.DB - conditions []string - args []any - entryID int64 - order string - direction string + db *sql.DB + args argsBuilder + where whereBuilder + orderBy orderByBuilder + entryID int64 + direction string } // WithSearchQuery adds full-text search query to the condition. func (e *entryPaginationBuilder) WithSearchQuery(query string) *entryPaginationBuilder { - if query != "" { - e.conditions = append(e.conditions, fmt.Sprintf("e.document_vectors @@ plainto_tsquery($%d)", len(e.args)+1)) - e.args = append(e.args, query) + if query == "" { + return e } + nArgs := e.args.append(query) + e.where.andf("e.document_vectors @@ plainto_tsquery($%d)", nArgs) + return e } // WithStarred adds starred to the condition. func (e *entryPaginationBuilder) WithStarred() *entryPaginationBuilder { - e.conditions = append(e.conditions, "e.starred is true") + e.where.and("e.starred is true") return e } // WithFeedID adds feed_id to the condition. func (e *entryPaginationBuilder) WithFeedID(feedID int64) *entryPaginationBuilder { - if feedID != 0 { - e.conditions = append(e.conditions, "e.feed_id = $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, feedID) + if feedID == 0 { + return e } + nArgs := e.args.append(feedID) + e.where.and("e.feed_id = $" + strconv.Itoa(nArgs)) + return e } // WithCategoryID adds category_id to the condition. func (e *entryPaginationBuilder) WithCategoryID(categoryID int64) *entryPaginationBuilder { - if categoryID != 0 { - e.conditions = append(e.conditions, "f.category_id = $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, categoryID) + if categoryID == 0 { + return e } + nArgs := e.args.append(categoryID) + e.where.and("f.category_id = $" + strconv.Itoa(nArgs)) + return e } // WithStatus adds status to the condition. func (e *entryPaginationBuilder) WithStatus(status string) *entryPaginationBuilder { - if status != "" { - e.conditions = append(e.conditions, "e.status = $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, status) + if status == "" { + return e } + nArgs := e.args.append(status) + e.where.and("e.status = $" + strconv.Itoa(nArgs)) + return e } @@ -78,31 +85,31 @@ func (e *entryPaginationBuilder) WithStatusOrEntryID(status string, entryID int6 } if entryID == 0 { - e.WithStatus(status) - return e + return e.WithStatus(status) } - statusArg := len(e.args) + 1 - entryArg := len(e.args) + 2 - e.conditions = append(e.conditions, fmt.Sprintf("(e.status = $%d OR e.id = $%d)", statusArg, entryArg)) - e.args = append(e.args, status, entryID) + statusArg := e.args.append(status) + entryArg := e.args.append(entryID) + e.where.andf("(e.status = $%d OR e.id = $%d)", statusArg, entryArg) return e } -func (e *entryPaginationBuilder) WithTags(tags []string) *entryPaginationBuilder { - if len(tags) > 0 { - e.conditions = append(e.conditions, fmt.Sprintf("LOWER(e.tags::text)::text[] @> LOWER($%d::text)::text[]", len(e.args)+1)) - e.args = append(e.args, pq.Array(tags)) +func (e *entryPaginationBuilder) WithTags(tags ...string) *entryPaginationBuilder { + if len(tags) == 0 { + return e } + nArgs := e.args.append(pq.Array(tags)) + e.where.andf("LOWER(e.tags::text)::text[] @> LOWER($%d::text)::text[]", nArgs) + return e } // WithGloballyVisible adds global visibility to the condition. func (e *entryPaginationBuilder) WithGloballyVisible() *entryPaginationBuilder { - e.conditions = append(e.conditions, "not c.hide_globally") - e.conditions = append(e.conditions, "not f.hide_globally") + e.where.and("c.hide_globally IS FALSE") + e.where.and("f.hide_globally IS FALSE") return e } @@ -143,27 +150,29 @@ func (e *entryPaginationBuilder) Entries() (*model.Entry, *model.Entry, error) { func (e *entryPaginationBuilder) getPrevNextID(tx *sql.Tx) (prevID int64, nextID int64, err error) { cte := ` - WITH entry_pagination AS ( - SELECT - e.id, - lag(e.id) over (order by e.%[1]s asc, e.created_at asc, e.id desc) as prev_id, - lead(e.id) over (order by e.%[1]s asc, e.created_at asc, e.id desc) as next_id - FROM entries AS e - JOIN feeds AS f ON f.id=e.feed_id + SELECT + e.id, + lag(e.id) over (` + e.orderBy.String() + `) as prev_id, + lead(e.id) over (` + e.orderBy.String() + `) as next_id + FROM entries AS e + JOIN feeds AS f ON f.id = e.feed_id JOIN categories c ON c.id = f.category_id - WHERE %[2]s - ORDER BY e.%[1]s asc, e.created_at asc, e.id desc - ) - SELECT prev_id, next_id FROM entry_pagination AS ep WHERE %[3]s; - ` + ` + e.where.String() + " " + e.orderBy.String() + + finalWhere := whereBuilder{} + finalArgs := e.args.clone() - subCondition := strings.Join(e.conditions, " AND ") - finalCondition := "ep.id = $" + strconv.Itoa(len(e.args)+1) - query := fmt.Sprintf(cte, e.order, subCondition, finalCondition) - e.args = append(e.args, e.entryID) + nArgs := finalArgs.append(e.entryID) + finalWhere.and("ep.id = $" + strconv.Itoa(nArgs)) + + query := ` + WITH entry_pagination AS (` + cte + `) + SELECT prev_id, next_id + FROM entry_pagination AS ep + ` + finalWhere.String() var pID, nID sql.NullInt64 - err = tx.QueryRow(query, e.args...).Scan(&pID, &nID) + err = tx.QueryRow(query, finalArgs.all()...).Scan(&pID, &nID) switch { case errors.Is(err, sql.ErrNoRows): return 0, 0, nil @@ -202,12 +211,18 @@ func (e *entryPaginationBuilder) getEntry(tx *sql.Tx, entryID int64) (*model.Ent // NewEntryPaginationBuilder returns a new EntryPaginationBuilder. func (s *Storage) NewEntryPaginationBuilder(userID, entryID int64, order, direction string) *entryPaginationBuilder { - return &entryPaginationBuilder{ - db: s.db, - args: []any{userID}, - conditions: []string{"e.user_id = $1"}, - entryID: entryID, - order: pq.QuoteIdentifier(order), - direction: direction, + e := entryPaginationBuilder{ + db: s.db, + entryID: entryID, + direction: direction, } + + nArgs := e.args.append(userID) + e.where.and("e.user_id = $" + strconv.Itoa(nArgs)) + + e.orderBy.asc("e." + pq.QuoteIdentifier(order)) + e.orderBy.asc("e.created_at") + e.orderBy.desc("e.id") + + return &e } diff --git a/internal/storage/entry_query_builder.go b/internal/storage/entry_query_builder.go index 83be7c0da60..29164fdf19a 100644 --- a/internal/storage/entry_query_builder.go +++ b/internal/storage/entry_query_builder.go @@ -19,9 +19,9 @@ import ( // EntryQueryBuilder builds a SQL query to fetch entries. type EntryQueryBuilder struct { store *Storage - args []any - conditions []string - sortExpressions []string + args argsBuilder + where whereBuilder + orderBy orderByBuilder limit int offset int fetchEnclosures bool @@ -44,146 +44,180 @@ func (e *EntryQueryBuilder) WithoutContent() *EntryQueryBuilder { // WithSearchQuery adds full-text search query to the condition. func (e *EntryQueryBuilder) WithSearchQuery(query string) *EntryQueryBuilder { - if query != "" { - nArgs := len(e.args) + 1 - e.conditions = append(e.conditions, fmt.Sprintf("e.document_vectors @@ plainto_tsquery($%d)", nArgs)) - e.args = append(e.args, query) + if query == "" { + return e + } - // 0.0000001 = 0.1 / (seconds_in_a_day) + nArgs := e.args.append(query) + e.where.andf("e.document_vectors @@ plainto_tsquery($%d)", nArgs) + + // 0.0000001 = 0.1 / (seconds_in_a_day) + e.orderBy.desc( + fmt.Sprintf("ts_rank(document_vectors, plainto_tsquery($%d)) - extract (epoch from now() - published_at)::float * 0.0000001", nArgs), + ) - e.sortExpressions = append(e.sortExpressions, - fmt.Sprintf("ts_rank(document_vectors, plainto_tsquery($%d)) - extract (epoch from now() - published_at)::float * 0.0000001 DESC", nArgs), - ) - } return e } // WithStarred adds starred filter. func (e *EntryQueryBuilder) WithStarred(starred bool) *EntryQueryBuilder { - if starred { - e.conditions = append(e.conditions, "e.starred is true") - } else { - e.conditions = append(e.conditions, "e.starred is false") - } + e.where.and("e.starred is " + strconv.FormatBool(starred)) + return e } // BeforeChangedDate adds a condition < changed_at func (e *EntryQueryBuilder) BeforeChangedDate(date time.Time) *EntryQueryBuilder { - e.conditions = append(e.conditions, "e.changed_at < $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, date) + nArgs := e.args.append(date) + e.where.and("e.changed_at < $" + strconv.Itoa(nArgs)) + return e } // AfterChangedDate adds a condition > changed_at func (e *EntryQueryBuilder) AfterChangedDate(date time.Time) *EntryQueryBuilder { - e.conditions = append(e.conditions, "e.changed_at > $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, date) + nArgs := e.args.append(date) + e.where.and("e.changed_at > $" + strconv.Itoa(nArgs)) + return e } // BeforePublishedDate adds a condition < published_at func (e *EntryQueryBuilder) BeforePublishedDate(date time.Time) *EntryQueryBuilder { - e.conditions = append(e.conditions, "e.published_at < $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, date) + nArgs := e.args.append(date) + e.where.and("e.published_at < $" + strconv.Itoa(nArgs)) + return e } // AfterPublishedDate adds a condition > published_at func (e *EntryQueryBuilder) AfterPublishedDate(date time.Time) *EntryQueryBuilder { - e.conditions = append(e.conditions, "e.published_at > $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, date) + nArgs := e.args.append(date) + e.where.and("e.published_at > $" + strconv.Itoa(nArgs)) + return e } // BeforeEntryID adds a condition < entryID. func (e *EntryQueryBuilder) BeforeEntryID(entryID int64) *EntryQueryBuilder { - if entryID != 0 { - e.conditions = append(e.conditions, "e.id < $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, entryID) + if entryID == 0 { + return e } + + nArgs := e.args.append(entryID) + e.where.and("e.id < $" + strconv.Itoa(nArgs)) + return e } // AfterEntryID adds a condition > entryID. func (e *EntryQueryBuilder) AfterEntryID(entryID int64) *EntryQueryBuilder { - if entryID != 0 { - e.conditions = append(e.conditions, "e.id > $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, entryID) + if entryID == 0 { + return e } + + nArgs := e.args.append(entryID) + e.where.and("e.id > $" + strconv.Itoa(nArgs)) + return e } // WithEntryIDs filter by entry IDs. func (e *EntryQueryBuilder) WithEntryIDs(entryIDs ...int64) *EntryQueryBuilder { + if len(entryIDs) == 0 { + return e + } + if len(entryIDs) == 1 { - e.conditions = append(e.conditions, fmt.Sprintf("e.id = $%d", len(e.args)+1)) - e.args = append(e.args, entryIDs[0]) - } else if len(entryIDs) > 1 { - e.conditions = append(e.conditions, fmt.Sprintf("e.id = ANY($%d)", len(e.args)+1)) - e.args = append(e.args, pq.Int64Array(entryIDs)) + nArgs := e.args.append(entryIDs[0]) + e.where.and("e.id = $" + strconv.Itoa(nArgs)) + + return e } + + nArgs := e.args.append(pq.Int64Array(entryIDs)) + e.where.andf("e.id = ANY($%d)", nArgs) + return e } // WithFeedID filter by feed ID. func (e *EntryQueryBuilder) WithFeedID(feedID int64) *EntryQueryBuilder { - if feedID > 0 { - e.conditions = append(e.conditions, "e.feed_id = $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, feedID) + if feedID == 0 { + return e } + + nArgs := e.args.append(feedID) + e.where.and("e.feed_id = $" + strconv.Itoa(nArgs)) + return e } // WithCategoryID filter by category ID. func (e *EntryQueryBuilder) WithCategoryID(categoryID int64) *EntryQueryBuilder { - if categoryID > 0 { - e.conditions = append(e.conditions, "f.category_id = $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, categoryID) + if categoryID == 0 { + return e } + + nArgs := e.args.append(categoryID) + e.where.and("f.category_id = $" + strconv.Itoa(nArgs)) + return e } // WithStatuses filter by a list of entry statuses. func (e *EntryQueryBuilder) WithStatuses(statuses ...string) *EntryQueryBuilder { + if len(statuses) == 0 { + return e + } + if len(statuses) == 1 { - e.conditions = append(e.conditions, fmt.Sprintf("e.status = $%d", len(e.args)+1)) - e.args = append(e.args, statuses[0]) - } else if len(statuses) > 1 { - e.conditions = append(e.conditions, fmt.Sprintf("e.status = ANY($%d)", len(e.args)+1)) - e.args = append(e.args, pq.StringArray(statuses)) + nArgs := e.args.append(statuses[0]) + e.where.and("e.status = $" + strconv.Itoa(nArgs)) + + return e } + + nArgs := e.args.append(pq.StringArray(statuses)) + e.where.andf("e.status = ANY($%d)", nArgs) + return e } // WithTags filter by a list of entry tags. func (e *EntryQueryBuilder) WithTags(tags ...string) *EntryQueryBuilder { - if len(tags) > 0 { - e.conditions = append(e.conditions, fmt.Sprintf("LOWER(e.tags::text)::text[] @> LOWER($%d::text)::text[]", len(e.args)+1)) - e.args = append(e.args, pq.Array(tags)) + if len(tags) == 0 { + return e } + + nArgs := e.args.append(pq.Array(tags)) + e.where.andf("LOWER(e.tags::text)::text[] @> LOWER($%d::text)::text[]", nArgs) + return e } // WithoutStatus set the entry status that should not be returned. func (e *EntryQueryBuilder) WithoutStatus(status string) *EntryQueryBuilder { - if status != "" { - e.conditions = append(e.conditions, "e.status <> $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, status) + if status == "" { + return e } + + nArgs := e.args.append(status) + e.where.and("e.status <> $" + strconv.Itoa(nArgs)) + return e } // WithShareCode set the entry share code. func (e *EntryQueryBuilder) WithShareCode(shareCode string) *EntryQueryBuilder { - e.conditions = append(e.conditions, "e.share_code = $"+strconv.Itoa(len(e.args)+1)) - e.args = append(e.args, shareCode) + nArgs := e.args.append(shareCode) + e.where.and("e.share_code = $" + strconv.Itoa(nArgs)) return e } // WithShareCodeNotEmpty adds a filter for non-empty share code. func (e *EntryQueryBuilder) WithShareCodeNotEmpty() *EntryQueryBuilder { - e.conditions = append(e.conditions, "e.share_code <> ''") + e.where.and("e.share_code <> ''") return e } @@ -191,9 +225,9 @@ func (e *EntryQueryBuilder) WithShareCodeNotEmpty() *EntryQueryBuilder { func (e *EntryQueryBuilder) WithSorting(column, direction string) *EntryQueryBuilder { switch { case strings.EqualFold(direction, "ASC"): - e.sortExpressions = append(e.sortExpressions, pq.QuoteIdentifier(column)+" ASC") + e.orderBy.asc(pq.QuoteIdentifier(column)) case strings.EqualFold(direction, "DESC"): - e.sortExpressions = append(e.sortExpressions, pq.QuoteIdentifier(column)+" DESC") + e.orderBy.desc(pq.QuoteIdentifier(column)) } return e @@ -201,23 +235,29 @@ func (e *EntryQueryBuilder) WithSorting(column, direction string) *EntryQueryBui // WithLimit set the limit. func (e *EntryQueryBuilder) WithLimit(limit int) *EntryQueryBuilder { - if limit > 0 { - e.limit = min(limit, model.MaxEntryLimit) + if limit <= 0 { + return e } + + e.limit = min(limit, model.MaxEntryLimit) + return e } // WithOffset set the offset. func (e *EntryQueryBuilder) WithOffset(offset int) *EntryQueryBuilder { - if offset > 0 { - e.offset = offset + if offset <= 0 { + return e } + + e.offset = offset + return e } func (e *EntryQueryBuilder) WithGloballyVisible() *EntryQueryBuilder { - e.conditions = append(e.conditions, "c.hide_globally IS FALSE") - e.conditions = append(e.conditions, "f.hide_globally IS FALSE") + e.where.and("c.hide_globally IS FALSE") + e.where.and("f.hide_globally IS FALSE") return e } @@ -228,9 +268,9 @@ func (e *EntryQueryBuilder) CountEntries() (count int, err error) { FROM entries e JOIN feeds f ON f.id = e.feed_id JOIN categories c ON c.id = f.category_id - WHERE ` + e.buildCondition() + ` + e.where.String() - err = e.store.db.QueryRow(query, e.args...).Scan(&count) + err = e.store.db.QueryRow(query, e.args.all()...).Scan(&count) if err != nil { return 0, fmt.Errorf("store: unable to count entries: %v", err) } @@ -275,65 +315,70 @@ func (e *EntryQueryBuilder) GetEntriesWithCount() (model.Entries, int, error) { // When withCount is true, count(*) OVER() is included in the SELECT and the total // count of matching rows is returned; otherwise the returned count is 0. func (e *EntryQueryBuilder) fetchEntries(withCount bool) (model.Entries, int, error) { - countColumn := "" + var qb strings.Builder + + qb.WriteString(`SELECT `) + if withCount { - countColumn = "count(*) OVER()," + qb.WriteString(`count(*) OVER(),`) } - query := ` - SELECT - ` + countColumn + ` - e.id, - e.user_id, - e.feed_id, - e.hash, - e.published_at at time zone u.timezone, - e.title, - e.url, - e.comments_url, - e.author, - e.share_code, - ` + e.contentColumn() + `, - e.status, - e.starred, - e.reading_time, - e.created_at, - e.changed_at, - e.tags, - f.title as feed_title, - f.feed_url, - f.site_url, - f.description, - f.checked_at, - f.category_id, - c.title as category_title, - c.hide_globally as category_hidden, - f.scraper_rules, - f.rewrite_rules, - f.crawler, - f.user_agent, - f.cookie, - f.hide_globally, - f.no_media_player, - f.webhook_url, - fi.icon_id, - i.external_id AS icon_external_id, - u.timezone - FROM - entries e - INNER JOIN - feeds f ON f.id=e.feed_id - INNER JOIN - categories c ON c.id=f.category_id - LEFT JOIN - feed_icons fi ON fi.feed_id=f.id - LEFT JOIN - icons i ON i.id=fi.icon_id - INNER JOIN - users u ON u.id=e.user_id - WHERE ` + e.buildCondition() + " " + e.buildSorting() - - rows, err := e.store.db.Query(query, e.args...) + qb.WriteString(` + e.id, + e.user_id, + e.feed_id, + e.hash, + e.published_at at time zone u.timezone, + e.title, + e.url, + e.comments_url, + e.author, + e.share_code,` + + e.contentColumn() + ` as content,` + + `e.status, + e.starred, + e.reading_time, + e.created_at, + e.changed_at, + e.tags, + f.title as feed_title, + f.feed_url, + f.site_url, + f.description, + f.checked_at, + f.category_id, + c.title as category_title, + c.hide_globally as category_hidden, + f.scraper_rules, + f.rewrite_rules, + f.crawler, + f.user_agent, + f.cookie, + f.hide_globally, + f.no_media_player, + f.webhook_url, + fi.icon_id, + i.external_id as icon_external_id, + u.timezone + FROM + entries e + INNER JOIN + feeds f ON f.id=e.feed_id + INNER JOIN + categories c ON c.id=f.category_id + LEFT JOIN + feed_icons fi ON fi.feed_id=f.id + LEFT JOIN + icons i ON i.id=fi.icon_id + INNER JOIN + users u ON u.id=e.user_id + `) + + qb.WriteString(" " + e.where.String()) + + qb.WriteString(" " + e.buildSorting()) + + rows, err := e.store.db.Query(qb.String(), e.args.all()...) if err != nil { return nil, 0, fmt.Errorf("store: unable to get entries: %v", err) } @@ -451,9 +496,9 @@ func (e *EntryQueryBuilder) GetEntryIDs() ([]int64, error) { feeds f ON f.id=e.feed_id - WHERE ` + e.buildCondition() + " " + e.buildSorting() + ` + e.where.String() + " " + e.buildSorting() - rows, err := e.store.db.Query(query, e.args...) + rows, err := e.store.db.Query(query, e.args.all()...) if err != nil { return nil, fmt.Errorf("store: unable to get entries: %v", err) } @@ -476,40 +521,39 @@ func (e *EntryQueryBuilder) GetEntryIDs() ([]int64, error) { func (e *EntryQueryBuilder) contentColumn() string { if e.excludeContent { - return "'' AS content" + return "''" } return "e.content" } -func (e *EntryQueryBuilder) buildCondition() string { - return strings.Join(e.conditions, " AND ") -} - func (e *EntryQueryBuilder) buildSorting() string { - var parts string + var parts strings.Builder - if len(e.sortExpressions) > 0 { - parts += " ORDER BY " + strings.Join(e.sortExpressions, ", ") - } + parts.WriteString(e.orderBy.String()) if e.limit > 0 { - parts += " LIMIT " + strconv.Itoa(e.limit) + parts.WriteString(" LIMIT ") + parts.WriteString(strconv.Itoa(e.limit)) } if e.offset > 0 { - parts += " OFFSET " + strconv.Itoa(e.offset) + parts.WriteString(" OFFSET ") + parts.WriteString(strconv.Itoa(e.offset)) } - return parts + return parts.String() } // NewEntryQueryBuilder returns a new EntryQueryBuilder. func (s *Storage) NewEntryQueryBuilder(userID int64) *EntryQueryBuilder { - return &EntryQueryBuilder{ - store: s, - args: []any{userID}, - conditions: []string{"e.user_id = $1"}, + qb := EntryQueryBuilder{ + store: s, } + + nArgs := qb.args.append(userID) + qb.where.and("e.user_id = $" + strconv.Itoa(nArgs)) + + return &qb } // NewAnonymousQueryBuilder returns a new EntryQueryBuilder suitable for anonymous users. diff --git a/internal/storage/feed_query_builder.go b/internal/storage/feed_query_builder.go index 1677614a16d..a0342411a78 100644 --- a/internal/storage/feed_query_builder.go +++ b/internal/storage/feed_query_builder.go @@ -16,47 +16,61 @@ import ( // feedQueryBuilder builds a SQL query to fetch feeds. type feedQueryBuilder struct { - db *sql.DB - args []any - conditions []string - sortExpressions []string - limit int - offset int - withCounters bool - counterJoinFeeds bool - counterArgs []any - counterConditions []string + db *sql.DB + args argsBuilder + where whereBuilder + orderBy orderByBuilder + limit int + offset int + withCounters bool + counterJoinFeeds bool + counterArgs argsBuilder + counterWhere whereBuilder } // NewFeedQueryBuilder returns a new FeedQueryBuilder. func (s *Storage) NewFeedQueryBuilder(userID int64) *feedQueryBuilder { - return &feedQueryBuilder{ - db: s.db, - args: []any{userID}, - conditions: []string{"f.user_id = $1"}, - counterArgs: []any{userID, model.EntryStatusRead, model.EntryStatusUnread}, - counterConditions: []string{"e.user_id = $1", "e.status IN ($2, $3)"}, + f := feedQueryBuilder{ + db: s.db, } + + nArgs := f.args.append(userID) + f.where.and("f.user_id = $" + strconv.Itoa(nArgs)) + + cArgs := f.counterArgs.append(userID) + f.counterWhere.and("e.user_id = $" + strconv.Itoa(cArgs)) + + f.counterWhere.and("e.status IN (" + model.EntryStatusRead + ", " + model.EntryStatusUnread + ")") + + return &f } // WithCategoryID filter by category ID. func (f *feedQueryBuilder) WithCategoryID(categoryID int64) *feedQueryBuilder { - if categoryID > 0 { - f.conditions = append(f.conditions, "f.category_id = $"+strconv.Itoa(len(f.args)+1)) - f.args = append(f.args, categoryID) - f.counterConditions = append(f.counterConditions, "f.category_id = $"+strconv.Itoa(len(f.counterArgs)+1)) - f.counterArgs = append(f.counterArgs, categoryID) - f.counterJoinFeeds = true + if categoryID == 0 { + return f } + + nArgs := f.args.append(categoryID) + f.where.and("f.category_id = $" + strconv.Itoa(nArgs)) + + cArgs := f.args.append(categoryID) + f.counterWhere.and("f.category_id = $" + strconv.Itoa(cArgs)) + + f.counterJoinFeeds = true + return f } // WithFeedID filter by feed ID. func (f *feedQueryBuilder) WithFeedID(feedID int64) *feedQueryBuilder { - if feedID > 0 { - f.conditions = append(f.conditions, "f.id = $"+strconv.Itoa(len(f.args)+1)) - f.args = append(f.args, feedID) + if feedID == 0 { + return f } + + nArgs := f.args.append(feedID) + f.where.and("f.id = $" + strconv.Itoa(nArgs)) + return f } @@ -70,9 +84,9 @@ func (f *feedQueryBuilder) WithCounters() *feedQueryBuilder { func (f *feedQueryBuilder) WithSorting(column, direction string) *feedQueryBuilder { switch { case strings.EqualFold(direction, "ASC"): - f.sortExpressions = append(f.sortExpressions, pq.QuoteIdentifier(column)+" ASC") + f.orderBy.asc(pq.QuoteIdentifier(column)) case strings.EqualFold(direction, "DESC"): - f.sortExpressions = append(f.sortExpressions, pq.QuoteIdentifier(column)+" DESC") + f.orderBy.desc(pq.QuoteIdentifier(column)) } return f @@ -80,44 +94,44 @@ func (f *feedQueryBuilder) WithSorting(column, direction string) *feedQueryBuild // WithLimit set the limit. func (f *feedQueryBuilder) WithLimit(limit int) *feedQueryBuilder { + if limit <= 0 { + return f + } + f.limit = limit return f } // WithOffset set the offset. func (f *feedQueryBuilder) WithOffset(offset int) *feedQueryBuilder { + if offset <= 0 { + return f + } + f.offset = offset return f } -func (f *feedQueryBuilder) buildCondition() string { - return strings.Join(f.conditions, " AND ") -} - -func (f *feedQueryBuilder) buildCounterCondition() string { - return strings.Join(f.counterConditions, " AND ") -} - func (f *feedQueryBuilder) buildSorting() string { - var parts string + var parts strings.Builder - if len(f.sortExpressions) > 0 { - parts += " ORDER BY " + strings.Join(f.sortExpressions, ", ") - } + parts.WriteString(f.orderBy.String()) - if len(parts) > 0 { - parts += ", lower(f.title) ASC" + if parts.Len() > 0 { + parts.WriteString(", lower(f.title) ASC") } if f.limit > 0 { - parts += " LIMIT " + strconv.Itoa(f.limit) + parts.WriteString(" LIMIT ") + parts.WriteString(strconv.Itoa(f.limit)) } if f.offset > 0 { - parts += " OFFSET " + strconv.Itoa(f.offset) + parts.WriteString(" OFFSET ") + parts.WriteString(strconv.Itoa(f.offset)) } - return parts + return parts.String() } // GetFeed returns a single feed that match the condition. @@ -195,18 +209,14 @@ func (f *feedQueryBuilder) GetFeeds() (model.Feeds, error) { icons i ON i.id=fi.icon_id LEFT JOIN users u ON u.id=f.user_id - WHERE %s - %s - ` - - query = fmt.Sprintf(query, f.buildCondition(), f.buildSorting()) + ` + f.where.String() + " " + f.buildSorting() readCounters, unreadCounters, err := f.fetchFeedCounter() if err != nil { return nil, err } - rows, err := f.db.Query(query, f.args...) + rows, err := f.db.Query(query, f.args.all()...) if err != nil { return nil, fmt.Errorf(`store: unable to fetch feeds: %w`, err) } @@ -303,26 +313,23 @@ func (f *feedQueryBuilder) fetchFeedCounter() (unreadCounters map[int64]int, rea if !f.withCounters { return nil, nil, nil } - query := ` - SELECT - e.feed_id, - e.status, - count(*) - FROM - entries e - %s - WHERE - %s - GROUP BY - e.feed_id, e.status - ` - join := "" + + var qb strings.Builder + + qb.WriteString(` + SELECT e.feed_id, e.status, count(*) + FROM entries e + `) + if f.counterJoinFeeds { - join = "INNER JOIN feeds f ON f.id=e.feed_id" + qb.WriteString(` INNER JOIN feeds f ON f.id=e.feed_id`) } - query = fmt.Sprintf(query, join, f.buildCounterCondition()) - rows, err := f.db.Query(query, f.counterArgs...) + qb.WriteString(" " + f.counterWhere.String()) + + qb.WriteString(` GROUP BY e.feed_id, e.status`) + + rows, err := f.db.Query(qb.String(), f.counterArgs.all()...) if err != nil { return nil, nil, fmt.Errorf(`store: unable to fetch feed counts: %w`, err) } diff --git a/internal/storage/query_builder.go b/internal/storage/query_builder.go new file mode 100644 index 00000000000..ea7f691908f --- /dev/null +++ b/internal/storage/query_builder.go @@ -0,0 +1,87 @@ +package storage + +import ( + "fmt" + "slices" + "strings" +) + +// whereBuilder constructs WHERE expression string using [strings.Builder]. +type whereBuilder struct { + sb strings.Builder +} + +// String returns WHERE condidion string, including WHERE keyword. +func (b *whereBuilder) String() string { + return b.sb.String() +} + +func (b *whereBuilder) and(s string) { + if b.sb.Len() == 0 { + b.sb.WriteString("WHERE ") + } else { + b.sb.WriteString(" AND ") + } + + b.sb.WriteString(s) +} + +func (b *whereBuilder) andf(format string, args ...any) { + if b.sb.Len() == 0 { + b.sb.WriteString("WHERE ") + } else { + b.sb.WriteString(" AND ") + } + + fmt.Fprintf(&b.sb, format, args...) +} + +// orderByBuilder constructs ORDER BY expression string using [strings.Builder]. +type orderByBuilder struct { + sb strings.Builder +} + +// String returns ORDER BY expression string, including ORDER BY keyword. +func (b *orderByBuilder) String() string { + return b.sb.String() +} + +func (b *orderByBuilder) asc(column string) { + if b.sb.Len() == 0 { + b.sb.WriteString("ORDER BY ") + } else { + b.sb.WriteString(", ") + } + + b.sb.WriteString(column) + b.sb.WriteString(" ASC") +} + +func (b *orderByBuilder) desc(column string) { + if b.sb.Len() == 0 { + b.sb.WriteString("ORDER BY ") + } else { + b.sb.WriteString(", ") + } + + b.sb.WriteString(column) + b.sb.WriteString(" DESC") +} + +// argsBuilder collects all parametrized args. +type argsBuilder struct { + args []any +} + +func (b *argsBuilder) clone() argsBuilder { + return argsBuilder{args: slices.Clone(b.args)} +} + +func (b *argsBuilder) all() []any { + return b.args +} + +func (b *argsBuilder) append(arg any) int { + b.args = append(b.args, arg) + return len(b.args) +} diff --git a/internal/storage/user.go b/internal/storage/user.go index 22d0ca92d06..388adc26662 100644 --- a/internal/storage/user.go +++ b/internal/storage/user.go @@ -343,92 +343,16 @@ func (s *Storage) UserLanguage(userID int64) (language string) { // UserByID finds a user by the ID. func (s *Storage) UserByID(userID int64) (*model.User, error) { - query := ` - SELECT - id, - username, - is_admin, - theme, - language, - timezone, - entry_direction, - entries_per_page, - keyboard_shortcuts, - show_reading_time, - entry_swipe, - gesture_nav, - last_login_at, - stylesheet, - custom_js, - external_font_hosts, - google_id, - openid_connect_id, - display_mode, - entry_order, - default_reading_speed, - cjk_reading_speed, - default_home_page, - categories_sorting_order, - mark_read_on_view, - mark_read_on_media_player_completion, - media_playback_rate, - block_filter_entry_rules, - keep_filter_entry_rules, - always_open_external_links, - open_external_links_in_new_tab - FROM - users - WHERE - id = $1 - ` - return s.fetchUser(query, userID) + return s.UserByField("id", userID) } // UserByUsername finds a user by the username. func (s *Storage) UserByUsername(username string) (*model.User, error) { - query := ` - SELECT - id, - username, - is_admin, - theme, - language, - timezone, - entry_direction, - entries_per_page, - keyboard_shortcuts, - show_reading_time, - entry_swipe, - gesture_nav, - last_login_at, - stylesheet, - custom_js, - external_font_hosts, - google_id, - openid_connect_id, - display_mode, - entry_order, - default_reading_speed, - cjk_reading_speed, - default_home_page, - categories_sorting_order, - mark_read_on_view, - mark_read_on_media_player_completion, - media_playback_rate, - block_filter_entry_rules, - keep_filter_entry_rules, - always_open_external_links, - open_external_links_in_new_tab - FROM - users - WHERE - username=LOWER($1) - ` - return s.fetchUser(query, username) + return s.UserByField("username", strings.ToLower(username)) } // UserByField returns the user matching the given column name and value. -func (s *Storage) UserByField(field, value string) (*model.User, error) { +func (s *Storage) UserByField(field string, value any) (*model.User, error) { query := ` SELECT id, @@ -465,9 +389,9 @@ func (s *Storage) UserByField(field, value string) (*model.User, error) { FROM users WHERE - %s=$1 - ` - return s.fetchUser(fmt.Sprintf(query, pq.QuoteIdentifier(field)), value) + ` + pq.QuoteIdentifier(field) + "=$1" + + return s.fetchUser(query, value) } // AnotherUserWithFieldExists returns true if a user other than userID has the given value in the given column. diff --git a/internal/ui/entry_tag.go b/internal/ui/entry_tag.go index 784227526c2..694efbcf30d 100644 --- a/internal/ui/entry_tag.go +++ b/internal/ui/entry_tag.go @@ -52,7 +52,7 @@ func (h *handler) showTagEntryPage(w http.ResponseWriter, r *http.Request) { } prevEntry, nextEntry, err := h.store.NewEntryPaginationBuilder(user.ID, entry.ID, user.EntryOrder, user.EntryDirection). - WithTags([]string{tagName}). + WithTags(tagName). Entries() if err != nil { response.HTMLServerError(w, r, err)