Skip to content
58 changes: 37 additions & 21 deletions internal/storage/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,76 +16,92 @@ import (

type batchBuilder struct {
db *sql.DB
args []any
conditions []string
args argsBuilder
where whereBuilder
batchSize int
limitPerHost int
}

func (s *Storage) NewBatchBuilder() *batchBuilder {
return &batchBuilder{
b := batchBuilder{
db: s.db,
}

return &b
}

func (b *batchBuilder) WithBatchSize(batchSize int) *batchBuilder {
if batchSize <= 0 {
return b
}

b.batchSize = batchSize

return b
}

func (b *batchBuilder) WithUserID(userID int64) *batchBuilder {
b.conditions = append(b.conditions, "user_id = $"+strconv.Itoa(len(b.args)+1))
b.args = append(b.args, userID)
nArgs := b.args.append(userID)
b.where.and("user_id = $" + strconv.Itoa(nArgs))
return b
}

func (b *batchBuilder) WithCategoryID(categoryID int64) *batchBuilder {
b.conditions = append(b.conditions, "category_id = $"+strconv.Itoa(len(b.args)+1))
b.args = append(b.args, categoryID)
nArgs := b.args.append(categoryID)
b.where.and("category_id = $" + strconv.Itoa(nArgs))
return b
}

func (b *batchBuilder) WithErrorLimit(limit int) *batchBuilder {
if limit > 0 {
b.conditions = append(b.conditions, "parsing_error_count < $"+strconv.Itoa(len(b.args)+1))
b.args = append(b.args, limit)
if limit <= 0 {
return b
}

nArgs := b.args.append(limit)
b.where.and("parsing_error_count < $" + strconv.Itoa(nArgs))

return b
}

func (b *batchBuilder) WithNextCheckExpired() *batchBuilder {
b.conditions = append(b.conditions, "next_check_at < now()")
b.where.and("next_check_at < now()")
return b
}

func (b *batchBuilder) WithoutDisabledFeeds() *batchBuilder {
b.conditions = append(b.conditions, "disabled IS false")
b.where.and("disabled IS false")
return b
}

func (b *batchBuilder) WithLimitPerHost(limit int) *batchBuilder {
if limit > 0 {
b.limitPerHost = limit
if limit <= 0 {
return b
}

b.limitPerHost = limit

return b
}

// FetchJobs retrieves a batch of jobs based on the conditions set in the builder.
// When limitPerHost is set, it limits the number of jobs per feed hostname to prevent overwhelming a single host.
func (b *batchBuilder) FetchJobs() (model.JobList, error) {
query := `SELECT id, user_id, feed_url FROM feeds`
var qb strings.Builder

if len(b.conditions) > 0 {
query += " WHERE " + strings.Join(b.conditions, " AND ")
}
qb.WriteString(`
SELECT id, user_id, feed_url
FROM feeds
`)

qb.WriteString(" " + b.where.String())

query += " ORDER BY next_check_at ASC"
qb.WriteString(` ORDER BY next_check_at ASC`)

if b.batchSize > 0 {
query += " LIMIT " + strconv.Itoa(b.batchSize)
qb.WriteString(` LIMIT ` + strconv.Itoa(b.batchSize))
}

rows, err := b.db.Query(query, b.args...)
rows, err := b.db.Query(qb.String(), b.args.all()...)
if err != nil {
return nil, fmt.Errorf(`store: unable to fetch batch of jobs: %v`, err)
}
Expand Down
127 changes: 71 additions & 56 deletions internal/storage/entry_pagination_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,66 +8,73 @@ import (
"errors"
"fmt"
"strconv"
"strings"

"github.com/lib/pq"
"miniflux.app/v2/internal/model"
)

// entryPaginationBuilder is a builder for entry prev/next queries.
type entryPaginationBuilder struct {
db *sql.DB
conditions []string
args []any
entryID int64
order string
direction string
db *sql.DB
args argsBuilder
where whereBuilder
orderBy orderByBuilder
entryID int64
direction string
}

// WithSearchQuery adds full-text search query to the condition.
func (e *entryPaginationBuilder) WithSearchQuery(query string) *entryPaginationBuilder {
if query != "" {
e.conditions = append(e.conditions, fmt.Sprintf("e.document_vectors @@ plainto_tsquery($%d)", len(e.args)+1))
e.args = append(e.args, query)
if query == "" {
return e
}

nArgs := e.args.append(query)
e.where.andf("e.document_vectors @@ plainto_tsquery($%d)", nArgs)

return e
}

// WithStarred adds starred to the condition.
func (e *entryPaginationBuilder) WithStarred() *entryPaginationBuilder {
e.conditions = append(e.conditions, "e.starred is true")
e.where.and("e.starred is true")

return e
}

// WithFeedID adds feed_id to the condition.
func (e *entryPaginationBuilder) WithFeedID(feedID int64) *entryPaginationBuilder {
if feedID != 0 {
e.conditions = append(e.conditions, "e.feed_id = $"+strconv.Itoa(len(e.args)+1))
e.args = append(e.args, feedID)
if feedID == 0 {
return e
}

nArgs := e.args.append(feedID)
e.where.and("e.feed_id = $" + strconv.Itoa(nArgs))

return e
}

// WithCategoryID adds category_id to the condition.
func (e *entryPaginationBuilder) WithCategoryID(categoryID int64) *entryPaginationBuilder {
if categoryID != 0 {
e.conditions = append(e.conditions, "f.category_id = $"+strconv.Itoa(len(e.args)+1))
e.args = append(e.args, categoryID)
if categoryID == 0 {
return e
}

nArgs := e.args.append(categoryID)
e.where.and("f.category_id = $" + strconv.Itoa(nArgs))

return e
}

// WithStatus adds status to the condition.
func (e *entryPaginationBuilder) WithStatus(status string) *entryPaginationBuilder {
if status != "" {
e.conditions = append(e.conditions, "e.status = $"+strconv.Itoa(len(e.args)+1))
e.args = append(e.args, status)
if status == "" {
return e
}

nArgs := e.args.append(status)
e.where.and("e.status = $" + strconv.Itoa(nArgs))

return e
}

Expand All @@ -78,31 +85,31 @@ func (e *entryPaginationBuilder) WithStatusOrEntryID(status string, entryID int6
}

if entryID == 0 {
e.WithStatus(status)
return e
return e.WithStatus(status)
}

statusArg := len(e.args) + 1
entryArg := len(e.args) + 2
e.conditions = append(e.conditions, fmt.Sprintf("(e.status = $%d OR e.id = $%d)", statusArg, entryArg))
e.args = append(e.args, status, entryID)
statusArg := e.args.append(status)
entryArg := e.args.append(entryID)
e.where.andf("(e.status = $%d OR e.id = $%d)", statusArg, entryArg)

return e
}

func (e *entryPaginationBuilder) WithTags(tags []string) *entryPaginationBuilder {
if len(tags) > 0 {
e.conditions = append(e.conditions, fmt.Sprintf("LOWER(e.tags::text)::text[] @> LOWER($%d::text)::text[]", len(e.args)+1))
e.args = append(e.args, pq.Array(tags))
func (e *entryPaginationBuilder) WithTags(tags ...string) *entryPaginationBuilder {
if len(tags) == 0 {
return e
}

nArgs := e.args.append(pq.Array(tags))
e.where.andf("LOWER(e.tags::text)::text[] @> LOWER($%d::text)::text[]", nArgs)

return e
}

// WithGloballyVisible adds global visibility to the condition.
func (e *entryPaginationBuilder) WithGloballyVisible() *entryPaginationBuilder {
e.conditions = append(e.conditions, "not c.hide_globally")
e.conditions = append(e.conditions, "not f.hide_globally")
e.where.and("c.hide_globally IS FALSE")
e.where.and("f.hide_globally IS FALSE")

return e
}
Expand Down Expand Up @@ -143,27 +150,29 @@ func (e *entryPaginationBuilder) Entries() (*model.Entry, *model.Entry, error) {

func (e *entryPaginationBuilder) getPrevNextID(tx *sql.Tx) (prevID int64, nextID int64, err error) {
cte := `
WITH entry_pagination AS (
SELECT
e.id,
lag(e.id) over (order by e.%[1]s asc, e.created_at asc, e.id desc) as prev_id,
lead(e.id) over (order by e.%[1]s asc, e.created_at asc, e.id desc) as next_id
FROM entries AS e
JOIN feeds AS f ON f.id=e.feed_id
SELECT
e.id,
lag(e.id) over (` + e.orderBy.String() + `) as prev_id,
lead(e.id) over (` + e.orderBy.String() + `) as next_id
FROM entries AS e
JOIN feeds AS f ON f.id = e.feed_id
JOIN categories c ON c.id = f.category_id
WHERE %[2]s
ORDER BY e.%[1]s asc, e.created_at asc, e.id desc
)
SELECT prev_id, next_id FROM entry_pagination AS ep WHERE %[3]s;
`
` + e.where.String() + " " + e.orderBy.String()

finalWhere := whereBuilder{}
finalArgs := e.args.clone()

subCondition := strings.Join(e.conditions, " AND ")
finalCondition := "ep.id = $" + strconv.Itoa(len(e.args)+1)
query := fmt.Sprintf(cte, e.order, subCondition, finalCondition)
e.args = append(e.args, e.entryID)
nArgs := finalArgs.append(e.entryID)
finalWhere.and("ep.id = $" + strconv.Itoa(nArgs))

query := `
WITH entry_pagination AS (` + cte + `)
SELECT prev_id, next_id
FROM entry_pagination AS ep
` + finalWhere.String()

var pID, nID sql.NullInt64
err = tx.QueryRow(query, e.args...).Scan(&pID, &nID)
err = tx.QueryRow(query, finalArgs.all()...).Scan(&pID, &nID)
switch {
case errors.Is(err, sql.ErrNoRows):
return 0, 0, nil
Expand Down Expand Up @@ -202,12 +211,18 @@ func (e *entryPaginationBuilder) getEntry(tx *sql.Tx, entryID int64) (*model.Ent

// NewEntryPaginationBuilder returns a new EntryPaginationBuilder.
func (s *Storage) NewEntryPaginationBuilder(userID, entryID int64, order, direction string) *entryPaginationBuilder {
return &entryPaginationBuilder{
db: s.db,
args: []any{userID},
conditions: []string{"e.user_id = $1"},
entryID: entryID,
order: pq.QuoteIdentifier(order),
direction: direction,
e := entryPaginationBuilder{
db: s.db,
entryID: entryID,
direction: direction,
}

nArgs := e.args.append(userID)
e.where.and("e.user_id = $" + strconv.Itoa(nArgs))

e.orderBy.asc("e." + pq.QuoteIdentifier(order))
e.orderBy.asc("e.created_at")
e.orderBy.desc("e.id")

return &e
}
Loading