update
Some checks failed
Pipeline: Test, Lint, Build / Get version info (push) Has been cancelled
Pipeline: Test, Lint, Build / Lint Go code (push) Has been cancelled
Pipeline: Test, Lint, Build / Test Go code (push) Has been cancelled
Pipeline: Test, Lint, Build / Test JS code (push) Has been cancelled
Pipeline: Test, Lint, Build / Lint i18n files (push) Has been cancelled
Pipeline: Test, Lint, Build / Check Docker configuration (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (darwin/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (darwin/arm64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/386) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v5) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v6) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v7) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (windows/386) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (windows/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Push to GHCR (push) Has been cancelled
Pipeline: Test, Lint, Build / Push to Docker Hub (push) Has been cancelled
Pipeline: Test, Lint, Build / Cleanup digest artifacts (push) Has been cancelled
Pipeline: Test, Lint, Build / Build Windows installers (push) Has been cancelled
Pipeline: Test, Lint, Build / Package/Release (push) Has been cancelled
Pipeline: Test, Lint, Build / Upload Linux PKG (push) Has been cancelled
Close stale issues and PRs / stale (push) Has been cancelled
POEditor import / update-translations (push) Has been cancelled

This commit is contained in:
2025-12-08 16:16:23 +01:00
commit c251f174ed
1349 changed files with 194301 additions and 0 deletions

View File

@@ -0,0 +1,442 @@
package persistence
import (
"context"
"encoding/json"
"fmt"
"maps"
"slices"
"strings"
"sync"
"time"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/google/uuid"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
type albumRepository struct {
sqlRepository
ms *MeilisearchService
}
type dbAlbum struct {
*model.Album `structs:",flatten"`
Discs string `structs:"-" json:"discs"`
Participants string `structs:"-" json:"-"`
Tags string `structs:"-" json:"-"`
FolderIDs string `structs:"-" json:"-"`
}
func (a *dbAlbum) PostScan() error {
var err error
if a.Discs != "" {
if err = json.Unmarshal([]byte(a.Discs), &a.Album.Discs); err != nil {
return fmt.Errorf("parsing album discs from db: %w", err)
}
}
a.Album.Participants, err = unmarshalParticipants(a.Participants)
if err != nil {
return fmt.Errorf("parsing album from db: %w", err)
}
if a.Tags != "" {
a.Album.Tags, err = unmarshalTags(a.Tags)
if err != nil {
return fmt.Errorf("parsing album from db: %w", err)
}
a.Genre, a.Genres = a.Album.Tags.ToGenres()
}
if a.FolderIDs != "" {
var ids []string
if err = json.Unmarshal([]byte(a.FolderIDs), &ids); err != nil {
return fmt.Errorf("parsing album folder_ids from db: %w", err)
}
a.Album.FolderIDs = ids
}
return nil
}
func (a *dbAlbum) PostMapArgs(args map[string]any) error {
fullText := []string{a.Name, a.SortAlbumName, a.AlbumArtist}
fullText = append(fullText, a.Album.Participants.AllNames()...)
fullText = append(fullText, slices.Collect(maps.Values(a.Album.Discs))...)
fullText = append(fullText, a.Album.Tags[model.TagAlbumVersion]...)
fullText = append(fullText, a.Album.Tags[model.TagCatalogNumber]...)
args["full_text"] = formatFullText(fullText...)
args["tags"] = marshalTags(a.Album.Tags)
args["participants"] = marshalParticipants(a.Album.Participants)
folderIDs, err := json.Marshal(a.Album.FolderIDs)
if err != nil {
return fmt.Errorf("marshalling album folder_ids: %w", err)
}
args["folder_ids"] = string(folderIDs)
b, err := json.Marshal(a.Album.Discs)
if err != nil {
return fmt.Errorf("marshalling album discs: %w", err)
}
args["discs"] = string(b)
return nil
}
type dbAlbums []dbAlbum
func (as dbAlbums) toModels() model.Albums {
return slice.Map(as, func(a dbAlbum) model.Album { return *a.Album })
}
func NewAlbumRepository(ctx context.Context, db dbx.Builder, ms *MeilisearchService) model.AlbumRepository {
r := &albumRepository{}
r.ctx = ctx
r.db = db
r.ms = ms
r.tableName = "album"
r.registerModel(&model.Album{}, albumFilters())
r.setSortMappings(map[string]string{
"name": "order_album_name, order_album_artist_name",
"artist": "compilation, order_album_artist_name, order_album_name",
"album_artist": "compilation, order_album_artist_name, order_album_name",
// TODO Rename this to just year (or date)
"max_year": "coalesce(nullif(original_date,''), cast(max_year as text)), release_date, name",
"random": "random",
"recently_added": recentlyAddedSort(),
"starred_at": "starred, starred_at",
"rated_at": "rating, rated_at",
})
return r
}
var albumFilters = sync.OnceValue(func() map[string]filterFunc {
filters := map[string]filterFunc{
"id": idFilter("album"),
"name": fullTextFilter("album", "mbz_album_id", "mbz_release_group_id"),
"compilation": booleanFilter,
"artist_id": artistFilter,
"year": yearFilter,
"recently_played": recentlyPlayedFilter,
"starred": booleanFilter,
"has_rating": hasRatingFilter,
"missing": booleanFilter,
"genre_id": tagIDFilter,
"role_total_id": allRolesFilter,
"library_id": libraryIdFilter,
}
// Add all album tags as filters
for tag := range model.AlbumLevelTags() {
filters[string(tag)] = tagIDFilter
}
for role := range model.AllRoles {
filters["role_"+role+"_id"] = artistRoleFilter
}
return filters
})
func recentlyAddedSort() string {
if conf.Server.RecentlyAddedByModTime {
return "updated_at"
}
return "created_at"
}
func recentlyPlayedFilter(string, interface{}) Sqlizer {
return Gt{"play_count": 0}
}
func hasRatingFilter(string, interface{}) Sqlizer {
return Gt{"rating": 0}
}
func yearFilter(_ string, value interface{}) Sqlizer {
return Or{
And{
Gt{"min_year": 0},
LtOrEq{"min_year": value},
GtOrEq{"max_year": value},
},
Eq{"max_year": value},
}
}
func artistFilter(_ string, value interface{}) Sqlizer {
return Or{
Exists("json_tree(participants, '$.albumartist')", Eq{"value": value}),
Exists("json_tree(participants, '$.artist')", Eq{"value": value}),
}
}
func artistRoleFilter(name string, value interface{}) Sqlizer {
roleName := strings.TrimSuffix(strings.TrimPrefix(name, "role_"), "_id")
// Check if the role name is valid. If not, return an invalid filter
if _, ok := model.AllRoles[roleName]; !ok {
return Gt{"": nil}
}
return Exists(fmt.Sprintf("json_tree(participants, '$.%s')", roleName), Eq{"value": value})
}
func allRolesFilter(_ string, value interface{}) Sqlizer {
return Like{"participants": fmt.Sprintf(`%%"%s"%%`, value)}
}
func (r *albumRepository) CountAll(options ...model.QueryOptions) (int64, error) {
query := r.newSelect()
query = r.withAnnotation(query, "album.id")
query = r.applyLibraryFilter(query)
return r.count(query, options...)
}
func (r *albumRepository) Exists(id string) (bool, error) {
return r.exists(Eq{"album.id": id})
}
func (r *albumRepository) Put(al *model.Album) error {
al.ImportedAt = time.Now()
id, err := r.put(al.ID, &dbAlbum{Album: al})
if err != nil {
return err
}
al.ID = id
if len(al.Participants) > 0 {
err = r.updateParticipants(al.ID, al.Participants)
if err != nil {
return err
}
}
if r.ms != nil {
r.ms.IndexAlbum(al)
}
return err
}
// TODO Move external metadata to a separated table
func (r *albumRepository) UpdateExternalInfo(al *model.Album) error {
_, err := r.put(al.ID, &dbAlbum{Album: al}, "description", "small_image_url", "medium_image_url", "large_image_url", "external_url", "external_info_updated_at")
return err
}
func (r *albumRepository) selectAlbum(options ...model.QueryOptions) SelectBuilder {
sql := r.newSelect(options...).Columns("album.*", "library.path as library_path", "library.name as library_name").
LeftJoin("library on album.library_id = library.id")
sql = r.withAnnotation(sql, "album.id")
return r.applyLibraryFilter(sql)
}
func (r *albumRepository) Get(id string) (*model.Album, error) {
res, err := r.GetAll(model.QueryOptions{Filters: Eq{"album.id": id}})
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, model.ErrNotFound
}
return &res[0], nil
}
func (r *albumRepository) GetAll(options ...model.QueryOptions) (model.Albums, error) {
sq := r.selectAlbum(options...)
var res dbAlbums
err := r.queryAll(sq, &res)
if err != nil {
return nil, err
}
return res.toModels(), err
}
func (r *albumRepository) CopyAttributes(fromID, toID string, columns ...string) error {
var from dbx.NullStringMap
err := r.queryOne(Select(columns...).From(r.tableName).Where(Eq{"id": fromID}), &from)
if err != nil {
return fmt.Errorf("getting album to copy fields from: %w", err)
}
to := make(map[string]interface{})
for _, col := range columns {
to[col] = from[col]
}
_, err = r.executeSQL(Update(r.tableName).SetMap(to).Where(Eq{"id": toID}))
return err
}
// Touch flags an album as being scanned by the scanner, but not necessarily updated.
// This is used for when missing tracks are detected for an album during scan.
func (r *albumRepository) Touch(ids ...string) error {
if len(ids) == 0 {
return nil
}
for ids := range slices.Chunk(ids, 200) {
upd := Update(r.tableName).Set("imported_at", time.Now()).Where(Eq{"id": ids})
c, err := r.executeSQL(upd)
if err != nil {
return fmt.Errorf("error touching albums: %w", err)
}
log.Debug(r.ctx, "Touching albums", "ids", ids, "updated", c)
}
return nil
}
// TouchByMissingFolder touches all albums that have missing folders
func (r *albumRepository) TouchByMissingFolder() (int64, error) {
upd := Update(r.tableName).Set("imported_at", time.Now()).
Where(And{
NotEq{"folder_ids": nil},
ConcatExpr("EXISTS (SELECT 1 FROM json_each(folder_ids) AS je JOIN main.folder AS f ON je.value = f.id WHERE f.missing = true)"),
})
c, err := r.executeSQL(upd)
if err != nil {
return 0, fmt.Errorf("error touching albums by missing folder: %w", err)
}
return c, nil
}
// GetTouchedAlbums returns all albums that were touched by the scanner for a given library, in the
// current library scan run.
// It does not need to load participants, as they are not used by the scanner.
func (r *albumRepository) GetTouchedAlbums(libID int) (model.AlbumCursor, error) {
query := r.selectAlbum().
Where(And{
Eq{"library.id": libID},
ConcatExpr("album.imported_at > library.last_scan_at"),
})
cursor, err := queryWithStableResults[dbAlbum](r.sqlRepository, query)
if err != nil {
return nil, err
}
return func(yield func(model.Album, error) bool) {
for a, err := range cursor {
if a.Album == nil {
yield(model.Album{}, fmt.Errorf("unexpected nil album: %v", a))
return
}
if !yield(*a.Album, err) || err != nil {
return
}
}
}, nil
}
// RefreshPlayCounts updates the play count and last play date annotations for all albums, based
// on the media files associated with them.
func (r *albumRepository) RefreshPlayCounts() (int64, error) {
query := Expr(`
with play_counts as (
select user_id, album_id, sum(play_count) as total_play_count, max(play_date) as last_play_date
from media_file
join annotation on item_id = media_file.id
group by user_id, album_id
)
insert into annotation (user_id, item_id, item_type, play_count, play_date)
select user_id, album_id, 'album', total_play_count, last_play_date
from play_counts
where total_play_count > 0
on conflict (user_id, item_id, item_type) do update
set play_count = excluded.play_count,
play_date = excluded.play_date;
`)
return r.executeSQL(query)
}
func (r *albumRepository) purgeEmpty(libraryIDs ...int) error {
del := Delete(r.tableName).Where("id not in (select distinct(album_id) from media_file)")
// If libraryIDs are specified, only purge albums from those libraries
if len(libraryIDs) > 0 {
del = del.Where(Eq{"library_id": libraryIDs})
}
c, err := r.executeSQL(del)
if err != nil {
return fmt.Errorf("purging empty albums: %w", err)
}
// TODO: Delete from Meilisearch.
// Since purgeEmpty executes a DELETE statement without returning IDs, we can't easily sync Meilisearch here.
// Ideally we should select IDs first, then delete. But this is a cleanup task.
// For now we skip Meilisearch deletion here as it requires more changes.
// The stale entries in Meilisearch will just return empty results when fetched from DB, which Search handles.
if c > 0 {
log.Debug(r.ctx, "Purged empty albums", "totalDeleted", c)
}
return nil
}
func (r *albumRepository) Search(q string, offset int, size int, options ...model.QueryOptions) (model.Albums, error) {
var res dbAlbums
if uuid.Validate(q) == nil {
err := r.searchByMBID(r.selectAlbum(options...), q, []string{"mbz_album_id", "mbz_release_group_id"}, &res)
if err != nil {
return nil, fmt.Errorf("searching album by MBID %q: %w", q, err)
}
} else {
if r.ms != nil {
ids, err := r.ms.Search("albums", q, offset, size)
if err == nil {
if len(ids) == 0 {
return model.Albums{}, nil
}
// Fetch matching albums from the database
albums, err := r.GetAll(model.QueryOptions{Filters: Eq{"album.id": ids}})
if err != nil {
return nil, fmt.Errorf("fetching albums from meilisearch ids: %w", err)
}
// Reorder results to match Meilisearch order
idMap := make(map[string]model.Album, len(albums))
for _, a := range albums {
idMap[a.ID] = a
}
sorted := make(model.Albums, 0, len(albums))
for _, id := range ids {
if a, ok := idMap[id]; ok {
sorted = append(sorted, a)
}
}
return sorted, nil
}
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
}
err := r.doSearch(r.selectAlbum(options...), q, offset, size, &res, "album.rowid", "name")
if err != nil {
return nil, fmt.Errorf("searching album by query %q: %w", q, err)
}
}
return res.toModels(), nil
}
func (r *albumRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *albumRepository) Read(id string) (interface{}, error) {
return r.Get(id)
}
func (r *albumRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
if len(options) > 0 && r.ms != nil {
if name, ok := options[0].Filters["name"].(string); ok && name != "" {
ids, err := r.ms.Search("albums", name, 0, 10000)
if err == nil {
log.Debug(r.ctx, "Meilisearch found matches", "count", len(ids), "query", name)
delete(options[0].Filters, "name")
options[0].Filters["id"] = ids
} else {
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
}
}
}
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *albumRepository) EntityName() string {
return "album"
}
func (r *albumRepository) NewInstance() interface{} {
return &model.Album{}
}
var _ model.AlbumRepository = (*albumRepository)(nil)
var _ model.ResourceRepository = (*albumRepository)(nil)

View File

@@ -0,0 +1,524 @@
package persistence
import (
"fmt"
"time"
"github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/consts"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("AlbumRepository", func() {
var albumRepo *albumRepository
BeforeEach(func() {
ctx := request.WithUser(GinkgoT().Context(), model.User{ID: "userid", UserName: "johndoe"})
albumRepo = NewAlbumRepository(ctx, GetDBXBuilder(), nil).(*albumRepository)
})
Describe("Get", func() {
var Get = func(id string) (*model.Album, error) {
album, err := albumRepo.Get(id)
if album != nil {
album.ImportedAt = time.Time{}
}
return album, err
}
It("returns an existent album", func() {
Expect(Get("103")).To(Equal(&albumRadioactivity))
})
It("returns ErrNotFound when the album does not exist", func() {
_, err := Get("666")
Expect(err).To(MatchError(model.ErrNotFound))
})
})
Describe("GetAll", func() {
var GetAll = func(opts ...model.QueryOptions) (model.Albums, error) {
albums, err := albumRepo.GetAll(opts...)
for i := range albums {
albums[i].ImportedAt = time.Time{}
}
return albums, err
}
It("returns all records", func() {
Expect(GetAll()).To(Equal(testAlbums))
})
It("returns all records sorted", func() {
Expect(GetAll(model.QueryOptions{Sort: "name"})).To(Equal(model.Albums{
albumAbbeyRoad,
albumMultiDisc,
albumRadioactivity,
albumSgtPeppers,
}))
})
It("returns all records sorted desc", func() {
Expect(GetAll(model.QueryOptions{Sort: "name", Order: "desc"})).To(Equal(model.Albums{
albumSgtPeppers,
albumRadioactivity,
albumMultiDisc,
albumAbbeyRoad,
}))
})
It("paginates the result", func() {
Expect(GetAll(model.QueryOptions{Offset: 1, Max: 1})).To(Equal(model.Albums{
albumAbbeyRoad,
}))
})
})
Describe("Album.PlayCount", func() {
// Implementation is in withAnnotation() method
DescribeTable("normalizes play count when AlbumPlayCountMode is absolute",
func(songCount, playCount, expected int) {
conf.Server.AlbumPlayCountMode = consts.AlbumPlayCountModeAbsolute
newID := id.NewRandom()
Expect(albumRepo.Put(&model.Album{LibraryID: 1, ID: newID, Name: "name", SongCount: songCount})).To(Succeed())
for i := 0; i < playCount; i++ {
Expect(albumRepo.IncPlayCount(newID, time.Now())).To(Succeed())
}
album, err := albumRepo.Get(newID)
Expect(err).ToNot(HaveOccurred())
Expect(album.PlayCount).To(Equal(int64(expected)))
},
Entry("1 song, 0 plays", 1, 0, 0),
Entry("1 song, 4 plays", 1, 4, 4),
Entry("3 songs, 6 plays", 3, 6, 6),
Entry("10 songs, 6 plays", 10, 6, 6),
Entry("70 songs, 70 plays", 70, 70, 70),
Entry("10 songs, 50 plays", 10, 50, 50),
Entry("120 songs, 121 plays", 120, 121, 121),
)
DescribeTable("normalizes play count when AlbumPlayCountMode is normalized",
func(songCount, playCount, expected int) {
conf.Server.AlbumPlayCountMode = consts.AlbumPlayCountModeNormalized
newID := id.NewRandom()
Expect(albumRepo.Put(&model.Album{LibraryID: 1, ID: newID, Name: "name", SongCount: songCount})).To(Succeed())
for i := 0; i < playCount; i++ {
Expect(albumRepo.IncPlayCount(newID, time.Now())).To(Succeed())
}
album, err := albumRepo.Get(newID)
Expect(err).ToNot(HaveOccurred())
Expect(album.PlayCount).To(Equal(int64(expected)))
},
Entry("1 song, 0 plays", 1, 0, 0),
Entry("1 song, 4 plays", 1, 4, 4),
Entry("3 songs, 6 plays", 3, 6, 2),
Entry("10 songs, 6 plays", 10, 6, 1),
Entry("70 songs, 70 plays", 70, 70, 1),
Entry("10 songs, 50 plays", 10, 50, 5),
Entry("120 songs, 121 plays", 120, 121, 1),
)
})
Describe("dbAlbum mapping", func() {
var (
a model.Album
dba *dbAlbum
args map[string]any
)
BeforeEach(func() {
a = al(model.Album{ID: "1", Name: "name"})
dba = &dbAlbum{Album: &a, Participants: "{}"}
args = make(map[string]any)
})
Describe("PostScan", func() {
It("parses Discs correctly", func() {
dba.Discs = `{"1":"disc1","2":"disc2"}`
Expect(dba.PostScan()).To(Succeed())
Expect(dba.Album.Discs).To(Equal(model.Discs{1: "disc1", 2: "disc2"}))
})
It("parses Participants correctly", func() {
dba.Participants = `{"composer":[{"id":"1","name":"Composer 1"}],` +
`"artist":[{"id":"2","name":"Artist 2"},{"id":"3","name":"Artist 3","subRole":"subRole"}]}`
Expect(dba.PostScan()).To(Succeed())
Expect(dba.Album.Participants).To(HaveLen(2))
Expect(dba.Album.Participants).To(HaveKeyWithValue(
model.RoleFromString("composer"),
model.ParticipantList{{Artist: model.Artist{ID: "1", Name: "Composer 1"}}},
))
Expect(dba.Album.Participants).To(HaveKeyWithValue(
model.RoleFromString("artist"),
model.ParticipantList{{Artist: model.Artist{ID: "2", Name: "Artist 2"}}, {Artist: model.Artist{ID: "3", Name: "Artist 3"}, SubRole: "subRole"}},
))
})
It("parses Tags correctly", func() {
dba.Tags = `{"genre":[{"id":"1","value":"rock"},{"id":"2","value":"pop"}],"mood":[{"id":"3","value":"happy"}]}`
Expect(dba.PostScan()).To(Succeed())
Expect(dba.Album.Tags).To(HaveKeyWithValue(
model.TagName("mood"), []string{"happy"},
))
Expect(dba.Album.Tags).To(HaveKeyWithValue(
model.TagName("genre"), []string{"rock", "pop"},
))
Expect(dba.Album.Genre).To(Equal("rock"))
Expect(dba.Album.Genres).To(HaveLen(2))
})
It("parses Paths correctly", func() {
dba.FolderIDs = `["folder1","folder2"]`
Expect(dba.PostScan()).To(Succeed())
Expect(dba.Album.FolderIDs).To(Equal([]string{"folder1", "folder2"}))
})
})
Describe("PostMapArgs", func() {
It("maps full_text correctly", func() {
Expect(dba.PostMapArgs(args)).To(Succeed())
Expect(args).To(HaveKeyWithValue("full_text", " name"))
})
It("maps tags correctly", func() {
dba.Album.Tags = model.Tags{"genre": {"rock", "pop"}, "mood": {"happy"}}
Expect(dba.PostMapArgs(args)).To(Succeed())
Expect(args).To(HaveKeyWithValue("tags",
`{"genre":[{"id":"5qDZoz1FBC36K73YeoJ2lF","value":"rock"},{"id":"4H0KjnlS2ob9nKLL0zHOqB",`+
`"value":"pop"}],"mood":[{"id":"1F4tmb516DIlHKFT1KzE1Z","value":"happy"}]}`,
))
})
It("maps participants correctly", func() {
dba.Album.Participants = model.Participants{
model.RoleAlbumArtist: model.ParticipantList{_p("AA1", "AlbumArtist1")},
model.RoleComposer: model.ParticipantList{{Artist: model.Artist{ID: "C1", Name: "Composer1"}, SubRole: "composer"}},
}
Expect(dba.PostMapArgs(args)).To(Succeed())
Expect(args).To(HaveKeyWithValue(
"participants",
`{"albumartist":[{"id":"AA1","name":"AlbumArtist1"}],`+
`"composer":[{"id":"C1","name":"Composer1","subRole":"composer"}]}`,
))
})
It("maps discs correctly", func() {
dba.Album.Discs = model.Discs{1: "disc1", 2: "disc2"}
Expect(dba.PostMapArgs(args)).To(Succeed())
Expect(args).To(HaveKeyWithValue("discs", `{"1":"disc1","2":"disc2"}`))
})
It("maps paths correctly", func() {
dba.Album.FolderIDs = []string{"folder1", "folder2"}
Expect(dba.PostMapArgs(args)).To(Succeed())
Expect(args).To(HaveKeyWithValue("folder_ids", `["folder1","folder2"]`))
})
})
})
Describe("dbAlbums.toModels", func() {
It("converts dbAlbums to model.Albums", func() {
dba := dbAlbums{
{Album: &model.Album{ID: "1", Name: "name", SongCount: 2, Annotations: model.Annotations{PlayCount: 4}}},
{Album: &model.Album{ID: "2", Name: "name2", SongCount: 3, Annotations: model.Annotations{PlayCount: 6}}},
}
albums := dba.toModels()
for i := range dba {
Expect(albums[i].ID).To(Equal(dba[i].Album.ID))
Expect(albums[i].Name).To(Equal(dba[i].Album.Name))
Expect(albums[i].SongCount).To(Equal(dba[i].Album.SongCount))
Expect(albums[i].PlayCount).To(Equal(dba[i].Album.PlayCount))
}
})
})
Describe("artistRoleFilter", func() {
DescribeTable("creates correct SQL expressions for artist roles",
func(filterName, artistID, expectedSQL string) {
sqlizer := artistRoleFilter(filterName, artistID)
sql, args, err := sqlizer.ToSql()
Expect(err).ToNot(HaveOccurred())
Expect(sql).To(Equal(expectedSQL))
Expect(args).To(Equal([]interface{}{artistID}))
},
Entry("artist role", "role_artist_id", "123",
"exists (select 1 from json_tree(participants, '$.artist') where value = ?)"),
Entry("albumartist role", "role_albumartist_id", "456",
"exists (select 1 from json_tree(participants, '$.albumartist') where value = ?)"),
Entry("composer role", "role_composer_id", "789",
"exists (select 1 from json_tree(participants, '$.composer') where value = ?)"),
)
It("works with the actual filter map", func() {
filters := albumFilters()
for roleName := range model.AllRoles {
filterName := "role_" + roleName + "_id"
filterFunc, exists := filters[filterName]
Expect(exists).To(BeTrue(), fmt.Sprintf("Filter %s should exist", filterName))
sqlizer := filterFunc(filterName, "test-id")
sql, args, err := sqlizer.ToSql()
Expect(err).ToNot(HaveOccurred())
Expect(sql).To(Equal(fmt.Sprintf("exists (select 1 from json_tree(participants, '$.%s') where value = ?)", roleName)))
Expect(args).To(Equal([]interface{}{"test-id"}))
}
})
It("rejects invalid roles", func() {
sqlizer := artistRoleFilter("role_invalid_id", "123")
_, _, err := sqlizer.ToSql()
Expect(err).To(HaveOccurred())
})
It("rejects invalid filter names", func() {
sqlizer := artistRoleFilter("invalid_name", "123")
_, _, err := sqlizer.ToSql()
Expect(err).To(HaveOccurred())
})
})
Describe("Participant Foreign Key Handling", func() {
// albumArtistRecord represents a record in the album_artists table
type albumArtistRecord struct {
ArtistID string `db:"artist_id"`
Role string `db:"role"`
SubRole string `db:"sub_role"`
}
var artistRepo *artistRepository
BeforeEach(func() {
ctx := request.WithUser(GinkgoT().Context(), adminUser)
artistRepo = NewArtistRepository(ctx, GetDBXBuilder(), nil).(*artistRepository)
})
// Helper to verify album_artists records
verifyAlbumArtists := func(albumID string, expected []albumArtistRecord) {
GinkgoHelper()
var actual []albumArtistRecord
sq := squirrel.Select("artist_id", "role", "sub_role").
From("album_artists").
Where(squirrel.Eq{"album_id": albumID}).
OrderBy("role", "artist_id", "sub_role")
err := albumRepo.queryAll(sq, &actual)
Expect(err).ToNot(HaveOccurred())
Expect(actual).To(Equal(expected))
}
It("verifies that participant records are actually inserted into database", func() {
// Create a real artist in the database first
artist := &model.Artist{
ID: "real-artist-1",
Name: "Real Artist",
OrderArtistName: "real artist",
SortArtistName: "Artist, Real",
}
err := createArtistWithLibrary(artistRepo, artist, 1)
Expect(err).ToNot(HaveOccurred())
// Create an album with participants that reference the real artist
album := &model.Album{
LibraryID: 1,
ID: "test-album-db-insert",
Name: "Test Album DB Insert",
AlbumArtistID: "real-artist-1",
AlbumArtist: "Real Artist",
Participants: model.Participants{
model.RoleArtist: {
{Artist: model.Artist{ID: "real-artist-1", Name: "Real Artist"}},
},
model.RoleComposer: {
{Artist: model.Artist{ID: "real-artist-1", Name: "Real Artist"}, SubRole: "primary"},
},
},
}
// Insert the album
err = albumRepo.Put(album)
Expect(err).ToNot(HaveOccurred())
// Verify that participant records were actually inserted into album_artists table
expected := []albumArtistRecord{
{ArtistID: "real-artist-1", Role: "artist", SubRole: ""},
{ArtistID: "real-artist-1", Role: "composer", SubRole: "primary"},
}
verifyAlbumArtists(album.ID, expected)
// Clean up the test artist and album created for this test
_, _ = artistRepo.executeSQL(squirrel.Delete("artist").Where(squirrel.Eq{"id": artist.ID}))
_, _ = albumRepo.executeSQL(squirrel.Delete("album").Where(squirrel.Eq{"id": album.ID}))
})
It("filters out invalid artist IDs leaving only valid participants in database", func() {
// Create two real artists in the database
artist1 := &model.Artist{
ID: "real-artist-mix-1",
Name: "Real Artist 1",
OrderArtistName: "real artist 1",
}
artist2 := &model.Artist{
ID: "real-artist-mix-2",
Name: "Real Artist 2",
OrderArtistName: "real artist 2",
}
err := createArtistWithLibrary(artistRepo, artist1, 1)
Expect(err).ToNot(HaveOccurred())
err = createArtistWithLibrary(artistRepo, artist2, 1)
Expect(err).ToNot(HaveOccurred())
// Create an album with mix of valid and invalid artist IDs
album := &model.Album{
LibraryID: 1,
ID: "test-album-mixed-validity",
Name: "Test Album Mixed Validity",
AlbumArtistID: "real-artist-mix-1",
AlbumArtist: "Real Artist 1",
Participants: model.Participants{
model.RoleArtist: {
{Artist: model.Artist{ID: "real-artist-mix-1", Name: "Real Artist 1"}},
{Artist: model.Artist{ID: "non-existent-mix-1", Name: "Non Existent 1"}},
{Artist: model.Artist{ID: "real-artist-mix-2", Name: "Real Artist 2"}},
},
model.RoleComposer: {
{Artist: model.Artist{ID: "non-existent-mix-2", Name: "Non Existent 2"}},
{Artist: model.Artist{ID: "real-artist-mix-1", Name: "Real Artist 1"}},
},
},
}
// This should not fail - only valid artists should be inserted
err = albumRepo.Put(album)
Expect(err).ToNot(HaveOccurred())
// Verify that only valid artist IDs were inserted into album_artists table
// Non-existent artists should be filtered out by the INNER JOIN
expected := []albumArtistRecord{
{ArtistID: "real-artist-mix-1", Role: "artist", SubRole: ""},
{ArtistID: "real-artist-mix-2", Role: "artist", SubRole: ""},
{ArtistID: "real-artist-mix-1", Role: "composer", SubRole: ""},
}
verifyAlbumArtists(album.ID, expected)
// Clean up the test artists and album created for this test
artistIDs := []string{artist1.ID, artist2.ID}
_, _ = artistRepo.executeSQL(squirrel.Delete("artist").Where(squirrel.Eq{"id": artistIDs}))
_, _ = albumRepo.executeSQL(squirrel.Delete("album").Where(squirrel.Eq{"id": album.ID}))
})
It("handles complex nested JSON with multiple roles and sub-roles", func() {
// Create 4 artists for this test
artists := []*model.Artist{
{ID: "complex-artist-1", Name: "Lead Vocalist", OrderArtistName: "lead vocalist"},
{ID: "complex-artist-2", Name: "Guitarist", OrderArtistName: "guitarist"},
{ID: "complex-artist-3", Name: "Producer", OrderArtistName: "producer"},
{ID: "complex-artist-4", Name: "Engineer", OrderArtistName: "engineer"},
}
for _, artist := range artists {
err := createArtistWithLibrary(artistRepo, artist, 1)
Expect(err).ToNot(HaveOccurred())
}
// Create album with complex participant structure
album := &model.Album{
LibraryID: 1,
ID: "test-album-complex-json",
Name: "Test Album Complex JSON",
AlbumArtistID: "complex-artist-1",
AlbumArtist: "Lead Vocalist",
Participants: model.Participants{
model.RoleArtist: {
{Artist: model.Artist{ID: "complex-artist-1", Name: "Lead Vocalist"}},
{Artist: model.Artist{ID: "complex-artist-2", Name: "Guitarist"}, SubRole: "lead guitar"},
{Artist: model.Artist{ID: "complex-artist-2", Name: "Guitarist"}, SubRole: "rhythm guitar"},
},
model.RoleProducer: {
{Artist: model.Artist{ID: "complex-artist-3", Name: "Producer"}, SubRole: "executive"},
},
model.RoleEngineer: {
{Artist: model.Artist{ID: "complex-artist-4", Name: "Engineer"}, SubRole: "mixing"},
{Artist: model.Artist{ID: "complex-artist-4", Name: "Engineer"}, SubRole: "mastering"},
},
},
}
err := albumRepo.Put(album)
Expect(err).ToNot(HaveOccurred())
// Verify complex JSON structure was correctly parsed and inserted
expected := []albumArtistRecord{
{ArtistID: "complex-artist-1", Role: "artist", SubRole: ""},
{ArtistID: "complex-artist-2", Role: "artist", SubRole: "lead guitar"},
{ArtistID: "complex-artist-2", Role: "artist", SubRole: "rhythm guitar"},
{ArtistID: "complex-artist-4", Role: "engineer", SubRole: "mastering"},
{ArtistID: "complex-artist-4", Role: "engineer", SubRole: "mixing"},
{ArtistID: "complex-artist-3", Role: "producer", SubRole: "executive"},
}
verifyAlbumArtists(album.ID, expected)
// Clean up the test artists and album created for this test
artistIDs := make([]string, len(artists))
for i, artist := range artists {
artistIDs[i] = artist.ID
}
_, _ = artistRepo.executeSQL(squirrel.Delete("artist").Where(squirrel.Eq{"id": artistIDs}))
_, _ = albumRepo.executeSQL(squirrel.Delete("album").Where(squirrel.Eq{"id": album.ID}))
})
It("handles albums with non-existent artist IDs without constraint errors", func() {
// Regression test for foreign key constraint error when album participants
// contain artist IDs that don't exist in the artist table
// Create an album with participants that reference non-existent artist IDs
album := &model.Album{
LibraryID: 1,
ID: "test-album-fk-constraints",
Name: "Test Album with Invalid Artist References",
AlbumArtistID: "non-existent-artist-1",
AlbumArtist: "Non Existent Album Artist",
Participants: model.Participants{
model.RoleArtist: {
{Artist: model.Artist{ID: "non-existent-artist-1", Name: "Non Existent Artist 1"}},
{Artist: model.Artist{ID: "non-existent-artist-2", Name: "Non Existent Artist 2"}},
},
model.RoleComposer: {
{Artist: model.Artist{ID: "non-existent-composer-1", Name: "Non Existent Composer 1"}},
{Artist: model.Artist{ID: "non-existent-composer-2", Name: "Non Existent Composer 2"}},
},
model.RoleAlbumArtist: {
{Artist: model.Artist{ID: "non-existent-album-artist-1", Name: "Non Existent Album Artist 1"}},
},
},
}
// This should not fail with foreign key constraint error
// The updateParticipants method should handle non-existent artist IDs gracefully
err := albumRepo.Put(album)
Expect(err).ToNot(HaveOccurred())
// Verify that no participant records were inserted since all artist IDs were invalid
// The INNER JOIN with the artist table should filter out all non-existent artists
verifyAlbumArtists(album.ID, []albumArtistRecord{})
// Clean up the test album created for this test
_, _ = albumRepo.executeSQL(squirrel.Delete("album").Where(squirrel.Eq{"id": album.ID}))
})
})
})
func _p(id, name string, sortName ...string) model.Participant {
p := model.Participant{Artist: model.Artist{ID: id, Name: name}}
if len(sortName) > 0 {
p.Artist.SortArtistName = sortName[0]
}
return p
}

View File

@@ -0,0 +1,608 @@
package persistence
import (
"cmp"
"context"
"encoding/json"
"fmt"
"slices"
"strings"
"time"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/google/uuid"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils"
. "github.com/navidrome/navidrome/utils/gg"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
type artistRepository struct {
sqlRepository
indexGroups utils.IndexGroups
ms *MeilisearchService
}
type dbArtist struct {
*model.Artist `structs:",flatten"`
SimilarArtists string `structs:"-" json:"-"`
LibraryStatsJSON string `structs:"-" json:"-"`
}
type dbSimilarArtist struct {
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
}
func (a *dbArtist) PostScan() error {
a.Artist.Stats = make(map[model.Role]model.ArtistStats)
if a.LibraryStatsJSON != "" {
var rawLibStats map[string]map[string]map[string]int64
if err := json.Unmarshal([]byte(a.LibraryStatsJSON), &rawLibStats); err != nil {
return fmt.Errorf("parsing artist stats from db: %w", err)
}
for _, stats := range rawLibStats {
// Sum all libraries roles stats
for key, stat := range stats {
// Aggregate stats into the main Artist.Stats map
artistStats := model.ArtistStats{
SongCount: int(stat["m"]),
AlbumCount: int(stat["a"]),
Size: stat["s"],
}
// Store total stats into the main attributes
if key == "total" {
a.Artist.Size += artistStats.Size
a.Artist.SongCount += artistStats.SongCount
a.Artist.AlbumCount += artistStats.AlbumCount
}
role := model.RoleFromString(key)
if role == model.RoleInvalid {
continue
}
current := a.Artist.Stats[role]
current.Size += artistStats.Size
current.SongCount += artistStats.SongCount
current.AlbumCount += artistStats.AlbumCount
a.Artist.Stats[role] = current
}
}
}
a.Artist.SimilarArtists = nil
if a.SimilarArtists == "" {
return nil
}
var sa []dbSimilarArtist
if err := json.Unmarshal([]byte(a.SimilarArtists), &sa); err != nil {
return fmt.Errorf("parsing similar artists from db: %w", err)
}
for _, s := range sa {
a.Artist.SimilarArtists = append(a.Artist.SimilarArtists, model.Artist{
ID: s.ID,
Name: s.Name,
})
}
return nil
}
func (a *dbArtist) PostMapArgs(m map[string]any) error {
sa := make([]dbSimilarArtist, 0)
for _, s := range a.Artist.SimilarArtists {
sa = append(sa, dbSimilarArtist{ID: s.ID, Name: s.Name})
}
similarArtists, _ := json.Marshal(sa)
m["similar_artists"] = string(similarArtists)
m["full_text"] = formatFullText(a.Name, a.SortArtistName)
// Do not override the sort_artist_name and mbz_artist_id fields if they are empty
// TODO: Better way to handle this?
if v, ok := m["sort_artist_name"]; !ok || v.(string) == "" {
delete(m, "sort_artist_name")
}
if v, ok := m["mbz_artist_id"]; !ok || v.(string) == "" {
delete(m, "mbz_artist_id")
}
return nil
}
type dbArtists []dbArtist
func (dba dbArtists) toModels() model.Artists {
res := make(model.Artists, len(dba))
for i := range dba {
res[i] = *dba[i].Artist
}
return res
}
func NewArtistRepository(ctx context.Context, db dbx.Builder, ms *MeilisearchService) model.ArtistRepository {
r := &artistRepository{}
r.ctx = ctx
r.db = db
r.ms = ms
r.indexGroups = utils.ParseIndexGroups(conf.Server.IndexGroups)
r.tableName = "artist" // To be used by the idFilter below
r.registerModel(&model.Artist{}, map[string]filterFunc{
"id": idFilter(r.tableName),
"name": fullTextFilter(r.tableName, "mbz_artist_id"),
"starred": booleanFilter,
"role": roleFilter,
"missing": booleanFilter,
"library_id": artistLibraryIdFilter,
})
r.setSortMappings(map[string]string{
"name": "order_artist_name",
"starred_at": "starred, starred_at",
"rated_at": "rating, rated_at",
"song_count": "stats->>'total'->>'m'",
"album_count": "stats->>'total'->>'a'",
"size": "stats->>'total'->>'s'",
// Stats by credits that are currently available
"maincredit_song_count": "sum(stats->>'maincredit'->>'m')",
"maincredit_album_count": "sum(stats->>'maincredit'->>'a')",
"maincredit_size": "sum(stats->>'maincredit'->>'s')",
})
return r
}
func roleFilter(_ string, role any) Sqlizer {
if role, ok := role.(string); ok {
if _, ok := model.AllRoles[role]; ok {
return Expr("JSON_EXTRACT(library_artist.stats, '$." + role + ".m') IS NOT NULL")
}
}
return Eq{"1": 2}
}
// artistLibraryIdFilter filters artists based on library access through the library_artist table
func artistLibraryIdFilter(_ string, value interface{}) Sqlizer {
return Eq{"library_artist.library_id": value}
}
// applyLibraryFilterToArtistQuery applies library filtering to artist queries through the library_artist junction table
func (r *artistRepository) applyLibraryFilterToArtistQuery(query SelectBuilder) SelectBuilder {
user := loggedUser(r.ctx)
// Join with library_artist first to ensure only artists with content in libraries are included
// Exclude artists with empty stats (no actual content in the library)
query = query.Join("library_artist on library_artist.artist_id = artist.id")
//query = query.Join("library_artist on library_artist.artist_id = artist.id AND library_artist.stats != '{}'")
// Admin users see all artists from all libraries, no additional filtering needed
if user.ID != invalidUserId && !user.IsAdmin {
// Apply library filtering only for non-admin users by joining with their accessible libraries
query = query.Join("user_library on user_library.library_id = library_artist.library_id AND user_library.user_id = ?", user.ID)
}
return query
}
func (r *artistRepository) selectArtist(options ...model.QueryOptions) SelectBuilder {
// Stats Format: {"1": {"albumartist": {"m": 10, "a": 5, "s": 1024}, "artist": {...}}, "2": {...}}
query := r.newSelect(options...).Columns("artist.*",
"JSON_GROUP_OBJECT(library_artist.library_id, JSONB(library_artist.stats)) as library_stats_json")
query = r.applyLibraryFilterToArtistQuery(query)
query = query.GroupBy("artist.id")
return r.withAnnotation(query, "artist.id")
}
func (r *artistRepository) CountAll(options ...model.QueryOptions) (int64, error) {
query := r.newSelect()
query = r.applyLibraryFilterToArtistQuery(query)
query = r.withAnnotation(query, "artist.id")
return r.count(query, options...)
}
// Exists checks if an artist with the given ID exists in the database and is accessible by the current user.
func (r *artistRepository) Exists(id string) (bool, error) {
// Create a query using the same library filtering logic as selectArtist()
query := r.newSelect().Columns("count(distinct artist.id) as exist").Where(Eq{"artist.id": id})
query = r.applyLibraryFilterToArtistQuery(query)
var res struct{ Exist int64 }
err := r.queryOne(query, &res)
return res.Exist > 0, err
}
func (r *artistRepository) Put(a *model.Artist, colsToUpdate ...string) error {
dba := &dbArtist{Artist: a}
dba.CreatedAt = P(time.Now())
dba.UpdatedAt = dba.CreatedAt
_, err := r.put(dba.ID, dba, colsToUpdate...)
if err == nil && r.ms != nil {
r.ms.IndexArtist(a)
}
return err
}
func (r *artistRepository) UpdateExternalInfo(a *model.Artist) error {
dba := &dbArtist{Artist: a}
_, err := r.put(a.ID, dba,
"biography", "small_image_url", "medium_image_url", "large_image_url",
"similar_artists", "external_url", "external_info_updated_at")
return err
}
func (r *artistRepository) Get(id string) (*model.Artist, error) {
sel := r.selectArtist().Where(Eq{"artist.id": id})
var dba dbArtists
if err := r.queryAll(sel, &dba); err != nil {
return nil, err
}
if len(dba) == 0 {
return nil, model.ErrNotFound
}
res := dba.toModels()
return &res[0], nil
}
func (r *artistRepository) GetAll(options ...model.QueryOptions) (model.Artists, error) {
sel := r.selectArtist(options...)
var dba dbArtists
err := r.queryAll(sel, &dba)
if err != nil {
return nil, err
}
res := dba.toModels()
return res, err
}
func (r *artistRepository) getIndexKey(a model.Artist) string {
source := a.OrderArtistName
if conf.Server.PreferSortTags {
source = cmp.Or(a.SortArtistName, a.OrderArtistName)
}
name := strings.ToLower(source)
for k, v := range r.indexGroups {
if strings.HasPrefix(name, strings.ToLower(k)) {
return v
}
}
return "#"
}
// GetIndex returns a list of artists grouped by the first letter of their name, or by the index group if configured.
// It can filter by roles and libraries, and optionally include artists that are missing (i.e., have no albums).
// TODO Cache the index (recalculate at scan time)
func (r *artistRepository) GetIndex(includeMissing bool, libraryIds []int, roles ...model.Role) (model.ArtistIndexes, error) {
// Validate library IDs. If no library IDs are provided, return an empty index.
if len(libraryIds) == 0 {
return nil, nil
}
options := model.QueryOptions{Sort: "name"}
if len(roles) > 0 {
roleFilters := slice.Map(roles, func(r model.Role) Sqlizer {
return roleFilter("role", r.String())
})
options.Filters = Or(roleFilters)
}
if !includeMissing {
if options.Filters == nil {
options.Filters = Eq{"artist.missing": false}
} else {
options.Filters = And{options.Filters, Eq{"artist.missing": false}}
}
}
libFilter := artistLibraryIdFilter("library_id", libraryIds)
if options.Filters == nil {
options.Filters = libFilter
} else {
options.Filters = And{options.Filters, libFilter}
}
artists, err := r.GetAll(options)
if err != nil {
return nil, err
}
var result model.ArtistIndexes
for k, v := range slice.Group(artists, r.getIndexKey) {
result = append(result, model.ArtistIndex{ID: k, Artists: v})
}
slices.SortFunc(result, func(a, b model.ArtistIndex) int {
return cmp.Compare(a.ID, b.ID)
})
return result, nil
}
func (r *artistRepository) purgeEmpty() error {
del := Delete(r.tableName).Where("id not in (select artist_id from album_artists)")
c, err := r.executeSQL(del)
if err != nil {
return fmt.Errorf("purging empty artists: %w", err)
}
if c > 0 {
log.Debug(r.ctx, "Purged empty artists", "totalDeleted", c)
}
return nil
}
// markMissing marks artists as missing if all their albums are missing.
func (r *artistRepository) markMissing() error {
q := Expr(`
with artists_with_non_missing_albums as (
select distinct aa.artist_id
from album_artists aa
join album a on aa.album_id = a.id
where a.missing = false
)
update artist
set missing = (artist.id not in (select artist_id from artists_with_non_missing_albums));
`)
_, err := r.executeSQL(q)
if err != nil {
return fmt.Errorf("marking missing artists: %w", err)
}
return nil
}
// RefreshPlayCounts updates the play count and last play date annotations for all artists, based
// on the media files associated with them.
func (r *artistRepository) RefreshPlayCounts() (int64, error) {
query := Expr(`
with play_counts as (
select user_id, atom as artist_id, sum(play_count) as total_play_count, max(play_date) as last_play_date
from media_file
join annotation on item_id = media_file.id
left join json_tree(participants, '$.artist') as jt
where atom is not null and key = 'id'
group by user_id, atom
)
insert into annotation (user_id, item_id, item_type, play_count, play_date)
select user_id, artist_id, 'artist', total_play_count, last_play_date
from play_counts
where total_play_count > 0
on conflict (user_id, item_id, item_type) do update
set play_count = excluded.play_count,
play_date = excluded.play_date;
`)
return r.executeSQL(query)
}
// RefreshStats updates the stats field for artists whose associated media files were updated after the oldest recorded library scan time.
// When allArtists is true, it refreshes stats for all artists. It processes artists in batches to handle potentially large updates.
// This method now calculates per-library statistics and stores them in the library_artist junction table.
func (r *artistRepository) RefreshStats(allArtists bool) (int64, error) {
var allTouchedArtistIDs []string
if allArtists {
// Refresh stats for all artists
allArtistsQuerySQL := `SELECT DISTINCT id FROM artist WHERE id <> ''`
if err := r.db.NewQuery(allArtistsQuerySQL).Column(&allTouchedArtistIDs); err != nil {
return 0, fmt.Errorf("fetching all artist IDs: %w", err)
}
log.Debug(r.ctx, "RefreshStats: Refreshing all artists.", "count", len(allTouchedArtistIDs))
} else {
// Only refresh artists with updated timestamps
touchedArtistsQuerySQL := `
SELECT DISTINCT id
FROM artist
WHERE updated_at > (SELECT last_scan_at FROM library ORDER BY last_scan_at ASC LIMIT 1)
`
if err := r.db.NewQuery(touchedArtistsQuerySQL).Column(&allTouchedArtistIDs); err != nil {
return 0, fmt.Errorf("fetching touched artist IDs: %w", err)
}
log.Debug(r.ctx, "RefreshStats: Refreshing touched artists.", "count", len(allTouchedArtistIDs))
}
if len(allTouchedArtistIDs) == 0 {
log.Debug(r.ctx, "RefreshStats: No artists to update.")
return 0, nil
}
// Template for the batch update with placeholder markers that we'll replace
// This now calculates per-library statistics and stores them in library_artist.stats
batchUpdateStatsSQL := `
WITH artist_role_counters AS (
SELECT mfa.artist_id,
mf.library_id,
mfa.role,
count(DISTINCT mf.album_id) AS album_count,
count(DISTINCT mf.id) AS count,
sum(mf.size) AS size
FROM media_file_artists mfa
JOIN media_file mf ON mfa.media_file_id = mf.id
WHERE mfa.artist_id IN (ROLE_IDS_PLACEHOLDER) -- Will replace with actual placeholders
GROUP BY mfa.artist_id, mf.library_id, mfa.role
),
artist_total_counters AS (
SELECT mfa.artist_id,
mf.library_id,
'total' AS role,
count(DISTINCT mf.album_id) AS album_count,
count(DISTINCT mf.id) AS count,
sum(mf.size) AS size
FROM media_file_artists mfa
JOIN media_file mf ON mfa.media_file_id = mf.id
WHERE mfa.artist_id IN (ROLE_IDS_PLACEHOLDER) -- Will replace with actual placeholders
GROUP BY mfa.artist_id, mf.library_id
),
artist_participant_counter AS (
SELECT mfa.artist_id,
mf.library_id,
'maincredit' AS role,
count(DISTINCT mf.album_id) AS album_count,
count(DISTINCT mf.id) AS count,
sum(mf.size) AS size
FROM media_file_artists mfa
JOIN media_file mf ON mfa.media_file_id = mf.id
WHERE mfa.artist_id IN (ROLE_IDS_PLACEHOLDER) -- Will replace with actual placeholders
AND mfa.role IN ('albumartist', 'artist')
GROUP BY mfa.artist_id, mf.library_id
),
combined_counters AS (
SELECT artist_id, library_id, role, album_count, count, size FROM artist_role_counters
UNION ALL
SELECT artist_id, library_id, role, album_count, count, size FROM artist_total_counters
UNION ALL
SELECT artist_id, library_id, role, album_count, count, size FROM artist_participant_counter
),
library_artist_counters AS (
SELECT artist_id,
library_id,
json_group_object(
role,
json_object('a', album_count, 'm', count, 's', size)
) AS counters
FROM combined_counters
GROUP BY artist_id, library_id
)
UPDATE library_artist
SET stats = coalesce((SELECT counters FROM library_artist_counters lac
WHERE lac.artist_id = library_artist.artist_id
AND lac.library_id = library_artist.library_id), '{}')
WHERE library_artist.artist_id IN (ROLE_IDS_PLACEHOLDER);` // Will replace with actual placeholders
var totalRowsAffected int64 = 0
const batchSize = 1000
batchCounter := 0
for artistIDBatch := range slice.CollectChunks(slices.Values(allTouchedArtistIDs), batchSize) {
batchCounter++
log.Trace(r.ctx, "RefreshStats: Processing batch", "batchNum", batchCounter, "batchSize", len(artistIDBatch))
// Create placeholders for each ID in the IN clauses
placeholders := make([]string, len(artistIDBatch))
for i := range artistIDBatch {
placeholders[i] = "?"
}
// Don't add extra parentheses, the IN clause already expects them in SQL syntax
inClause := strings.Join(placeholders, ",")
// Replace the placeholder markers with actual SQL placeholders
batchSQL := strings.Replace(batchUpdateStatsSQL, "ROLE_IDS_PLACEHOLDER", inClause, 4)
// Create a single parameter array with all IDs (repeated 4 times for each IN clause)
// We need to repeat each ID 4 times (once for each IN clause)
args := make([]any, 4*len(artistIDBatch))
for idx, id := range artistIDBatch {
for i := range 4 {
startIdx := i * len(artistIDBatch)
args[startIdx+idx] = id
}
}
// Now use Expr with the expanded SQL and all parameters
sqlizer := Expr(batchSQL, args...)
rowsAffected, err := r.executeSQL(sqlizer)
if err != nil {
return totalRowsAffected, fmt.Errorf("executing batch update for artist stats (batch %d): %w", batchCounter, err)
}
totalRowsAffected += rowsAffected
}
// // Remove library_artist entries for artists that no longer have any content in any library
cleanupSQL := Delete("library_artist").Where("stats = '{}'")
cleanupRows, err := r.executeSQL(cleanupSQL)
if err != nil {
log.Warn(r.ctx, "Failed to cleanup empty library_artist entries", "error", err)
} else if cleanupRows > 0 {
log.Debug(r.ctx, "Cleaned up empty library_artist entries", "rowsDeleted", cleanupRows)
}
log.Debug(r.ctx, "RefreshStats: Successfully updated stats.", "totalArtistsProcessed", len(allTouchedArtistIDs), "totalDBRowsAffected", totalRowsAffected)
return totalRowsAffected, nil
}
func (r *artistRepository) Search(q string, offset int, size int, options ...model.QueryOptions) (model.Artists, error) {
var res dbArtists
if uuid.Validate(q) == nil {
err := r.searchByMBID(r.selectArtist(options...), q, []string{"mbz_artist_id"}, &res)
if err != nil {
return nil, fmt.Errorf("searching artist by MBID %q: %w", q, err)
}
} else {
if r.ms != nil {
ids, err := r.ms.Search("artists", q, offset, size)
if err == nil {
if len(ids) == 0 {
return model.Artists{}, nil
}
// Fetch matching artists from the database
// We need to fetch all fields to return complete objects
artists, err := r.GetAll(model.QueryOptions{Filters: Eq{"artist.id": ids}})
if err != nil {
return nil, fmt.Errorf("fetching artists from meilisearch ids: %w", err)
}
// Reorder results to match Meilisearch order
idMap := make(map[string]model.Artist, len(artists))
for _, a := range artists {
idMap[a.ID] = a
}
sorted := make(model.Artists, 0, len(artists))
for _, id := range ids {
if a, ok := idMap[id]; ok {
sorted = append(sorted, a)
}
}
return sorted, nil
}
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
}
// Natural order for artists is more performant by ID, due to GROUP BY clause in selectArtist
err := r.doSearch(r.selectArtist(options...), q, offset, size, &res, "artist.id",
"sum(json_extract(stats, '$.total.m')) desc", "name")
if err != nil {
return nil, fmt.Errorf("searching artist by query %q: %w", q, err)
}
}
return res.toModels(), nil
}
func (r *artistRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *artistRepository) Read(id string) (interface{}, error) {
return r.Get(id)
}
func (r *artistRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
role := "total"
if len(options) > 0 {
if v, ok := options[0].Filters["role"].(string); ok {
role = v
}
if r.ms != nil {
if name, ok := options[0].Filters["name"].(string); ok && name != "" {
ids, err := r.ms.Search("artists", name, 0, 10000)
if err == nil {
log.Debug(r.ctx, "Meilisearch found matches", "count", len(ids), "query", name)
delete(options[0].Filters, "name")
options[0].Filters["id"] = ids
} else {
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
}
}
}
}
r.sortMappings["song_count"] = "sum(stats->>'" + role + "'->>'m')"
r.sortMappings["album_count"] = "sum(stats->>'" + role + "'->>'a')"
r.sortMappings["size"] = "sum(stats->>'" + role + "'->>'s')"
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *artistRepository) EntityName() string {
return "artist"
}
func (r *artistRepository) NewInstance() interface{} {
return &model.Artist{}
}
var _ model.ArtistRepository = (*artistRepository)(nil)
var _ model.ResourceRepository = (*artistRepository)(nil)

View File

@@ -0,0 +1,772 @@
package persistence
import (
"context"
"encoding/json"
"github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/conf/configtest"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
"github.com/navidrome/navidrome/utils"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Test helper functions to reduce duplication
func createTestArtistWithMBID(id, name, mbid string) model.Artist {
return model.Artist{
ID: id,
Name: name,
MbzArtistID: mbid,
}
}
func createUserWithLibraries(userID string, libraryIDs []int) model.User {
user := model.User{
ID: userID,
UserName: userID,
Name: userID,
Email: userID + "@test.com",
IsAdmin: false,
}
if len(libraryIDs) > 0 {
user.Libraries = make(model.Libraries, len(libraryIDs))
for i, libID := range libraryIDs {
user.Libraries[i] = model.Library{ID: libID, Name: "Test Library", Path: "/test"}
}
}
return user
}
var _ = Describe("ArtistRepository", func() {
Context("Core Functionality", func() {
Describe("GetIndexKey", func() {
// Note: OrderArtistName should never be empty, so we don't need to test for that
r := artistRepository{indexGroups: utils.ParseIndexGroups(conf.Server.IndexGroups)}
DescribeTable("returns correct index key based on PreferSortTags setting",
func(preferSortTags bool, sortArtistName, orderArtistName, expectedKey string) {
DeferCleanup(configtest.SetupConfig())
conf.Server.PreferSortTags = preferSortTags
a := model.Artist{SortArtistName: sortArtistName, OrderArtistName: orderArtistName, Name: "Test"}
idx := GetIndexKey(&r, a)
Expect(idx).To(Equal(expectedKey))
},
Entry("PreferSortTags=false, SortArtistName empty -> uses OrderArtistName", false, "", "Bar", "B"),
Entry("PreferSortTags=false, SortArtistName not empty -> still uses OrderArtistName", false, "Foo", "Bar", "B"),
Entry("PreferSortTags=true, SortArtistName not empty -> uses SortArtistName", true, "Foo", "Bar", "F"),
Entry("PreferSortTags=true, SortArtistName empty -> falls back to OrderArtistName", true, "", "Bar", "B"),
)
})
Describe("roleFilter", func() {
DescribeTable("validates roles and returns appropriate SQL expressions",
func(role string, shouldBeValid bool) {
result := roleFilter("", role)
if shouldBeValid {
expectedExpr := squirrel.Expr("JSON_EXTRACT(library_artist.stats, '$." + role + ".m') IS NOT NULL")
Expect(result).To(Equal(expectedExpr))
} else {
expectedInvalid := squirrel.Eq{"1": 2}
Expect(result).To(Equal(expectedInvalid))
}
},
// Valid roles from model.AllRoles
Entry("artist role", "artist", true),
Entry("albumartist role", "albumartist", true),
Entry("composer role", "composer", true),
Entry("conductor role", "conductor", true),
Entry("lyricist role", "lyricist", true),
Entry("arranger role", "arranger", true),
Entry("producer role", "producer", true),
Entry("director role", "director", true),
Entry("engineer role", "engineer", true),
Entry("mixer role", "mixer", true),
Entry("remixer role", "remixer", true),
Entry("djmixer role", "djmixer", true),
Entry("performer role", "performer", true),
Entry("maincredit role", "maincredit", true),
// Invalid roles
Entry("invalid role - wizard", "wizard", false),
Entry("invalid role - songanddanceman", "songanddanceman", false),
Entry("empty string", "", false),
Entry("SQL injection attempt", "artist') SELECT LIKE(CHAR(65,66,67,68,69,70,71),UPPER(HEX(RANDOMBLOB(500000000/2))))--", false),
)
It("handles non-string input types", func() {
expectedInvalid := squirrel.Eq{"1": 2}
Expect(roleFilter("", 123)).To(Equal(expectedInvalid))
Expect(roleFilter("", nil)).To(Equal(expectedInvalid))
Expect(roleFilter("", []string{"artist"})).To(Equal(expectedInvalid))
})
})
Describe("dbArtist mapping", func() {
var (
artist *model.Artist
dba *dbArtist
)
BeforeEach(func() {
artist = &model.Artist{ID: "1", Name: "Eddie Van Halen", SortArtistName: "Van Halen, Eddie"}
dba = &dbArtist{Artist: artist}
})
Describe("PostScan", func() {
It("parses stats and similar artists correctly", func() {
stats := map[string]map[string]map[string]int64{
"1": {
"total": {"s": 1000, "m": 10, "a": 2},
"composer": {"s": 500, "m": 5, "a": 1},
},
}
statsJSON, _ := json.Marshal(stats)
dba.LibraryStatsJSON = string(statsJSON)
dba.SimilarArtists = `[{"id":"2","Name":"AC/DC"},{"name":"Test;With:Sep,Chars"}]`
err := dba.PostScan()
Expect(err).ToNot(HaveOccurred())
Expect(dba.Artist.Size).To(Equal(int64(1000)))
Expect(dba.Artist.SongCount).To(Equal(10))
Expect(dba.Artist.AlbumCount).To(Equal(2))
Expect(dba.Artist.Stats).To(HaveLen(1))
Expect(dba.Artist.Stats[model.RoleFromString("composer")].Size).To(Equal(int64(500)))
Expect(dba.Artist.Stats[model.RoleFromString("composer")].SongCount).To(Equal(5))
Expect(dba.Artist.Stats[model.RoleFromString("composer")].AlbumCount).To(Equal(1))
Expect(dba.Artist.SimilarArtists).To(HaveLen(2))
Expect(dba.Artist.SimilarArtists[0].ID).To(Equal("2"))
Expect(dba.Artist.SimilarArtists[0].Name).To(Equal("AC/DC"))
Expect(dba.Artist.SimilarArtists[1].ID).To(BeEmpty())
Expect(dba.Artist.SimilarArtists[1].Name).To(Equal("Test;With:Sep,Chars"))
})
})
Describe("PostMapArgs", func() {
It("maps empty similar artists correctly", func() {
m := make(map[string]any)
err := dba.PostMapArgs(m)
Expect(err).ToNot(HaveOccurred())
Expect(m).To(HaveKeyWithValue("similar_artists", "[]"))
})
It("maps similar artists and full text correctly", func() {
artist.SimilarArtists = []model.Artist{
{ID: "2", Name: "AC/DC"},
{Name: "Test;With:Sep,Chars"},
}
m := make(map[string]any)
err := dba.PostMapArgs(m)
Expect(err).ToNot(HaveOccurred())
Expect(m).To(HaveKeyWithValue("similar_artists", `[{"id":"2","name":"AC/DC"},{"name":"Test;With:Sep,Chars"}]`))
Expect(m).To(HaveKeyWithValue("full_text", " eddie halen van"))
})
It("does not override empty sort_artist_name and mbz_artist_id", func() {
m := map[string]any{
"sort_artist_name": "",
"mbz_artist_id": "",
}
err := dba.PostMapArgs(m)
Expect(err).ToNot(HaveOccurred())
Expect(m).ToNot(HaveKey("sort_artist_name"))
Expect(m).ToNot(HaveKey("mbz_artist_id"))
})
})
})
})
Context("Admin User Operations", func() {
var repo model.ArtistRepository
BeforeEach(func() {
ctx := GinkgoT().Context()
ctx = request.WithUser(ctx, adminUser)
repo = NewArtistRepository(ctx, GetDBXBuilder(), nil).(*artistRepository)
})
Describe("Basic Operations", func() {
Describe("Count", func() {
It("returns the number of artists in the DB", func() {
Expect(repo.CountAll()).To(Equal(int64(2)))
})
})
Describe("Exists", func() {
It("returns true for an artist that is in the DB", func() {
Expect(repo.Exists("3")).To(BeTrue())
})
It("returns false for an artist that is NOT in the DB", func() {
Expect(repo.Exists("666")).To(BeFalse())
})
})
Describe("Get", func() {
It("retrieves existing artist data", func() {
artist, err := repo.Get("2")
Expect(err).ToNot(HaveOccurred())
Expect(artist.Name).To(Equal(artistKraftwerk.Name))
})
})
})
Describe("GetIndex", func() {
When("PreferSortTags is true", func() {
BeforeEach(func() {
conf.Server.PreferSortTags = true
})
It("returns the index when PreferSortTags is true and SortArtistName is not empty", func() {
// Set SortArtistName to "Foo" for Beatles
artistBeatles.SortArtistName = "Foo"
er := repo.Put(&artistBeatles)
Expect(er).To(BeNil())
idx, err := repo.GetIndex(false, []int{1})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(2))
Expect(idx[0].ID).To(Equal("F"))
Expect(idx[0].Artists).To(HaveLen(1))
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
Expect(idx[1].ID).To(Equal("K"))
Expect(idx[1].Artists).To(HaveLen(1))
Expect(idx[1].Artists[0].Name).To(Equal(artistKraftwerk.Name))
// Restore the original value
artistBeatles.SortArtistName = ""
er = repo.Put(&artistBeatles)
Expect(er).To(BeNil())
})
// BFR Empty SortArtistName is not saved in the DB anymore
XIt("returns the index when PreferSortTags is true and SortArtistName is empty", func() {
idx, err := repo.GetIndex(false, []int{1})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(2))
Expect(idx[0].ID).To(Equal("B"))
Expect(idx[0].Artists).To(HaveLen(1))
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
Expect(idx[1].ID).To(Equal("K"))
Expect(idx[1].Artists).To(HaveLen(1))
Expect(idx[1].Artists[0].Name).To(Equal(artistKraftwerk.Name))
})
})
When("PreferSortTags is false", func() {
BeforeEach(func() {
conf.Server.PreferSortTags = false
})
It("returns the index when SortArtistName is NOT empty", func() {
// Set SortArtistName to "Foo" for Beatles
artistBeatles.SortArtistName = "Foo"
er := repo.Put(&artistBeatles)
Expect(er).To(BeNil())
idx, err := repo.GetIndex(false, []int{1})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(2))
Expect(idx[0].ID).To(Equal("B"))
Expect(idx[0].Artists).To(HaveLen(1))
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
Expect(idx[1].ID).To(Equal("K"))
Expect(idx[1].Artists).To(HaveLen(1))
Expect(idx[1].Artists[0].Name).To(Equal(artistKraftwerk.Name))
// Restore the original value
artistBeatles.SortArtistName = ""
er = repo.Put(&artistBeatles)
Expect(er).To(BeNil())
})
It("returns the index when SortArtistName is empty", func() {
idx, err := repo.GetIndex(false, []int{1})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(2))
Expect(idx[0].ID).To(Equal("B"))
Expect(idx[0].Artists).To(HaveLen(1))
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
Expect(idx[1].ID).To(Equal("K"))
Expect(idx[1].Artists).To(HaveLen(1))
Expect(idx[1].Artists[0].Name).To(Equal(artistKraftwerk.Name))
})
})
When("filtering by role", func() {
var raw *artistRepository
BeforeEach(func() {
raw = repo.(*artistRepository)
// Add stats to library_artist table since stats are now stored per-library
composerStats := `{"composer": {"s": 1000, "m": 5, "a": 2}}`
producerStats := `{"producer": {"s": 500, "m": 3, "a": 1}}`
// Set Beatles as composer in library 1
_, err := raw.executeSQL(squirrel.Insert("library_artist").
Columns("library_id", "artist_id", "stats").
Values(1, artistBeatles.ID, composerStats).
Suffix("ON CONFLICT(library_id, artist_id) DO UPDATE SET stats = excluded.stats"))
Expect(err).ToNot(HaveOccurred())
// Set Kraftwerk as producer in library 1
_, err = raw.executeSQL(squirrel.Insert("library_artist").
Columns("library_id", "artist_id", "stats").
Values(1, artistKraftwerk.ID, producerStats).
Suffix("ON CONFLICT(library_id, artist_id) DO UPDATE SET stats = excluded.stats"))
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
// Clean up stats from library_artist table
_, _ = raw.executeSQL(squirrel.Update("library_artist").
Set("stats", "{}").
Where(squirrel.Eq{"artist_id": artistBeatles.ID, "library_id": 1}))
_, _ = raw.executeSQL(squirrel.Update("library_artist").
Set("stats", "{}").
Where(squirrel.Eq{"artist_id": artistKraftwerk.ID, "library_id": 1}))
})
It("returns only artists with the specified role", func() {
idx, err := repo.GetIndex(false, []int{1}, model.RoleComposer)
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(1))
Expect(idx[0].ID).To(Equal("B"))
Expect(idx[0].Artists).To(HaveLen(1))
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
})
It("returns artists with any of the specified roles", func() {
idx, err := repo.GetIndex(false, []int{1}, model.RoleComposer, model.RoleProducer)
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(2))
// Find Beatles and Kraftwerk in the results
var beatlesFound, kraftwerkFound bool
for _, index := range idx {
for _, artist := range index.Artists {
if artist.Name == artistBeatles.Name {
beatlesFound = true
}
if artist.Name == artistKraftwerk.Name {
kraftwerkFound = true
}
}
}
Expect(beatlesFound).To(BeTrue())
Expect(kraftwerkFound).To(BeTrue())
})
It("returns empty index when no artists have the specified role", func() {
idx, err := repo.GetIndex(false, []int{1}, model.RoleDirector)
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(0))
})
})
When("validating library IDs", func() {
It("returns nil when no library IDs are provided", func() {
idx, err := repo.GetIndex(false, []int{})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(BeNil())
})
It("returns artists when library IDs are provided (admin user sees all content)", func() {
// Admin users can see all content when valid library IDs are provided
idx, err := repo.GetIndex(false, []int{1})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(2))
// With non-existent library ID, admin users see no content because no artists are associated with that library
idx, err = repo.GetIndex(false, []int{999})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(0)) // Even admin users need valid library associations
})
})
})
Describe("MBID and Text Search", func() {
var lib2 model.Library
var lr model.LibraryRepository
var restrictedUser model.User
var restrictedRepo model.ArtistRepository
var headlessRepo model.ArtistRepository
BeforeEach(func() {
// Set up headless repo (no user context)
headlessRepo = NewArtistRepository(context.Background(), GetDBXBuilder(), nil)
// Create library for testing access restrictions
lib2 = model.Library{ID: 0, Name: "Artist Test Library", Path: "/artist/test/lib"}
lr = NewLibraryRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
err := lr.Put(&lib2)
Expect(err).ToNot(HaveOccurred())
// Create a user with access to only library 1
restrictedUser = createUserWithLibraries("search_user", []int{1})
// Create repository context for the restricted user
ctx := request.WithUser(GinkgoT().Context(), restrictedUser)
restrictedRepo = NewArtistRepository(ctx, GetDBXBuilder(), nil)
// Ensure both test artists are associated with library 1
err = lr.AddArtist(1, artistBeatles.ID)
Expect(err).ToNot(HaveOccurred())
err = lr.AddArtist(1, artistKraftwerk.ID)
Expect(err).ToNot(HaveOccurred())
// Create the restricted user in the database
ur := NewUserRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
err = ur.Put(&restrictedUser)
Expect(err).ToNot(HaveOccurred())
err = ur.SetUserLibraries(restrictedUser.ID, []int{1})
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
// Clean up library 2
lr := NewLibraryRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
_ = lr.(*libraryRepository).delete(squirrel.Eq{"id": lib2.ID})
})
DescribeTable("MBID search behavior across different user types",
func(testRepo *model.ArtistRepository, shouldFind bool, testDesc string) {
// Create test artist with MBID
artistWithMBID := createTestArtistWithMBID("test-mbid-artist", "Test MBID Artist", "550e8400-e29b-41d4-a716-446655440010")
err := createArtistWithLibrary(*testRepo, &artistWithMBID, 1)
Expect(err).ToNot(HaveOccurred())
// Test the search
results, err := (*testRepo).Search("550e8400-e29b-41d4-a716-446655440010", 0, 10)
Expect(err).ToNot(HaveOccurred())
if shouldFind {
Expect(results).To(HaveLen(1), testDesc)
Expect(results[0].ID).To(Equal("test-mbid-artist"))
} else {
Expect(results).To(BeEmpty(), testDesc)
}
// Clean up
if raw, ok := (*testRepo).(*artistRepository); ok {
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": artistWithMBID.ID}))
}
},
Entry("Admin user can find artist by MBID", &repo, true, "Admin should find MBID artist"),
Entry("Restricted user can find artist by MBID in accessible library", &restrictedRepo, true, "Restricted user should find MBID artist in accessible library"),
Entry("Headless process can find artist by MBID", &headlessRepo, true, "Headless process should find MBID artist"),
)
It("prevents restricted user from finding artist by MBID when not in accessible library", func() {
// Create an artist in library 2 (not accessible to restricted user)
inaccessibleArtist := createTestArtistWithMBID("inaccessible-mbid-artist", "Inaccessible MBID Artist", "a74b1b7f-71a5-4011-9441-d0b5e4122711")
err := repo.Put(&inaccessibleArtist)
Expect(err).ToNot(HaveOccurred())
// Add to library 2 (not accessible to restricted user)
err = lr.AddArtist(lib2.ID, inaccessibleArtist.ID)
Expect(err).ToNot(HaveOccurred())
// Restricted user should not find this artist
results, err := restrictedRepo.Search("a74b1b7f-71a5-4011-9441-d0b5e4122711", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
// But admin should find it
results, err = repo.Search("a74b1b7f-71a5-4011-9441-d0b5e4122711", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(1))
// Clean up
if raw, ok := repo.(*artistRepository); ok {
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": inaccessibleArtist.ID}))
}
})
Context("Text Search", func() {
It("allows admin to find artists by name regardless of library", func() {
results, err := repo.Search("Beatles", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(1))
Expect(results[0].Name).To(Equal("The Beatles"))
})
It("correctly prevents restricted user from finding artists by name when not in accessible library", func() {
// Create an artist in library 2 (not accessible to restricted user)
inaccessibleArtist := model.Artist{
ID: "inaccessible-text-artist",
Name: "Unique Search Name Artist",
}
err := repo.Put(&inaccessibleArtist)
Expect(err).ToNot(HaveOccurred())
// Add to library 2 (not accessible to restricted user)
err = lr.AddArtist(lib2.ID, inaccessibleArtist.ID)
Expect(err).ToNot(HaveOccurred())
// Restricted user should not find this artist
results, err := restrictedRepo.Search("Unique Search Name", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty(), "Text search should respect library filtering")
// Clean up
if raw, ok := repo.(*artistRepository); ok {
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": inaccessibleArtist.ID}))
}
})
})
Context("Headless Processes (No User Context)", func() {
It("should see all artists from all libraries when no user is in context", func() {
// Add artists to different libraries
err := lr.AddArtist(lib2.ID, artistBeatles.ID)
Expect(err).ToNot(HaveOccurred())
// Headless processes should see all artists regardless of library
artists, err := headlessRepo.GetAll()
Expect(err).ToNot(HaveOccurred())
// Should see all artists from all libraries
found := false
for _, artist := range artists {
if artist.ID == artistBeatles.ID {
found = true
break
}
}
Expect(found).To(BeTrue(), "Headless process should see artists from all libraries")
})
It("should allow headless processes to apply explicit library_id filters", func() {
// Add artists to different libraries
err := lr.AddArtist(lib2.ID, artistBeatles.ID)
Expect(err).ToNot(HaveOccurred())
// Filter by specific library
artists, err := headlessRepo.GetAll(model.QueryOptions{
Filters: squirrel.Eq{"library_id": lib2.ID},
})
Expect(err).ToNot(HaveOccurred())
// Should see only artists from the specified library
for _, artist := range artists {
if artist.ID == artistBeatles.ID {
return // Found the expected artist
}
}
Expect(false).To(BeTrue(), "Should find artist from specified library")
})
It("should get individual artists when no user is in context", func() {
// Add artist to a library
err := lr.AddArtist(lib2.ID, artistBeatles.ID)
Expect(err).ToNot(HaveOccurred())
// Headless process should be able to get the artist
artist, err := headlessRepo.Get(artistBeatles.ID)
Expect(err).ToNot(HaveOccurred())
Expect(artist.ID).To(Equal(artistBeatles.ID))
})
})
})
Describe("Admin User Library Access", func() {
It("sees all artists regardless of library permissions", func() {
count, err := repo.CountAll()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(2)))
artists, err := repo.GetAll()
Expect(err).ToNot(HaveOccurred())
Expect(artists).To(HaveLen(2))
exists, err := repo.Exists(artistBeatles.ID)
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
})
})
Describe("Missing Artist Handling", func() {
var missingArtist model.Artist
var raw *artistRepository
BeforeEach(func() {
raw = repo.(*artistRepository)
missingArtist = model.Artist{ID: "missing_test", Name: "Missing Artist", OrderArtistName: "missing artist"}
// Create and mark as missing
err := createArtistWithLibrary(repo, &missingArtist, 1)
Expect(err).ToNot(HaveOccurred())
_, err = raw.executeSQL(squirrel.Update(raw.tableName).Set("missing", true).Where(squirrel.Eq{"id": missingArtist.ID}))
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": missingArtist.ID}))
})
It("missing artists are never returned by search", func() {
// Should see missing artist in GetAll by default for admin users
artists, err := repo.GetAll()
Expect(err).ToNot(HaveOccurred())
Expect(artists).To(HaveLen(3)) // Including the missing artist
// Search never returns missing artists (hardcoded behavior)
results, err := repo.Search("Missing Artist", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
})
})
})
Context("Regular User Operations", func() {
var restrictedRepo model.ArtistRepository
var unauthorizedUser model.User
BeforeEach(func() {
// Create a user without access to any libraries
unauthorizedUser = model.User{ID: "restricted_user", UserName: "restricted", Name: "Restricted User", Email: "restricted@test.com", IsAdmin: false}
// Create repository context for the unauthorized user
ctx := GinkgoT().Context()
ctx = request.WithUser(ctx, unauthorizedUser)
restrictedRepo = NewArtistRepository(ctx, GetDBXBuilder(), nil)
})
Describe("Library Access Restrictions", func() {
It("CountAll returns 0 for users without library access", func() {
count, err := restrictedRepo.CountAll()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(0)))
})
It("GetAll returns empty list for users without library access", func() {
artists, err := restrictedRepo.GetAll()
Expect(err).ToNot(HaveOccurred())
Expect(artists).To(BeEmpty())
})
It("Exists returns false for existing artists when user has no library access", func() {
// These artists exist in the DB but the user has no access to them
exists, err := restrictedRepo.Exists(artistBeatles.ID)
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeFalse())
exists, err = restrictedRepo.Exists(artistKraftwerk.ID)
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeFalse())
})
It("Get returns ErrNotFound for existing artists when user has no library access", func() {
_, err := restrictedRepo.Get(artistBeatles.ID)
Expect(err).To(Equal(model.ErrNotFound))
_, err = restrictedRepo.Get(artistKraftwerk.ID)
Expect(err).To(Equal(model.ErrNotFound))
})
It("Search returns empty results for users without library access", func() {
results, err := restrictedRepo.Search("Beatles", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
results, err = restrictedRepo.Search("Kraftwerk", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
})
It("GetIndex returns empty index for users without library access", func() {
idx, err := restrictedRepo.GetIndex(false, []int{1})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(0))
})
})
Context("when user gains library access", func() {
BeforeEach(func() {
ctx := GinkgoT().Context()
// Give the user access to library 1
ur := NewUserRepository(request.WithUser(ctx, adminUser), GetDBXBuilder())
// First create the user if not exists
err := ur.Put(&unauthorizedUser)
Expect(err).ToNot(HaveOccurred())
// Then add library access
err = ur.SetUserLibraries(unauthorizedUser.ID, []int{1})
Expect(err).ToNot(HaveOccurred())
// Update the user object with the libraries to simulate middleware behavior
libraries, err := ur.GetUserLibraries(unauthorizedUser.ID)
Expect(err).ToNot(HaveOccurred())
unauthorizedUser.Libraries = libraries
// Recreate repository context with updated user
ctx = request.WithUser(ctx, unauthorizedUser)
restrictedRepo = NewArtistRepository(ctx, GetDBXBuilder(), nil)
})
AfterEach(func() {
// Clean up: remove the user's library access
ur := NewUserRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
_ = ur.SetUserLibraries(unauthorizedUser.ID, []int{})
})
It("CountAll returns correct count after gaining access", func() {
count, err := restrictedRepo.CountAll()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(2))) // Beatles and Kraftwerk
})
It("GetAll returns artists after gaining access", func() {
artists, err := restrictedRepo.GetAll()
Expect(err).ToNot(HaveOccurred())
Expect(artists).To(HaveLen(2))
var names []string
for _, artist := range artists {
names = append(names, artist.Name)
}
Expect(names).To(ContainElements("The Beatles", "Kraftwerk"))
})
It("Exists returns true for accessible artists", func() {
exists, err := restrictedRepo.Exists(artistBeatles.ID)
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
exists, err = restrictedRepo.Exists(artistKraftwerk.ID)
Expect(err).ToNot(HaveOccurred())
Expect(exists).To(BeTrue())
})
It("GetIndex returns artists with proper library filtering", func() {
// With valid library access, should see artists
idx, err := restrictedRepo.GetIndex(false, []int{1})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(2))
// With non-existent library ID, should see nothing (non-admin user)
idx, err = restrictedRepo.GetIndex(false, []int{999})
Expect(err).ToNot(HaveOccurred())
Expect(idx).To(HaveLen(0))
})
})
})
})
// Helper function to create an artist with proper library association.
// This ensures test artists always have library_artist associations to avoid orphaned artists in tests.
func createArtistWithLibrary(repo model.ArtistRepository, artist *model.Artist, libraryID int) error {
err := repo.Put(artist)
if err != nil {
return err
}
// Add the artist to the specified library
lr := NewLibraryRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
return lr.AddArtist(libraryID, artist.ID)
}

View File

@@ -0,0 +1,122 @@
package persistence
import (
"database/sql"
"errors"
"fmt"
"regexp"
"github.com/navidrome/navidrome/db"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// When creating migrations that change existing columns, it is easy to miss the original collation of a column.
// These tests enforce that the required collation of the columns and indexes in the database are kept in place.
// This is important to ensure that the database can perform fast case-insensitive searches and sorts.
var _ = Describe("Collation", func() {
conn := db.Db()
DescribeTable("Column collation",
func(table, column string) {
Expect(checkCollation(conn, table, column)).To(Succeed())
},
Entry("artist.order_artist_name", "artist", "order_artist_name"),
Entry("artist.sort_artist_name", "artist", "sort_artist_name"),
Entry("album.order_album_name", "album", "order_album_name"),
Entry("album.order_album_artist_name", "album", "order_album_artist_name"),
Entry("album.sort_album_name", "album", "sort_album_name"),
Entry("album.sort_album_artist_name", "album", "sort_album_artist_name"),
Entry("media_file.order_title", "media_file", "order_title"),
Entry("media_file.order_album_name", "media_file", "order_album_name"),
Entry("media_file.order_artist_name", "media_file", "order_artist_name"),
Entry("media_file.sort_title", "media_file", "sort_title"),
Entry("media_file.sort_album_name", "media_file", "sort_album_name"),
Entry("media_file.sort_artist_name", "media_file", "sort_artist_name"),
Entry("radio.name", "radio", "name"),
Entry("user.name", "user", "name"),
)
DescribeTable("Index collation",
func(table, column string) {
Expect(checkIndexUsage(conn, table, column)).To(Succeed())
},
Entry("artist.order_artist_name", "artist", "order_artist_name collate nocase"),
Entry("artist.sort_artist_name", "artist", "coalesce(nullif(sort_artist_name,''),order_artist_name) collate nocase"),
Entry("album.order_album_name", "album", "order_album_name collate nocase"),
Entry("album.order_album_artist_name", "album", "order_album_artist_name collate nocase"),
Entry("album.sort_album_name", "album", "coalesce(nullif(sort_album_name,''),order_album_name) collate nocase"),
Entry("album.sort_album_artist_name", "album", "coalesce(nullif(sort_album_artist_name,''),order_album_artist_name) collate nocase"),
Entry("media_file.order_title", "media_file", "order_title collate nocase"),
Entry("media_file.order_album_name", "media_file", "order_album_name collate nocase"),
Entry("media_file.order_artist_name", "media_file", "order_artist_name collate nocase"),
Entry("media_file.sort_title", "media_file", "coalesce(nullif(sort_title,''),order_title) collate nocase"),
Entry("media_file.sort_album_name", "media_file", "coalesce(nullif(sort_album_name,''),order_album_name) collate nocase"),
Entry("media_file.sort_artist_name", "media_file", "coalesce(nullif(sort_artist_name,''),order_artist_name) collate nocase"),
Entry("media_file.path", "media_file", "path collate nocase"),
Entry("radio.name", "radio", "name collate nocase"),
Entry("user.user_name", "user", "user_name collate nocase"),
)
})
func checkIndexUsage(conn *sql.DB, table string, column string) error {
rows, err := conn.Query(fmt.Sprintf(`
explain query plan select * from %[1]s
where %[2]s = 'test'
order by %[2]s`, table, column))
if err != nil {
return err
}
defer rows.Close()
err = rows.Err()
if err != nil {
return err
}
if rows.Next() {
var dummy int
var detail string
err = rows.Scan(&dummy, &dummy, &dummy, &detail)
if err != nil {
return nil
}
if ok, _ := regexp.MatchString("SEARCH.*USING INDEX", detail); ok {
return nil
} else {
return fmt.Errorf("INDEX for '%s' not used: %s", column, detail)
}
}
return errors.New("no rows returned")
}
func checkCollation(conn *sql.DB, table string, column string) error {
rows, err := conn.Query(fmt.Sprintf("SELECT sql FROM sqlite_master WHERE type='table' AND tbl_name='%s'", table))
if err != nil {
return err
}
defer rows.Close()
err = rows.Err()
if err != nil {
return err
}
if rows.Next() {
var res string
err = rows.Scan(&res)
if err != nil {
return err
}
re := regexp.MustCompile(fmt.Sprintf(`(?i)\b%s\b.*varchar`, column))
if !re.MatchString(res) {
return fmt.Errorf("column '%s' not found in table '%s'", column, table)
}
re = regexp.MustCompile(fmt.Sprintf(`(?i)\b%s\b.*collate\s+NOCASE`, column))
if re.MatchString(res) {
return nil
}
} else {
return fmt.Errorf("table '%s' not found", table)
}
return fmt.Errorf("column '%s' in table '%s' does not have NOCASE collation", column, table)
}

View File

@@ -0,0 +1,4 @@
package persistence
// Definitions for testing private methods
var GetIndexKey = (*artistRepository).getIndexKey

View File

@@ -0,0 +1,216 @@
package persistence
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"time"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
type folderRepository struct {
sqlRepository
}
type dbFolder struct {
*model.Folder `structs:",flatten"`
ImageFiles string `structs:"-" json:"-"`
}
func (f *dbFolder) PostScan() error {
var err error
if f.ImageFiles != "" {
if err = json.Unmarshal([]byte(f.ImageFiles), &f.Folder.ImageFiles); err != nil {
return fmt.Errorf("parsing folder image files from db: %w", err)
}
}
return nil
}
func (f *dbFolder) PostMapArgs(args map[string]any) error {
if f.Folder.ImageFiles == nil {
args["image_files"] = "[]"
} else {
imgFiles, err := json.Marshal(f.Folder.ImageFiles)
if err != nil {
return fmt.Errorf("marshalling image files: %w", err)
}
args["image_files"] = string(imgFiles)
}
return nil
}
type dbFolders []dbFolder
func (fs dbFolders) toModels() []model.Folder {
return slice.Map(fs, func(f dbFolder) model.Folder { return *f.Folder })
}
func newFolderRepository(ctx context.Context, db dbx.Builder) model.FolderRepository {
r := &folderRepository{}
r.ctx = ctx
r.db = db
r.tableName = "folder"
return r
}
func (r folderRepository) selectFolder(options ...model.QueryOptions) SelectBuilder {
sql := r.newSelect(options...).Columns("folder.*", "library.path as library_path").
Join("library on library.id = folder.library_id")
return r.applyLibraryFilter(sql)
}
func (r folderRepository) Get(id string) (*model.Folder, error) {
sq := r.selectFolder().Where(Eq{"folder.id": id})
var res dbFolder
err := r.queryOne(sq, &res)
return res.Folder, err
}
func (r folderRepository) GetByPath(lib model.Library, path string) (*model.Folder, error) {
id := model.NewFolder(lib, path).ID
return r.Get(id)
}
func (r folderRepository) GetAll(opt ...model.QueryOptions) ([]model.Folder, error) {
sq := r.selectFolder(opt...)
var res dbFolders
err := r.queryAll(sq, &res)
return res.toModels(), err
}
func (r folderRepository) CountAll(opt ...model.QueryOptions) (int64, error) {
query := r.newSelect(opt...).Columns("count(*)")
query = r.applyLibraryFilter(query)
return r.count(query)
}
func (r folderRepository) GetFolderUpdateInfo(lib model.Library, targetPaths ...string) (map[string]model.FolderUpdateInfo, error) {
where := And{
Eq{"library_id": lib.ID},
Eq{"missing": false},
}
// If specific paths are requested, include those folders and all their descendants
if len(targetPaths) > 0 {
// Collect folder IDs for exact target folders and path conditions for descendants
folderIDs := make([]string, 0, len(targetPaths))
pathConditions := make(Or, 0, len(targetPaths)*2)
for _, targetPath := range targetPaths {
if targetPath == "" || targetPath == "." {
// Root path - include everything in this library
pathConditions = Or{}
folderIDs = nil
break
}
// Clean the path to normalize it. Paths stored in the folder table do not have leading/trailing slashes.
cleanPath := strings.TrimPrefix(targetPath, string(os.PathSeparator))
cleanPath = filepath.Clean(cleanPath)
// Include the target folder itself by ID
folderIDs = append(folderIDs, model.FolderID(lib, cleanPath))
// Include all descendants: folders whose path field equals or starts with the target path
// Note: Folder.Path is the directory path, so children have path = targetPath
pathConditions = append(pathConditions, Eq{"path": cleanPath})
pathConditions = append(pathConditions, Like{"path": cleanPath + "/%"})
}
// Combine conditions: exact folder IDs OR descendant path patterns
if len(folderIDs) > 0 {
where = append(where, Or{Eq{"id": folderIDs}, pathConditions})
} else if len(pathConditions) > 0 {
where = append(where, pathConditions)
}
}
sq := r.newSelect().Columns("id", "updated_at", "hash").Where(where)
var res []struct {
ID string
UpdatedAt time.Time
Hash string
}
err := r.queryAll(sq, &res)
if err != nil {
return nil, err
}
m := make(map[string]model.FolderUpdateInfo, len(res))
for _, f := range res {
m[f.ID] = model.FolderUpdateInfo{UpdatedAt: f.UpdatedAt, Hash: f.Hash}
}
return m, nil
}
func (r folderRepository) Put(f *model.Folder) error {
dbf := dbFolder{Folder: f}
_, err := r.put(dbf.ID, &dbf)
return err
}
func (r folderRepository) MarkMissing(missing bool, ids ...string) error {
log.Debug(r.ctx, "Marking folders as missing", "ids", ids, "missing", missing)
for chunk := range slices.Chunk(ids, 200) {
sq := Update(r.tableName).
Set("missing", missing).
Set("updated_at", time.Now()).
Where(Eq{"id": chunk})
_, err := r.executeSQL(sq)
if err != nil {
return err
}
}
return nil
}
func (r folderRepository) GetTouchedWithPlaylists() (model.FolderCursor, error) {
query := r.selectFolder().Where(And{
Eq{"missing": false},
Gt{"num_playlists": 0},
ConcatExpr("folder.updated_at > library.last_scan_at"),
})
cursor, err := queryWithStableResults[dbFolder](r.sqlRepository, query)
if err != nil {
return nil, err
}
return func(yield func(model.Folder, error) bool) {
for f, err := range cursor {
if !yield(*f.Folder, err) || err != nil {
return
}
}
}, nil
}
func (r folderRepository) purgeEmpty(libraryIDs ...int) error {
sq := Delete(r.tableName).Where(And{
Eq{"num_audio_files": 0},
Eq{"num_playlists": 0},
Eq{"image_files": "[]"},
ConcatExpr("id not in (select parent_id from folder)"),
ConcatExpr("id not in (select folder_id from media_file)"),
})
// If libraryIDs are specified, only purge folders from those libraries
if len(libraryIDs) > 0 {
sq = sq.Where(Eq{"library_id": libraryIDs})
}
c, err := r.executeSQL(sq)
if err != nil {
return fmt.Errorf("purging empty folders: %w", err)
}
if c > 0 {
log.Debug(r.ctx, "Purging empty folders", "totalDeleted", c)
}
return nil
}
var _ model.FolderRepository = (*folderRepository)(nil)

View File

@@ -0,0 +1,213 @@
package persistence
import (
"context"
"fmt"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
var _ = Describe("FolderRepository", func() {
var repo model.FolderRepository
var ctx context.Context
var conn *dbx.DB
var testLib, otherLib model.Library
BeforeEach(func() {
ctx = request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid"})
conn = GetDBXBuilder()
repo = newFolderRepository(ctx, conn)
// Use existing library ID 1 from test fixtures
libRepo := NewLibraryRepository(ctx, conn)
lib, err := libRepo.Get(1)
Expect(err).ToNot(HaveOccurred())
testLib = *lib
// Create a second library with its own folder to verify isolation
otherLib = model.Library{Name: "Other Library", Path: "/other/path"}
Expect(libRepo.Put(&otherLib)).To(Succeed())
})
AfterEach(func() {
// Clean up only test folders created by our tests (paths starting with "Test")
// This prevents interference with fixture data needed by other tests
_, _ = conn.NewQuery("DELETE FROM folder WHERE library_id = 1 AND path LIKE 'Test%'").Execute()
_, _ = conn.NewQuery(fmt.Sprintf("DELETE FROM library WHERE id = %d", otherLib.ID)).Execute()
})
Describe("GetFolderUpdateInfo", func() {
Context("with no target paths", func() {
It("returns all folders in the library", func() {
// Create test folders with unique names to avoid conflicts
folder1 := model.NewFolder(testLib, "TestGetLastUpdates/Folder1")
folder2 := model.NewFolder(testLib, "TestGetLastUpdates/Folder2")
err := repo.Put(folder1)
Expect(err).ToNot(HaveOccurred())
err = repo.Put(folder2)
Expect(err).ToNot(HaveOccurred())
otherFolder := model.NewFolder(otherLib, "TestOtherLib/Folder")
err = repo.Put(otherFolder)
Expect(err).ToNot(HaveOccurred())
// Query all folders (no target paths) - should only return folders from testLib
results, err := repo.GetFolderUpdateInfo(testLib)
Expect(err).ToNot(HaveOccurred())
// Should include folders from testLib
Expect(results).To(HaveKey(folder1.ID))
Expect(results).To(HaveKey(folder2.ID))
// Should NOT include folders from other library
Expect(results).ToNot(HaveKey(otherFolder.ID))
})
})
Context("with specific target paths", func() {
It("returns folder info for existing folders", func() {
// Create test folders with unique names
folder1 := model.NewFolder(testLib, "TestSpecific/Rock")
folder2 := model.NewFolder(testLib, "TestSpecific/Jazz")
folder3 := model.NewFolder(testLib, "TestSpecific/Classical")
err := repo.Put(folder1)
Expect(err).ToNot(HaveOccurred())
err = repo.Put(folder2)
Expect(err).ToNot(HaveOccurred())
err = repo.Put(folder3)
Expect(err).ToNot(HaveOccurred())
// Query specific paths
results, err := repo.GetFolderUpdateInfo(testLib, "TestSpecific/Rock", "TestSpecific/Classical")
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(2))
// Verify folder IDs are in results
Expect(results).To(HaveKey(folder1.ID))
Expect(results).To(HaveKey(folder3.ID))
Expect(results).ToNot(HaveKey(folder2.ID))
// Verify update info is populated
Expect(results[folder1.ID].UpdatedAt).ToNot(BeZero())
Expect(results[folder1.ID].Hash).To(Equal(folder1.Hash))
})
It("includes all child folders when querying parent", func() {
// Create a parent folder with multiple children
parent := model.NewFolder(testLib, "TestParent/Music")
child1 := model.NewFolder(testLib, "TestParent/Music/Rock/Queen")
child2 := model.NewFolder(testLib, "TestParent/Music/Jazz")
otherParent := model.NewFolder(testLib, "TestParent2/Music/Jazz")
Expect(repo.Put(parent)).To(Succeed())
Expect(repo.Put(child1)).To(Succeed())
Expect(repo.Put(child2)).To(Succeed())
// Query the parent folder - should return parent and all children
results, err := repo.GetFolderUpdateInfo(testLib, "TestParent/Music")
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(3))
Expect(results).To(HaveKey(parent.ID))
Expect(results).To(HaveKey(child1.ID))
Expect(results).To(HaveKey(child2.ID))
Expect(results).ToNot(HaveKey(otherParent.ID))
})
It("excludes children from other libraries", func() {
// Create parent in testLib
parent := model.NewFolder(testLib, "TestIsolation/Parent")
child := model.NewFolder(testLib, "TestIsolation/Parent/Child")
Expect(repo.Put(parent)).To(Succeed())
Expect(repo.Put(child)).To(Succeed())
// Create similar path in other library
otherParent := model.NewFolder(otherLib, "TestIsolation/Parent")
otherChild := model.NewFolder(otherLib, "TestIsolation/Parent/Child")
Expect(repo.Put(otherParent)).To(Succeed())
Expect(repo.Put(otherChild)).To(Succeed())
// Query should only return folders from testLib
results, err := repo.GetFolderUpdateInfo(testLib, "TestIsolation/Parent")
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(2))
Expect(results).To(HaveKey(parent.ID))
Expect(results).To(HaveKey(child.ID))
Expect(results).ToNot(HaveKey(otherParent.ID))
Expect(results).ToNot(HaveKey(otherChild.ID))
})
It("excludes missing children when querying parent", func() {
// Create parent and children, mark one as missing
parent := model.NewFolder(testLib, "TestMissingChild/Parent")
child1 := model.NewFolder(testLib, "TestMissingChild/Parent/Child1")
child2 := model.NewFolder(testLib, "TestMissingChild/Parent/Child2")
child2.Missing = true
Expect(repo.Put(parent)).To(Succeed())
Expect(repo.Put(child1)).To(Succeed())
Expect(repo.Put(child2)).To(Succeed())
// Query parent - should only return parent and non-missing child
results, err := repo.GetFolderUpdateInfo(testLib, "TestMissingChild/Parent")
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(2))
Expect(results).To(HaveKey(parent.ID))
Expect(results).To(HaveKey(child1.ID))
Expect(results).ToNot(HaveKey(child2.ID))
})
It("handles mix of existing and non-existing target paths", func() {
// Create folders for one path but not the other
existingParent := model.NewFolder(testLib, "TestMixed/Exists")
existingChild := model.NewFolder(testLib, "TestMixed/Exists/Child")
Expect(repo.Put(existingParent)).To(Succeed())
Expect(repo.Put(existingChild)).To(Succeed())
// Query both existing and non-existing paths
results, err := repo.GetFolderUpdateInfo(testLib, "TestMixed/Exists", "TestMixed/DoesNotExist")
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(2))
Expect(results).To(HaveKey(existingParent.ID))
Expect(results).To(HaveKey(existingChild.ID))
})
It("handles empty folder path as root", func() {
// Test querying for root folder without creating it (fixtures should have one)
rootFolderID := model.FolderID(testLib, ".")
results, err := repo.GetFolderUpdateInfo(testLib, "")
Expect(err).ToNot(HaveOccurred())
// Should return the root folder if it exists
if len(results) > 0 {
Expect(results).To(HaveKey(rootFolderID))
}
})
It("returns empty map for non-existent folders", func() {
results, err := repo.GetFolderUpdateInfo(testLib, "NonExistent/Path")
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
})
It("skips missing folders", func() {
// Create a folder and mark it as missing
folder := model.NewFolder(testLib, "TestMissing/Folder")
folder.Missing = true
err := repo.Put(folder)
Expect(err).ToNot(HaveOccurred())
results, err := repo.GetFolderUpdateInfo(testLib, "TestMissing/Folder")
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
})
})
})
})

View File

@@ -0,0 +1,52 @@
package persistence
import (
"context"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type genreRepository struct {
*baseTagRepository
}
func NewGenreRepository(ctx context.Context, db dbx.Builder) model.GenreRepository {
genreFilter := model.TagGenre
return &genreRepository{
baseTagRepository: newBaseTagRepository(ctx, db, &genreFilter),
}
}
func (r *genreRepository) selectGenre(opt ...model.QueryOptions) SelectBuilder {
return r.newSelect(opt...).Columns("tag.tag_value as name")
}
func (r *genreRepository) GetAll(opt ...model.QueryOptions) (model.Genres, error) {
sq := r.selectGenre(opt...)
res := model.Genres{}
err := r.queryAll(sq, &res)
return res, err
}
// Override ResourceRepository methods to return Genre objects instead of Tag objects
func (r *genreRepository) Read(id string) (interface{}, error) {
sel := r.selectGenre().Where(Eq{"tag.id": id})
var res model.Genre
err := r.queryOne(sel, &res)
return &res, err
}
func (r *genreRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *genreRepository) NewInstance() interface{} {
return &model.Genre{}
}
var _ model.GenreRepository = (*genreRepository)(nil)
var _ model.ResourceRepository = (*genreRepository)(nil)

View File

@@ -0,0 +1,329 @@
package persistence
import (
"context"
"slices"
"strings"
"github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("GenreRepository", func() {
var repo model.GenreRepository
var restRepo model.ResourceRepository
var tagRepo model.TagRepository
var ctx context.Context
BeforeEach(func() {
ctx = request.WithUser(GinkgoT().Context(), model.User{ID: "userid", UserName: "johndoe", IsAdmin: true})
genreRepo := NewGenreRepository(ctx, GetDBXBuilder())
repo = genreRepo
restRepo = genreRepo.(model.ResourceRepository)
tagRepo = NewTagRepository(ctx, GetDBXBuilder())
// Clear any existing tags to ensure test isolation
db := GetDBXBuilder()
_, err := db.NewQuery("DELETE FROM tag").Execute()
Expect(err).ToNot(HaveOccurred())
// Ensure library 1 exists and user has access to it
_, err = db.NewQuery("INSERT OR IGNORE INTO library (id, name, path, default_new_users) VALUES (1, 'Test Library', '/test', true)").Execute()
Expect(err).ToNot(HaveOccurred())
_, err = db.NewQuery("INSERT OR IGNORE INTO user_library (user_id, library_id) VALUES ('userid', 1)").Execute()
Expect(err).ToNot(HaveOccurred())
// Add comprehensive test data that covers all test scenarios
newTag := func(name, value string) model.Tag {
return model.Tag{ID: id.NewTagID(name, value), TagName: model.TagName(name), TagValue: value}
}
err = tagRepo.Add(1,
newTag("genre", "rock"),
newTag("genre", "pop"),
newTag("genre", "jazz"),
newTag("genre", "electronic"),
newTag("genre", "classical"),
newTag("genre", "ambient"),
newTag("genre", "techno"),
newTag("genre", "house"),
newTag("genre", "trance"),
newTag("genre", "Alternative Rock"),
newTag("genre", "Blues"),
newTag("genre", "Country"),
// These should not be counted as genres
newTag("mood", "happy"),
newTag("mood", "ambient"),
)
Expect(err).ToNot(HaveOccurred())
})
Describe("GetAll", func() {
It("should return all genres", func() {
genres, err := repo.GetAll()
Expect(err).ToNot(HaveOccurred())
Expect(genres).To(HaveLen(12))
// Verify that all returned items are genres (TagName = "genre")
genreNames := make([]string, len(genres))
for i, genre := range genres {
genreNames[i] = genre.Name
}
Expect(genreNames).To(ContainElement("rock"))
Expect(genreNames).To(ContainElement("pop"))
Expect(genreNames).To(ContainElement("jazz"))
// Should not contain mood tags
Expect(genreNames).ToNot(ContainElement("happy"))
})
It("should support query options", func() {
// Test with limiting results
genres, err := repo.GetAll(model.QueryOptions{Max: 1})
Expect(err).ToNot(HaveOccurred())
Expect(genres).To(HaveLen(1))
})
It("should handle empty results gracefully", func() {
// Clear all genre tags
_, err := GetDBXBuilder().NewQuery("DELETE FROM tag WHERE tag_name = 'genre'").Execute()
Expect(err).ToNot(HaveOccurred())
genres, err := repo.GetAll()
Expect(err).ToNot(HaveOccurred())
Expect(genres).To(BeEmpty())
})
Describe("filtering and sorting", func() {
It("should filter by name using like match", func() {
// Test filtering by partial name match using the "name" filter which maps to containsFilter("tag_value")
options := model.QueryOptions{
Filters: squirrel.Like{"tag_value": "%rock%"}, // Direct field access
}
genres, err := repo.GetAll(options)
Expect(err).ToNot(HaveOccurred())
Expect(genres).To(HaveLen(2)) // Should match "rock" and "Alternative Rock"
// Verify all returned genres contain "rock" in their name
for _, genre := range genres {
Expect(strings.ToLower(genre.Name)).To(ContainSubstring("rock"))
}
})
It("should sort by name in ascending order", func() {
// Test sorting by name with the fixed mapping
options := model.QueryOptions{
Filters: squirrel.Like{"tag_value": "%e%"}, // Should match genres containing "e"
Sort: "name",
}
genres, err := repo.GetAll(options)
Expect(err).ToNot(HaveOccurred())
Expect(genres).To(HaveLen(7))
Expect(slices.IsSortedFunc(genres, func(a, b model.Genre) int {
return strings.Compare(b.Name, a.Name) // Inverted to check descending order
}))
})
It("should sort by name in descending order", func() {
// Test sorting by name in descending order
options := model.QueryOptions{
Filters: squirrel.Like{"tag_value": "%e%"}, // Should match genres containing "e"
Sort: "name",
Order: "desc",
}
genres, err := repo.GetAll(options)
Expect(err).ToNot(HaveOccurred())
Expect(genres).To(HaveLen(7))
Expect(slices.IsSortedFunc(genres, func(a, b model.Genre) int {
return strings.Compare(a.Name, b.Name)
}))
})
})
})
Describe("Count", func() {
It("should return correct count of genres", func() {
count, err := restRepo.Count()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(12))) // We have 12 genre tags
})
It("should handle zero count", func() {
// Clear all genre tags
_, err := GetDBXBuilder().NewQuery("DELETE FROM tag WHERE tag_name = 'genre'").Execute()
Expect(err).ToNot(HaveOccurred())
count, err := restRepo.Count()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(BeZero())
})
It("should only count genre tags", func() {
// Add a non-genre tag
nonGenreTag := model.Tag{
ID: id.NewTagID("mood", "energetic"),
TagName: "mood",
TagValue: "energetic",
}
err := tagRepo.Add(1, nonGenreTag)
Expect(err).ToNot(HaveOccurred())
count, err := restRepo.Count()
Expect(err).ToNot(HaveOccurred())
// Count should not include the mood tag
Expect(count).To(Equal(int64(12))) // Should still be 12 genre tags
})
It("should filter by name using like match", func() {
// Test filtering by partial name match using the "name" filter which maps to containsFilter("tag_value")
options := rest.QueryOptions{
Filters: map[string]interface{}{"name": "%rock%"},
}
count, err := restRepo.Count(options)
Expect(err).ToNot(HaveOccurred())
Expect(count).To(BeNumerically("==", 2))
})
})
Describe("Read", func() {
It("should return existing genre", func() {
// Use one of the existing genres from our consolidated dataset
genreID := id.NewTagID("genre", "rock")
result, err := restRepo.Read(genreID)
Expect(err).ToNot(HaveOccurred())
genre := result.(*model.Genre)
Expect(genre.ID).To(Equal(genreID))
Expect(genre.Name).To(Equal("rock"))
})
It("should return error for non-existent genre", func() {
_, err := restRepo.Read("non-existent-id")
Expect(err).To(HaveOccurred())
})
It("should not return non-genre tags", func() {
moodID := id.NewTagID("mood", "happy") // This exists as a mood tag, not genre
_, err := restRepo.Read(moodID)
Expect(err).To(HaveOccurred()) // Should not find it as a genre
})
})
Describe("ReadAll", func() {
It("should return all genres through ReadAll", func() {
result, err := restRepo.ReadAll()
Expect(err).ToNot(HaveOccurred())
genres := result.(model.Genres)
Expect(genres).To(HaveLen(12)) // We have 12 genre tags
genreNames := make([]string, len(genres))
for i, genre := range genres {
genreNames[i] = genre.Name
}
// Check for some of our consolidated dataset genres
Expect(genreNames).To(ContainElement("rock"))
Expect(genreNames).To(ContainElement("pop"))
Expect(genreNames).To(ContainElement("jazz"))
})
It("should support rest query options", func() {
result, err := restRepo.ReadAll()
Expect(err).ToNot(HaveOccurred())
Expect(result).ToNot(BeNil())
})
})
Describe("Library Filtering", func() {
Context("Headless Processes (No User Context)", func() {
var headlessRepo model.GenreRepository
var headlessRestRepo model.ResourceRepository
BeforeEach(func() {
// Create a repository with no user context (headless)
headlessGenreRepo := NewGenreRepository(context.Background(), GetDBXBuilder())
headlessRepo = headlessGenreRepo
headlessRestRepo = headlessGenreRepo.(model.ResourceRepository)
// Add genres to different libraries
db := GetDBXBuilder()
_, err := db.NewQuery("INSERT OR IGNORE INTO library (id, name, path) VALUES (2, 'Test Library 2', '/test2')").Execute()
Expect(err).ToNot(HaveOccurred())
// Add tags to different libraries
newTag := func(name, value string) model.Tag {
return model.Tag{ID: id.NewTagID(name, value), TagName: model.TagName(name), TagValue: value}
}
err = tagRepo.Add(2, newTag("genre", "jazz"))
Expect(err).ToNot(HaveOccurred())
})
It("should see all genres from all libraries when no user is in context", func() {
// Headless processes should see all genres regardless of library
genres, err := headlessRepo.GetAll()
Expect(err).ToNot(HaveOccurred())
// Should see genres from all libraries
var genreNames []string
for _, genre := range genres {
genreNames = append(genreNames, genre.Name)
}
// Should include both rock (library 1) and jazz (library 2)
Expect(genreNames).To(ContainElement("rock"))
Expect(genreNames).To(ContainElement("jazz"))
})
It("should count all genres from all libraries when no user is in context", func() {
count, err := headlessRestRepo.Count()
Expect(err).ToNot(HaveOccurred())
// Should count all genres from all libraries
Expect(count).To(BeNumerically(">=", 2))
})
It("should allow headless processes to apply explicit library_id filters", func() {
// Filter by specific library
genres, err := headlessRestRepo.ReadAll(rest.QueryOptions{
Filters: map[string]interface{}{"library_id": 2},
})
Expect(err).ToNot(HaveOccurred())
genreList := genres.(model.Genres)
// Should see only genres from library 2
Expect(genreList).To(HaveLen(1))
Expect(genreList[0].Name).To(Equal("jazz"))
})
It("should get individual genres when no user is in context", func() {
// Get all genres first to find an ID
genres, err := headlessRepo.GetAll()
Expect(err).ToNot(HaveOccurred())
Expect(genres).ToNot(BeEmpty())
// Headless process should be able to get the genre
genre, err := headlessRestRepo.Read(genres[0].ID)
Expect(err).ToNot(HaveOccurred())
Expect(genre).ToNot(BeNil())
})
})
})
Describe("EntityName", func() {
It("should return correct entity name", func() {
name := restRepo.EntityName()
Expect(name).To(Equal("tag")) // Genre repository uses tag table
})
})
Describe("NewInstance", func() {
It("should return new genre instance", func() {
instance := restRepo.NewInstance()
Expect(instance).To(BeAssignableToTypeOf(&model.Genre{}))
})
})
})

92
persistence/helpers.go Normal file
View File

@@ -0,0 +1,92 @@
package persistence
import (
"database/sql/driver"
"fmt"
"regexp"
"strings"
"time"
"github.com/Masterminds/squirrel"
"github.com/fatih/structs"
)
type PostMapper interface {
PostMapArgs(map[string]any) error
}
func toSQLArgs(rec interface{}) (map[string]interface{}, error) {
m := structs.Map(rec)
for k, v := range m {
switch t := v.(type) {
case *time.Time:
if t != nil {
m[k] = *t
}
case driver.Valuer:
var err error
m[k], err = t.Value()
if err != nil {
return nil, err
}
}
}
if r, ok := rec.(PostMapper); ok {
err := r.PostMapArgs(m)
if err != nil {
return nil, err
}
}
return m, nil
}
var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
func toSnakeCase(str string) string {
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
return strings.ToLower(snake)
}
var matchUnderscore = regexp.MustCompile("_([A-Za-z])")
func toCamelCase(str string) string {
return matchUnderscore.ReplaceAllStringFunc(str, func(s string) string {
return strings.ToUpper(strings.Replace(s, "_", "", -1))
})
}
func Exists(subTable string, cond squirrel.Sqlizer) existsCond {
return existsCond{subTable: subTable, cond: cond, not: false}
}
func NotExists(subTable string, cond squirrel.Sqlizer) existsCond {
return existsCond{subTable: subTable, cond: cond, not: true}
}
type existsCond struct {
subTable string
cond squirrel.Sqlizer
not bool
}
func (e existsCond) ToSql() (string, []interface{}, error) {
sql, args, err := e.cond.ToSql()
sql = fmt.Sprintf("exists (select 1 from %s where %s)", e.subTable, sql)
if e.not {
sql = "not " + sql
}
return sql, args, err
}
var sortOrderRegex = regexp.MustCompile(`order_([a-z_]+)`)
// Convert the order_* columns to an expression using sort_* columns. Example:
// sort_album_name -> (coalesce(nullif(sort_album_name,”),order_album_name) collate nocase)
// It finds order column names anywhere in the substring
func mapSortOrder(tableName, order string) string {
order = strings.ToLower(order)
repl := fmt.Sprintf("(coalesce(nullif(%[1]s.sort_$1,''),%[1]s.order_$1) collate nocase)", tableName)
return sortOrderRegex.ReplaceAllString(order, repl)
}

106
persistence/helpers_test.go Normal file
View File

@@ -0,0 +1,106 @@
package persistence
import (
"time"
"github.com/Masterminds/squirrel"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Helpers", func() {
Describe("toSnakeCase", func() {
It("converts camelCase", func() {
Expect(toSnakeCase("camelCase")).To(Equal("camel_case"))
})
It("converts PascalCase", func() {
Expect(toSnakeCase("PascalCase")).To(Equal("pascal_case"))
})
It("converts ALLCAPS", func() {
Expect(toSnakeCase("ALLCAPS")).To(Equal("allcaps"))
})
It("does not converts snake_case", func() {
Expect(toSnakeCase("snake_case")).To(Equal("snake_case"))
})
})
Describe("toCamelCase", func() {
It("converts snake_case", func() {
Expect(toCamelCase("snake_case")).To(Equal("snakeCase"))
})
It("converts PascalCase", func() {
Expect(toCamelCase("PascalCase")).To(Equal("PascalCase"))
})
It("converts camelCase", func() {
Expect(toCamelCase("camelCase")).To(Equal("camelCase"))
})
It("converts ALLCAPS", func() {
Expect(toCamelCase("ALLCAPS")).To(Equal("ALLCAPS"))
})
})
Describe("toSQLArgs", func() {
type Embed struct{}
type Model struct {
Embed `structs:"-"`
ID string `structs:"id" json:"id"`
AlbumId string `structs:"album_id" json:"albumId"`
PlayCount int `structs:"play_count" json:"playCount"`
UpdatedAt *time.Time `structs:"updated_at"`
CreatedAt time.Time `structs:"created_at"`
}
It("returns a map with snake_case keys", func() {
now := time.Now()
m := &Model{ID: "123", AlbumId: "456", CreatedAt: now, UpdatedAt: &now, PlayCount: 2}
args, err := toSQLArgs(m)
Expect(err).To(BeNil())
Expect(args).To(SatisfyAll(
HaveKeyWithValue("id", "123"),
HaveKeyWithValue("album_id", "456"),
HaveKeyWithValue("play_count", 2),
HaveKeyWithValue("updated_at", BeTemporally("~", now)),
HaveKeyWithValue("created_at", BeTemporally("~", now)),
Not(HaveKey("Embed")),
))
})
})
Describe("Exists", func() {
It("constructs the correct EXISTS query", func() {
e := Exists("album", squirrel.Eq{"id": 1})
sql, args, err := e.ToSql()
Expect(sql).To(Equal("exists (select 1 from album where id = ?)"))
Expect(args).To(ConsistOf(1))
Expect(err).To(BeNil())
})
})
Describe("NotExists", func() {
It("constructs the correct NOT EXISTS query", func() {
e := NotExists("artist", squirrel.ConcatExpr("id = artist_id"))
sql, args, err := e.ToSql()
Expect(sql).To(Equal("not exists (select 1 from artist where id = artist_id)"))
Expect(args).To(BeEmpty())
Expect(err).To(BeNil())
})
})
Describe("mapSortOrder", func() {
It("does not change the sort string if there are no order columns", func() {
sort := "album_name asc"
mapped := mapSortOrder("album", sort)
Expect(mapped).To(Equal(sort))
})
It("changes order columns to sort expression", func() {
sort := "ORDER_ALBUM_NAME asc"
mapped := mapSortOrder("album", sort)
Expect(mapped).To(Equal(`(coalesce(nullif(album.sort_album_name,''),album.order_album_name)` +
` collate nocase) asc`))
})
It("changes multiple order columns to sort expressions", func() {
sort := "compilation, order_title asc, order_album_artist_name desc, year desc"
mapped := mapSortOrder("album", sort)
Expect(mapped).To(Equal(`compilation, (coalesce(nullif(album.sort_title,''),album.order_title) collate nocase) asc,` +
` (coalesce(nullif(album.sort_album_artist_name,''),album.order_album_artist_name) collate nocase) desc, year desc`))
})
})
})

View File

@@ -0,0 +1,347 @@
package persistence
import (
"context"
"fmt"
"strconv"
"sync"
"time"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/run"
"github.com/pocketbase/dbx"
)
type libraryRepository struct {
sqlRepository
}
var (
libCache = map[int]string{}
libLock sync.RWMutex
)
func NewLibraryRepository(ctx context.Context, db dbx.Builder) model.LibraryRepository {
r := &libraryRepository{}
r.ctx = ctx
r.db = db
r.registerModel(&model.Library{}, nil)
return r
}
func (r *libraryRepository) Get(id int) (*model.Library, error) {
sq := r.newSelect().Columns("*").Where(Eq{"id": id})
var res model.Library
err := r.queryOne(sq, &res)
return &res, err
}
func (r *libraryRepository) GetPath(id int) (string, error) {
l := func() string {
libLock.RLock()
defer libLock.RUnlock()
if l, ok := libCache[id]; ok {
return l
}
return ""
}()
if l != "" {
return l, nil
}
libLock.Lock()
defer libLock.Unlock()
libs, err := r.GetAll()
if err != nil {
log.Error(r.ctx, "Error loading libraries from DB", err)
return "", err
}
for _, l := range libs {
libCache[l.ID] = l.Path
}
if l, ok := libCache[id]; ok {
return l, nil
} else {
return "", model.ErrNotFound
}
}
func (r *libraryRepository) Put(l *model.Library) error {
if l.ID == model.DefaultLibraryID {
currentLib, err := r.Get(1)
// if we are creating it, it's ok.
if err == nil { // it exists, so we are updating it
if currentLib.Path != l.Path {
return fmt.Errorf("%w: path for library with ID 1 cannot be changed", model.ErrValidation)
}
}
}
var err error
l.UpdatedAt = time.Now()
if l.ID == 0 {
// Insert with autoassigned ID
l.CreatedAt = time.Now()
err = r.db.Model(l).Insert()
} else {
// Try to update first
cols := map[string]any{
"name": l.Name,
"path": l.Path,
"remote_path": l.RemotePath,
"default_new_users": l.DefaultNewUsers,
"updated_at": l.UpdatedAt,
}
sq := Update(r.tableName).SetMap(cols).Where(Eq{"id": l.ID})
rowsAffected, updateErr := r.executeSQL(sq)
if updateErr != nil {
return updateErr
}
// If no rows were affected, the record doesn't exist, so insert it
if rowsAffected == 0 {
l.CreatedAt = time.Now()
l.UpdatedAt = time.Now()
err = r.db.Model(l).Insert()
}
}
if err != nil {
return err
}
// Auto-assign all libraries to all admin users
sql := Expr(`
INSERT INTO user_library (user_id, library_id)
SELECT u.id, l.id
FROM user u
CROSS JOIN library l
WHERE u.is_admin = true
ON CONFLICT (user_id, library_id) DO NOTHING;`,
)
if _, err = r.executeSQL(sql); err != nil {
return fmt.Errorf("failed to assign library to admin users: %w", err)
}
libLock.Lock()
defer libLock.Unlock()
libCache[l.ID] = l.Path
return nil
}
// TODO Remove this method when we have a proper UI to add libraries
// This is a temporary method to store the music folder path from the config in the DB
func (r *libraryRepository) StoreMusicFolder() error {
sq := Update(r.tableName).Set("path", conf.Server.MusicFolder).
Set("updated_at", time.Now()).
Where(Eq{"id": model.DefaultLibraryID})
_, err := r.executeSQL(sq)
if err != nil {
libLock.Lock()
defer libLock.Unlock()
libCache[model.DefaultLibraryID] = conf.Server.MusicFolder
}
return err
}
func (r *libraryRepository) AddArtist(id int, artistID string) error {
sq := Insert("library_artist").Columns("library_id", "artist_id").Values(id, artistID).
Suffix(`on conflict(library_id, artist_id) do nothing`)
_, err := r.executeSQL(sq)
if err != nil {
return err
}
return nil
}
func (r *libraryRepository) ScanBegin(id int, fullScan bool) error {
sq := Update(r.tableName).
Set("last_scan_started_at", time.Now()).
Set("full_scan_in_progress", fullScan).
Where(Eq{"id": id})
_, err := r.executeSQL(sq)
return err
}
func (r *libraryRepository) ScanEnd(id int) error {
sq := Update(r.tableName).
Set("last_scan_at", time.Now()).
Set("full_scan_in_progress", false).
Set("last_scan_started_at", time.Time{}).
Where(Eq{"id": id})
_, err := r.executeSQL(sq)
if err != nil {
return err
}
// https://www.sqlite.org/pragma.html#pragma_optimize
// Use mask 0x10000 to check table sizes without running ANALYZE
// Running ANALYZE can cause query planner issues with expression-based collation indexes
if conf.Server.DevOptimizeDB {
_, err = r.executeSQL(Expr("PRAGMA optimize=0x10000;"))
}
return err
}
func (r *libraryRepository) ScanInProgress() (bool, error) {
query := r.newSelect().Where(NotEq{"last_scan_started_at": time.Time{}})
count, err := r.count(query)
return count > 0, err
}
func (r *libraryRepository) RefreshStats(id int) error {
var songsRes, albumsRes, artistsRes, foldersRes, filesRes, missingRes struct{ Count int64 }
var sizeRes struct{ Sum int64 }
var durationRes struct{ Sum float64 }
err := run.Parallel(
func() error {
return r.queryOne(Select("count(*) as count").From("media_file").Where(Eq{"library_id": id, "missing": false}), &songsRes)
},
func() error {
return r.queryOne(Select("count(*) as count").From("album").Where(Eq{"library_id": id, "missing": false}), &albumsRes)
},
func() error {
return r.queryOne(Select("count(*) as count").From("library_artist la").
Join("artist a on la.artist_id = a.id").
Where(Eq{"la.library_id": id, "a.missing": false}), &artistsRes)
},
func() error {
return r.queryOne(Select("count(*) as count").From("folder").
Where(And{
Eq{"library_id": id, "missing": false},
Gt{"num_audio_files": 0},
}), &foldersRes)
},
func() error {
return r.queryOne(Select("ifnull(sum(num_audio_files + num_playlists + json_array_length(image_files)),0) as count").
From("folder").Where(Eq{"library_id": id, "missing": false}), &filesRes)
},
func() error {
return r.queryOne(Select("count(*) as count").From("media_file").Where(Eq{"library_id": id, "missing": true}), &missingRes)
},
func() error {
return r.queryOne(Select("ifnull(sum(size),0) as sum").From("album").Where(Eq{"library_id": id, "missing": false}), &sizeRes)
},
func() error {
return r.queryOne(Select("ifnull(sum(duration),0) as sum").From("album").Where(Eq{"library_id": id, "missing": false}), &durationRes)
},
)()
if err != nil {
return err
}
sq := Update(r.tableName).
Set("total_songs", songsRes.Count).
Set("total_albums", albumsRes.Count).
Set("total_artists", artistsRes.Count).
Set("total_folders", foldersRes.Count).
Set("total_files", filesRes.Count).
Set("total_missing_files", missingRes.Count).
Set("total_size", sizeRes.Sum).
Set("total_duration", durationRes.Sum).
Set("updated_at", time.Now()).
Where(Eq{"id": id})
_, err = r.executeSQL(sq)
return err
}
func (r *libraryRepository) Delete(id int) error {
if !loggedUser(r.ctx).IsAdmin {
return model.ErrNotAuthorized
}
if id == 1 {
return fmt.Errorf("%w: library with ID 1 cannot be deleted", model.ErrValidation)
}
err := r.delete(Eq{"id": id})
if err != nil {
return err
}
// Clear cache entry for this library only if DB operation was successful
libLock.Lock()
defer libLock.Unlock()
delete(libCache, id)
return nil
}
func (r *libraryRepository) GetAll(ops ...model.QueryOptions) (model.Libraries, error) {
sq := r.newSelect(ops...).Columns("*")
res := model.Libraries{}
err := r.queryAll(sq, &res)
return res, err
}
func (r *libraryRepository) CountAll(ops ...model.QueryOptions) (int64, error) {
sq := r.newSelect(ops...)
return r.count(sq)
}
// User-library association methods
func (r *libraryRepository) GetUsersWithLibraryAccess(libraryID int) (model.Users, error) {
sel := Select("u.*").
From("user u").
Join("user_library ul ON u.id = ul.user_id").
Where(Eq{"ul.library_id": libraryID}).
OrderBy("u.name")
var res model.Users
err := r.queryAll(sel, &res)
return res, err
}
// REST interface methods
func (r *libraryRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *libraryRepository) Read(id string) (interface{}, error) {
idInt, err := strconv.Atoi(id)
if err != nil {
log.Trace(r.ctx, "invalid library id: %s", id, err)
return nil, rest.ErrNotFound
}
return r.Get(idInt)
}
func (r *libraryRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *libraryRepository) EntityName() string {
return "library"
}
func (r *libraryRepository) NewInstance() interface{} {
return &model.Library{}
}
func (r *libraryRepository) Save(entity interface{}) (string, error) {
lib := entity.(*model.Library)
lib.ID = 0 // Reset ID to ensure we create a new library
err := r.Put(lib)
if err != nil {
return "", err
}
return strconv.Itoa(lib.ID), nil
}
func (r *libraryRepository) Update(id string, entity interface{}, cols ...string) error {
lib := entity.(*model.Library)
idInt, err := strconv.Atoi(id)
if err != nil {
return fmt.Errorf("invalid library ID: %s", id)
}
lib.ID = idInt
return r.Put(lib)
}
var _ model.LibraryRepository = (*libraryRepository)(nil)
var _ rest.Repository = (*libraryRepository)(nil)

View File

@@ -0,0 +1,203 @@
package persistence
import (
"context"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
var _ = Describe("LibraryRepository", func() {
var repo model.LibraryRepository
var ctx context.Context
var conn *dbx.DB
BeforeEach(func() {
ctx = request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid"})
conn = GetDBXBuilder()
repo = NewLibraryRepository(ctx, conn)
})
AfterEach(func() {
// Clean up test libraries (keep ID 1 which is the default library)
_, _ = conn.NewQuery("DELETE FROM library WHERE id > 1").Execute()
})
Describe("Put", func() {
Context("when ID is 0", func() {
It("inserts a new library with autoassigned ID", func() {
lib := &model.Library{
ID: 0,
Name: "Test Library",
Path: "/music/test",
}
err := repo.Put(lib)
Expect(err).ToNot(HaveOccurred())
Expect(lib.ID).To(BeNumerically(">", 0))
Expect(lib.CreatedAt).ToNot(BeZero())
Expect(lib.UpdatedAt).ToNot(BeZero())
// Verify it was inserted
savedLib, err := repo.Get(lib.ID)
Expect(err).ToNot(HaveOccurred())
Expect(savedLib.Name).To(Equal("Test Library"))
Expect(savedLib.Path).To(Equal("/music/test"))
})
})
Context("when ID is non-zero and record exists", func() {
It("updates the existing record", func() {
// First create a library
lib := &model.Library{
ID: 0,
Name: "Original Library",
Path: "/music/original",
}
err := repo.Put(lib)
Expect(err).ToNot(HaveOccurred())
originalID := lib.ID
originalCreatedAt := lib.CreatedAt
// Now update it
lib.Name = "Updated Library"
lib.Path = "/music/updated"
err = repo.Put(lib)
Expect(err).ToNot(HaveOccurred())
// Verify it was updated, not inserted
Expect(lib.ID).To(Equal(originalID))
Expect(lib.CreatedAt).To(Equal(originalCreatedAt))
Expect(lib.UpdatedAt).To(BeTemporally(">", originalCreatedAt))
// Verify the changes were saved
savedLib, err := repo.Get(lib.ID)
Expect(err).ToNot(HaveOccurred())
Expect(savedLib.Name).To(Equal("Updated Library"))
Expect(savedLib.Path).To(Equal("/music/updated"))
})
})
Context("when ID is non-zero but record doesn't exist", func() {
It("inserts a new record with the specified ID", func() {
lib := &model.Library{
ID: 999,
Name: "New Library with ID",
Path: "/music/new",
}
// Ensure the record doesn't exist
_, err := repo.Get(999)
Expect(err).To(HaveOccurred())
// Put should insert it
err = repo.Put(lib)
Expect(err).ToNot(HaveOccurred())
Expect(lib.ID).To(Equal(999))
Expect(lib.CreatedAt).ToNot(BeZero())
Expect(lib.UpdatedAt).ToNot(BeZero())
// Verify it was inserted with the correct ID
savedLib, err := repo.Get(999)
Expect(err).ToNot(HaveOccurred())
Expect(savedLib.ID).To(Equal(999))
Expect(savedLib.Name).To(Equal("New Library with ID"))
Expect(savedLib.Path).To(Equal("/music/new"))
})
})
})
It("refreshes stats", func() {
libBefore, err := repo.Get(1)
Expect(err).ToNot(HaveOccurred())
Expect(repo.RefreshStats(1)).To(Succeed())
libAfter, err := repo.Get(1)
Expect(err).ToNot(HaveOccurred())
Expect(libAfter.UpdatedAt).To(BeTemporally(">", libBefore.UpdatedAt))
var songsRes, albumsRes, artistsRes, foldersRes, filesRes, missingRes struct{ Count int64 }
var sizeRes struct{ Sum int64 }
var durationRes struct{ Sum float64 }
Expect(conn.NewQuery("select count(*) as count from media_file where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&songsRes)).To(Succeed())
Expect(conn.NewQuery("select count(*) as count from album where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&albumsRes)).To(Succeed())
Expect(conn.NewQuery("select count(*) as count from library_artist la join artist a on la.artist_id = a.id where la.library_id = {:id} and a.missing = 0").Bind(dbx.Params{"id": 1}).One(&artistsRes)).To(Succeed())
Expect(conn.NewQuery("select count(*) as count from folder where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&foldersRes)).To(Succeed())
Expect(conn.NewQuery("select ifnull(sum(num_audio_files + num_playlists + json_array_length(image_files)),0) as count from folder where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&filesRes)).To(Succeed())
Expect(conn.NewQuery("select count(*) as count from media_file where library_id = {:id} and missing = 1").Bind(dbx.Params{"id": 1}).One(&missingRes)).To(Succeed())
Expect(conn.NewQuery("select ifnull(sum(size),0) as sum from album where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&sizeRes)).To(Succeed())
Expect(conn.NewQuery("select ifnull(sum(duration),0) as sum from album where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&durationRes)).To(Succeed())
Expect(libAfter.TotalSongs).To(Equal(int(songsRes.Count)))
Expect(libAfter.TotalAlbums).To(Equal(int(albumsRes.Count)))
Expect(libAfter.TotalArtists).To(Equal(int(artistsRes.Count)))
Expect(libAfter.TotalFolders).To(Equal(int(foldersRes.Count)))
Expect(libAfter.TotalFiles).To(Equal(int(filesRes.Count)))
Expect(libAfter.TotalMissingFiles).To(Equal(int(missingRes.Count)))
Expect(libAfter.TotalSize).To(Equal(sizeRes.Sum))
Expect(libAfter.TotalDuration).To(Equal(durationRes.Sum))
})
Describe("ScanBegin and ScanEnd", func() {
var lib *model.Library
BeforeEach(func() {
lib = &model.Library{
ID: 0,
Name: "Test Scan Library",
Path: "/music/test-scan",
}
err := repo.Put(lib)
Expect(err).ToNot(HaveOccurred())
})
DescribeTable("ScanBegin",
func(fullScan bool, expectedFullScanInProgress bool) {
err := repo.ScanBegin(lib.ID, fullScan)
Expect(err).ToNot(HaveOccurred())
updatedLib, err := repo.Get(lib.ID)
Expect(err).ToNot(HaveOccurred())
Expect(updatedLib.LastScanStartedAt).ToNot(BeZero())
Expect(updatedLib.FullScanInProgress).To(Equal(expectedFullScanInProgress))
},
Entry("sets FullScanInProgress to true for full scan", true, true),
Entry("sets FullScanInProgress to false for quick scan", false, false),
)
Context("ScanEnd", func() {
BeforeEach(func() {
err := repo.ScanBegin(lib.ID, true)
Expect(err).ToNot(HaveOccurred())
})
It("sets LastScanAt and clears FullScanInProgress and LastScanStartedAt", func() {
err := repo.ScanEnd(lib.ID)
Expect(err).ToNot(HaveOccurred())
updatedLib, err := repo.Get(lib.ID)
Expect(err).ToNot(HaveOccurred())
Expect(updatedLib.LastScanAt).ToNot(BeZero())
Expect(updatedLib.FullScanInProgress).To(BeFalse())
Expect(updatedLib.LastScanStartedAt).To(BeZero())
})
It("sets LastScanAt to be after LastScanStartedAt", func() {
libBefore, err := repo.Get(lib.ID)
Expect(err).ToNot(HaveOccurred())
err = repo.ScanEnd(lib.ID)
Expect(err).ToNot(HaveOccurred())
libAfter, err := repo.Get(lib.ID)
Expect(err).ToNot(HaveOccurred())
Expect(libAfter.LastScanAt).To(BeTemporally(">=", libBefore.LastScanStartedAt))
})
})
})
})

View File

@@ -0,0 +1,477 @@
package persistence
import (
"context"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/google/uuid"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
type mediaFileRepository struct {
sqlRepository
ms *MeilisearchService
}
type dbMediaFile struct {
*model.MediaFile `structs:",flatten"`
Participants string `structs:"-" json:"-"`
Tags string `structs:"-" json:"-"`
// These are necessary to map the correct names (rg_*) to the correct fields (RG*)
// without using `db` struct tags in the model.MediaFile struct
RgAlbumGain *float64 `structs:"-" json:"-"`
RgAlbumPeak *float64 `structs:"-" json:"-"`
RgTrackGain *float64 `structs:"-" json:"-"`
RgTrackPeak *float64 `structs:"-" json:"-"`
}
func (m *dbMediaFile) PostScan() error {
m.RGTrackGain = m.RgTrackGain
m.RGTrackPeak = m.RgTrackPeak
m.RGAlbumGain = m.RgAlbumGain
m.RGAlbumPeak = m.RgAlbumPeak
var err error
m.MediaFile.Participants, err = unmarshalParticipants(m.Participants)
if err != nil {
return fmt.Errorf("parsing media_file from db: %w", err)
}
if m.Tags != "" {
m.MediaFile.Tags, err = unmarshalTags(m.Tags)
if err != nil {
return fmt.Errorf("parsing media_file from db: %w", err)
}
m.Genre, m.Genres = m.MediaFile.Tags.ToGenres()
}
return nil
}
func (m *dbMediaFile) PostMapArgs(args map[string]any) error {
fullText := []string{m.FullTitle(), m.Album, m.Artist, m.AlbumArtist,
m.SortTitle, m.SortAlbumName, m.SortArtistName, m.SortAlbumArtistName, m.DiscSubtitle}
fullText = append(fullText, m.MediaFile.Participants.AllNames()...)
args["full_text"] = formatFullText(fullText...)
args["tags"] = marshalTags(m.MediaFile.Tags)
args["participants"] = marshalParticipants(m.MediaFile.Participants)
return nil
}
type dbMediaFiles []dbMediaFile
func (m dbMediaFiles) toModels() model.MediaFiles {
return slice.Map(m, func(mf dbMediaFile) model.MediaFile { return *mf.MediaFile })
}
func NewMediaFileRepository(ctx context.Context, db dbx.Builder, ms *MeilisearchService) model.MediaFileRepository {
r := &mediaFileRepository{}
r.ctx = ctx
r.db = db
r.ms = ms
r.tableName = "media_file"
r.registerModel(&model.MediaFile{}, mediaFileFilter())
r.setSortMappings(map[string]string{
"title": "order_title",
"artist": "order_artist_name, order_album_name, release_date, disc_number, track_number",
"album_artist": "order_album_artist_name, order_album_name, release_date, disc_number, track_number",
"album": "order_album_name, album_id, disc_number, track_number, order_artist_name, title",
"random": "random",
"created_at": "media_file.created_at",
"recently_added": mediaFileRecentlyAddedSort(),
"starred_at": "starred, starred_at",
"rated_at": "rating, rated_at",
})
return r
}
var mediaFileFilter = sync.OnceValue(func() map[string]filterFunc {
filters := map[string]filterFunc{
"id": idFilter("media_file"),
"title": fullTextFilter("media_file", "mbz_recording_id", "mbz_release_track_id"),
"starred": booleanFilter,
"genre_id": tagIDFilter,
"missing": booleanFilter,
"artists_id": artistFilter,
"library_id": libraryIdFilter,
}
// Add all album tags as filters
for tag := range model.TagMappings() {
if _, exists := filters[string(tag)]; !exists {
filters[string(tag)] = tagIDFilter
}
}
return filters
})
func mediaFileRecentlyAddedSort() string {
if conf.Server.RecentlyAddedByModTime {
return "media_file.updated_at"
}
return "media_file.created_at"
}
func (r *mediaFileRepository) CountAll(options ...model.QueryOptions) (int64, error) {
query := r.newSelect()
query = r.withAnnotation(query, "media_file.id")
query = r.applyLibraryFilter(query)
return r.count(query, options...)
}
func (r *mediaFileRepository) Exists(id string) (bool, error) {
return r.exists(Eq{"media_file.id": id})
}
func (r *mediaFileRepository) Put(m *model.MediaFile) error {
m.CreatedAt = time.Now()
id, err := r.putByMatch(Eq{"path": m.Path, "library_id": m.LibraryID}, m.ID, &dbMediaFile{MediaFile: m})
if err != nil {
return err
}
m.ID = id
err = r.updateParticipants(m.ID, m.Participants)
if err != nil {
return err
}
if r.ms != nil {
r.ms.IndexMediaFile(m)
}
return nil
}
func (r *mediaFileRepository) selectMediaFile(options ...model.QueryOptions) SelectBuilder {
sql := r.newSelect(options...).Columns("media_file.*", "library.path as library_path", "library.name as library_name").
LeftJoin("library on media_file.library_id = library.id")
sql = r.withAnnotation(sql, "media_file.id")
sql = r.withBookmark(sql, "media_file.id")
return r.applyLibraryFilter(sql)
}
func (r *mediaFileRepository) Get(id string) (*model.MediaFile, error) {
res, err := r.GetAll(model.QueryOptions{Filters: Eq{"media_file.id": id}})
if err != nil {
return nil, err
}
if len(res) == 0 {
return nil, model.ErrNotFound
}
return &res[0], nil
}
func (r *mediaFileRepository) GetWithParticipants(id string) (*model.MediaFile, error) {
m, err := r.Get(id)
if err != nil {
return nil, err
}
m.Participants, err = r.getParticipants(m)
return m, err
}
func (r *mediaFileRepository) GetAll(options ...model.QueryOptions) (model.MediaFiles, error) {
sq := r.selectMediaFile(options...)
var res dbMediaFiles
err := r.queryAll(sq, &res, options...)
if err != nil {
return nil, err
}
return res.toModels(), nil
}
func (r *mediaFileRepository) GetCursor(options ...model.QueryOptions) (model.MediaFileCursor, error) {
sq := r.selectMediaFile(options...)
cursor, err := queryWithStableResults[dbMediaFile](r.sqlRepository, sq)
if err != nil {
return nil, err
}
return func(yield func(model.MediaFile, error) bool) {
for m, err := range cursor {
if m.MediaFile == nil {
yield(model.MediaFile{}, fmt.Errorf("unexpected nil mediafile: %v", m))
return
}
if !yield(*m.MediaFile, err) || err != nil {
return
}
}
}, nil
}
// FindByPaths finds media files by their paths.
// The paths can be library-qualified (format: "libraryID:path") or unqualified ("path").
// Library-qualified paths search within the specified library, while unqualified paths
// search across all libraries for backward compatibility.
func (r *mediaFileRepository) FindByPaths(paths []string) (model.MediaFiles, error) {
query := Or{}
for _, path := range paths {
parts := strings.SplitN(path, ":", 2)
if len(parts) == 2 {
// Library-qualified path: "libraryID:path"
libraryID, err := strconv.Atoi(parts[0])
if err != nil {
// Invalid format, skip
continue
}
relativePath := parts[1]
query = append(query, And{
Eq{"path collate nocase": relativePath},
Eq{"library_id": libraryID},
})
} else {
// Unqualified path: search across all libraries
query = append(query, Eq{"path collate nocase": path})
}
}
if len(query) == 0 {
return model.MediaFiles{}, nil
}
sel := r.newSelect().Columns("*").Where(query)
var res dbMediaFiles
if err := r.queryAll(sel, &res); err != nil {
return nil, err
}
return res.toModels(), nil
}
func (r *mediaFileRepository) Delete(id string) error {
err := r.delete(Eq{"id": id})
if err == nil && r.ms != nil {
r.ms.DeleteMediaFile(id)
}
return err
}
func (r *mediaFileRepository) DeleteAllMissing() (int64, error) {
user := loggedUser(r.ctx)
if !user.IsAdmin {
return 0, rest.ErrPermissionDenied
}
del := Delete(r.tableName).Where(Eq{"missing": true})
var ids []string
if r.ms != nil {
_ = r.db.Select("id").From(r.tableName).Where(dbx.HashExp{"missing": true}).Column(&ids)
}
c, err := r.executeSQL(del)
if err == nil && r.ms != nil && len(ids) > 0 {
r.ms.DeleteMediaFiles(ids)
}
return c, err
}
func (r *mediaFileRepository) DeleteMissing(ids []string) error {
user := loggedUser(r.ctx)
if !user.IsAdmin {
return rest.ErrPermissionDenied
}
err := r.delete(
And{
Eq{"missing": true},
Eq{"id": ids},
},
)
if err == nil && r.ms != nil {
r.ms.DeleteMediaFiles(ids)
}
return err
}
func (r *mediaFileRepository) MarkMissing(missing bool, mfs ...*model.MediaFile) error {
ids := slice.SeqFunc(mfs, func(m *model.MediaFile) string { return m.ID })
for chunk := range slice.CollectChunks(ids, 200) {
upd := Update(r.tableName).
Set("missing", missing).
Set("updated_at", time.Now()).
Where(Eq{"id": chunk})
c, err := r.executeSQL(upd)
if err != nil || c == 0 {
log.Error(r.ctx, "Error setting mediafile missing flag", "ids", chunk, err)
return err
}
log.Debug(r.ctx, "Marked missing mediafiles", "total", c, "ids", chunk)
}
return nil
}
func (r *mediaFileRepository) MarkMissingByFolder(missing bool, folderIDs ...string) error {
for chunk := range slices.Chunk(folderIDs, 200) {
upd := Update(r.tableName).
Set("missing", missing).
Set("updated_at", time.Now()).
Where(And{
Eq{"folder_id": chunk},
Eq{"missing": !missing},
})
c, err := r.executeSQL(upd)
if err != nil {
log.Error(r.ctx, "Error setting mediafile missing flag", "folderIDs", chunk, err)
return err
}
log.Debug(r.ctx, "Marked missing mediafiles from missing folders", "total", c, "folders", chunk)
}
return nil
}
// GetMissingAndMatching returns all mediafiles that are missing and their potential matches (comparing PIDs)
// that were added/updated after the last scan started. The result is ordered by PID.
// It does not need to load bookmarks, annotations and participants, as they are not used by the scanner.
func (r *mediaFileRepository) GetMissingAndMatching(libId int) (model.MediaFileCursor, error) {
subQ := r.newSelect().Columns("pid").
Where(And{
Eq{"media_file.missing": true},
Eq{"library_id": libId},
})
subQText, subQArgs, err := subQ.PlaceholderFormat(Question).ToSql()
if err != nil {
return nil, err
}
sel := r.newSelect().Columns("media_file.*", "library.path as library_path", "library.name as library_name").
LeftJoin("library on media_file.library_id = library.id").
Where("pid in ("+subQText+")", subQArgs...).
Where(Or{
Eq{"missing": true},
ConcatExpr("media_file.created_at > library.last_scan_started_at"),
}).
OrderBy("pid")
cursor, err := queryWithStableResults[dbMediaFile](r.sqlRepository, sel)
if err != nil {
return nil, err
}
return func(yield func(model.MediaFile, error) bool) {
for m, err := range cursor {
if !yield(*m.MediaFile, err) || err != nil {
return
}
}
}, nil
}
// FindRecentFilesByMBZTrackID finds recently added files by MusicBrainz Track ID in other libraries
func (r *mediaFileRepository) FindRecentFilesByMBZTrackID(missing model.MediaFile, since time.Time) (model.MediaFiles, error) {
sel := r.selectMediaFile().Where(And{
NotEq{"media_file.library_id": missing.LibraryID},
Eq{"media_file.mbz_release_track_id": missing.MbzReleaseTrackID},
NotEq{"media_file.mbz_release_track_id": ""}, // Exclude empty MBZ Track IDs
Eq{"media_file.suffix": missing.Suffix},
Gt{"media_file.created_at": since},
Eq{"media_file.missing": false},
}).OrderBy("media_file.created_at DESC")
var res dbMediaFiles
err := r.queryAll(sel, &res)
if err != nil {
return nil, err
}
return res.toModels(), nil
}
// FindRecentFilesByProperties finds recently added files by intrinsic properties in other libraries
func (r *mediaFileRepository) FindRecentFilesByProperties(missing model.MediaFile, since time.Time) (model.MediaFiles, error) {
sel := r.selectMediaFile().Where(And{
NotEq{"media_file.library_id": missing.LibraryID},
Eq{"media_file.title": missing.Title},
Eq{"media_file.size": missing.Size},
Eq{"media_file.suffix": missing.Suffix},
Eq{"media_file.disc_number": missing.DiscNumber},
Eq{"media_file.track_number": missing.TrackNumber},
Eq{"media_file.album": missing.Album},
Eq{"media_file.mbz_release_track_id": ""}, // Exclude files with MBZ Track ID
Gt{"media_file.created_at": since},
Eq{"media_file.missing": false},
}).OrderBy("media_file.created_at DESC")
var res dbMediaFiles
err := r.queryAll(sel, &res)
if err != nil {
return nil, err
}
return res.toModels(), nil
}
func (r *mediaFileRepository) Search(q string, offset int, size int, options ...model.QueryOptions) (model.MediaFiles, error) {
var res dbMediaFiles
if uuid.Validate(q) == nil {
err := r.searchByMBID(r.selectMediaFile(options...), q, []string{"mbz_recording_id", "mbz_release_track_id"}, &res)
if err != nil {
return nil, fmt.Errorf("searching media_file by MBID %q: %w", q, err)
}
} else {
if r.ms != nil {
ids, err := r.ms.Search("mediafiles", q, offset, size)
if err == nil {
if len(ids) == 0 {
return model.MediaFiles{}, nil
}
// Fetch matching media files from the database
// We need to fetch all fields to return complete objects
mfs, err := r.GetAll(model.QueryOptions{Filters: Eq{"media_file.id": ids}})
if err != nil {
return nil, fmt.Errorf("fetching media_files from meilisearch ids: %w", err)
}
// Reorder results to match Meilisearch order
idMap := make(map[string]model.MediaFile, len(mfs))
for _, mf := range mfs {
idMap[mf.ID] = mf
}
sorted := make(model.MediaFiles, 0, len(mfs))
for _, id := range ids {
if mf, ok := idMap[id]; ok {
sorted = append(sorted, mf)
}
}
return sorted, nil
}
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
}
err := r.doSearch(r.selectMediaFile(options...), q, offset, size, &res, "media_file.rowid", "title")
if err != nil {
return nil, fmt.Errorf("searching media_file by query %q: %w", q, err)
}
}
return res.toModels(), nil
}
func (r *mediaFileRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *mediaFileRepository) Read(id string) (interface{}, error) {
return r.Get(id)
}
func (r *mediaFileRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
if len(options) > 0 && r.ms != nil {
if title, ok := options[0].Filters["title"].(string); ok && title != "" {
ids, err := r.ms.Search("mediafiles", title, 0, 10000)
if err == nil {
log.Debug(r.ctx, "Meilisearch found matches", "count", len(ids), "query", title)
delete(options[0].Filters, "title")
options[0].Filters["id"] = ids
} else {
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
}
}
}
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *mediaFileRepository) EntityName() string {
return "mediafile"
}
func (r *mediaFileRepository) NewInstance() interface{} {
return &model.MediaFile{}
}
var _ model.MediaFileRepository = (*mediaFileRepository)(nil)
var _ model.ResourceRepository = (*mediaFileRepository)(nil)

View File

@@ -0,0 +1,413 @@
package persistence
import (
"context"
"time"
"github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/conf/configtest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
var _ = Describe("MediaRepository", func() {
var mr model.MediaFileRepository
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid"})
mr = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
})
It("gets mediafile from the DB", func() {
actual, err := mr.Get("1004")
Expect(err).ToNot(HaveOccurred())
actual.CreatedAt = time.Time{}
Expect(actual).To(Equal(&songAntenna))
})
It("returns ErrNotFound", func() {
_, err := mr.Get("56")
Expect(err).To(MatchError(model.ErrNotFound))
})
It("counts the number of mediafiles in the DB", func() {
Expect(mr.CountAll()).To(Equal(int64(10)))
})
It("returns songs ordered by lyrics with a specific title/artist", func() {
// attempt to mimic filters.SongsByArtistTitleWithLyricsFirst, except we want all items
results, err := mr.GetAll(model.QueryOptions{
Sort: "lyrics, updated_at",
Order: "desc",
Filters: squirrel.And{
squirrel.Eq{"title": "Antenna"},
squirrel.Or{
Exists("json_tree(participants, '$.albumartist')", squirrel.Eq{"value": "Kraftwerk"}),
Exists("json_tree(participants, '$.artist')", squirrel.Eq{"value": "Kraftwerk"}),
},
},
})
Expect(err).To(BeNil())
Expect(results).To(HaveLen(3))
Expect(results[0].Lyrics).To(Equal(`[{"lang":"xxx","line":[{"value":"This is a set of lyrics"}],"synced":false}]`))
for _, item := range results[1:] {
Expect(item.Lyrics).To(Equal("[]"))
Expect(item.Title).To(Equal("Antenna"))
Expect(item.Participants[model.RoleArtist][0].Name).To(Equal("Kraftwerk"))
}
})
It("checks existence of mediafiles in the DB", func() {
Expect(mr.Exists(songAntenna.ID)).To(BeTrue())
Expect(mr.Exists("666")).To(BeFalse())
})
It("delete tracks by id", func() {
newID := id.NewRandom()
Expect(mr.Put(&model.MediaFile{LibraryID: 1, ID: newID})).To(Succeed())
Expect(mr.Delete(newID)).To(Succeed())
_, err := mr.Get(newID)
Expect(err).To(MatchError(model.ErrNotFound))
})
It("deletes all missing files", func() {
new1 := model.MediaFile{ID: id.NewRandom(), LibraryID: 1}
new2 := model.MediaFile{ID: id.NewRandom(), LibraryID: 1}
Expect(mr.Put(&new1)).To(Succeed())
Expect(mr.Put(&new2)).To(Succeed())
Expect(mr.MarkMissing(true, &new1, &new2)).To(Succeed())
adminCtx := request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", IsAdmin: true})
adminRepo := NewMediaFileRepository(adminCtx, GetDBXBuilder(), nil)
// Ensure the files are marked as missing and we have 2 of them
count, err := adminRepo.CountAll(model.QueryOptions{Filters: squirrel.Eq{"missing": true}})
Expect(count).To(BeNumerically("==", 2))
Expect(err).ToNot(HaveOccurred())
count, err = adminRepo.DeleteAllMissing()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(BeNumerically("==", 2))
_, err = mr.Get(new1.ID)
Expect(err).To(MatchError(model.ErrNotFound))
_, err = mr.Get(new2.ID)
Expect(err).To(MatchError(model.ErrNotFound))
})
Context("Annotations", func() {
It("increments play count when the tracks does not have annotations", func() {
id := "incplay.firsttime"
Expect(mr.Put(&model.MediaFile{LibraryID: 1, ID: id})).To(BeNil())
playDate := time.Now()
Expect(mr.IncPlayCount(id, playDate)).To(BeNil())
mf, err := mr.Get(id)
Expect(err).To(BeNil())
Expect(mf.PlayDate.Unix()).To(Equal(playDate.Unix()))
Expect(mf.PlayCount).To(Equal(int64(1)))
})
It("preserves play date if and only if provided date is older", func() {
id := "incplay.playdate"
Expect(mr.Put(&model.MediaFile{LibraryID: 1, ID: id})).To(BeNil())
playDate := time.Now()
Expect(mr.IncPlayCount(id, playDate)).To(BeNil())
mf, err := mr.Get(id)
Expect(err).To(BeNil())
Expect(mf.PlayDate.Unix()).To(Equal(playDate.Unix()))
Expect(mf.PlayCount).To(Equal(int64(1)))
playDateLate := playDate.AddDate(0, 0, 1)
Expect(mr.IncPlayCount(id, playDateLate)).To(BeNil())
mf, err = mr.Get(id)
Expect(err).To(BeNil())
Expect(mf.PlayDate.Unix()).To(Equal(playDateLate.Unix()))
Expect(mf.PlayCount).To(Equal(int64(2)))
playDateEarly := playDate.AddDate(0, 0, -1)
Expect(mr.IncPlayCount(id, playDateEarly)).To(BeNil())
mf, err = mr.Get(id)
Expect(err).To(BeNil())
Expect(mf.PlayDate.Unix()).To(Equal(playDateLate.Unix()))
Expect(mf.PlayCount).To(Equal(int64(3)))
})
It("increments play count on newly starred items", func() {
id := "star.incplay"
Expect(mr.Put(&model.MediaFile{LibraryID: 1, ID: id})).To(BeNil())
Expect(mr.SetStar(true, id)).To(BeNil())
playDate := time.Now()
Expect(mr.IncPlayCount(id, playDate)).To(BeNil())
mf, err := mr.Get(id)
Expect(err).To(BeNil())
Expect(mf.PlayDate.Unix()).To(Equal(playDate.Unix()))
Expect(mf.PlayCount).To(Equal(int64(1)))
})
})
Context("Sort options", func() {
Context("recently_added sort", func() {
var testMediaFiles []model.MediaFile
BeforeEach(func() {
DeferCleanup(configtest.SetupConfig())
// Create test media files with specific timestamps
testMediaFiles = []model.MediaFile{
{
ID: id.NewRandom(),
LibraryID: 1,
Title: "Old Song",
Path: "/test/old.mp3",
},
{
ID: id.NewRandom(),
LibraryID: 1,
Title: "Middle Song",
Path: "/test/middle.mp3",
},
{
ID: id.NewRandom(),
LibraryID: 1,
Title: "New Song",
Path: "/test/new.mp3",
},
}
// Insert test data first
for i := range testMediaFiles {
Expect(mr.Put(&testMediaFiles[i])).To(Succeed())
}
// Then manually update timestamps using direct SQL to bypass the repository logic
db := GetDBXBuilder()
// Set specific timestamps for testing
oldTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
middleTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
newTime := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
// Update "Old Song": created long ago, updated recently
_, err := db.Update("media_file",
map[string]interface{}{
"created_at": oldTime,
"updated_at": newTime,
},
dbx.HashExp{"id": testMediaFiles[0].ID}).Execute()
Expect(err).ToNot(HaveOccurred())
// Update "Middle Song": created and updated at the same middle time
_, err = db.Update("media_file",
map[string]interface{}{
"created_at": middleTime,
"updated_at": middleTime,
},
dbx.HashExp{"id": testMediaFiles[1].ID}).Execute()
Expect(err).ToNot(HaveOccurred())
// Update "New Song": created recently, updated long ago
_, err = db.Update("media_file",
map[string]interface{}{
"created_at": newTime,
"updated_at": oldTime,
},
dbx.HashExp{"id": testMediaFiles[2].ID}).Execute()
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
// Clean up test data
for _, mf := range testMediaFiles {
_ = mr.Delete(mf.ID)
}
})
When("RecentlyAddedByModTime is false", func() {
var testRepo model.MediaFileRepository
BeforeEach(func() {
conf.Server.RecentlyAddedByModTime = false
// Create repository AFTER setting config
ctx := log.NewContext(GinkgoT().Context())
ctx = request.WithUser(ctx, model.User{ID: "userid"})
testRepo = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
})
It("sorts by created_at", func() {
// Get results sorted by recently_added (should use created_at)
results, err := testRepo.GetAll(model.QueryOptions{
Sort: "recently_added",
Order: "desc",
Filters: squirrel.Eq{"media_file.id": []string{testMediaFiles[0].ID, testMediaFiles[1].ID, testMediaFiles[2].ID}},
})
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(3))
// Verify sorting by created_at (newest first in descending order)
Expect(results[0].Title).To(Equal("New Song")) // created 2022
Expect(results[1].Title).To(Equal("Middle Song")) // created 2021
Expect(results[2].Title).To(Equal("Old Song")) // created 2020
})
It("sorts in ascending order when specified", func() {
// Get results sorted by recently_added in ascending order
results, err := testRepo.GetAll(model.QueryOptions{
Sort: "recently_added",
Order: "asc",
Filters: squirrel.Eq{"media_file.id": []string{testMediaFiles[0].ID, testMediaFiles[1].ID, testMediaFiles[2].ID}},
})
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(3))
// Verify sorting by created_at (oldest first)
Expect(results[0].Title).To(Equal("Old Song")) // created 2020
Expect(results[1].Title).To(Equal("Middle Song")) // created 2021
Expect(results[2].Title).To(Equal("New Song")) // created 2022
})
})
When("RecentlyAddedByModTime is true", func() {
var testRepo model.MediaFileRepository
BeforeEach(func() {
conf.Server.RecentlyAddedByModTime = true
// Create repository AFTER setting config
ctx := log.NewContext(GinkgoT().Context())
ctx = request.WithUser(ctx, model.User{ID: "userid"})
testRepo = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
})
It("sorts by updated_at", func() {
// Get results sorted by recently_added (should use updated_at)
results, err := testRepo.GetAll(model.QueryOptions{
Sort: "recently_added",
Order: "desc",
Filters: squirrel.Eq{"media_file.id": []string{testMediaFiles[0].ID, testMediaFiles[1].ID, testMediaFiles[2].ID}},
})
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(3))
// Verify sorting by updated_at (newest first in descending order)
Expect(results[0].Title).To(Equal("Old Song")) // updated 2022
Expect(results[1].Title).To(Equal("Middle Song")) // updated 2021
Expect(results[2].Title).To(Equal("New Song")) // updated 2020
})
})
})
})
Describe("Search", func() {
Context("text search", func() {
It("finds media files by title", func() {
results, err := mr.Search("Antenna", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(3)) // songAntenna, songAntennaWithLyrics, songAntenna2
for _, result := range results {
Expect(result.Title).To(Equal("Antenna"))
}
})
It("finds media files case insensitively", func() {
results, err := mr.Search("antenna", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(3))
for _, result := range results {
Expect(result.Title).To(Equal("Antenna"))
}
})
It("returns empty result when no matches found", func() {
results, err := mr.Search("nonexistent", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
})
})
Context("MBID search", func() {
var mediaFileWithMBID model.MediaFile
var raw *mediaFileRepository
BeforeEach(func() {
raw = mr.(*mediaFileRepository)
// Create a test media file with MBID
mediaFileWithMBID = model.MediaFile{
ID: "test-mbid-mediafile",
Title: "Test MBID MediaFile",
MbzRecordingID: "550e8400-e29b-41d4-a716-446655440020", // Valid UUID v4
MbzReleaseTrackID: "550e8400-e29b-41d4-a716-446655440021", // Valid UUID v4
LibraryID: 1,
Path: "/test/path/test.mp3",
}
// Insert the test media file into the database
err := mr.Put(&mediaFileWithMBID)
Expect(err).ToNot(HaveOccurred())
})
AfterEach(func() {
// Clean up test data using direct SQL
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": mediaFileWithMBID.ID}))
})
It("finds media file by mbz_recording_id", func() {
results, err := mr.Search("550e8400-e29b-41d4-a716-446655440020", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(1))
Expect(results[0].ID).To(Equal("test-mbid-mediafile"))
Expect(results[0].Title).To(Equal("Test MBID MediaFile"))
})
It("finds media file by mbz_release_track_id", func() {
results, err := mr.Search("550e8400-e29b-41d4-a716-446655440021", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(HaveLen(1))
Expect(results[0].ID).To(Equal("test-mbid-mediafile"))
Expect(results[0].Title).To(Equal("Test MBID MediaFile"))
})
It("returns empty result when MBID is not found", func() {
results, err := mr.Search("550e8400-e29b-41d4-a716-446655440099", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
})
It("missing media files are never returned by search", func() {
// Create a missing media file with MBID
missingMediaFile := model.MediaFile{
ID: "test-missing-mbid-mediafile",
Title: "Test Missing MBID MediaFile",
MbzRecordingID: "550e8400-e29b-41d4-a716-446655440022",
LibraryID: 1,
Path: "/test/path/missing.mp3",
Missing: true,
}
err := mr.Put(&missingMediaFile)
Expect(err).ToNot(HaveOccurred())
// Search never returns missing media files (hardcoded behavior)
results, err := mr.Search("550e8400-e29b-41d4-a716-446655440022", 0, 10)
Expect(err).ToNot(HaveOccurred())
Expect(results).To(BeEmpty())
// Clean up
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": missingMediaFile.ID}))
})
})
})
})

191
persistence/meilisearch.go Normal file
View File

@@ -0,0 +1,191 @@
package persistence
import (
"encoding/json"
"fmt"
"github.com/meilisearch/meilisearch-go"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
)
type MeilisearchService struct {
client meilisearch.ServiceManager
}
func NewMeilisearchService() *MeilisearchService {
if !conf.Server.Meilisearch.Enabled {
return nil
}
client := meilisearch.New(conf.Server.Meilisearch.Host, meilisearch.WithAPIKey(conf.Server.Meilisearch.ApiKey))
return &MeilisearchService{client: client}
}
func (s *MeilisearchService) IndexMediaFile(mf *model.MediaFile) {
if s == nil {
return
}
s.IndexMediaFiles([]model.MediaFile{*mf})
}
func (s *MeilisearchService) IndexMediaFiles(mfs []model.MediaFile) {
if s == nil || len(mfs) == 0 {
return
}
docs := make([]map[string]interface{}, len(mfs))
for i, mf := range mfs {
docs[i] = map[string]interface{}{
"id": mf.ID,
"title": mf.Title,
"artist": mf.Artist,
"album": mf.Album,
"albumArtist": mf.AlbumArtist,
"path": mf.Path,
"year": mf.Year,
"genre": mf.Genre,
}
}
_, err := s.client.Index("mediafiles").AddDocuments(docs, nil)
if err != nil {
log.Error("Error indexing mediafiles", "count", len(mfs), err)
}
}
func (s *MeilisearchService) DeleteMediaFile(id string) {
if s == nil {
return
}
_, err := s.client.Index("mediafiles").DeleteDocument(id)
if err != nil {
log.Error("Error deleting mediafile from index", "id", id, err)
}
}
func (s *MeilisearchService) DeleteMediaFiles(ids []string) {
if s == nil {
return
}
_, err := s.client.Index("mediafiles").DeleteDocuments(ids)
if err != nil {
log.Error("Error deleting mediafiles from index", "ids", ids, err)
}
}
func (s *MeilisearchService) IndexAlbum(album *model.Album) {
if s == nil {
return
}
s.IndexAlbums([]model.Album{*album})
}
func (s *MeilisearchService) IndexAlbums(albums []model.Album) {
if s == nil || len(albums) == 0 {
return
}
docs := make([]map[string]interface{}, len(albums))
for i, album := range albums {
docs[i] = map[string]interface{}{
"id": album.ID,
"name": album.Name,
"artist": album.AlbumArtist,
"albumArtist": album.AlbumArtist,
"year": album.MinYear,
"genre": album.Genre,
}
}
_, err := s.client.Index("albums").AddDocuments(docs, nil)
if err != nil {
log.Error("Error indexing albums", "count", len(albums), err)
}
}
func (s *MeilisearchService) DeleteAlbum(id string) {
if s == nil {
return
}
_, err := s.client.Index("albums").DeleteDocument(id)
if err != nil {
log.Error("Error deleting album from index", "id", id, err)
}
}
func (s *MeilisearchService) DeleteAlbums(ids []string) {
if s == nil {
return
}
_, err := s.client.Index("albums").DeleteDocuments(ids)
if err != nil {
log.Error("Error deleting albums from index", "ids", ids, err)
}
}
func (s *MeilisearchService) IndexArtist(artist *model.Artist) {
if s == nil {
return
}
s.IndexArtists([]model.Artist{*artist})
}
func (s *MeilisearchService) IndexArtists(artists []model.Artist) {
if s == nil || len(artists) == 0 {
return
}
docs := make([]map[string]interface{}, len(artists))
for i, artist := range artists {
docs[i] = map[string]interface{}{
"id": artist.ID,
"name": artist.Name,
}
}
_, err := s.client.Index("artists").AddDocuments(docs, nil)
if err != nil {
log.Error("Error indexing artists", "count", len(artists), err)
}
}
func (s *MeilisearchService) DeleteArtist(id string) {
if s == nil {
return
}
_, err := s.client.Index("artists").DeleteDocument(id)
if err != nil {
log.Error("Error deleting artist from index", "id", id, err)
}
}
func (s *MeilisearchService) DeleteArtists(ids []string) {
if s == nil {
return
}
_, err := s.client.Index("artists").DeleteDocuments(ids)
if err != nil {
log.Error("Error deleting artists from index", "ids", ids, err)
}
}
func (s *MeilisearchService) Search(indexName string, query string, offset, limit int) ([]string, error) {
if s == nil {
return nil, fmt.Errorf("meilisearch is not enabled")
}
searchRes, err := s.client.Index(indexName).Search(query, &meilisearch.SearchRequest{
Offset: int64(offset),
Limit: int64(limit),
})
if err != nil {
return nil, err
}
var ids []string
for _, hit := range searchRes.Hits {
if id, ok := hit["id"]; ok {
var idStr string
if err := json.Unmarshal(id, &idStr); err == nil {
ids = append(ids, idStr)
}
}
}
// Handle case where id might be non-string if necessary, though simpler is better for now.
// Meilisearch returns map[string]interface{}
return ids, nil
}

246
persistence/persistence.go Normal file
View File

@@ -0,0 +1,246 @@
package persistence
import (
"context"
"database/sql"
"fmt"
"reflect"
"time"
"github.com/navidrome/navidrome/db"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/run"
"github.com/pocketbase/dbx"
)
type SQLStore struct {
db dbx.Builder
ms *MeilisearchService
}
func New(conn *sql.DB) model.DataStore {
return &SQLStore{
db: dbx.NewFromDB(conn, db.Driver),
ms: NewMeilisearchService(),
}
}
func (s *SQLStore) Album(ctx context.Context) model.AlbumRepository {
return NewAlbumRepository(ctx, s.getDBXBuilder(), s.ms)
}
func (s *SQLStore) Artist(ctx context.Context) model.ArtistRepository {
return NewArtistRepository(ctx, s.getDBXBuilder(), s.ms)
}
func (s *SQLStore) MediaFile(ctx context.Context) model.MediaFileRepository {
return NewMediaFileRepository(ctx, s.getDBXBuilder(), s.ms)
}
func (s *SQLStore) Library(ctx context.Context) model.LibraryRepository {
return NewLibraryRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Folder(ctx context.Context) model.FolderRepository {
return newFolderRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Genre(ctx context.Context) model.GenreRepository {
return NewGenreRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Tag(ctx context.Context) model.TagRepository {
return NewTagRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) PlayQueue(ctx context.Context) model.PlayQueueRepository {
return NewPlayQueueRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Playlist(ctx context.Context) model.PlaylistRepository {
return NewPlaylistRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Property(ctx context.Context) model.PropertyRepository {
return NewPropertyRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Radio(ctx context.Context) model.RadioRepository {
return NewRadioRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) UserProps(ctx context.Context) model.UserPropsRepository {
return NewUserPropsRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Share(ctx context.Context) model.ShareRepository {
return NewShareRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) User(ctx context.Context) model.UserRepository {
return NewUserRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Transcoding(ctx context.Context) model.TranscodingRepository {
return NewTranscodingRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Player(ctx context.Context) model.PlayerRepository {
return NewPlayerRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) ScrobbleBuffer(ctx context.Context) model.ScrobbleBufferRepository {
return NewScrobbleBufferRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Scrobble(ctx context.Context) model.ScrobbleRepository {
return NewScrobbleRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Resource(ctx context.Context, m interface{}) model.ResourceRepository {
switch m.(type) {
case model.User:
return s.User(ctx).(model.ResourceRepository)
case model.Transcoding:
return s.Transcoding(ctx).(model.ResourceRepository)
case model.Player:
return s.Player(ctx).(model.ResourceRepository)
case model.Artist:
return s.Artist(ctx).(model.ResourceRepository)
case model.Album:
return s.Album(ctx).(model.ResourceRepository)
case model.MediaFile:
return s.MediaFile(ctx).(model.ResourceRepository)
case model.Genre:
return s.Genre(ctx).(model.ResourceRepository)
case model.Playlist:
return s.Playlist(ctx).(model.ResourceRepository)
case model.Radio:
return s.Radio(ctx).(model.ResourceRepository)
case model.Share:
return s.Share(ctx).(model.ResourceRepository)
case model.Tag:
return s.Tag(ctx).(model.ResourceRepository)
}
log.Error("Resource not implemented", "model", reflect.TypeOf(m).Name())
return nil
}
func (s *SQLStore) ReindexAll(ctx context.Context) error {
if s.ms == nil {
return nil
}
log.Info("Starting full re-index")
// Index Artists
artists, err := s.Artist(ctx).GetAll()
if err != nil {
return fmt.Errorf("fetching artists: %w", err)
}
s.ms.IndexArtists(artists)
log.Info(ctx, "Indexed artists", "count", len(artists))
// Index Albums
albums, err := s.Album(ctx).GetAll()
if err != nil {
return fmt.Errorf("fetching albums: %w", err)
}
s.ms.IndexAlbums(albums)
log.Info(ctx, "Indexed albums", "count", len(albums))
// Index MediaFiles
mfs, err := s.MediaFile(ctx).GetAll()
if err != nil {
return fmt.Errorf("fetching media files: %w", err)
}
batchSize := 2000
for i := 0; i < len(mfs); i += batchSize {
end := i + batchSize
if end > len(mfs) {
end = len(mfs)
}
s.ms.IndexMediaFiles(mfs[i:end])
log.Info(ctx, "Indexed media files batch", "start", i, "end", end)
}
return nil
}
func (s *SQLStore) WithTx(block func(tx model.DataStore) error, scope ...string) error {
var msg string
if len(scope) > 0 {
msg = scope[0]
}
start := time.Now()
conn, inTx := s.db.(*dbx.DB)
if !inTx {
log.Trace("Nested Transaction started", "scope", msg)
conn = dbx.NewFromDB(db.Db(), db.Driver)
} else {
log.Trace("Transaction started", "scope", msg)
}
return conn.Transactional(func(tx *dbx.Tx) error {
newDb := &SQLStore{db: tx}
err := block(newDb)
if !inTx {
log.Trace("Nested Transaction finished", "scope", msg, "elapsed", time.Since(start), err)
} else {
log.Trace("Transaction finished", "scope", msg, "elapsed", time.Since(start), err)
}
return err
})
}
func (s *SQLStore) WithTxImmediate(block func(tx model.DataStore) error, scope ...string) error {
ctx := context.Background()
return s.WithTx(func(tx model.DataStore) error {
// Workaround to force the transaction to be upgraded to immediate mode to avoid deadlocks
// See https://berthub.eu/articles/posts/a-brief-post-on-sqlite3-database-locked-despite-timeout/
_ = tx.Property(ctx).Put("tmp_lock_flag", "")
defer func() {
_ = tx.Property(ctx).Delete("tmp_lock_flag")
}()
return block(tx)
}, scope...)
}
func (s *SQLStore) GC(ctx context.Context, libraryIDs ...int) error {
trace := func(ctx context.Context, msg string, f func() error) func() error {
return func() error {
start := time.Now()
err := f()
log.Debug(ctx, "GC: "+msg, "elapsed", time.Since(start), err)
return err
}
}
// If libraryIDs are provided, scope operations to those libraries where possible
scoped := len(libraryIDs) > 0
if scoped {
log.Debug(ctx, "GC: Running selective garbage collection", "libraryIDs", libraryIDs)
}
err := run.Sequentially(
trace(ctx, "purge empty albums", func() error { return s.Album(ctx).(*albumRepository).purgeEmpty(libraryIDs...) }),
trace(ctx, "purge empty artists", func() error { return s.Artist(ctx).(*artistRepository).purgeEmpty() }),
trace(ctx, "mark missing artists", func() error { return s.Artist(ctx).(*artistRepository).markMissing() }),
trace(ctx, "purge empty folders", func() error { return s.Folder(ctx).(*folderRepository).purgeEmpty(libraryIDs...) }),
trace(ctx, "clean album annotations", func() error { return s.Album(ctx).(*albumRepository).cleanAnnotations() }),
trace(ctx, "clean artist annotations", func() error { return s.Artist(ctx).(*artistRepository).cleanAnnotations() }),
trace(ctx, "clean media file annotations", func() error { return s.MediaFile(ctx).(*mediaFileRepository).cleanAnnotations() }),
trace(ctx, "clean media file bookmarks", func() error { return s.MediaFile(ctx).(*mediaFileRepository).cleanBookmarks() }),
trace(ctx, "purge non used tags", func() error { return s.Tag(ctx).(*tagRepository).purgeUnused() }),
trace(ctx, "remove orphan playlist tracks", func() error { return s.Playlist(ctx).(*playlistRepository).removeOrphans() }),
)
if err != nil {
log.Error(ctx, "Error tidying up database", err)
}
return err
}
func (s *SQLStore) getDBXBuilder() dbx.Builder {
if s.db == nil {
return dbx.NewFromDB(db.Db(), db.Driver)
}
return s.db
}

View File

@@ -0,0 +1,271 @@
package persistence
import (
"context"
"path/filepath"
"testing"
_ "github.com/mattn/go-sqlite3"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/db"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
"github.com/navidrome/navidrome/tests"
"github.com/navidrome/navidrome/utils/gg"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
func TestPersistence(t *testing.T) {
tests.Init(t, true)
//os.Remove("./test-123.db")
//conf.Server.DbPath = "./test-123.db"
conf.Server.DbPath = "file::memory:?cache=shared&_foreign_keys=on"
defer db.Init(context.Background())()
log.SetLevel(log.LevelFatal)
RegisterFailHandler(Fail)
RunSpecs(t, "Persistence Suite")
}
func mf(mf model.MediaFile) model.MediaFile {
mf.Tags = model.Tags{}
mf.LibraryID = 1
mf.LibraryPath = "music" // Default folder
mf.LibraryName = "Music Library"
mf.Participants = model.Participants{
model.RoleArtist: model.ParticipantList{
model.Participant{Artist: model.Artist{ID: mf.ArtistID, Name: mf.Artist}},
},
}
if mf.Lyrics == "" {
mf.Lyrics = "[]"
}
return mf
}
func al(al model.Album) model.Album {
al.LibraryID = 1
al.LibraryPath = "music"
al.LibraryName = "Music Library"
al.Discs = model.Discs{}
al.Tags = model.Tags{}
al.Participants = model.Participants{}
return al
}
var (
artistKraftwerk = model.Artist{ID: "2", Name: "Kraftwerk", OrderArtistName: "kraftwerk"}
artistBeatles = model.Artist{ID: "3", Name: "The Beatles", OrderArtistName: "beatles"}
testArtists = model.Artists{
artistKraftwerk,
artistBeatles,
}
)
var (
albumSgtPeppers = al(model.Album{ID: "101", Name: "Sgt Peppers", AlbumArtist: "The Beatles", OrderAlbumName: "sgt peppers", AlbumArtistID: "3", EmbedArtPath: p("/beatles/1/sgt/a day.mp3"), SongCount: 1, MaxYear: 1967})
albumAbbeyRoad = al(model.Album{ID: "102", Name: "Abbey Road", AlbumArtist: "The Beatles", OrderAlbumName: "abbey road", AlbumArtistID: "3", EmbedArtPath: p("/beatles/1/come together.mp3"), SongCount: 1, MaxYear: 1969})
albumRadioactivity = al(model.Album{ID: "103", Name: "Radioactivity", AlbumArtist: "Kraftwerk", OrderAlbumName: "radioactivity", AlbumArtistID: "2", EmbedArtPath: p("/kraft/radio/radio.mp3"), SongCount: 2})
albumMultiDisc = al(model.Album{ID: "104", Name: "Multi Disc Album", AlbumArtist: "Test Artist", OrderAlbumName: "multi disc album", AlbumArtistID: "1", EmbedArtPath: p("/test/multi/disc1/track1.mp3"), SongCount: 4})
testAlbums = model.Albums{
albumSgtPeppers,
albumAbbeyRoad,
albumRadioactivity,
albumMultiDisc,
}
)
var (
songDayInALife = mf(model.MediaFile{ID: "1001", Title: "A Day In A Life", ArtistID: "3", Artist: "The Beatles", AlbumID: "101", Album: "Sgt Peppers", Path: p("/beatles/1/sgt/a day.mp3")})
songComeTogether = mf(model.MediaFile{ID: "1002", Title: "Come Together", ArtistID: "3", Artist: "The Beatles", AlbumID: "102", Album: "Abbey Road", Path: p("/beatles/1/come together.mp3")})
songRadioactivity = mf(model.MediaFile{ID: "1003", Title: "Radioactivity", ArtistID: "2", Artist: "Kraftwerk", AlbumID: "103", Album: "Radioactivity", Path: p("/kraft/radio/radio.mp3")})
songAntenna = mf(model.MediaFile{ID: "1004", Title: "Antenna", ArtistID: "2", Artist: "Kraftwerk",
AlbumID: "103",
Path: p("/kraft/radio/antenna.mp3"),
RGAlbumGain: gg.P(1.0), RGAlbumPeak: gg.P(2.0), RGTrackGain: gg.P(3.0), RGTrackPeak: gg.P(4.0),
})
songAntennaWithLyrics = mf(model.MediaFile{
ID: "1005",
Title: "Antenna",
ArtistID: "2",
Artist: "Kraftwerk",
AlbumID: "103",
Lyrics: `[{"lang":"xxx","line":[{"value":"This is a set of lyrics"}],"synced":false}]`,
})
songAntenna2 = mf(model.MediaFile{ID: "1006", Title: "Antenna", ArtistID: "2", Artist: "Kraftwerk", AlbumID: "103"})
// Multi-disc album tracks (intentionally out of order to test sorting)
songDisc2Track11 = mf(model.MediaFile{ID: "2001", Title: "Disc 2 Track 11", ArtistID: "1", Artist: "Test Artist", AlbumID: "104", Album: "Multi Disc Album", DiscNumber: 2, TrackNumber: 11, Path: p("/test/multi/disc2/track11.mp3"), OrderAlbumName: "multi disc album", OrderArtistName: "test artist"})
songDisc1Track01 = mf(model.MediaFile{ID: "2002", Title: "Disc 1 Track 1", ArtistID: "1", Artist: "Test Artist", AlbumID: "104", Album: "Multi Disc Album", DiscNumber: 1, TrackNumber: 1, Path: p("/test/multi/disc1/track1.mp3"), OrderAlbumName: "multi disc album", OrderArtistName: "test artist"})
songDisc2Track01 = mf(model.MediaFile{ID: "2003", Title: "Disc 2 Track 1", ArtistID: "1", Artist: "Test Artist", AlbumID: "104", Album: "Multi Disc Album", DiscNumber: 2, TrackNumber: 1, Path: p("/test/multi/disc2/track1.mp3"), OrderAlbumName: "multi disc album", OrderArtistName: "test artist"})
songDisc1Track02 = mf(model.MediaFile{ID: "2004", Title: "Disc 1 Track 2", ArtistID: "1", Artist: "Test Artist", AlbumID: "104", Album: "Multi Disc Album", DiscNumber: 1, TrackNumber: 2, Path: p("/test/multi/disc1/track2.mp3"), OrderAlbumName: "multi disc album", OrderArtistName: "test artist"})
testSongs = model.MediaFiles{
songDayInALife,
songComeTogether,
songRadioactivity,
songAntenna,
songAntennaWithLyrics,
songAntenna2,
songDisc2Track11,
songDisc1Track01,
songDisc2Track01,
songDisc1Track02,
}
)
var (
radioWithoutHomePage = model.Radio{ID: "1235", StreamUrl: "https://example.com:8000/1/stream.mp3", HomePageUrl: "", Name: "No Homepage"}
radioWithHomePage = model.Radio{ID: "5010", StreamUrl: "https://example.com/stream.mp3", Name: "Example Radio", HomePageUrl: "https://example.com"}
testRadios = model.Radios{radioWithoutHomePage, radioWithHomePage}
)
var (
plsBest model.Playlist
plsCool model.Playlist
testPlaylists []*model.Playlist
)
var (
adminUser = model.User{ID: "userid", UserName: "userid", Name: "admin", Email: "admin@email.com", IsAdmin: true}
regularUser = model.User{ID: "2222", UserName: "regular-user", Name: "Regular User", Email: "regular@example.com"}
testUsers = model.Users{adminUser, regularUser}
)
func p(path string) string {
return filepath.FromSlash(path)
}
// Initialize test DB
// TODO Load this data setup from file(s)
var _ = BeforeSuite(func() {
conn := GetDBXBuilder()
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, adminUser)
ur := NewUserRepository(ctx, conn)
for i := range testUsers {
err := ur.Put(&testUsers[i])
if err != nil {
panic(err)
}
}
// Associate users with library 1 (default test library)
for i := range testUsers {
err := ur.SetUserLibraries(testUsers[i].ID, []int{1})
if err != nil {
panic(err)
}
}
alr := NewAlbumRepository(ctx, conn, nil).(*albumRepository)
for i := range testAlbums {
a := testAlbums[i]
err := alr.Put(&a)
if err != nil {
panic(err)
}
}
arr := NewArtistRepository(ctx, conn, nil)
for i := range testArtists {
a := testArtists[i]
err := arr.Put(&a)
if err != nil {
panic(err)
}
}
// Associate artists with library 1 (default test library)
lr := NewLibraryRepository(ctx, conn)
for i := range testArtists {
err := lr.AddArtist(1, testArtists[i].ID)
if err != nil {
panic(err)
}
}
mr := NewMediaFileRepository(ctx, conn, nil)
for i := range testSongs {
err := mr.Put(&testSongs[i])
if err != nil {
panic(err)
}
}
rar := NewRadioRepository(ctx, conn)
for i := range testRadios {
r := testRadios[i]
err := rar.Put(&r)
if err != nil {
panic(err)
}
}
plsBest = model.Playlist{
Name: "Best",
Comment: "No Comments",
OwnerID: "userid",
OwnerName: "userid",
Public: true,
SongCount: 2,
}
plsBest.AddMediaFilesByID([]string{"1001", "1003"})
plsCool = model.Playlist{Name: "Cool", OwnerID: "userid", OwnerName: "userid"}
plsCool.AddMediaFilesByID([]string{"1004"})
testPlaylists = []*model.Playlist{&plsBest, &plsCool}
pr := NewPlaylistRepository(ctx, conn)
for i := range testPlaylists {
err := pr.Put(testPlaylists[i])
if err != nil {
panic(err)
}
}
// Prepare annotations
if err := arr.SetStar(true, artistBeatles.ID); err != nil {
panic(err)
}
ar, err := arr.Get(artistBeatles.ID)
if err != nil {
panic(err)
}
if ar == nil {
panic("artist not found after SetStar")
}
artistBeatles.Starred = true
artistBeatles.StarredAt = ar.StarredAt
testArtists[1] = artistBeatles
if err := alr.SetStar(true, albumRadioactivity.ID); err != nil {
panic(err)
}
al, err := alr.Get(albumRadioactivity.ID)
if err != nil {
panic(err)
}
if al == nil {
panic("album not found after SetStar")
}
albumRadioactivity.Starred = true
albumRadioactivity.StarredAt = al.StarredAt
testAlbums[2] = albumRadioactivity
if err := mr.SetStar(true, songComeTogether.ID); err != nil {
panic(err)
}
mf, err := mr.Get(songComeTogether.ID)
if err != nil {
panic(err)
}
songComeTogether.Starred = true
songComeTogether.StarredAt = mf.StarredAt
testSongs[1] = songComeTogether
})
func GetDBXBuilder() *dbx.DB {
return dbx.NewFromDB(db.Db(), db.Dialect)
}

View File

@@ -0,0 +1,58 @@
package persistence
import (
"context"
"github.com/navidrome/navidrome/db"
"github.com/navidrome/navidrome/model"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("SQLStore", func() {
var ds model.DataStore
var ctx context.Context
BeforeEach(func() {
ds = New(db.Db())
ctx = context.Background()
})
Describe("WithTx", func() {
Context("When block returns nil", func() {
It("commits changes to the DB", func() {
err := ds.WithTx(func(tx model.DataStore) error {
pl := tx.Player(ctx)
err := pl.Put(&model.Player{ID: "666", UserId: "userid"})
Expect(err).ToNot(HaveOccurred())
pr := tx.Property(ctx)
err = pr.Put("777", "value")
Expect(err).ToNot(HaveOccurred())
return nil
})
Expect(err).ToNot(HaveOccurred())
Expect(ds.Player(ctx).Get("666")).To(Equal(&model.Player{ID: "666", UserId: "userid", Username: "userid"}))
Expect(ds.Property(ctx).Get("777")).To(Equal("value"))
})
})
Context("When block returns an error", func() {
It("rollbacks changes to the DB", func() {
err := ds.WithTx(func(tx model.DataStore) error {
pr := tx.Property(ctx)
err := pr.Put("999", "value")
Expect(err).ToNot(HaveOccurred())
// Will fail as it is missing the UserName
pl := tx.Player(ctx)
err = pl.Put(&model.Player{ID: "888"})
Expect(err).To(HaveOccurred())
return err
})
Expect(err).To(HaveOccurred())
_, err = ds.Property(ctx).Get("999")
Expect(err).To(MatchError(model.ErrNotFound))
_, err = ds.Player(ctx).Get("888")
Expect(err).To(MatchError(model.ErrNotFound))
})
})
})
})

View File

@@ -0,0 +1,169 @@
package persistence
import (
"context"
"errors"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type playerRepository struct {
sqlRepository
}
func NewPlayerRepository(ctx context.Context, db dbx.Builder) model.PlayerRepository {
r := &playerRepository{}
r.ctx = ctx
r.db = db
r.registerModel(&model.Player{}, map[string]filterFunc{
"name": containsFilter("player.name"),
})
r.setSortMappings(map[string]string{
"user_name": "username", //TODO rename all user_name and userName to username
})
return r
}
func (r *playerRepository) Put(p *model.Player) error {
_, err := r.put(p.ID, p)
return err
}
func (r *playerRepository) selectPlayer(options ...model.QueryOptions) SelectBuilder {
return r.newSelect(options...).
Columns("player.*").
Join("user ON player.user_id = user.id").
Columns("user.user_name username")
}
func (r *playerRepository) Get(id string) (*model.Player, error) {
sel := r.selectPlayer().Where(Eq{"player.id": id})
var res model.Player
err := r.queryOne(sel, &res)
return &res, err
}
func (r *playerRepository) FindMatch(userId, client, userAgent string) (*model.Player, error) {
sel := r.selectPlayer().Where(And{
Eq{"client": client},
Eq{"user_agent": userAgent},
Eq{"user_id": userId},
})
var res model.Player
err := r.queryOne(sel, &res)
return &res, err
}
func (r *playerRepository) newRestSelect(options ...model.QueryOptions) SelectBuilder {
s := r.selectPlayer(options...)
return s.Where(r.addRestriction())
}
func (r *playerRepository) addRestriction(sql ...Sqlizer) Sqlizer {
s := And{}
if len(sql) > 0 {
s = append(s, sql[0])
}
u := loggedUser(r.ctx)
if u.IsAdmin {
return s
}
return append(s, Eq{"user_id": u.ID})
}
func (r *playerRepository) CountByClient(options ...model.QueryOptions) (map[string]int64, error) {
sel := r.newSelect(options...).
Columns(
"case when client = 'NavidromeUI' then name else client end as player",
"count(*) as count",
).GroupBy("client")
var res []struct {
Player string
Count int64
}
err := r.queryAll(sel, &res)
if err != nil {
return nil, err
}
counts := make(map[string]int64, len(res))
for _, c := range res {
counts[c.Player] = c.Count
}
return counts, nil
}
func (r *playerRepository) CountAll(options ...model.QueryOptions) (int64, error) {
return r.count(r.newRestSelect(), options...)
}
func (r *playerRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *playerRepository) Read(id string) (interface{}, error) {
sel := r.newRestSelect().Where(Eq{"player.id": id})
var res model.Player
err := r.queryOne(sel, &res)
return &res, err
}
func (r *playerRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
sel := r.newRestSelect(r.parseRestOptions(r.ctx, options...))
res := model.Players{}
err := r.queryAll(sel, &res)
return res, err
}
func (r *playerRepository) EntityName() string {
return "player"
}
func (r *playerRepository) NewInstance() interface{} {
return &model.Player{}
}
func (r *playerRepository) isPermitted(p *model.Player) bool {
u := loggedUser(r.ctx)
return u.IsAdmin || p.UserId == u.ID
}
func (r *playerRepository) Save(entity interface{}) (string, error) {
t := entity.(*model.Player)
if !r.isPermitted(t) {
return "", rest.ErrPermissionDenied
}
id, err := r.put(t.ID, t)
if errors.Is(err, model.ErrNotFound) {
return "", rest.ErrNotFound
}
return id, err
}
func (r *playerRepository) Update(id string, entity interface{}, cols ...string) error {
t := entity.(*model.Player)
t.ID = id
if !r.isPermitted(t) {
return rest.ErrPermissionDenied
}
_, err := r.put(id, t, cols...)
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
func (r *playerRepository) Delete(id string) error {
filter := r.addRestriction(And{Eq{"player.id": id}})
err := r.delete(filter)
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
var _ model.PlayerRepository = (*playerRepository)(nil)
var _ rest.Repository = (*playerRepository)(nil)
var _ rest.Persistable = (*playerRepository)(nil)

View File

@@ -0,0 +1,247 @@
package persistence
import (
"context"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
var _ = Describe("PlayerRepository", func() {
var adminRepo *playerRepository
var database *dbx.DB
var (
adminPlayer1 = model.Player{ID: "1", Name: "NavidromeUI [Firefox/Linux]", UserAgent: "Firefox/Linux", UserId: adminUser.ID, Username: adminUser.UserName, Client: "NavidromeUI", IP: "127.0.0.1", ReportRealPath: true, ScrobbleEnabled: true}
adminPlayer2 = model.Player{ID: "2", Name: "GenericClient [Chrome/Windows]", IP: "192.168.0.5", UserAgent: "Chrome/Windows", UserId: adminUser.ID, Username: adminUser.UserName, Client: "GenericClient", MaxBitRate: 128}
regularPlayer = model.Player{ID: "3", Name: "NavidromeUI [Safari/macOS]", UserAgent: "Safari/macOS", UserId: regularUser.ID, Username: regularUser.UserName, Client: "NavidromeUI", ReportRealPath: true, ScrobbleEnabled: false}
players = model.Players{adminPlayer1, adminPlayer2, regularPlayer}
)
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, adminUser)
database = GetDBXBuilder()
adminRepo = NewPlayerRepository(ctx, database).(*playerRepository)
for idx := range players {
err := adminRepo.Put(&players[idx])
Expect(err).To(BeNil())
}
})
AfterEach(func() {
items, err := adminRepo.ReadAll()
Expect(err).To(BeNil())
players, ok := items.(model.Players)
Expect(ok).To(BeTrue())
for i := range players {
err = adminRepo.Delete(players[i].ID)
Expect(err).To(BeNil())
}
})
Describe("EntityName", func() {
It("returns the right name", func() {
Expect(adminRepo.EntityName()).To(Equal("player"))
})
})
Describe("FindMatch", func() {
It("finds existing match", func() {
player, err := adminRepo.FindMatch(adminUser.ID, "NavidromeUI", "Firefox/Linux")
Expect(err).To(BeNil())
Expect(*player).To(Equal(adminPlayer1))
})
It("doesn't find bad match", func() {
_, err := adminRepo.FindMatch(regularUser.ID, "NavidromeUI", "Firefox/Linux")
Expect(err).To(Equal(model.ErrNotFound))
})
})
Describe("Get", func() {
It("Gets an existing item from user", func() {
player, err := adminRepo.Get(adminPlayer1.ID)
Expect(err).To(BeNil())
Expect(*player).To(Equal(adminPlayer1))
})
It("Gets an existing item from another user", func() {
player, err := adminRepo.Get(regularPlayer.ID)
Expect(err).To(BeNil())
Expect(*player).To(Equal(regularPlayer))
})
It("does not get nonexistent item", func() {
_, err := adminRepo.Get("i don't exist")
Expect(err).To(Equal(model.ErrNotFound))
})
})
DescribeTableSubtree("per context", func(admin bool, players model.Players, userPlayer model.Player, otherPlayer model.Player) {
var repo *playerRepository
BeforeEach(func() {
if admin {
repo = adminRepo
} else {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, regularUser)
repo = NewPlayerRepository(ctx, database).(*playerRepository)
}
})
baseCount := int64(len(players))
Describe("Count", func() {
It("should return all", func() {
count, err := repo.Count()
Expect(err).To(BeNil())
Expect(count).To(Equal(baseCount))
})
})
Describe("Delete", func() {
DescribeTable("item type", func(player model.Player) {
err := repo.Delete(player.ID)
Expect(err).To(BeNil())
isReal := player.UserId != ""
canDelete := admin || player.UserId == userPlayer.UserId
count, err := repo.Count()
Expect(err).To(BeNil())
if isReal && canDelete {
Expect(count).To(Equal(baseCount - 1))
} else {
Expect(count).To(Equal(baseCount))
}
item, err := repo.Get(player.ID)
if !isReal || canDelete {
Expect(err).To(Equal(model.ErrNotFound))
} else {
Expect(*item).To(Equal(player))
}
},
Entry("same user", userPlayer),
Entry("other item", otherPlayer),
Entry("fake item", model.Player{}),
)
})
Describe("Read", func() {
It("can read from current user", func() {
player, err := repo.Read(userPlayer.ID)
Expect(err).To(BeNil())
Expect(player).To(Equal(&userPlayer))
})
It("can read from other user or fail if not admin", func() {
player, err := repo.Read(otherPlayer.ID)
if admin {
Expect(err).To(BeNil())
Expect(player).To(Equal(&otherPlayer))
} else {
Expect(err).To(Equal(model.ErrNotFound))
}
})
It("does not get nonexistent item", func() {
_, err := repo.Read("i don't exist")
Expect(err).To(Equal(model.ErrNotFound))
})
})
Describe("ReadAll", func() {
It("should get all items", func() {
data, err := repo.ReadAll()
Expect(err).To(BeNil())
Expect(data).To(Equal(players))
})
})
Describe("Save", func() {
DescribeTable("item type", func(player model.Player) {
clone := player
clone.ID = ""
clone.IP = "192.168.1.1"
id, err := repo.Save(&clone)
if clone.UserId == "" {
Expect(err).To(HaveOccurred())
} else if !admin && player.Username == adminPlayer1.Username {
Expect(err).To(Equal(rest.ErrPermissionDenied))
clone.UserId = ""
} else {
Expect(err).To(BeNil())
Expect(id).ToNot(BeEmpty())
}
count, err := repo.Count()
Expect(err).To(BeNil())
clone.ID = id
newItem, err := repo.Get(id)
if clone.UserId == "" {
Expect(count).To(Equal(baseCount))
Expect(err).To(Equal(model.ErrNotFound))
} else {
Expect(count).To(Equal(baseCount + 1))
Expect(err).To(BeNil())
Expect(*newItem).To(Equal(clone))
}
},
Entry("same user", userPlayer),
Entry("other item", otherPlayer),
Entry("fake item", model.Player{}),
)
})
Describe("Update", func() {
DescribeTable("item type", func(player model.Player) {
clone := player
clone.IP = "192.168.1.1"
clone.MaxBitRate = 10000
err := repo.Update(clone.ID, &clone, "ip")
if clone.UserId == "" {
Expect(err).To(HaveOccurred())
} else if !admin && player.Username == adminPlayer1.Username {
Expect(err).To(Equal(rest.ErrPermissionDenied))
clone.IP = player.IP
} else {
Expect(err).To(BeNil())
}
clone.MaxBitRate = player.MaxBitRate
newItem, err := repo.Get(clone.ID)
if player.UserId == "" {
Expect(err).To(Equal(model.ErrNotFound))
} else if !admin && player.UserId == adminUser.ID {
Expect(*newItem).To(Equal(player))
} else {
Expect(*newItem).To(Equal(clone))
}
},
Entry("same user", userPlayer),
Entry("other item", otherPlayer),
Entry("fake item", model.Player{}),
)
})
},
Entry("admin context", true, players, adminPlayer1, regularPlayer),
Entry("regular context", false, model.Players{regularPlayer}, regularPlayer, adminPlayer1),
)
})

View File

@@ -0,0 +1,528 @@
package persistence
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"slices"
"time"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/criteria"
"github.com/pocketbase/dbx"
)
type playlistRepository struct {
sqlRepository
}
type dbPlaylist struct {
model.Playlist `structs:",flatten"`
Rules sql.NullString `structs:"-"`
}
func (p *dbPlaylist) PostScan() error {
if p.Rules.String != "" {
return json.Unmarshal([]byte(p.Rules.String), &p.Playlist.Rules)
}
return nil
}
func (p dbPlaylist) PostMapArgs(args map[string]any) error {
var err error
if p.Playlist.IsSmartPlaylist() {
args["rules"], err = json.Marshal(p.Playlist.Rules)
if err != nil {
return fmt.Errorf("invalid criteria expression: %w", err)
}
return nil
}
delete(args, "rules")
return nil
}
func NewPlaylistRepository(ctx context.Context, db dbx.Builder) model.PlaylistRepository {
r := &playlistRepository{}
r.ctx = ctx
r.db = db
r.registerModel(&model.Playlist{}, map[string]filterFunc{
"q": playlistFilter,
"smart": smartPlaylistFilter,
})
r.setSortMappings(map[string]string{
"owner_name": "owner_name",
})
return r
}
func playlistFilter(_ string, value interface{}) Sqlizer {
return Or{
substringFilter("playlist.name", value),
substringFilter("playlist.comment", value),
}
}
func smartPlaylistFilter(string, interface{}) Sqlizer {
return Or{
Eq{"rules": ""},
Eq{"rules": nil},
}
}
func (r *playlistRepository) userFilter() Sqlizer {
user := loggedUser(r.ctx)
if user.IsAdmin {
return And{}
}
return Or{
Eq{"public": true},
Eq{"owner_id": user.ID},
}
}
func (r *playlistRepository) CountAll(options ...model.QueryOptions) (int64, error) {
sq := Select().Where(r.userFilter())
return r.count(sq, options...)
}
func (r *playlistRepository) Exists(id string) (bool, error) {
return r.exists(And{Eq{"id": id}, r.userFilter()})
}
func (r *playlistRepository) Delete(id string) error {
usr := loggedUser(r.ctx)
if !usr.IsAdmin {
pls, err := r.Get(id)
if err != nil {
return err
}
if pls.OwnerID != usr.ID {
return rest.ErrPermissionDenied
}
}
return r.delete(And{Eq{"id": id}, r.userFilter()})
}
func (r *playlistRepository) Put(p *model.Playlist) error {
pls := dbPlaylist{Playlist: *p}
if pls.ID == "" {
pls.CreatedAt = time.Now()
} else {
ok, err := r.Exists(pls.ID)
if err != nil {
return err
}
if !ok {
return model.ErrNotAuthorized
}
}
pls.UpdatedAt = time.Now()
id, err := r.put(pls.ID, pls)
if err != nil {
return err
}
p.ID = id
if p.IsSmartPlaylist() {
// Do not update tracks at this point, as it may take a long time and lock the DB, breaking the scan process
//r.refreshSmartPlaylist(p)
return nil
}
// Only update tracks if they were specified
if len(pls.Tracks) > 0 {
return r.updateTracks(id, p.MediaFiles())
}
return r.refreshCounters(&pls.Playlist)
}
func (r *playlistRepository) Get(id string) (*model.Playlist, error) {
return r.findBy(And{Eq{"playlist.id": id}, r.userFilter()})
}
func (r *playlistRepository) GetWithTracks(id string, refreshSmartPlaylist, includeMissing bool) (*model.Playlist, error) {
pls, err := r.Get(id)
if err != nil {
return nil, err
}
if refreshSmartPlaylist {
r.refreshSmartPlaylist(pls)
}
tracks, err := r.loadTracks(Select().From("playlist_tracks").
Where(Eq{"missing": false}).
OrderBy("playlist_tracks.id"), id)
if err != nil {
log.Error(r.ctx, "Error loading playlist tracks ", "playlist", pls.Name, "id", pls.ID, err)
return nil, err
}
pls.SetTracks(tracks)
return pls, nil
}
func (r *playlistRepository) FindByPath(path string) (*model.Playlist, error) {
return r.findBy(Eq{"path": path})
}
func (r *playlistRepository) findBy(sql Sqlizer) (*model.Playlist, error) {
sel := r.selectPlaylist().Where(sql)
var pls []dbPlaylist
err := r.queryAll(sel, &pls)
if err != nil {
return nil, err
}
if len(pls) == 0 {
return nil, model.ErrNotFound
}
return &pls[0].Playlist, nil
}
func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playlists, error) {
sel := r.selectPlaylist(options...).Where(r.userFilter())
var res []dbPlaylist
err := r.queryAll(sel, &res)
if err != nil {
return nil, err
}
playlists := make(model.Playlists, len(res))
for i, p := range res {
playlists[i] = p.Playlist
}
return playlists, err
}
func (r *playlistRepository) GetPlaylists(mediaFileId string) (model.Playlists, error) {
sel := r.selectPlaylist(model.QueryOptions{Sort: "name"}).
Join("playlist_tracks on playlist.id = playlist_tracks.playlist_id").
Where(And{Eq{"playlist_tracks.media_file_id": mediaFileId}, r.userFilter()})
var res []dbPlaylist
err := r.queryAll(sel, &res)
if err != nil {
if errors.Is(err, model.ErrNotFound) {
return model.Playlists{}, nil
}
return nil, err
}
playlists := make(model.Playlists, len(res))
for i, p := range res {
playlists[i] = p.Playlist
}
return playlists, nil
}
func (r *playlistRepository) selectPlaylist(options ...model.QueryOptions) SelectBuilder {
return r.newSelect(options...).Join("user on user.id = owner_id").
Columns(r.tableName+".*", "user.user_name as owner_name")
}
func (r *playlistRepository) refreshSmartPlaylist(pls *model.Playlist) bool {
// Only refresh if it is a smart playlist and was not refreshed within the interval provided by the refresh delay config
if !pls.IsSmartPlaylist() || (pls.EvaluatedAt != nil && time.Since(*pls.EvaluatedAt) < conf.Server.SmartPlaylistRefreshDelay) {
return false
}
// Never refresh other users' playlists
usr := loggedUser(r.ctx)
if pls.OwnerID != usr.ID {
log.Trace(r.ctx, "Not refreshing smart playlist from other user", "playlist", pls.Name, "id", pls.ID)
return false
}
log.Debug(r.ctx, "Refreshing smart playlist", "playlist", pls.Name, "id", pls.ID)
start := time.Now()
// Remove old tracks
del := Delete("playlist_tracks").Where(Eq{"playlist_id": pls.ID})
_, err := r.executeSQL(del)
if err != nil {
log.Error(r.ctx, "Error deleting old smart playlist tracks", "playlist", pls.Name, "id", pls.ID, err)
return false
}
// Re-populate playlist based on Smart Playlist criteria
rules := *pls.Rules
// If the playlist depends on other playlists, recursively refresh them first
childPlaylistIds := rules.ChildPlaylistIds()
for _, id := range childPlaylistIds {
childPls, err := r.Get(id)
if err != nil {
log.Error(r.ctx, "Error loading child playlist", "id", pls.ID, "childId", id, err)
return false
}
r.refreshSmartPlaylist(childPls)
}
sq := Select("row_number() over (order by "+rules.OrderBy()+") as id", "'"+pls.ID+"' as playlist_id", "media_file.id as media_file_id").
From("media_file").LeftJoin("annotation on (" +
"annotation.item_id = media_file.id" +
" AND annotation.item_type = 'media_file'" +
" AND annotation.user_id = '" + usr.ID + "')")
// Only include media files from libraries the user has access to
sq = r.applyLibraryFilter(sq, "media_file")
// Apply the criteria rules
sq = r.addCriteria(sq, rules)
insSql := Insert("playlist_tracks").Columns("id", "playlist_id", "media_file_id").Select(sq)
_, err = r.executeSQL(insSql)
if err != nil {
log.Error(r.ctx, "Error refreshing smart playlist tracks", "playlist", pls.Name, "id", pls.ID, err)
return false
}
// Update playlist stats
err = r.refreshCounters(pls)
if err != nil {
log.Error(r.ctx, "Error updating smart playlist stats", "playlist", pls.Name, "id", pls.ID, err)
return false
}
// Update when the playlist was last refreshed (for cache purposes)
updSql := Update(r.tableName).Set("evaluated_at", time.Now()).Where(Eq{"id": pls.ID})
_, err = r.executeSQL(updSql)
if err != nil {
log.Error(r.ctx, "Error updating smart playlist", "playlist", pls.Name, "id", pls.ID, err)
return false
}
log.Debug(r.ctx, "Refreshed playlist", "playlist", pls.Name, "id", pls.ID, "numTracks", pls.SongCount, "elapsed", time.Since(start))
return true
}
func (r *playlistRepository) addCriteria(sql SelectBuilder, c criteria.Criteria) SelectBuilder {
sql = sql.Where(c)
if c.Limit > 0 {
sql = sql.Limit(uint64(c.Limit)).Offset(uint64(c.Offset))
}
if order := c.OrderBy(); order != "" {
sql = sql.OrderBy(order)
}
return sql
}
func (r *playlistRepository) updateTracks(id string, tracks model.MediaFiles) error {
ids := make([]string, len(tracks))
for i := range tracks {
ids[i] = tracks[i].ID
}
return r.updatePlaylist(id, ids)
}
func (r *playlistRepository) updatePlaylist(playlistId string, mediaFileIds []string) error {
if !r.isWritable(playlistId) {
return rest.ErrPermissionDenied
}
// Remove old tracks
del := Delete("playlist_tracks").Where(Eq{"playlist_id": playlistId})
_, err := r.executeSQL(del)
if err != nil {
return err
}
return r.addTracks(playlistId, 1, mediaFileIds)
}
func (r *playlistRepository) addTracks(playlistId string, startingPos int, mediaFileIds []string) error {
// Break the track list in chunks to avoid hitting SQLITE_MAX_VARIABLE_NUMBER limit
// Add new tracks, chunk by chunk
pos := startingPos
for chunk := range slices.Chunk(mediaFileIds, 200) {
ins := Insert("playlist_tracks").Columns("playlist_id", "media_file_id", "id")
for _, t := range chunk {
ins = ins.Values(playlistId, t, pos)
pos++
}
_, err := r.executeSQL(ins)
if err != nil {
return err
}
}
return r.refreshCounters(&model.Playlist{ID: playlistId})
}
// refreshCounters updates total playlist duration, size and count
func (r *playlistRepository) refreshCounters(pls *model.Playlist) error {
statsSql := Select(
"coalesce(sum(duration), 0) as duration",
"coalesce(sum(size), 0) as size",
"count(*) as count",
).
From("media_file").
Join("playlist_tracks f on f.media_file_id = media_file.id").
Where(Eq{"playlist_id": pls.ID})
var res struct{ Duration, Size, Count float32 }
err := r.queryOne(statsSql, &res)
if err != nil {
return err
}
// Update playlist's total duration, size and count
upd := Update("playlist").
Set("duration", res.Duration).
Set("size", res.Size).
Set("song_count", res.Count).
Set("updated_at", time.Now()).
Where(Eq{"id": pls.ID})
_, err = r.executeSQL(upd)
if err != nil {
return err
}
pls.SongCount = int(res.Count)
pls.Duration = res.Duration
pls.Size = int64(res.Size)
return nil
}
func (r *playlistRepository) loadTracks(sel SelectBuilder, id string) (model.PlaylistTracks, error) {
sel = r.applyLibraryFilter(sel, "f")
userID := loggedUser(r.ctx).ID
tracksQuery := sel.
Columns(
"coalesce(starred, 0) as starred",
"starred_at",
"coalesce(play_count, 0) as play_count",
"play_date",
"coalesce(rating, 0) as rating",
"rated_at",
"f.*",
"playlist_tracks.*",
"library.path as library_path",
"library.name as library_name",
).
LeftJoin("annotation on (" +
"annotation.item_id = media_file_id" +
" AND annotation.item_type = 'media_file'" +
" AND annotation.user_id = '" + userID + "')").
Join("media_file f on f.id = media_file_id").
Join("library on f.library_id = library.id").
Where(Eq{"playlist_id": id})
tracks := dbPlaylistTracks{}
err := r.queryAll(tracksQuery, &tracks)
if err != nil {
return nil, err
}
return tracks.toModels(), err
}
func (r *playlistRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *playlistRepository) Read(id string) (interface{}, error) {
return r.Get(id)
}
func (r *playlistRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *playlistRepository) EntityName() string {
return "playlist"
}
func (r *playlistRepository) NewInstance() interface{} {
return &model.Playlist{}
}
func (r *playlistRepository) Save(entity interface{}) (string, error) {
pls := entity.(*model.Playlist)
pls.OwnerID = loggedUser(r.ctx).ID
pls.ID = "" // Make sure we don't override an existing playlist
err := r.Put(pls)
if err != nil {
return "", err
}
return pls.ID, err
}
func (r *playlistRepository) Update(id string, entity interface{}, cols ...string) error {
pls := dbPlaylist{Playlist: *entity.(*model.Playlist)}
current, err := r.Get(id)
if err != nil {
return err
}
usr := loggedUser(r.ctx)
if !usr.IsAdmin {
// Only the owner can update the playlist
if current.OwnerID != usr.ID {
return rest.ErrPermissionDenied
}
// Regular users can't change the ownership of a playlist
if pls.OwnerID != "" && pls.OwnerID != usr.ID {
return rest.ErrPermissionDenied
}
}
pls.ID = id
pls.UpdatedAt = time.Now()
_, err = r.put(id, pls, append(cols, "updatedAt")...)
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
func (r *playlistRepository) removeOrphans() error {
sel := Select("playlist_tracks.playlist_id as id", "p.name").From("playlist_tracks").
Join("playlist p on playlist_tracks.playlist_id = p.id").
LeftJoin("media_file mf on playlist_tracks.media_file_id = mf.id").
Where(Eq{"mf.id": nil}).
GroupBy("playlist_tracks.playlist_id")
var pls []struct{ Id, Name string }
err := r.queryAll(sel, &pls)
if err != nil {
return fmt.Errorf("fetching playlists with orphan tracks: %w", err)
}
for _, pl := range pls {
log.Debug(r.ctx, "Cleaning-up orphan tracks from playlist", "id", pl.Id, "name", pl.Name)
del := Delete("playlist_tracks").Where(And{
ConcatExpr("media_file_id not in (select id from media_file)"),
Eq{"playlist_id": pl.Id},
})
n, err := r.executeSQL(del)
if n == 0 || err != nil {
return fmt.Errorf("deleting orphan tracks from playlist %s: %w", pl.Name, err)
}
log.Debug(r.ctx, "Deleted tracks, now reordering", "id", pl.Id, "name", pl.Name, "deleted", n)
// Renumber the playlist if any track was removed
if err := r.renumber(pl.Id); err != nil {
return fmt.Errorf("renumbering playlist %s: %w", pl.Name, err)
}
}
return nil
}
func (r *playlistRepository) renumber(id string) error {
var ids []string
sq := Select("media_file_id").From("playlist_tracks").Where(Eq{"playlist_id": id}).OrderBy("id")
err := r.queryAllSlice(sq, &ids)
if err != nil {
return err
}
return r.updatePlaylist(id, ids)
}
func (r *playlistRepository) isWritable(playlistId string) bool {
usr := loggedUser(r.ctx)
if usr.IsAdmin {
return true
}
pls, err := r.Get(playlistId)
return err == nil && pls.OwnerID == usr.ID
}
var _ model.PlaylistRepository = (*playlistRepository)(nil)
var _ rest.Repository = (*playlistRepository)(nil)
var _ rest.Persistable = (*playlistRepository)(nil)

View File

@@ -0,0 +1,501 @@
package persistence
import (
"time"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/criteria"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
var _ = Describe("PlaylistRepository", func() {
var repo model.PlaylistRepository
BeforeEach(func() {
ctx := log.NewContext(GinkgoT().Context())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
repo = NewPlaylistRepository(ctx, GetDBXBuilder())
})
Describe("Count", func() {
It("returns the number of playlists in the DB", func() {
Expect(repo.CountAll()).To(Equal(int64(2)))
})
})
Describe("Exists", func() {
It("returns true for an existing playlist", func() {
Expect(repo.Exists(plsCool.ID)).To(BeTrue())
})
It("returns false for a non-existing playlist", func() {
Expect(repo.Exists("666")).To(BeFalse())
})
})
Describe("Get", func() {
It("returns an existing playlist", func() {
p, err := repo.Get(plsBest.ID)
Expect(err).To(BeNil())
// Compare all but Tracks and timestamps
p2 := *p
p2.Tracks = plsBest.Tracks
p2.UpdatedAt = plsBest.UpdatedAt
p2.CreatedAt = plsBest.CreatedAt
Expect(p2).To(Equal(plsBest))
// Compare tracks
for i := range p.Tracks {
Expect(p.Tracks[i].ID).To(Equal(plsBest.Tracks[i].ID))
}
})
It("returns ErrNotFound for a non-existing playlist", func() {
_, err := repo.Get("666")
Expect(err).To(MatchError(model.ErrNotFound))
})
It("returns all tracks", func() {
pls, err := repo.GetWithTracks(plsBest.ID, true, false)
Expect(err).ToNot(HaveOccurred())
Expect(pls.Name).To(Equal(plsBest.Name))
Expect(pls.Tracks).To(HaveLen(2))
Expect(pls.Tracks[0].ID).To(Equal("1"))
Expect(pls.Tracks[0].PlaylistID).To(Equal(plsBest.ID))
Expect(pls.Tracks[0].MediaFileID).To(Equal(songDayInALife.ID))
Expect(pls.Tracks[0].MediaFile.ID).To(Equal(songDayInALife.ID))
Expect(pls.Tracks[1].ID).To(Equal("2"))
Expect(pls.Tracks[1].PlaylistID).To(Equal(plsBest.ID))
Expect(pls.Tracks[1].MediaFileID).To(Equal(songRadioactivity.ID))
Expect(pls.Tracks[1].MediaFile.ID).To(Equal(songRadioactivity.ID))
mfs := pls.MediaFiles()
Expect(mfs).To(HaveLen(2))
Expect(mfs[0].ID).To(Equal(songDayInALife.ID))
Expect(mfs[1].ID).To(Equal(songRadioactivity.ID))
})
})
It("Put/Exists/Delete", func() {
By("saves the playlist to the DB")
newPls := model.Playlist{Name: "Great!", OwnerID: "userid"}
newPls.AddMediaFilesByID([]string{"1004", "1003"})
By("saves the playlist to the DB")
Expect(repo.Put(&newPls)).To(BeNil())
By("adds repeated songs to a playlist and keeps the order")
newPls.AddMediaFilesByID([]string{"1004"})
Expect(repo.Put(&newPls)).To(BeNil())
saved, _ := repo.GetWithTracks(newPls.ID, true, false)
Expect(saved.Tracks).To(HaveLen(3))
Expect(saved.Tracks[0].MediaFileID).To(Equal("1004"))
Expect(saved.Tracks[1].MediaFileID).To(Equal("1003"))
Expect(saved.Tracks[2].MediaFileID).To(Equal("1004"))
By("returns the newly created playlist")
Expect(repo.Exists(newPls.ID)).To(BeTrue())
By("returns deletes the playlist")
Expect(repo.Delete(newPls.ID)).To(BeNil())
By("returns error if tries to retrieve the deleted playlist")
Expect(repo.Exists(newPls.ID)).To(BeFalse())
})
Describe("GetAll", func() {
It("returns all playlists from DB", func() {
all, err := repo.GetAll()
Expect(err).To(BeNil())
Expect(all[0].ID).To(Equal(plsBest.ID))
Expect(all[1].ID).To(Equal(plsCool.ID))
})
})
Describe("GetPlaylists", func() {
It("returns playlists for a track", func() {
pls, err := repo.GetPlaylists(songRadioactivity.ID)
Expect(err).ToNot(HaveOccurred())
Expect(pls).To(HaveLen(1))
Expect(pls[0].ID).To(Equal(plsBest.ID))
})
It("returns empty when none", func() {
pls, err := repo.GetPlaylists("9999")
Expect(err).ToNot(HaveOccurred())
Expect(pls).To(HaveLen(0))
})
})
Context("Smart Playlists", func() {
var rules *criteria.Criteria
BeforeEach(func() {
rules = &criteria.Criteria{
Expression: criteria.All{
criteria.Contains{"title": "love"},
},
}
})
Context("valid rules", func() {
Specify("Put/Get", func() {
newPls := model.Playlist{Name: "Great!", OwnerID: "userid", Rules: rules}
Expect(repo.Put(&newPls)).To(Succeed())
savedPls, err := repo.Get(newPls.ID)
Expect(err).ToNot(HaveOccurred())
Expect(savedPls.Rules).To(Equal(rules))
})
})
Context("invalid rules", func() {
It("fails to Put it in the DB", func() {
rules = &criteria.Criteria{
// This is invalid because "contains" cannot have multiple fields
Expression: criteria.All{
criteria.Contains{"genre": "Hardcore", "filetype": "mp3"},
},
}
newPls := model.Playlist{Name: "Great!", OwnerID: "userid", Rules: rules}
Expect(repo.Put(&newPls)).To(MatchError(ContainSubstring("invalid criteria expression")))
})
})
// TODO Validate these tests
XContext("child smart playlists", func() {
When("refresh day has expired", func() {
It("should refresh tracks for smart playlist referenced in parent smart playlist criteria", func() {
conf.Server.SmartPlaylistRefreshDelay = -1 * time.Second
nestedPls := model.Playlist{Name: "Nested", OwnerID: "userid", Rules: rules}
Expect(repo.Put(&nestedPls)).To(Succeed())
parentPls := model.Playlist{Name: "Parent", OwnerID: "userid", Rules: &criteria.Criteria{
Expression: criteria.All{
criteria.InPlaylist{"id": nestedPls.ID},
},
}}
Expect(repo.Put(&parentPls)).To(Succeed())
nestedPlsRead, err := repo.Get(nestedPls.ID)
Expect(err).ToNot(HaveOccurred())
_, err = repo.GetWithTracks(parentPls.ID, true, false)
Expect(err).ToNot(HaveOccurred())
// Check that the nested playlist was refreshed by parent get by verifying evaluatedAt is updated since first nestedPls get
nestedPlsAfterParentGet, err := repo.Get(nestedPls.ID)
Expect(err).ToNot(HaveOccurred())
Expect(*nestedPlsAfterParentGet.EvaluatedAt).To(BeTemporally(">", *nestedPlsRead.EvaluatedAt))
})
})
When("refresh day has not expired", func() {
It("should NOT refresh tracks for smart playlist referenced in parent smart playlist criteria", func() {
conf.Server.SmartPlaylistRefreshDelay = 1 * time.Hour
nestedPls := model.Playlist{Name: "Nested", OwnerID: "userid", Rules: rules}
Expect(repo.Put(&nestedPls)).To(Succeed())
parentPls := model.Playlist{Name: "Parent", OwnerID: "userid", Rules: &criteria.Criteria{
Expression: criteria.All{
criteria.InPlaylist{"id": nestedPls.ID},
},
}}
Expect(repo.Put(&parentPls)).To(Succeed())
nestedPlsRead, err := repo.Get(nestedPls.ID)
Expect(err).ToNot(HaveOccurred())
_, err = repo.GetWithTracks(parentPls.ID, true, false)
Expect(err).ToNot(HaveOccurred())
// Check that the nested playlist was not refreshed by parent get by verifying evaluatedAt is not updated since first nestedPls get
nestedPlsAfterParentGet, err := repo.Get(nestedPls.ID)
Expect(err).ToNot(HaveOccurred())
Expect(*nestedPlsAfterParentGet.EvaluatedAt).To(Equal(*nestedPlsRead.EvaluatedAt))
})
})
})
})
Describe("Playlist Track Sorting", func() {
var testPlaylistID string
AfterEach(func() {
if testPlaylistID != "" {
Expect(repo.Delete(testPlaylistID)).To(BeNil())
testPlaylistID = ""
}
})
It("sorts tracks correctly by album (disc and track number)", func() {
By("creating a playlist with multi-disc album tracks in arbitrary order")
newPls := model.Playlist{Name: "Multi-Disc Test", OwnerID: "userid"}
// Add tracks in intentionally scrambled order
newPls.AddMediaFilesByID([]string{"2001", "2002", "2003", "2004"})
Expect(repo.Put(&newPls)).To(Succeed())
testPlaylistID = newPls.ID
By("retrieving tracks sorted by album")
tracksRepo := repo.Tracks(newPls.ID, false)
tracks, err := tracksRepo.GetAll(model.QueryOptions{Sort: "album", Order: "asc"})
Expect(err).ToNot(HaveOccurred())
By("verifying tracks are sorted by disc number then track number")
Expect(tracks).To(HaveLen(4))
// Expected order: Disc 1 Track 1, Disc 1 Track 2, Disc 2 Track 1, Disc 2 Track 11
Expect(tracks[0].MediaFileID).To(Equal("2002")) // Disc 1, Track 1
Expect(tracks[1].MediaFileID).To(Equal("2004")) // Disc 1, Track 2
Expect(tracks[2].MediaFileID).To(Equal("2003")) // Disc 2, Track 1
Expect(tracks[3].MediaFileID).To(Equal("2001")) // Disc 2, Track 11
})
})
Describe("Smart Playlists with Tag Criteria", func() {
var mfRepo model.MediaFileRepository
var testPlaylistID string
var songWithGrouping, songWithoutGrouping model.MediaFile
BeforeEach(func() {
ctx := log.NewContext(GinkgoT().Context())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
mfRepo = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
// Register 'grouping' as a valid tag for smart playlists
criteria.AddTagNames([]string{"grouping"})
// Create a song with the grouping tag
songWithGrouping = model.MediaFile{
ID: "test-grouping-1",
Title: "Song With Grouping",
Artist: "Test Artist",
ArtistID: "1",
Album: "Test Album",
AlbumID: "101",
Path: "/test/grouping/song1.mp3",
Tags: model.Tags{
"grouping": []string{"My Crate"},
},
Participants: model.Participants{},
LibraryID: 1,
Lyrics: "[]",
}
Expect(mfRepo.Put(&songWithGrouping)).To(Succeed())
// Create a song without the grouping tag
songWithoutGrouping = model.MediaFile{
ID: "test-grouping-2",
Title: "Song Without Grouping",
Artist: "Test Artist",
ArtistID: "1",
Album: "Test Album",
AlbumID: "101",
Path: "/test/grouping/song2.mp3",
Tags: model.Tags{},
Participants: model.Participants{},
LibraryID: 1,
Lyrics: "[]",
}
Expect(mfRepo.Put(&songWithoutGrouping)).To(Succeed())
})
AfterEach(func() {
if testPlaylistID != "" {
_ = repo.Delete(testPlaylistID)
testPlaylistID = ""
}
// Clean up test media files
_, _ = GetDBXBuilder().Delete("media_file", dbx.HashExp{"id": "test-grouping-1"}).Execute()
_, _ = GetDBXBuilder().Delete("media_file", dbx.HashExp{"id": "test-grouping-2"}).Execute()
})
It("matches tracks with a tag value using 'contains' with empty string (issue #4728 workaround)", func() {
By("creating a smart playlist that checks if grouping tag has any value")
// This is the workaround for issue #4728: using 'contains' with empty string
// generates SQL: value LIKE '%%' which matches any non-empty string
rules := &criteria.Criteria{
Expression: criteria.All{
criteria.Contains{"grouping": ""},
},
}
newPls := model.Playlist{Name: "Tracks with Grouping", OwnerID: "userid", Rules: rules}
Expect(repo.Put(&newPls)).To(Succeed())
testPlaylistID = newPls.ID
By("refreshing the smart playlist")
conf.Server.SmartPlaylistRefreshDelay = -1 * time.Second // Force refresh
pls, err := repo.GetWithTracks(newPls.ID, true, false)
Expect(err).ToNot(HaveOccurred())
By("verifying only the track with grouping tag is matched")
Expect(pls.Tracks).To(HaveLen(1))
Expect(pls.Tracks[0].MediaFileID).To(Equal(songWithGrouping.ID))
})
It("excludes tracks with a tag value using 'notContains' with empty string", func() {
By("creating a smart playlist that checks if grouping tag is NOT set")
rules := &criteria.Criteria{
Expression: criteria.All{
criteria.NotContains{"grouping": ""},
},
}
newPls := model.Playlist{Name: "Tracks without Grouping", OwnerID: "userid", Rules: rules}
Expect(repo.Put(&newPls)).To(Succeed())
testPlaylistID = newPls.ID
By("refreshing the smart playlist")
conf.Server.SmartPlaylistRefreshDelay = -1 * time.Second // Force refresh
pls, err := repo.GetWithTracks(newPls.ID, true, false)
Expect(err).ToNot(HaveOccurred())
By("verifying the track with grouping is NOT in the playlist")
for _, track := range pls.Tracks {
Expect(track.MediaFileID).ToNot(Equal(songWithGrouping.ID))
}
By("verifying the track without grouping IS in the playlist")
var foundWithoutGrouping bool
for _, track := range pls.Tracks {
if track.MediaFileID == songWithoutGrouping.ID {
foundWithoutGrouping = true
break
}
}
Expect(foundWithoutGrouping).To(BeTrue())
})
})
Describe("Smart Playlists Library Filtering", func() {
var mfRepo model.MediaFileRepository
var testPlaylistID string
var lib2ID int
var restrictedUserID string
var uniqueLibPath string
BeforeEach(func() {
db := GetDBXBuilder()
// Generate unique IDs for this test run
uniqueSuffix := time.Now().Format("20060102150405.000")
restrictedUserID = "restricted-user-" + uniqueSuffix
uniqueLibPath = "/music/lib2-" + uniqueSuffix
// Create a second library with unique name and path to avoid conflicts with other tests
_, err := db.DB().Exec("INSERT INTO library (name, path, created_at, updated_at) VALUES (?, ?, datetime('now'), datetime('now'))", "Library 2-"+uniqueSuffix, uniqueLibPath)
Expect(err).ToNot(HaveOccurred())
err = db.DB().QueryRow("SELECT last_insert_rowid()").Scan(&lib2ID)
Expect(err).ToNot(HaveOccurred())
// Create a restricted user with access only to library 1
_, err = db.DB().Exec("INSERT INTO user (id, user_name, name, is_admin, password, created_at, updated_at) VALUES (?, ?, 'Restricted User', false, 'pass', datetime('now'), datetime('now'))", restrictedUserID, restrictedUserID)
Expect(err).ToNot(HaveOccurred())
_, err = db.DB().Exec("INSERT INTO user_library (user_id, library_id) VALUES (?, 1)", restrictedUserID)
Expect(err).ToNot(HaveOccurred())
// Create test media files in each library
ctx := log.NewContext(GinkgoT().Context())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
mfRepo = NewMediaFileRepository(ctx, db, nil)
// Song in library 1 (accessible by restricted user)
songLib1 := model.MediaFile{
ID: "lib1-song",
Title: "Song in Lib1",
Artist: "Test Artist",
ArtistID: "1",
Album: "Test Album",
AlbumID: "101",
Path: "/music/lib1/song.mp3",
LibraryID: 1,
Participants: model.Participants{},
Tags: model.Tags{},
Lyrics: "[]",
}
Expect(mfRepo.Put(&songLib1)).To(Succeed())
// Song in library 2 (NOT accessible by restricted user)
songLib2 := model.MediaFile{
ID: "lib2-song",
Title: "Song in Lib2",
Artist: "Test Artist",
ArtistID: "1",
Album: "Test Album",
AlbumID: "101",
Path: uniqueLibPath + "/song.mp3",
LibraryID: lib2ID,
Participants: model.Participants{},
Tags: model.Tags{},
Lyrics: "[]",
}
Expect(mfRepo.Put(&songLib2)).To(Succeed())
})
AfterEach(func() {
db := GetDBXBuilder()
if testPlaylistID != "" {
_ = repo.Delete(testPlaylistID)
testPlaylistID = ""
}
// Clean up test data
_, _ = db.Delete("media_file", dbx.HashExp{"id": "lib1-song"}).Execute()
_, _ = db.Delete("media_file", dbx.HashExp{"id": "lib2-song"}).Execute()
_, _ = db.Delete("user_library", dbx.HashExp{"user_id": restrictedUserID}).Execute()
_, _ = db.Delete("user", dbx.HashExp{"id": restrictedUserID}).Execute()
_, _ = db.DB().Exec("DELETE FROM library WHERE id = ?", lib2ID)
})
It("should only include tracks from libraries the user has access to (issue #4738)", func() {
db := GetDBXBuilder()
ctx := log.NewContext(GinkgoT().Context())
// Create the smart playlist as the restricted user
restrictedUser := model.User{ID: restrictedUserID, UserName: restrictedUserID, IsAdmin: false}
ctx = request.WithUser(ctx, restrictedUser)
restrictedRepo := NewPlaylistRepository(ctx, db)
// Create a smart playlist that matches all songs
rules := &criteria.Criteria{
Expression: criteria.All{
criteria.Gt{"playCount": -1}, // Matches everything
},
}
newPls := model.Playlist{Name: "All Songs", OwnerID: restrictedUserID, Rules: rules}
Expect(restrictedRepo.Put(&newPls)).To(Succeed())
testPlaylistID = newPls.ID
By("refreshing the smart playlist")
conf.Server.SmartPlaylistRefreshDelay = -1 * time.Second // Force refresh
pls, err := restrictedRepo.GetWithTracks(newPls.ID, true, false)
Expect(err).ToNot(HaveOccurred())
By("verifying only the track from library 1 is in the playlist")
var foundLib1Song, foundLib2Song bool
for _, track := range pls.Tracks {
if track.MediaFileID == "lib1-song" {
foundLib1Song = true
}
if track.MediaFileID == "lib2-song" {
foundLib2Song = true
}
}
Expect(foundLib1Song).To(BeTrue(), "Song from library 1 should be in the playlist")
Expect(foundLib2Song).To(BeFalse(), "Song from library 2 should NOT be in the playlist")
By("verifying playlist_tracks table only contains the accessible track")
var playlistTracksCount int
err = db.DB().QueryRow("SELECT count(*) FROM playlist_tracks WHERE playlist_id = ?", newPls.ID).Scan(&playlistTracksCount)
Expect(err).ToNot(HaveOccurred())
// Count should only include tracks visible to the user (lib1-song)
// The count may include other test songs from library 1, but NOT lib2-song
var lib2TrackCount int
err = db.DB().QueryRow("SELECT count(*) FROM playlist_tracks WHERE playlist_id = ? AND media_file_id = 'lib2-song'", newPls.ID).Scan(&lib2TrackCount)
Expect(err).ToNot(HaveOccurred())
Expect(lib2TrackCount).To(Equal(0), "lib2-song should not be in playlist_tracks")
By("verifying SongCount matches visible tracks")
Expect(pls.SongCount).To(Equal(len(pls.Tracks)), "SongCount should match the number of visible tracks")
})
})
})

View File

@@ -0,0 +1,247 @@
package persistence
import (
"database/sql"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/slice"
)
type playlistTrackRepository struct {
sqlRepository
playlistId string
playlist *model.Playlist
playlistRepo *playlistRepository
}
type dbPlaylistTrack struct {
dbMediaFile
*model.PlaylistTrack `structs:",flatten"`
}
func (t *dbPlaylistTrack) PostScan() error {
if err := t.dbMediaFile.PostScan(); err != nil {
return err
}
t.PlaylistTrack.MediaFile = *t.dbMediaFile.MediaFile
t.PlaylistTrack.MediaFile.ID = t.MediaFileID
return nil
}
type dbPlaylistTracks []dbPlaylistTrack
func (t dbPlaylistTracks) toModels() model.PlaylistTracks {
return slice.Map(t, func(trk dbPlaylistTrack) model.PlaylistTrack {
return *trk.PlaylistTrack
})
}
func (r *playlistRepository) Tracks(playlistId string, refreshSmartPlaylist bool) model.PlaylistTrackRepository {
p := &playlistTrackRepository{}
p.playlistRepo = r
p.playlistId = playlistId
p.ctx = r.ctx
p.db = r.db
p.tableName = "playlist_tracks"
p.registerModel(&model.PlaylistTrack{}, map[string]filterFunc{
"missing": booleanFilter,
"library_id": libraryIdFilter,
})
p.setSortMappings(
map[string]string{
"id": "playlist_tracks.id",
"artist": "order_artist_name",
"album_artist": "order_album_artist_name",
"album": "order_album_name, album_id, disc_number, track_number, order_artist_name, title",
"title": "order_title",
// To make sure these fields will be whitelisted
"duration": "duration",
"year": "year",
"bpm": "bpm",
"channels": "channels",
},
"f") // TODO I don't like this solution, but I won't change it now as it's not the focus of BFR.
pls, err := r.Get(playlistId)
if err != nil {
log.Warn(r.ctx, "Error getting playlist's tracks", "playlistId", playlistId, err)
return nil
}
if refreshSmartPlaylist {
r.refreshSmartPlaylist(pls)
}
p.playlist = pls
return p
}
func (r *playlistTrackRepository) Count(options ...rest.QueryOptions) (int64, error) {
query := Select().
LeftJoin("media_file f on f.id = media_file_id").
Where(Eq{"playlist_id": r.playlistId})
return r.count(query, r.parseRestOptions(r.ctx, options...))
}
func (r *playlistTrackRepository) Read(id string) (interface{}, error) {
userID := loggedUser(r.ctx).ID
sel := r.newSelect().
LeftJoin("annotation on ("+
"annotation.item_id = media_file_id"+
" AND annotation.item_type = 'media_file'"+
" AND annotation.user_id = '"+userID+"')").
Columns(
"coalesce(starred, 0) as starred",
"coalesce(play_count, 0) as play_count",
"coalesce(rating, 0) as rating",
"starred_at",
"play_date",
"rated_at",
"f.*",
"playlist_tracks.*",
).
Join("media_file f on f.id = media_file_id").
Where(And{Eq{"playlist_id": r.playlistId}, Eq{"playlist_tracks.id": id}})
var trk dbPlaylistTrack
err := r.queryOne(sel, &trk)
return trk.PlaylistTrack, err
}
func (r *playlistTrackRepository) GetAll(options ...model.QueryOptions) (model.PlaylistTracks, error) {
tracks, err := r.playlistRepo.loadTracks(r.newSelect(options...), r.playlistId)
if err != nil {
return nil, err
}
return tracks, err
}
func (r *playlistTrackRepository) GetAlbumIDs(options ...model.QueryOptions) ([]string, error) {
query := r.newSelect(options...).Columns("distinct mf.album_id").
Join("media_file mf on mf.id = media_file_id").
Where(Eq{"playlist_id": r.playlistId})
var ids []string
err := r.queryAllSlice(query, &ids)
if err != nil {
return nil, err
}
return ids, nil
}
func (r *playlistTrackRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *playlistTrackRepository) EntityName() string {
return "playlist_tracks"
}
func (r *playlistTrackRepository) NewInstance() interface{} {
return &model.PlaylistTrack{}
}
func (r *playlistTrackRepository) isTracksEditable() bool {
return r.playlistRepo.isWritable(r.playlistId) && !r.playlist.IsSmartPlaylist()
}
func (r *playlistTrackRepository) Add(mediaFileIds []string) (int, error) {
if !r.isTracksEditable() {
return 0, rest.ErrPermissionDenied
}
if len(mediaFileIds) > 0 {
log.Debug(r.ctx, "Adding songs to playlist", "playlistId", r.playlistId, "mediaFileIds", mediaFileIds)
} else {
return 0, nil
}
// Get next pos (ID) in playlist
sq := r.newSelect().Columns("max(id) as max").Where(Eq{"playlist_id": r.playlistId})
var res struct{ Max sql.NullInt32 }
err := r.queryOne(sq, &res)
if err != nil {
return 0, err
}
return len(mediaFileIds), r.playlistRepo.addTracks(r.playlistId, int(res.Max.Int32+1), mediaFileIds)
}
func (r *playlistTrackRepository) addMediaFileIds(cond Sqlizer) (int, error) {
sq := Select("id").From("media_file").Where(cond).OrderBy("album_artist, album, release_date, disc_number, track_number")
var ids []string
err := r.queryAllSlice(sq, &ids)
if err != nil {
log.Error(r.ctx, "Error getting tracks to add to playlist", err)
return 0, err
}
return r.Add(ids)
}
func (r *playlistTrackRepository) AddAlbums(albumIds []string) (int, error) {
return r.addMediaFileIds(Eq{"album_id": albumIds})
}
func (r *playlistTrackRepository) AddArtists(artistIds []string) (int, error) {
return r.addMediaFileIds(Eq{"album_artist_id": artistIds})
}
func (r *playlistTrackRepository) AddDiscs(discs []model.DiscID) (int, error) {
if len(discs) == 0 {
return 0, nil
}
var clauses Or
for _, d := range discs {
clauses = append(clauses, And{Eq{"album_id": d.AlbumID}, Eq{"release_date": d.ReleaseDate}, Eq{"disc_number": d.DiscNumber}})
}
return r.addMediaFileIds(clauses)
}
// Get ids from all current tracks
func (r *playlistTrackRepository) getTracks() ([]string, error) {
all := r.newSelect().Columns("media_file_id").Where(Eq{"playlist_id": r.playlistId}).OrderBy("id")
var ids []string
err := r.queryAllSlice(all, &ids)
if err != nil {
log.Error(r.ctx, "Error querying current tracks from playlist", "playlistId", r.playlistId, err)
return nil, err
}
return ids, nil
}
func (r *playlistTrackRepository) Delete(ids ...string) error {
if !r.isTracksEditable() {
return rest.ErrPermissionDenied
}
err := r.delete(And{Eq{"playlist_id": r.playlistId}, Eq{"id": ids}})
if err != nil {
return err
}
return r.playlistRepo.renumber(r.playlistId)
}
func (r *playlistTrackRepository) DeleteAll() error {
if !r.isTracksEditable() {
return rest.ErrPermissionDenied
}
err := r.delete(Eq{"playlist_id": r.playlistId})
if err != nil {
return err
}
return r.playlistRepo.renumber(r.playlistId)
}
func (r *playlistTrackRepository) Reorder(pos int, newPos int) error {
if !r.isTracksEditable() {
return rest.ErrPermissionDenied
}
ids, err := r.getTracks()
if err != nil {
return err
}
newOrder := slice.Move(ids, pos-1, newPos-1)
return r.playlistRepo.updatePlaylist(r.playlistId, newOrder)
}
var _ model.PlaylistTrackRepository = (*playlistTrackRepository)(nil)

View File

@@ -0,0 +1,178 @@
package persistence
import (
"context"
"errors"
"strings"
"time"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
type playQueueRepository struct {
sqlRepository
}
func NewPlayQueueRepository(ctx context.Context, db dbx.Builder) model.PlayQueueRepository {
r := &playQueueRepository{}
r.ctx = ctx
r.db = db
r.tableName = "playqueue"
return r
}
type playQueue struct {
ID string `structs:"id"`
UserID string `structs:"user_id"`
Current int `structs:"current"`
Position int64 `structs:"position"`
ChangedBy string `structs:"changed_by"`
Items string `structs:"items"`
CreatedAt time.Time `structs:"created_at"`
UpdatedAt time.Time `structs:"updated_at"`
}
func (r *playQueueRepository) Store(q *model.PlayQueue, colNames ...string) error {
u := loggedUser(r.ctx)
// Always find existing playqueue for this user
existingQueue, err := r.Retrieve(q.UserID)
if err != nil && !errors.Is(err, model.ErrNotFound) {
log.Error(r.ctx, "Error retrieving existing playqueue", "user", u.UserName, err)
return err
}
// Use existing ID if found, otherwise keep the provided ID (which may be empty for new records)
if !errors.Is(err, model.ErrNotFound) && existingQueue.ID != "" {
q.ID = existingQueue.ID
}
// When no specific columns are provided, we replace the whole queue
if len(colNames) == 0 {
err := r.clearPlayQueue(q.UserID)
if err != nil {
log.Error(r.ctx, "Error deleting previous playqueue", "user", u.UserName, err)
return err
}
if len(q.Items) == 0 {
return nil
}
}
pq := r.fromModel(q)
if pq.ID == "" {
pq.CreatedAt = time.Now()
}
pq.UpdatedAt = time.Now()
_, err = r.put(pq.ID, pq, colNames...)
if err != nil {
log.Error(r.ctx, "Error saving playqueue", "user", u.UserName, err)
return err
}
return nil
}
func (r *playQueueRepository) RetrieveWithMediaFiles(userId string) (*model.PlayQueue, error) {
sel := r.newSelect().Columns("*").Where(Eq{"user_id": userId})
var res playQueue
err := r.queryOne(sel, &res)
q := r.toModel(&res)
q.Items = r.loadTracks(q.Items)
return &q, err
}
func (r *playQueueRepository) Retrieve(userId string) (*model.PlayQueue, error) {
sel := r.newSelect().Columns("*").Where(Eq{"user_id": userId})
var res playQueue
err := r.queryOne(sel, &res)
q := r.toModel(&res)
return &q, err
}
func (r *playQueueRepository) fromModel(q *model.PlayQueue) playQueue {
pq := playQueue{
ID: q.ID,
UserID: q.UserID,
Current: q.Current,
Position: q.Position,
ChangedBy: q.ChangedBy,
CreatedAt: q.CreatedAt,
UpdatedAt: q.UpdatedAt,
}
var itemIDs []string
for _, t := range q.Items {
itemIDs = append(itemIDs, t.ID)
}
pq.Items = strings.Join(itemIDs, ",")
return pq
}
func (r *playQueueRepository) toModel(pq *playQueue) model.PlayQueue {
q := model.PlayQueue{
ID: pq.ID,
UserID: pq.UserID,
Current: pq.Current,
Position: pq.Position,
ChangedBy: pq.ChangedBy,
CreatedAt: pq.CreatedAt,
UpdatedAt: pq.UpdatedAt,
}
if strings.TrimSpace(pq.Items) != "" {
tracks := strings.Split(pq.Items, ",")
for _, t := range tracks {
q.Items = append(q.Items, model.MediaFile{ID: t})
}
}
return q
}
// loadTracks loads the tracks from the database. It receives a list of track IDs and returns a list of MediaFiles
// in the same order as the input list.
func (r *playQueueRepository) loadTracks(tracks model.MediaFiles) model.MediaFiles {
if len(tracks) == 0 {
return nil
}
mfRepo := NewMediaFileRepository(r.ctx, r.db, nil)
trackMap := map[string]model.MediaFile{}
// Create an iterator to collect all track IDs
ids := slice.SeqFunc(tracks, func(t model.MediaFile) string { return t.ID })
// Break the list in chunks, up to 500 items, to avoid hitting SQLITE_MAX_VARIABLE_NUMBER limit
for chunk := range slice.CollectChunks(ids, 500) {
idsFilter := Eq{"media_file.id": chunk}
tracks, err := mfRepo.GetAll(model.QueryOptions{Filters: idsFilter})
if err != nil {
u := loggedUser(r.ctx)
log.Error(r.ctx, "Could not load playqueue/bookmark's tracks", "user", u.UserName, err)
}
for _, t := range tracks {
trackMap[t.ID] = t
}
}
// Create a new list of tracks with the same order as the original
// Exclude tracks that are not in the DB anymore
newTracks := make(model.MediaFiles, 0, len(tracks))
for _, t := range tracks {
if track, ok := trackMap[t.ID]; ok {
newTracks = append(newTracks, track)
}
}
return newTracks
}
func (r *playQueueRepository) clearPlayQueue(userId string) error {
return r.delete(Eq{"user_id": userId})
}
func (r *playQueueRepository) Clear(userId string) error {
return r.clearPlayQueue(userId)
}
var _ model.PlayQueueRepository = (*playQueueRepository)(nil)

View File

@@ -0,0 +1,435 @@
package persistence
import (
"context"
"time"
"github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/conf/configtest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("PlayQueueRepository", func() {
var repo model.PlayQueueRepository
var ctx context.Context
BeforeEach(func() {
DeferCleanup(configtest.SetupConfig())
ctx = log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
repo = NewPlayQueueRepository(ctx, GetDBXBuilder())
})
Describe("Store", func() {
It("stores a complete playqueue", func() {
expected := aPlayQueue("userid", 1, 123, songComeTogether, songDayInALife)
Expect(repo.Store(expected)).To(Succeed())
actual, err := repo.RetrieveWithMediaFiles("userid")
Expect(err).ToNot(HaveOccurred())
AssertPlayQueue(expected, actual)
Expect(countPlayQueues(repo, "userid")).To(Equal(1))
})
It("replaces existing playqueue when storing without column names", func() {
By("Storing initial playqueue")
initial := aPlayQueue("userid", 0, 100, songComeTogether)
Expect(repo.Store(initial)).To(Succeed())
By("Storing replacement playqueue")
replacement := aPlayQueue("userid", 1, 200, songDayInALife, songAntenna)
Expect(repo.Store(replacement)).To(Succeed())
actual, err := repo.RetrieveWithMediaFiles("userid")
Expect(err).ToNot(HaveOccurred())
AssertPlayQueue(replacement, actual)
Expect(countPlayQueues(repo, "userid")).To(Equal(1))
})
It("clears playqueue when storing empty items", func() {
By("Storing initial playqueue")
initial := aPlayQueue("userid", 0, 100, songComeTogether)
Expect(repo.Store(initial)).To(Succeed())
By("Storing empty playqueue")
empty := aPlayQueue("userid", 0, 0)
Expect(repo.Store(empty)).To(Succeed())
By("Verifying playqueue is cleared")
_, err := repo.Retrieve("userid")
Expect(err).To(MatchError(model.ErrNotFound))
})
It("updates only current field when specified", func() {
By("Storing initial playqueue")
initial := aPlayQueue("userid", 0, 100, songComeTogether, songDayInALife)
Expect(repo.Store(initial)).To(Succeed())
By("Getting the existing playqueue to obtain its ID")
existing, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
By("Updating only current field")
update := &model.PlayQueue{
ID: existing.ID, // Use existing ID for partial update
UserID: "userid",
Current: 1,
ChangedBy: "test-update",
}
Expect(repo.Store(update, "current")).To(Succeed())
By("Verifying only current was updated")
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Current).To(Equal(1))
Expect(actual.Position).To(Equal(int64(100))) // Should remain unchanged
Expect(actual.Items).To(HaveLen(2)) // Should remain unchanged
})
It("updates only position field when specified", func() {
By("Storing initial playqueue")
initial := aPlayQueue("userid", 1, 100, songComeTogether, songDayInALife)
Expect(repo.Store(initial)).To(Succeed())
By("Getting the existing playqueue to obtain its ID")
existing, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
By("Updating only position field")
update := &model.PlayQueue{
ID: existing.ID, // Use existing ID for partial update
UserID: "userid",
Position: 500,
ChangedBy: "test-update",
}
Expect(repo.Store(update, "position")).To(Succeed())
By("Verifying only position was updated")
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Position).To(Equal(int64(500)))
Expect(actual.Current).To(Equal(1)) // Should remain unchanged
Expect(actual.Items).To(HaveLen(2)) // Should remain unchanged
})
It("updates multiple specified fields", func() {
By("Storing initial playqueue")
initial := aPlayQueue("userid", 0, 100, songComeTogether)
Expect(repo.Store(initial)).To(Succeed())
By("Getting the existing playqueue to obtain its ID")
existing, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
By("Updating current and position fields")
update := &model.PlayQueue{
ID: existing.ID, // Use existing ID for partial update
UserID: "userid",
Current: 1,
Position: 300,
ChangedBy: "test-update",
}
Expect(repo.Store(update, "current", "position")).To(Succeed())
By("Verifying both fields were updated")
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Current).To(Equal(1))
Expect(actual.Position).To(Equal(int64(300)))
Expect(actual.Items).To(HaveLen(1)) // Should remain unchanged
})
It("preserves existing data when updating with empty items list and column names", func() {
By("Storing initial playqueue")
initial := aPlayQueue("userid", 0, 100, songComeTogether, songDayInALife)
Expect(repo.Store(initial)).To(Succeed())
By("Getting the existing playqueue to obtain its ID")
existing, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
By("Updating only position with empty items")
update := &model.PlayQueue{
ID: existing.ID, // Use existing ID for partial update
UserID: "userid",
Position: 200,
ChangedBy: "test-update",
Items: []model.MediaFile{}, // Empty items
}
Expect(repo.Store(update, "position")).To(Succeed())
By("Verifying items are preserved")
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Position).To(Equal(int64(200)))
Expect(actual.Items).To(HaveLen(2)) // Should remain unchanged
})
It("ensures only one record per user by reusing existing record ID", func() {
By("Storing initial playqueue")
initial := aPlayQueue("userid", 0, 100, songComeTogether)
Expect(repo.Store(initial)).To(Succeed())
initialCount := countPlayQueues(repo, "userid")
Expect(initialCount).To(Equal(1))
By("Storing another playqueue with different ID but same user")
different := aPlayQueue("userid", 1, 200, songDayInALife)
different.ID = "different-id" // Force a different ID
Expect(repo.Store(different)).To(Succeed())
By("Verifying only one record exists for the user")
finalCount := countPlayQueues(repo, "userid")
Expect(finalCount).To(Equal(1))
By("Verifying the record was updated, not duplicated")
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Current).To(Equal(1)) // Should be updated value
Expect(actual.Position).To(Equal(int64(200))) // Should be updated value
Expect(actual.Items).To(HaveLen(1)) // Should be new items
Expect(actual.Items[0].ID).To(Equal(songDayInALife.ID))
})
It("ensures only one record per user even with partial updates", func() {
By("Storing initial playqueue")
initial := aPlayQueue("userid", 0, 100, songComeTogether, songDayInALife)
Expect(repo.Store(initial)).To(Succeed())
initialCount := countPlayQueues(repo, "userid")
Expect(initialCount).To(Equal(1))
By("Storing partial update with different ID but same user")
partialUpdate := &model.PlayQueue{
ID: "completely-different-id", // Use a completely different ID
UserID: "userid",
Current: 1,
ChangedBy: "test-partial",
}
Expect(repo.Store(partialUpdate, "current")).To(Succeed())
By("Verifying only one record still exists for the user")
finalCount := countPlayQueues(repo, "userid")
Expect(finalCount).To(Equal(1))
By("Verifying the existing record was updated with new current value")
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Current).To(Equal(1)) // Should be updated value
Expect(actual.Position).To(Equal(int64(100))) // Should remain unchanged
Expect(actual.Items).To(HaveLen(2)) // Should remain unchanged
})
})
Describe("Retrieve", func() {
It("returns notfound error if there's no playqueue for the user", func() {
_, err := repo.Retrieve("user999")
Expect(err).To(MatchError(model.ErrNotFound))
})
It("retrieves the playqueue with only track IDs (no full MediaFile data)", func() {
By("Storing a playqueue for the user")
expected := aPlayQueue("userid", 1, 123, songComeTogether, songDayInALife)
Expect(repo.Store(expected)).To(Succeed())
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
// Basic playqueue properties should match
Expect(actual.ID).To(Equal(expected.ID))
Expect(actual.UserID).To(Equal(expected.UserID))
Expect(actual.Current).To(Equal(expected.Current))
Expect(actual.Position).To(Equal(expected.Position))
Expect(actual.ChangedBy).To(Equal(expected.ChangedBy))
Expect(actual.Items).To(HaveLen(len(expected.Items)))
// Items should only contain IDs, not full MediaFile data
for i, item := range actual.Items {
Expect(item.ID).To(Equal(expected.Items[i].ID))
// These fields should be empty since we're not loading full MediaFiles
Expect(item.Title).To(BeEmpty())
Expect(item.Path).To(BeEmpty())
Expect(item.Album).To(BeEmpty())
Expect(item.Artist).To(BeEmpty())
}
})
It("returns items with IDs even when some tracks don't exist in the DB", func() {
// Add a new song to the DB
newSong := songRadioactivity
newSong.ID = "temp-track"
newSong.Path = "/new-path"
mfRepo := NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
Expect(mfRepo.Put(&newSong)).To(Succeed())
// Create a playqueue with the new song
pq := aPlayQueue("userid", 0, 0, newSong, songAntenna)
Expect(repo.Store(pq)).To(Succeed())
// Delete the new song from the database
Expect(mfRepo.Delete("temp-track")).To(Succeed())
// Retrieve the playqueue with Retrieve method
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
// The playqueue should still contain both track IDs (including the deleted one)
Expect(actual.Items).To(HaveLen(2))
Expect(actual.Items[0].ID).To(Equal("temp-track"))
Expect(actual.Items[1].ID).To(Equal(songAntenna.ID))
// Items should only contain IDs, no other data
for _, item := range actual.Items {
Expect(item.Title).To(BeEmpty())
Expect(item.Path).To(BeEmpty())
Expect(item.Album).To(BeEmpty())
Expect(item.Artist).To(BeEmpty())
}
})
})
Describe("RetrieveWithMediaFiles", func() {
It("returns notfound error if there's no playqueue for the user", func() {
_, err := repo.RetrieveWithMediaFiles("user999")
Expect(err).To(MatchError(model.ErrNotFound))
})
It("retrieves the playqueue with full MediaFile data", func() {
By("Storing a playqueue for the user")
expected := aPlayQueue("userid", 1, 123, songComeTogether, songDayInALife)
Expect(repo.Store(expected)).To(Succeed())
actual, err := repo.RetrieveWithMediaFiles("userid")
Expect(err).ToNot(HaveOccurred())
AssertPlayQueue(expected, actual)
})
It("does not return tracks if they don't exist in the DB", func() {
// Add a new song to the DB
newSong := songRadioactivity
newSong.ID = "temp-track"
newSong.Path = "/new-path"
mfRepo := NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
Expect(mfRepo.Put(&newSong)).To(Succeed())
// Create a playqueue with the new song
pq := aPlayQueue("userid", 0, 0, newSong, songAntenna)
Expect(repo.Store(pq)).To(Succeed())
// Retrieve the playqueue
actual, err := repo.RetrieveWithMediaFiles("userid")
Expect(err).ToNot(HaveOccurred())
// The playqueue should contain both tracks
AssertPlayQueue(pq, actual)
// Delete the new song
Expect(mfRepo.Delete("temp-track")).To(Succeed())
// Retrieve the playqueue
actual, err = repo.RetrieveWithMediaFiles("userid")
Expect(err).ToNot(HaveOccurred())
// The playqueue should not contain the deleted track
Expect(actual.Items).To(HaveLen(1))
Expect(actual.Items[0].ID).To(Equal(songAntenna.ID))
})
})
Describe("Clear", func() {
It("clears an existing playqueue", func() {
By("Storing a playqueue")
expected := aPlayQueue("userid", 1, 123, songComeTogether, songDayInALife)
Expect(repo.Store(expected)).To(Succeed())
By("Verifying playqueue exists")
_, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
By("Clearing the playqueue")
Expect(repo.Clear("userid")).To(Succeed())
By("Verifying playqueue is cleared")
_, err = repo.Retrieve("userid")
Expect(err).To(MatchError(model.ErrNotFound))
})
It("does not error when clearing non-existent playqueue", func() {
// Clear should not error even if no playqueue exists
Expect(repo.Clear("nonexistent-user")).To(Succeed())
})
It("only clears the specified user's playqueue", func() {
By("Creating users in the database to avoid foreign key constraints")
userRepo := NewUserRepository(ctx, GetDBXBuilder())
user1 := &model.User{ID: "user1", UserName: "user1", Name: "User 1", Email: "user1@test.com"}
user2 := &model.User{ID: "user2", UserName: "user2", Name: "User 2", Email: "user2@test.com"}
Expect(userRepo.Put(user1)).To(Succeed())
Expect(userRepo.Put(user2)).To(Succeed())
By("Storing playqueues for two users")
user1Queue := aPlayQueue("user1", 0, 100, songComeTogether)
user2Queue := aPlayQueue("user2", 1, 200, songDayInALife)
Expect(repo.Store(user1Queue)).To(Succeed())
Expect(repo.Store(user2Queue)).To(Succeed())
By("Clearing only user1's playqueue")
Expect(repo.Clear("user1")).To(Succeed())
By("Verifying user1's playqueue is cleared")
_, err := repo.Retrieve("user1")
Expect(err).To(MatchError(model.ErrNotFound))
By("Verifying user2's playqueue still exists")
actual, err := repo.Retrieve("user2")
Expect(err).ToNot(HaveOccurred())
Expect(actual.UserID).To(Equal("user2"))
Expect(actual.Current).To(Equal(1))
Expect(actual.Position).To(Equal(int64(200)))
})
})
})
func countPlayQueues(repo model.PlayQueueRepository, userId string) int {
r := repo.(*playQueueRepository)
c, err := r.count(squirrel.Select().Where(squirrel.Eq{"user_id": userId}))
if err != nil {
panic(err)
}
return int(c)
}
func AssertPlayQueue(expected, actual *model.PlayQueue) {
Expect(actual.ID).To(Equal(expected.ID))
Expect(actual.UserID).To(Equal(expected.UserID))
Expect(actual.Current).To(Equal(expected.Current))
Expect(actual.Position).To(Equal(expected.Position))
Expect(actual.ChangedBy).To(Equal(expected.ChangedBy))
Expect(actual.Items).To(HaveLen(len(expected.Items)))
for i, item := range actual.Items {
Expect(item.Title).To(Equal(expected.Items[i].Title))
}
}
func aPlayQueue(userId string, current int, position int64, items ...model.MediaFile) *model.PlayQueue {
createdAt := time.Now()
updatedAt := createdAt.Add(time.Minute)
return &model.PlayQueue{
ID: id.NewRandom(),
UserID: userId,
Current: current,
Position: position,
ChangedBy: "test",
Items: items,
CreatedAt: createdAt,
UpdatedAt: updatedAt,
}
}

View File

@@ -0,0 +1,63 @@
package persistence
import (
"context"
"errors"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type propertyRepository struct {
sqlRepository
}
func NewPropertyRepository(ctx context.Context, db dbx.Builder) model.PropertyRepository {
r := &propertyRepository{}
r.ctx = ctx
r.db = db
r.tableName = "property"
return r
}
func (r propertyRepository) Put(id string, value string) error {
update := Update(r.tableName).Set("value", value).Where(Eq{"id": id})
count, err := r.executeSQL(update)
if err != nil {
return err
}
if count > 0 {
return nil
}
insert := Insert(r.tableName).Columns("id", "value").Values(id, value)
_, err = r.executeSQL(insert)
return err
}
func (r propertyRepository) Get(id string) (string, error) {
sel := Select("value").From(r.tableName).Where(Eq{"id": id})
resp := struct {
Value string
}{}
err := r.queryOne(sel, &resp)
if err != nil {
return "", err
}
return resp.Value, nil
}
func (r propertyRepository) DefaultGet(id string, defaultValue string) (string, error) {
value, err := r.Get(id)
if errors.Is(err, model.ErrNotFound) {
return defaultValue, nil
}
if err != nil {
return defaultValue, err
}
return value, nil
}
func (r propertyRepository) Delete(id string) error {
return r.delete(Eq{"id": id})
}

View File

@@ -0,0 +1,34 @@
package persistence
import (
"context"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Property Repository", func() {
var pr model.PropertyRepository
BeforeEach(func() {
pr = NewPropertyRepository(log.NewContext(context.TODO()), GetDBXBuilder())
})
It("saves and restore a new property", func() {
id := "1"
value := "a_value"
Expect(pr.Put(id, value)).To(BeNil())
Expect(pr.Get(id)).To(Equal("a_value"))
})
It("updates a property", func() {
Expect(pr.Put("1", "another_value")).To(BeNil())
Expect(pr.Get("1")).To(Equal("another_value"))
})
It("returns a default value if property does not exist", func() {
Expect(pr.DefaultGet("2", "default")).To(Equal("default"))
})
})

View File

@@ -0,0 +1,139 @@
package persistence
import (
"context"
"errors"
"time"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/pocketbase/dbx"
)
type radioRepository struct {
sqlRepository
}
func NewRadioRepository(ctx context.Context, db dbx.Builder) model.RadioRepository {
r := &radioRepository{}
r.ctx = ctx
r.db = db
r.registerModel(&model.Radio{}, map[string]filterFunc{
"name": containsFilter("name"),
})
return r
}
func (r *radioRepository) isPermitted() bool {
user := loggedUser(r.ctx)
return user.IsAdmin
}
func (r *radioRepository) CountAll(options ...model.QueryOptions) (int64, error) {
sql := r.newSelect()
return r.count(sql, options...)
}
func (r *radioRepository) Delete(id string) error {
if !r.isPermitted() {
return rest.ErrPermissionDenied
}
return r.delete(Eq{"id": id})
}
func (r *radioRepository) Get(id string) (*model.Radio, error) {
sel := r.newSelect().Where(Eq{"id": id}).Columns("*")
res := model.Radio{}
err := r.queryOne(sel, &res)
return &res, err
}
func (r *radioRepository) GetAll(options ...model.QueryOptions) (model.Radios, error) {
sel := r.newSelect(options...).Columns("*")
res := model.Radios{}
err := r.queryAll(sel, &res)
return res, err
}
func (r *radioRepository) Put(radio *model.Radio) error {
if !r.isPermitted() {
return rest.ErrPermissionDenied
}
var values map[string]interface{}
radio.UpdatedAt = time.Now()
if radio.ID == "" {
radio.CreatedAt = time.Now()
radio.ID = id.NewRandom()
values, _ = toSQLArgs(*radio)
} else {
values, _ = toSQLArgs(*radio)
update := Update(r.tableName).Where(Eq{"id": radio.ID}).SetMap(values)
count, err := r.executeSQL(update)
if err != nil {
return err
} else if count > 0 {
return nil
}
}
values["created_at"] = time.Now()
insert := Insert(r.tableName).SetMap(values)
_, err := r.executeSQL(insert)
return err
}
func (r *radioRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *radioRepository) EntityName() string {
return "radio"
}
func (r *radioRepository) NewInstance() interface{} {
return &model.Radio{}
}
func (r *radioRepository) Read(id string) (interface{}, error) {
return r.Get(id)
}
func (r *radioRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *radioRepository) Save(entity interface{}) (string, error) {
t := entity.(*model.Radio)
if !r.isPermitted() {
return "", rest.ErrPermissionDenied
}
err := r.Put(t)
if errors.Is(err, model.ErrNotFound) {
return "", rest.ErrNotFound
}
return t.ID, err
}
func (r *radioRepository) Update(id string, entity interface{}, cols ...string) error {
t := entity.(*model.Radio)
t.ID = id
if !r.isPermitted() {
return rest.ErrPermissionDenied
}
err := r.Put(t)
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
var _ model.RadioRepository = (*radioRepository)(nil)
var _ rest.Repository = (*radioRepository)(nil)
var _ rest.Persistable = (*radioRepository)(nil)

View File

@@ -0,0 +1,175 @@
package persistence
import (
"context"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var (
NewId string = "123-456-789"
)
var _ = Describe("RadioRepository", func() {
var repo model.RadioRepository
Describe("Admin User", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
repo = NewRadioRepository(ctx, GetDBXBuilder())
_ = repo.Put(&radioWithHomePage)
})
AfterEach(func() {
all, _ := repo.GetAll()
for _, radio := range all {
_ = repo.Delete(radio.ID)
}
for i := range testRadios {
r := testRadios[i]
err := repo.Put(&r)
if err != nil {
panic(err)
}
}
})
Describe("Count", func() {
It("returns the number of radios in the DB", func() {
Expect(repo.CountAll()).To(Equal(int64(2)))
})
})
Describe("Delete", func() {
It("deletes existing item", func() {
err := repo.Delete(radioWithHomePage.ID)
Expect(err).To(BeNil())
_, err = repo.Get(radioWithHomePage.ID)
Expect(err).To(MatchError(model.ErrNotFound))
})
})
Describe("Get", func() {
It("returns an existing item", func() {
res, err := repo.Get(radioWithHomePage.ID)
Expect(err).To(BeNil())
Expect(res.ID).To(Equal(radioWithHomePage.ID))
})
It("errors when missing", func() {
_, err := repo.Get("notanid")
Expect(err).To(MatchError(model.ErrNotFound))
})
})
Describe("GetAll", func() {
It("returns all items from the DB", func() {
all, err := repo.GetAll()
Expect(err).To(BeNil())
Expect(all[0].ID).To(Equal(radioWithoutHomePage.ID))
Expect(all[1].ID).To(Equal(radioWithHomePage.ID))
})
})
Describe("Put", func() {
It("successfully updates item", func() {
err := repo.Put(&model.Radio{
ID: radioWithHomePage.ID,
Name: "New Name",
StreamUrl: "https://example.com:4533/app",
})
Expect(err).To(BeNil())
item, err := repo.Get(radioWithHomePage.ID)
Expect(err).To(BeNil())
Expect(item.HomePageUrl).To(Equal(""))
})
It("successfully creates item", func() {
err := repo.Put(&model.Radio{
Name: "New radio",
StreamUrl: "https://example.com:4533/app",
})
Expect(err).To(BeNil())
Expect(repo.CountAll()).To(Equal(int64(3)))
all, err := repo.GetAll()
Expect(err).To(BeNil())
Expect(all[2].StreamUrl).To(Equal("https://example.com:4533/app"))
})
})
})
Describe("Regular User", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: false})
repo = NewRadioRepository(ctx, GetDBXBuilder())
})
Describe("Count", func() {
It("returns the number of radios in the DB", func() {
Expect(repo.CountAll()).To(Equal(int64(2)))
})
})
Describe("Delete", func() {
It("fails to delete items", func() {
err := repo.Delete(radioWithHomePage.ID)
Expect(err).To(Equal(rest.ErrPermissionDenied))
})
})
Describe("Get", func() {
It("returns an existing item", func() {
res, err := repo.Get(radioWithHomePage.ID)
Expect(err).To((BeNil()))
Expect(res.ID).To(Equal(radioWithHomePage.ID))
})
It("errors when missing", func() {
_, err := repo.Get("notanid")
Expect(err).To(MatchError(model.ErrNotFound))
})
})
Describe("GetAll", func() {
It("returns all items from the DB", func() {
all, err := repo.GetAll()
Expect(err).To(BeNil())
Expect(all[0].ID).To(Equal(radioWithoutHomePage.ID))
Expect(all[1].ID).To(Equal(radioWithHomePage.ID))
})
})
Describe("Put", func() {
It("fails to update item", func() {
err := repo.Put(&model.Radio{
ID: radioWithHomePage.ID,
Name: "New Name",
StreamUrl: "https://example.com:4533/app",
})
Expect(err).To(Equal(rest.ErrPermissionDenied))
})
})
})
})

View File

@@ -0,0 +1,100 @@
package persistence
import (
"context"
"errors"
"time"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/pocketbase/dbx"
)
type scrobbleBufferRepository struct {
sqlRepository
}
type dbScrobbleBuffer struct {
dbMediaFile
*model.ScrobbleEntry `structs:",flatten"`
}
func (t *dbScrobbleBuffer) PostScan() error {
if err := t.dbMediaFile.PostScan(); err != nil {
return err
}
t.ScrobbleEntry.MediaFile = *t.dbMediaFile.MediaFile
t.ScrobbleEntry.MediaFile.ID = t.MediaFileID
return nil
}
func NewScrobbleBufferRepository(ctx context.Context, db dbx.Builder) model.ScrobbleBufferRepository {
r := &scrobbleBufferRepository{}
r.ctx = ctx
r.db = db
r.tableName = "scrobble_buffer"
return r
}
func (r *scrobbleBufferRepository) UserIDs(service string) ([]string, error) {
sql := Select().Columns("user_id").
From(r.tableName).
Where(And{
Eq{"service": service},
}).
GroupBy("user_id").
OrderBy("count(*)")
var userIds []string
err := r.queryAllSlice(sql, &userIds)
return userIds, err
}
func (r *scrobbleBufferRepository) Enqueue(service, userId, mediaFileId string, playTime time.Time) error {
ins := Insert(r.tableName).SetMap(map[string]interface{}{
"id": id.NewRandom(),
"user_id": userId,
"service": service,
"media_file_id": mediaFileId,
"play_time": playTime,
"enqueue_time": time.Now(),
})
_, err := r.executeSQL(ins)
return err
}
func (r *scrobbleBufferRepository) Next(service string, userId string) (*model.ScrobbleEntry, error) {
// Put `s.*` last or else m.id overrides s.id
sql := Select().Columns("m.*, s.*").
From(r.tableName+" s").
LeftJoin("media_file m on m.id = s.media_file_id").
Where(And{
Eq{"service": service},
Eq{"user_id": userId},
}).
OrderBy("play_time", "s.rowid").Limit(1)
var res dbScrobbleBuffer
err := r.queryOne(sql, &res)
if errors.Is(err, model.ErrNotFound) {
return nil, nil
}
if err != nil {
return nil, err
}
res.ScrobbleEntry.Participants, err = r.getParticipants(&res.ScrobbleEntry.MediaFile)
if err != nil {
return nil, err
}
return res.ScrobbleEntry, nil
}
func (r *scrobbleBufferRepository) Dequeue(entry *model.ScrobbleEntry) error {
return r.delete(Eq{"id": entry.ID})
}
func (r *scrobbleBufferRepository) Length() (int64, error) {
return r.count(Select())
}
var _ model.ScrobbleBufferRepository = (*scrobbleBufferRepository)(nil)

View File

@@ -0,0 +1,208 @@
package persistence
import (
"context"
"time"
"github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ScrobbleBufferRepository", func() {
var scrobble model.ScrobbleBufferRepository
var rawRepo sqlRepository
enqueueTime := time.Date(2025, 01, 01, 00, 00, 00, 00, time.Local)
var ids []string
var insertManually = func(service, userId, mediaFileId string, playTime time.Time) {
id := id.NewRandom()
ids = append(ids, id)
ins := squirrel.Insert("scrobble_buffer").SetMap(map[string]interface{}{
"id": id,
"user_id": userId,
"service": service,
"media_file_id": mediaFileId,
"play_time": playTime,
"enqueue_time": enqueueTime,
})
_, err := rawRepo.executeSQL(ins)
Expect(err).ToNot(HaveOccurred())
}
BeforeEach(func() {
ctx := request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", UserName: "johndoe", IsAdmin: true})
db := GetDBXBuilder()
scrobble = NewScrobbleBufferRepository(ctx, db)
rawRepo = sqlRepository{
ctx: ctx,
tableName: "scrobble_buffer",
db: db,
}
ids = []string{}
})
AfterEach(func() {
del := squirrel.Delete(rawRepo.tableName)
_, err := rawRepo.executeSQL(del)
Expect(err).ToNot(HaveOccurred())
})
Describe("Without data", func() {
Describe("Count", func() {
It("returns zero when empty", func() {
count, err := scrobble.Length()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(BeZero())
})
})
Describe("Dequeue", func() {
It("is a no-op when deleting a nonexistent item", func() {
err := scrobble.Dequeue(&model.ScrobbleEntry{ID: "fake"})
Expect(err).ToNot(HaveOccurred())
count, err := scrobble.Length()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(0)))
})
})
Describe("Next", func() {
It("should not fail with no item for the service", func() {
entry, err := scrobble.Next("fake", "userid")
Expect(entry).To(BeNil())
Expect(err).ToNot(HaveOccurred())
})
})
Describe("UserIds", func() {
It("should return empty list with no data", func() {
ids, err := scrobble.UserIDs("service")
Expect(err).ToNot(HaveOccurred())
Expect(ids).To(BeEmpty())
})
})
})
Describe("With data", func() {
timeA := enqueueTime.Add(24 * time.Hour)
timeB := enqueueTime.Add(48 * time.Hour)
timeC := enqueueTime.Add(72 * time.Hour)
timeD := enqueueTime.Add(96 * time.Hour)
BeforeEach(func() {
insertManually("a", "userid", "1001", timeB)
insertManually("a", "userid", "1002", timeA)
insertManually("a", "2222", "1003", timeC)
insertManually("b", "2222", "1004", timeD)
})
Describe("Count", func() {
It("Returns count when populated", func() {
count, err := scrobble.Length()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(4)))
})
})
Describe("Dequeue", func() {
It("is a no-op when deleting a nonexistent item", func() {
err := scrobble.Dequeue(&model.ScrobbleEntry{ID: "fake"})
Expect(err).ToNot(HaveOccurred())
count, err := scrobble.Length()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(4)))
})
It("deletes an item when specified properly", func() {
err := scrobble.Dequeue(&model.ScrobbleEntry{ID: ids[3]})
Expect(err).ToNot(HaveOccurred())
count, err := scrobble.Length()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(3)))
entry, err := scrobble.Next("b", "2222")
Expect(err).ToNot(HaveOccurred())
Expect(entry).To(BeNil())
})
})
Describe("Enqueue", func() {
DescribeTable("enqueues an item properly",
func(service, userId, fileId string, playTime time.Time) {
now := time.Now()
err := scrobble.Enqueue(service, userId, fileId, playTime)
Expect(err).ToNot(HaveOccurred())
count, err := scrobble.Length()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(5)))
entry, err := scrobble.Next(service, userId)
Expect(err).ToNot(HaveOccurred())
Expect(entry).ToNot(BeNil())
Expect(entry.EnqueueTime).To(BeTemporally("~", now, 100*time.Millisecond))
Expect(entry.MediaFileID).To(Equal(fileId))
Expect(entry.PlayTime).To(BeTemporally("==", playTime))
},
Entry("to an existing service with multiple values", "a", "userid", "1004", enqueueTime),
Entry("to a new service", "c", "2222", "1001", timeD),
Entry("to an existing service as new user", "b", "userid", "1003", timeC),
)
})
Describe("Next", func() {
DescribeTable("Returns the next item when populated",
func(service, id string, playTime time.Time, fileId, artistId string) {
entry, err := scrobble.Next(service, id)
Expect(err).ToNot(HaveOccurred())
Expect(entry).ToNot(BeNil())
Expect(entry.Service).To(Equal(service))
Expect(entry.UserID).To(Equal(id))
Expect(entry.PlayTime).To(BeTemporally("==", playTime))
Expect(entry.EnqueueTime).To(BeTemporally("==", enqueueTime))
Expect(entry.MediaFileID).To(Equal(fileId))
Expect(entry.MediaFile.Participants).To(HaveLen(1))
artists, ok := entry.MediaFile.Participants[model.RoleArtist]
Expect(ok).To(BeTrue(), "no artist role in participants")
Expect(artists).To(HaveLen(1))
Expect(artists[0].ID).To(Equal(artistId))
},
Entry("Service with multiple values for one user", "a", "userid", timeA, "1002", "3"),
Entry("Service with users", "a", "2222", timeC, "1003", "2"),
Entry("Service with one user", "b", "2222", timeD, "1004", "2"),
)
})
Describe("UserIds", func() {
It("should return ordered list for services", func() {
ids, err := scrobble.UserIDs("a")
Expect(err).ToNot(HaveOccurred())
Expect(ids).To(Equal([]string{"2222", "userid"}))
})
It("should return for a different service", func() {
ids, err := scrobble.UserIDs("b")
Expect(err).ToNot(HaveOccurred())
Expect(ids).To(Equal([]string{"2222"}))
})
})
})
})

View File

@@ -0,0 +1,34 @@
package persistence
import (
"context"
"time"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type scrobbleRepository struct {
sqlRepository
}
func NewScrobbleRepository(ctx context.Context, db dbx.Builder) model.ScrobbleRepository {
r := &scrobbleRepository{}
r.ctx = ctx
r.db = db
r.tableName = "scrobbles"
return r
}
func (r *scrobbleRepository) RecordScrobble(mediaFileID string, submissionTime time.Time) error {
userID := loggedUser(r.ctx).ID
values := map[string]interface{}{
"media_file_id": mediaFileID,
"user_id": userID,
"submission_time": submissionTime.Unix(),
}
insert := Insert(r.tableName).SetMap(values)
_, err := r.executeSQL(insert)
return err
}

View File

@@ -0,0 +1,84 @@
package persistence
import (
"context"
"time"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
var _ = Describe("ScrobbleRepository", func() {
var repo model.ScrobbleRepository
var rawRepo sqlRepository
var ctx context.Context
var fileID string
var userID string
BeforeEach(func() {
fileID = id.NewRandom()
userID = id.NewRandom()
ctx = request.WithUser(log.NewContext(GinkgoT().Context()), model.User{ID: userID, UserName: "johndoe", IsAdmin: true})
db := GetDBXBuilder()
repo = NewScrobbleRepository(ctx, db)
rawRepo = sqlRepository{
ctx: ctx,
tableName: "scrobbles",
db: db,
}
})
AfterEach(func() {
_, _ = rawRepo.db.Delete("scrobbles", dbx.HashExp{"media_file_id": fileID}).Execute()
_, _ = rawRepo.db.Delete("media_file", dbx.HashExp{"id": fileID}).Execute()
_, _ = rawRepo.db.Delete("user", dbx.HashExp{"id": userID}).Execute()
})
Describe("RecordScrobble", func() {
It("records a scrobble event", func() {
submissionTime := time.Now().UTC()
// Insert User
_, err := rawRepo.db.Insert("user", dbx.Params{
"id": userID,
"user_name": "user",
"password": "pw",
"created_at": time.Now(),
"updated_at": time.Now(),
}).Execute()
Expect(err).ToNot(HaveOccurred())
// Insert MediaFile
_, err = rawRepo.db.Insert("media_file", dbx.Params{
"id": fileID,
"path": "path",
"created_at": time.Now(),
"updated_at": time.Now(),
}).Execute()
Expect(err).ToNot(HaveOccurred())
err = repo.RecordScrobble(fileID, submissionTime)
Expect(err).ToNot(HaveOccurred())
// Verify insertion
var scrobble struct {
MediaFileID string `db:"media_file_id"`
UserID string `db:"user_id"`
SubmissionTime int64 `db:"submission_time"`
}
err = rawRepo.db.Select("*").From("scrobbles").
Where(dbx.HashExp{"media_file_id": fileID, "user_id": userID}).
One(&scrobble)
Expect(err).ToNot(HaveOccurred())
Expect(scrobble.MediaFileID).To(Equal(fileID))
Expect(scrobble.UserID).To(Equal(userID))
Expect(scrobble.SubmissionTime).To(Equal(submissionTime.Unix()))
})
})
})

View File

@@ -0,0 +1,202 @@
package persistence
import (
"context"
"errors"
"fmt"
"strings"
"time"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
"github.com/pocketbase/dbx"
)
type shareRepository struct {
sqlRepository
}
func NewShareRepository(ctx context.Context, db dbx.Builder) model.ShareRepository {
r := &shareRepository{}
r.ctx = ctx
r.db = db
r.registerModel(&model.Share{}, nil)
r.setSortMappings(map[string]string{
"username": "username",
})
return r
}
func (r *shareRepository) Delete(id string) error {
err := r.delete(Eq{"id": id})
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
func (r *shareRepository) selectShare(options ...model.QueryOptions) SelectBuilder {
return r.newSelect(options...).Join("user u on u.id = share.user_id").
Columns("share.*", "user_name as username")
}
func (r *shareRepository) Exists(id string) (bool, error) {
return r.exists(Eq{"id": id})
}
func (r *shareRepository) Get(id string) (*model.Share, error) {
sel := r.selectShare().Where(Eq{"share.id": id})
var res model.Share
err := r.queryOne(sel, &res)
if err != nil {
return nil, err
}
err = r.loadMedia(&res)
return &res, err
}
func (r *shareRepository) GetAll(options ...model.QueryOptions) (model.Shares, error) {
sq := r.selectShare(options...)
res := model.Shares{}
err := r.queryAll(sq, &res)
if err != nil {
return nil, err
}
for i := range res {
err = r.loadMedia(&res[i])
if err != nil {
return nil, fmt.Errorf("error loading media for share %s: %w", res[i].ID, err)
}
}
return res, err
}
func (r *shareRepository) loadMedia(share *model.Share) error {
var err error
ids := strings.Split(share.ResourceIDs, ",")
if len(ids) == 0 {
return nil
}
noMissing := func(cond Sqlizer) Sqlizer {
return And{cond, Eq{"missing": false}}
}
switch share.ResourceType {
case "artist":
albumRepo := NewAlbumRepository(r.ctx, r.db, nil)
share.Albums, err = albumRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"album_artist_id": ids}), Sort: "artist"})
if err != nil {
return err
}
mfRepo := NewMediaFileRepository(r.ctx, r.db, nil)
share.Tracks, err = mfRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"album_artist_id": ids}), Sort: "artist"})
return err
case "album":
albumRepo := NewAlbumRepository(r.ctx, r.db, nil)
share.Albums, err = albumRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"album.id": ids})})
if err != nil {
return err
}
mfRepo := NewMediaFileRepository(r.ctx, r.db, nil)
share.Tracks, err = mfRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"album_id": ids}), Sort: "album"})
return err
case "playlist":
// Create a context with a fake admin user, to be able to access all playlists
ctx := request.WithUser(r.ctx, model.User{IsAdmin: true})
plsRepo := NewPlaylistRepository(ctx, r.db)
tracks, err := plsRepo.Tracks(ids[0], true).GetAll(model.QueryOptions{Sort: "id", Filters: noMissing(Eq{})})
if err != nil {
return err
}
if len(tracks) >= 0 {
share.Tracks = tracks.MediaFiles()
}
return nil
case "media_file":
mfRepo := NewMediaFileRepository(r.ctx, r.db, nil)
tracks, err := mfRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"media_file.id": ids})})
share.Tracks = sortByIdPosition(tracks, ids)
return err
}
log.Warn(r.ctx, "Unsupported Share ResourceType", "share", share.ID, "resourceType", share.ResourceType)
return nil
}
func sortByIdPosition(mfs model.MediaFiles, ids []string) model.MediaFiles {
m := map[string]int{}
for i, mf := range mfs {
m[mf.ID] = i
}
var sorted model.MediaFiles
for _, id := range ids {
if idx, ok := m[id]; ok {
sorted = append(sorted, mfs[idx])
}
}
return sorted
}
func (r *shareRepository) Update(id string, entity interface{}, cols ...string) error {
s := entity.(*model.Share)
// TODO Validate record
s.ID = id
s.UpdatedAt = time.Now()
cols = append(cols, "updated_at")
_, err := r.put(id, s, cols...)
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
func (r *shareRepository) Save(entity interface{}) (string, error) {
s := entity.(*model.Share)
// TODO Validate record
u := loggedUser(r.ctx)
if s.UserID == "" {
s.UserID = u.ID
}
s.CreatedAt = time.Now()
s.UpdatedAt = time.Now()
id, err := r.put(s.ID, s)
if errors.Is(err, model.ErrNotFound) {
return "", rest.ErrNotFound
}
return id, err
}
func (r *shareRepository) CountAll(options ...model.QueryOptions) (int64, error) {
return r.count(r.selectShare(), options...)
}
func (r *shareRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *shareRepository) EntityName() string {
return "share"
}
func (r *shareRepository) NewInstance() interface{} {
return &model.Share{}
}
func (r *shareRepository) Read(id string) (interface{}, error) {
sel := r.selectShare().Where(Eq{"share.id": id})
var res model.Share
err := r.queryOne(sel, &res)
return &res, err
}
func (r *shareRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
sq := r.selectShare(r.parseRestOptions(r.ctx, options...))
res := model.Shares{}
err := r.queryAll(sq, &res)
return res, err
}
var _ model.ShareRepository = (*shareRepository)(nil)
var _ rest.Repository = (*shareRepository)(nil)
var _ rest.Persistable = (*shareRepository)(nil)

View File

@@ -0,0 +1,133 @@
package persistence
import (
"context"
"time"
"github.com/navidrome/navidrome/conf/configtest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("ShareRepository", func() {
var repo model.ShareRepository
var ctx context.Context
var adminUser = model.User{ID: "admin", UserName: "admin", IsAdmin: true}
BeforeEach(func() {
DeferCleanup(configtest.SetupConfig())
ctx = request.WithUser(log.NewContext(context.TODO()), adminUser)
repo = NewShareRepository(ctx, GetDBXBuilder())
// Insert the admin user into the database (required for foreign key constraint)
ur := NewUserRepository(ctx, GetDBXBuilder())
err := ur.Put(&adminUser)
Expect(err).ToNot(HaveOccurred())
// Clean up shares
db := GetDBXBuilder()
_, err = db.NewQuery("DELETE FROM share").Execute()
Expect(err).ToNot(HaveOccurred())
})
Describe("Headless Access", func() {
Context("Repository creation and basic operations", func() {
It("should create repository successfully with no user context", func() {
// Create repository with no user context (headless)
headlessRepo := NewShareRepository(context.Background(), GetDBXBuilder())
Expect(headlessRepo).ToNot(BeNil())
})
It("should handle GetAll for headless processes", func() {
// Create a simple share directly in database
shareID := "headless-test-share"
_, err := GetDBXBuilder().NewQuery(`
INSERT INTO share (id, user_id, description, resource_type, resource_ids, created_at, updated_at)
VALUES ({:id}, {:user}, {:desc}, {:type}, {:ids}, {:created}, {:updated})
`).Bind(map[string]interface{}{
"id": shareID,
"user": adminUser.ID,
"desc": "Headless Test Share",
"type": "song",
"ids": "song-1",
"created": time.Now(),
"updated": time.Now(),
}).Execute()
Expect(err).ToNot(HaveOccurred())
// Headless process should see all shares
headlessRepo := NewShareRepository(context.Background(), GetDBXBuilder())
shares, err := headlessRepo.GetAll()
Expect(err).ToNot(HaveOccurred())
found := false
for _, s := range shares {
if s.ID == shareID {
found = true
break
}
}
Expect(found).To(BeTrue(), "Headless process should see all shares")
})
It("should handle individual share retrieval for headless processes", func() {
// Create a simple share
shareID := "headless-get-share"
_, err := GetDBXBuilder().NewQuery(`
INSERT INTO share (id, user_id, description, resource_type, resource_ids, created_at, updated_at)
VALUES ({:id}, {:user}, {:desc}, {:type}, {:ids}, {:created}, {:updated})
`).Bind(map[string]interface{}{
"id": shareID,
"user": adminUser.ID,
"desc": "Headless Get Share",
"type": "song",
"ids": "song-2",
"created": time.Now(),
"updated": time.Now(),
}).Execute()
Expect(err).ToNot(HaveOccurred())
// Headless process should be able to get the share
headlessRepo := NewShareRepository(context.Background(), GetDBXBuilder())
share, err := headlessRepo.Get(shareID)
Expect(err).ToNot(HaveOccurred())
Expect(share.ID).To(Equal(shareID))
Expect(share.Description).To(Equal("Headless Get Share"))
})
})
})
Describe("SQL ambiguity fix verification", func() {
It("should handle share operations without SQL ambiguity errors", func() {
// This test verifies that the loadMedia function doesn't cause SQL ambiguity
// The key fix was using "album.id" instead of "id" in the album query filters
// Create a share that would trigger the loadMedia function
shareID := "sql-test-share"
_, err := GetDBXBuilder().NewQuery(`
INSERT INTO share (id, user_id, description, resource_type, resource_ids, created_at, updated_at)
VALUES ({:id}, {:user}, {:desc}, {:type}, {:ids}, {:created}, {:updated})
`).Bind(map[string]interface{}{
"id": shareID,
"user": adminUser.ID,
"desc": "SQL Test Share",
"type": "album",
"ids": "non-existent-album", // Won't find albums, but shouldn't cause SQL errors
"created": time.Now(),
"updated": time.Now(),
}).Execute()
Expect(err).ToNot(HaveOccurred())
// The Get operation should work without SQL ambiguity errors
// even if no albums are found
share, err := repo.Get(shareID)
Expect(err).ToNot(HaveOccurred())
Expect(share.ID).To(Equal(shareID))
// Albums array should be empty since we used non-existent album ID
Expect(share.Albums).To(BeEmpty())
})
})
})

View File

@@ -0,0 +1,130 @@
package persistence
import (
"database/sql"
"errors"
"fmt"
"time"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/consts"
"github.com/navidrome/navidrome/log"
)
const annotationTable = "annotation"
func (r sqlRepository) withAnnotation(query SelectBuilder, idField string) SelectBuilder {
userID := loggedUser(r.ctx).ID
if userID == invalidUserId {
return query
}
query = query.
LeftJoin("annotation on ("+
"annotation.item_id = "+idField+
" AND annotation.user_id = '"+userID+"')").
Columns(
"coalesce(starred, 0) as starred",
"coalesce(rating, 0) as rating",
"starred_at",
"play_date",
"rated_at",
)
if conf.Server.AlbumPlayCountMode == consts.AlbumPlayCountModeNormalized && r.tableName == "album" {
query = query.Columns(
fmt.Sprintf("round(coalesce(round(cast(play_count as float) / coalesce(%[1]s.song_count, 1), 1), 0)) as play_count", r.tableName),
)
} else {
query = query.Columns("coalesce(play_count, 0) as play_count")
}
return query
}
func (r sqlRepository) annId(itemID ...string) And {
userID := loggedUser(r.ctx).ID
return And{
Eq{annotationTable + ".user_id": userID},
Eq{annotationTable + ".item_type": r.tableName},
Eq{annotationTable + ".item_id": itemID},
}
}
func (r sqlRepository) annUpsert(values map[string]interface{}, itemIDs ...string) error {
upd := Update(annotationTable).Where(r.annId(itemIDs...))
for f, v := range values {
upd = upd.Set(f, v)
}
c, err := r.executeSQL(upd)
if c == 0 || errors.Is(err, sql.ErrNoRows) {
userID := loggedUser(r.ctx).ID
for _, itemID := range itemIDs {
values["user_id"] = userID
values["item_type"] = r.tableName
values["item_id"] = itemID
ins := Insert(annotationTable).SetMap(values)
_, err = r.executeSQL(ins)
if err != nil {
return err
}
}
}
return err
}
func (r sqlRepository) SetStar(starred bool, ids ...string) error {
starredAt := time.Now()
return r.annUpsert(map[string]interface{}{"starred": starred, "starred_at": starredAt}, ids...)
}
func (r sqlRepository) SetRating(rating int, itemID string) error {
ratedAt := time.Now()
return r.annUpsert(map[string]interface{}{"rating": rating, "rated_at": ratedAt}, itemID)
}
func (r sqlRepository) IncPlayCount(itemID string, ts time.Time) error {
upd := Update(annotationTable).Where(r.annId(itemID)).
Set("play_count", Expr("play_count+1")).
Set("play_date", Expr("max(ifnull(play_date,''),?)", ts))
c, err := r.executeSQL(upd)
if c == 0 || errors.Is(err, sql.ErrNoRows) {
userID := loggedUser(r.ctx).ID
values := map[string]interface{}{}
values["user_id"] = userID
values["item_type"] = r.tableName
values["item_id"] = itemID
values["play_count"] = 1
values["play_date"] = ts
ins := Insert(annotationTable).SetMap(values)
_, err = r.executeSQL(ins)
if err != nil {
return err
}
}
return err
}
func (r sqlRepository) ReassignAnnotation(prevID string, newID string) error {
if prevID == newID || prevID == "" || newID == "" {
return nil
}
upd := Update(annotationTable).Where(And{
Eq{annotationTable + ".item_type": r.tableName},
Eq{annotationTable + ".item_id": prevID},
}).Set("item_id", newID)
_, err := r.executeSQL(upd)
return err
}
func (r sqlRepository) cleanAnnotations() error {
del := Delete(annotationTable).Where(Eq{"item_type": r.tableName}).Where("item_id not in (select id from " + r.tableName + ")")
c, err := r.executeSQL(del)
if err != nil {
return fmt.Errorf("error cleaning up %s annotations: %w", r.tableName, err)
}
if c > 0 {
log.Debug(r.ctx, "Clean-up annotations", "table", r.tableName, "totalDeleted", c)
}
return nil
}

View File

@@ -0,0 +1,470 @@
package persistence
import (
"context"
"crypto/md5"
"database/sql"
"errors"
"fmt"
"iter"
"reflect"
"regexp"
"strings"
"time"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
id2 "github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
"github.com/navidrome/navidrome/utils/hasher"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
// sqlRepository is the base repository for all SQL repositories. It provides common functions to interact with the DB.
// When creating a new repository using this base, you must:
//
// - Embed this struct.
// - Set ctx and db fields. ctx should be the context passed to the constructor method, usually obtained from the request
// - Call registerModel with the model instance and any possible filters.
// - If the model has a different table name than the default (lowercase of the model name), it should be set manually
// using the tableName field.
// - Sort mappings must be set with setSortMappings method. If a sort field is not in the map, it will be used as the name of the column.
//
// All fields in filters and sortMappings must be in snake_case. Only sorts and filters based on real field names or
// defined in the mappings will be allowed.
type sqlRepository struct {
ctx context.Context
tableName string
db dbx.Builder
// Do not set these fields manually, they are set by the registerModel method
filterMappings map[string]filterFunc
isFieldWhiteListed fieldWhiteListedFunc
// Do not set this field manually, it is set by the setSortMappings method
sortMappings map[string]string
}
const invalidUserId = "-1"
func loggedUser(ctx context.Context) *model.User {
if user, ok := request.UserFrom(ctx); !ok {
return &model.User{ID: invalidUserId}
} else {
return &user
}
}
func (r *sqlRepository) registerModel(instance any, filters map[string]filterFunc) {
if r.tableName == "" {
r.tableName = strings.TrimPrefix(reflect.TypeOf(instance).String(), "*model.")
r.tableName = toSnakeCase(r.tableName)
}
r.tableName = strings.ToLower(r.tableName)
r.isFieldWhiteListed = registerModelWhiteList(instance)
r.filterMappings = filters
}
// setSortMappings sets the mappings for the sort fields. If the sort field is not in the map, it will be used as is.
//
// If PreferSortTags is enabled, it will map the order fields to the corresponding sort expression,
// which gives precedence to sort tags.
// Ex: order_title => (coalesce(nullif(sort_title,”),order_title) collate nocase)
// To avoid performance issues, indexes should be created for these sort expressions
//
// NOTE: if an individual item has spaces, it should be wrapped in parentheses. For example,
// you should write "(lyrics != '[]')". This prevents the item being split unexpectedly.
// Without parentheses, "lyrics != '[]'" would be mapped as simply "lyrics"
func (r *sqlRepository) setSortMappings(mappings map[string]string, tableName ...string) {
tn := r.tableName
if len(tableName) > 0 {
tn = tableName[0]
}
if conf.Server.PreferSortTags {
for k, v := range mappings {
v = mapSortOrder(tn, v)
mappings[k] = v
}
}
r.sortMappings = mappings
}
func (r sqlRepository) newSelect(options ...model.QueryOptions) SelectBuilder {
sq := Select().From(r.tableName)
if len(options) > 0 {
r.resetSeededRandom(options)
sq = r.applyOptions(sq, options...)
sq = r.applyFilters(sq, options...)
}
return sq
}
func (r sqlRepository) applyOptions(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
if len(options) > 0 {
if options[0].Max > 0 {
sq = sq.Limit(uint64(options[0].Max))
}
if options[0].Offset > 0 {
sq = sq.Offset(uint64(options[0].Offset))
}
if options[0].Sort != "" {
sq = sq.OrderBy(r.buildSortOrder(options[0].Sort, options[0].Order))
}
}
return sq
}
// TODO Change all sortMappings to have a consistent case
func (r sqlRepository) sortMapping(sort string) string {
if mapping, ok := r.sortMappings[sort]; ok {
return mapping
}
if mapping, ok := r.sortMappings[toCamelCase(sort)]; ok {
return mapping
}
sort = toSnakeCase(sort)
if mapping, ok := r.sortMappings[sort]; ok {
return mapping
}
return sort
}
func (r sqlRepository) buildSortOrder(sort, order string) string {
sort = r.sortMapping(sort)
order = strings.ToLower(strings.TrimSpace(order))
var reverseOrder string
if order == "desc" {
reverseOrder = "asc"
} else {
order = "asc"
reverseOrder = "desc"
}
parts := strings.FieldsFunc(sort, splitFunc(','))
newSort := make([]string, 0, len(parts))
for _, p := range parts {
f := strings.FieldsFunc(p, splitFunc(' '))
newField := make([]string, 1, len(f))
newField[0] = f[0]
if len(f) == 1 {
newField = append(newField, order)
} else {
if f[1] == "asc" {
newField = append(newField, order)
} else {
newField = append(newField, reverseOrder)
}
}
newSort = append(newSort, strings.Join(newField, " "))
}
return strings.Join(newSort, ", ")
}
func splitFunc(delimiter rune) func(c rune) bool {
open := 0
return func(c rune) bool {
if c == '(' {
open++
return false
}
if open > 0 {
if c == ')' {
open--
}
return false
}
return c == delimiter
}
}
func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
if len(options) > 0 && options[0].Filters != nil {
sq = sq.Where(options[0].Filters)
}
return sq
}
func (r *sqlRepository) withTableName(filter filterFunc) filterFunc {
return func(field string, value any) Sqlizer {
if r.tableName != "" {
field = r.tableName + "." + field
}
return filter(field, value)
}
}
// libraryIdFilter is a filter function to be added to resources that have a library_id column.
func libraryIdFilter(_ string, value interface{}) Sqlizer {
return Eq{"library_id": value}
}
// applyLibraryFilter adds library filtering to queries for tables that have a library_id column
// This ensures users only see content from libraries they have access to
func (r sqlRepository) applyLibraryFilter(sq SelectBuilder, tableName ...string) SelectBuilder {
user := loggedUser(r.ctx)
// If the user is an admin, or the user ID is invalid (e.g., when no user is logged in), skip the library filter
if user.IsAdmin || user.ID == invalidUserId {
return sq
}
table := r.tableName
if len(tableName) > 0 {
table = tableName[0]
}
// Get user's accessible library IDs
// Use subquery to filter by user's library access
return sq.Where(Expr(table+".library_id IN ("+
"SELECT ul.library_id FROM user_library ul WHERE ul.user_id = ?)", user.ID))
}
func (r sqlRepository) seedKey() string {
// Seed keys must be all lowercase, or else SQLite3 will encode it, making it not match the seed
// used in the query. Hashing the user ID and converting it to a hex string will do the trick
userIDHash := md5.Sum([]byte(loggedUser(r.ctx).ID))
return fmt.Sprintf("%s|%x", r.tableName, userIDHash)
}
func (r sqlRepository) resetSeededRandom(options []model.QueryOptions) {
if len(options) == 0 || options[0].Sort != "random" {
return
}
options[0].Sort = fmt.Sprintf("SEEDEDRAND('%s', %s.id)", r.seedKey(), r.tableName)
if options[0].Seed != "" {
hasher.SetSeed(r.seedKey(), options[0].Seed)
return
}
if options[0].Offset == 0 {
hasher.Reseed(r.seedKey())
}
}
func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
query, args, err := r.toSQL(sq)
if err != nil {
return 0, err
}
start := time.Now()
var c int64
res, err := r.db.NewQuery(query).Bind(args).WithContext(r.ctx).Execute()
if res != nil {
c, _ = res.RowsAffected()
}
r.logSQL(query, args, err, c, start)
if err != nil {
if err.Error() != "LastInsertId is not supported by this driver" {
return 0, err
}
}
return c, err
}
var placeholderRegex = regexp.MustCompile(`\?`)
func (r sqlRepository) toSQL(sq Sqlizer) (string, dbx.Params, error) {
query, args, err := sq.ToSql()
if err != nil {
return "", nil, err
}
// Replace query placeholders with named params
params := make(dbx.Params, len(args))
counter := 0
result := placeholderRegex.ReplaceAllStringFunc(query, func(_ string) string {
p := fmt.Sprintf("p%d", counter)
params[p] = args[counter]
counter++
return "{:" + p + "}"
})
return result, params, nil
}
func (r sqlRepository) queryOne(sq Sqlizer, response interface{}) error {
query, args, err := r.toSQL(sq)
if err != nil {
return err
}
start := time.Now()
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).One(response)
if errors.Is(err, sql.ErrNoRows) {
r.logSQL(query, args, nil, 0, start)
return model.ErrNotFound
}
r.logSQL(query, args, err, 1, start)
return err
}
// queryWithStableResults is a helper function to execute a query and return an iterator that will yield its results
// from a cursor, guaranteeing that the results will be stable, even if the underlying data changes.
func queryWithStableResults[T any](r sqlRepository, sq SelectBuilder, options ...model.QueryOptions) (iter.Seq2[T, error], error) {
if len(options) > 0 && options[0].Offset > 0 {
sq = r.optimizePagination(sq, options[0])
}
query, args, err := r.toSQL(sq)
if err != nil {
return nil, err
}
start := time.Now()
rows, err := r.db.NewQuery(query).Bind(args).WithContext(r.ctx).Rows()
r.logSQL(query, args, err, -1, start)
if err != nil {
return nil, err
}
return func(yield func(T, error) bool) {
defer rows.Close()
for rows.Next() {
var row T
err := rows.ScanStruct(&row)
if !yield(row, err) || err != nil {
return
}
}
if err := rows.Err(); err != nil {
var empty T
yield(empty, err)
}
}, nil
}
func (r sqlRepository) queryAll(sq SelectBuilder, response interface{}, options ...model.QueryOptions) error {
if len(options) > 0 && options[0].Offset > 0 {
sq = r.optimizePagination(sq, options[0])
}
query, args, err := r.toSQL(sq)
if err != nil {
return err
}
start := time.Now()
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).All(response)
if errors.Is(err, sql.ErrNoRows) {
r.logSQL(query, args, nil, -1, start)
return model.ErrNotFound
}
r.logSQL(query, args, err, int64(reflect.ValueOf(response).Elem().Len()), start)
return err
}
// queryAllSlice is a helper function to query a single column and return the result in a slice
func (r sqlRepository) queryAllSlice(sq SelectBuilder, response interface{}) error {
query, args, err := r.toSQL(sq)
if err != nil {
return err
}
start := time.Now()
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).Column(response)
if errors.Is(err, sql.ErrNoRows) {
r.logSQL(query, args, nil, -1, start)
return model.ErrNotFound
}
r.logSQL(query, args, err, int64(reflect.ValueOf(response).Elem().Len()), start)
return err
}
// optimizePagination uses a less inefficient pagination, by not using OFFSET.
// See https://gist.github.com/ssokolow/262503
func (r sqlRepository) optimizePagination(sq SelectBuilder, options model.QueryOptions) SelectBuilder {
if options.Offset > conf.Server.DevOffsetOptimize {
sq = sq.RemoveOffset()
rowidSq := sq.RemoveColumns().Columns(r.tableName + ".rowid")
rowidSq = rowidSq.Limit(uint64(options.Offset))
rowidSql, args, _ := rowidSq.ToSql()
sq = sq.Where(r.tableName+".rowid not in ("+rowidSql+")", args...)
}
return sq
}
func (r sqlRepository) exists(cond Sqlizer) (bool, error) {
existsQuery := Select("count(*) as exist").From(r.tableName).Where(cond)
var res struct{ Exist int64 }
err := r.queryOne(existsQuery, &res)
return res.Exist > 0, err
}
func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOptions) (int64, error) {
countQuery = countQuery.
RemoveColumns().Columns("count(distinct " + r.tableName + ".id) as count").
RemoveOffset().RemoveLimit().
OrderBy(r.tableName + ".id"). // To remove any ORDER BY clause that could slow down the query
From(r.tableName)
countQuery = r.applyFilters(countQuery, options...)
var res struct{ Count int64 }
err := r.queryOne(countQuery, &res)
return res.Count, err
}
func (r sqlRepository) putByMatch(filter Sqlizer, id string, m interface{}, colsToUpdate ...string) (string, error) {
if id != "" {
return r.put(id, m, colsToUpdate...)
}
existsQuery := r.newSelect().Columns("id").From(r.tableName).Where(filter)
var res struct{ ID string }
err := r.queryOne(existsQuery, &res)
if err != nil && !errors.Is(err, model.ErrNotFound) {
return "", err
}
return r.put(res.ID, m, colsToUpdate...)
}
func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (newId string, err error) {
values, err := toSQLArgs(m)
if err != nil {
return "", fmt.Errorf("error preparing values to write to DB: %w", err)
}
// If there's an ID, try to update first
if id != "" {
updateValues := map[string]interface{}{}
// This is a map of the columns that need to be updated, if specified
c2upd := slice.ToMap(colsToUpdate, func(s string) (string, struct{}) {
return toSnakeCase(s), struct{}{}
})
for k, v := range values {
if _, found := c2upd[k]; len(c2upd) == 0 || found {
updateValues[k] = v
}
}
updateValues["id"] = id
delete(updateValues, "created_at")
// To avoid updating the media_file birth_time on each scan. Not the best solution, but it works for now
// TODO move to mediafile_repository when each repo has its own upsert method
delete(updateValues, "birth_time")
update := Update(r.tableName).Where(Eq{"id": id}).SetMap(updateValues)
count, err := r.executeSQL(update)
if err != nil {
return "", err
}
if count > 0 {
return id, nil
}
}
// If it does not have an ID OR the ID was not found (when it is a new record with predefined id)
if id == "" {
id = id2.NewRandom()
values["id"] = id
}
insert := Insert(r.tableName).SetMap(values)
_, err = r.executeSQL(insert)
return id, err
}
func (r sqlRepository) delete(cond Sqlizer) error {
del := Delete(r.tableName).Where(cond)
_, err := r.executeSQL(del)
if errors.Is(err, sql.ErrNoRows) {
return model.ErrNotFound
}
return err
}
func (r sqlRepository) logSQL(sql string, args dbx.Params, err error, rowsAffected int64, start time.Time) {
elapsed := time.Since(start)
if err == nil || errors.Is(err, context.Canceled) {
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
} else {
log.Error(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
}
}

View File

@@ -0,0 +1,284 @@
package persistence
import (
"context"
"github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
"github.com/navidrome/navidrome/utils/hasher"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("sqlRepository", func() {
var r sqlRepository
BeforeEach(func() {
r.ctx = request.WithUser(context.Background(), model.User{ID: "user-id"})
r.tableName = "table"
})
Describe("applyOptions", func() {
var sq squirrel.SelectBuilder
BeforeEach(func() {
sq = squirrel.Select("*").From("test")
r.sortMappings = map[string]string{
"name": "title",
}
})
It("does not add any clauses when options is empty", func() {
sq = r.applyOptions(sq, model.QueryOptions{})
sql, _, _ := sq.ToSql()
Expect(sql).To(Equal("SELECT * FROM test"))
})
It("adds all option clauses", func() {
sq = r.applyOptions(sq, model.QueryOptions{
Sort: "name",
Order: "desc",
Max: 1,
Offset: 2,
})
sql, _, _ := sq.ToSql()
Expect(sql).To(Equal("SELECT * FROM test ORDER BY title desc LIMIT 1 OFFSET 2"))
})
})
Describe("toSQL", func() {
It("returns error for invalid SQL", func() {
sq := squirrel.Select("*").From("test").Where(1)
_, _, err := r.toSQL(sq)
Expect(err).To(HaveOccurred())
})
It("returns the same query when there are no placeholders", func() {
sq := squirrel.Select("*").From("test")
query, params, err := r.toSQL(sq)
Expect(err).NotTo(HaveOccurred())
Expect(query).To(Equal("SELECT * FROM test"))
Expect(params).To(BeEmpty())
})
It("replaces one placeholder correctly", func() {
sq := squirrel.Select("*").From("test").Where(squirrel.Eq{"id": 1})
query, params, err := r.toSQL(sq)
Expect(err).NotTo(HaveOccurred())
Expect(query).To(Equal("SELECT * FROM test WHERE id = {:p0}"))
Expect(params).To(HaveKeyWithValue("p0", 1))
})
It("replaces multiple placeholders correctly", func() {
sq := squirrel.Select("*").From("test").Where(squirrel.Eq{"id": 1, "name": "test"})
query, params, err := r.toSQL(sq)
Expect(err).NotTo(HaveOccurred())
Expect(query).To(Equal("SELECT * FROM test WHERE id = {:p0} AND name = {:p1}"))
Expect(params).To(HaveKeyWithValue("p0", 1))
Expect(params).To(HaveKeyWithValue("p1", "test"))
})
})
Describe("sanitizeSort", func() {
BeforeEach(func() {
r.registerModel(&struct {
Field string `structs:"field"`
}{}, nil)
r.sortMappings = map[string]string{
"sort1": "mappedSort1",
}
})
When("sanitizing sort", func() {
It("returns empty if the sort key is not found in the model nor in the mappings", func() {
sort, _ := r.sanitizeSort("unknown", "")
Expect(sort).To(BeEmpty())
})
It("returns the mapped value when sort key exists", func() {
sort, _ := r.sanitizeSort("sort1", "")
Expect(sort).To(Equal("mappedSort1"))
})
It("is case insensitive", func() {
sort, _ := r.sanitizeSort("Sort1", "")
Expect(sort).To(Equal("mappedSort1"))
})
It("returns the field if it is a valid field", func() {
sort, _ := r.sanitizeSort("field", "")
Expect(sort).To(Equal("field"))
})
It("is case insensitive for fields", func() {
sort, _ := r.sanitizeSort("FIELD", "")
Expect(sort).To(Equal("field"))
})
})
When("sanitizing order", func() {
It("returns 'asc' if order is empty", func() {
_, order := r.sanitizeSort("", "")
Expect(order).To(Equal(""))
})
It("returns 'asc' if order is 'asc'", func() {
_, order := r.sanitizeSort("", "ASC")
Expect(order).To(Equal("asc"))
})
It("returns 'desc' if order is 'desc'", func() {
_, order := r.sanitizeSort("", "desc")
Expect(order).To(Equal("desc"))
})
It("returns 'asc' if order is unknown", func() {
_, order := r.sanitizeSort("", "something")
Expect(order).To(Equal("asc"))
})
})
})
Describe("buildSortOrder", func() {
BeforeEach(func() {
r.sortMappings = map[string]string{}
})
Context("single field", func() {
It("sorts by specified field", func() {
sql := r.buildSortOrder("name", "desc")
Expect(sql).To(Equal("name desc"))
})
It("defaults to 'asc'", func() {
sql := r.buildSortOrder("name", "")
Expect(sql).To(Equal("name asc"))
})
It("inverts pre-defined order", func() {
sql := r.buildSortOrder("name desc", "desc")
Expect(sql).To(Equal("name asc"))
})
It("forces snake case for field names", func() {
sql := r.buildSortOrder("AlbumArtist", "asc")
Expect(sql).To(Equal("album_artist asc"))
})
})
Context("multiple fields", func() {
It("handles multiple fields", func() {
sql := r.buildSortOrder("name desc,age asc, status desc ", "asc")
Expect(sql).To(Equal("name desc, age asc, status desc"))
})
It("inverts multiple fields", func() {
sql := r.buildSortOrder("name desc, age, status asc", "desc")
Expect(sql).To(Equal("name asc, age desc, status desc"))
})
It("handles spaces in mapped field", func() {
r.sortMappings = map[string]string{
"has_lyrics": "(lyrics != '[]'), updated_at",
}
sql := r.buildSortOrder("has_lyrics", "desc")
Expect(sql).To(Equal("(lyrics != '[]') desc, updated_at desc"))
})
})
Context("function fields", func() {
It("handles functions with multiple params", func() {
sql := r.buildSortOrder("substr(id, 7)", "asc")
Expect(sql).To(Equal("substr(id, 7) asc"))
})
It("handles functions with multiple params mixed with multiple fields", func() {
sql := r.buildSortOrder("name desc, substr(id, 7), status asc", "desc")
Expect(sql).To(Equal("name asc, substr(id, 7) desc, status desc"))
})
It("handles nested functions", func() {
sql := r.buildSortOrder("name desc, coalesce(nullif(release_date, ''), nullif(original_date, '')), status asc", "desc")
Expect(sql).To(Equal("name asc, coalesce(nullif(release_date, ''), nullif(original_date, '')) desc, status desc"))
})
})
})
Describe("resetSeededRandom", func() {
var id string
BeforeEach(func() {
id = r.seedKey()
hasher.SetSeed(id, "")
})
It("does not reset seed if sort is not random", func() {
var options []model.QueryOptions
r.resetSeededRandom(options)
Expect(hasher.CurrentSeed(id)).To(BeEmpty())
})
It("resets seed if sort is random", func() {
options := []model.QueryOptions{{Sort: "random"}}
r.resetSeededRandom(options)
Expect(hasher.CurrentSeed(id)).NotTo(BeEmpty())
})
It("resets seed if sort is random and seed is provided", func() {
options := []model.QueryOptions{{Sort: "random", Seed: "seed"}}
r.resetSeededRandom(options)
Expect(hasher.CurrentSeed(id)).To(Equal("seed"))
})
It("keeps seed when paginating", func() {
options := []model.QueryOptions{{Sort: "random", Seed: "seed", Offset: 0}}
r.resetSeededRandom(options)
Expect(hasher.CurrentSeed(id)).To(Equal("seed"))
options = []model.QueryOptions{{Sort: "random", Offset: 1}}
r.resetSeededRandom(options)
Expect(hasher.CurrentSeed(id)).To(Equal("seed"))
})
})
Describe("applyLibraryFilter", func() {
var sq squirrel.SelectBuilder
BeforeEach(func() {
sq = squirrel.Select("*").From("test_table")
})
Context("Admin User", func() {
BeforeEach(func() {
r.ctx = request.WithUser(context.Background(), model.User{ID: "admin", IsAdmin: true})
})
It("should not apply library filter for admin users", func() {
result := r.applyLibraryFilter(sq)
sql, _, _ := result.ToSql()
Expect(sql).To(Equal("SELECT * FROM test_table"))
})
})
Context("Regular User", func() {
BeforeEach(func() {
r.ctx = request.WithUser(context.Background(), model.User{ID: "user123", IsAdmin: false})
})
It("should apply library filter for regular users", func() {
result := r.applyLibraryFilter(sq)
sql, args, _ := result.ToSql()
Expect(sql).To(ContainSubstring("IN (SELECT ul.library_id FROM user_library ul WHERE ul.user_id = ?)"))
Expect(args).To(ContainElement("user123"))
})
It("should use custom table name when provided", func() {
result := r.applyLibraryFilter(sq, "custom_table")
sql, args, _ := result.ToSql()
Expect(sql).To(ContainSubstring("custom_table.library_id IN"))
Expect(args).To(ContainElement("user123"))
})
})
Context("Headless Process (No User Context)", func() {
BeforeEach(func() {
r.ctx = context.Background() // No user context
})
It("should not apply library filter for headless processes", func() {
result := r.applyLibraryFilter(sq)
sql, _, _ := result.ToSql()
Expect(sql).To(Equal("SELECT * FROM test_table"))
})
It("should not apply library filter even with custom table name", func() {
result := r.applyLibraryFilter(sq, "custom_table")
sql, _, _ := result.ToSql()
Expect(sql).To(Equal("SELECT * FROM test_table"))
})
})
})
})

View File

@@ -0,0 +1,157 @@
package persistence
import (
"database/sql"
"errors"
"fmt"
"time"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
)
const bookmarkTable = "bookmark"
func (r sqlRepository) withBookmark(query SelectBuilder, idField string) SelectBuilder {
userID := loggedUser(r.ctx).ID
if userID == invalidUserId {
return query
}
return query.
LeftJoin("bookmark on (" +
"bookmark.item_id = " + idField +
" AND bookmark.user_id = '" + userID + "')").
Columns("coalesce(position, 0) as bookmark_position")
}
func (r sqlRepository) bmkID(itemID ...string) And {
return And{
Eq{bookmarkTable + ".user_id": loggedUser(r.ctx).ID},
Eq{bookmarkTable + ".item_type": r.tableName},
Eq{bookmarkTable + ".item_id": itemID},
}
}
func (r sqlRepository) bmkUpsert(itemID, comment string, position int64) error {
client, _ := request.ClientFrom(r.ctx)
user, _ := request.UserFrom(r.ctx)
values := map[string]interface{}{
"comment": comment,
"position": position,
"updated_at": time.Now(),
"changed_by": client,
}
upd := Update(bookmarkTable).Where(r.bmkID(itemID)).SetMap(values)
c, err := r.executeSQL(upd)
if err == nil {
log.Debug(r.ctx, "Updated bookmark", "id", itemID, "user", user.UserName, "position", position, "comment", comment)
}
if c == 0 || errors.Is(err, sql.ErrNoRows) {
values["user_id"] = user.ID
values["item_type"] = r.tableName
values["item_id"] = itemID
values["created_at"] = time.Now()
values["updated_at"] = time.Now()
ins := Insert(bookmarkTable).SetMap(values)
_, err = r.executeSQL(ins)
if err != nil {
return err
}
log.Debug(r.ctx, "Added bookmark", "id", itemID, "user", user.UserName, "position", position, "comment", comment)
}
return err
}
func (r sqlRepository) AddBookmark(id, comment string, position int64) error {
user, _ := request.UserFrom(r.ctx)
err := r.bmkUpsert(id, comment, position)
if err != nil {
log.Error(r.ctx, "Error adding bookmark", "id", id, "user", user.UserName, "position", position, "comment", comment)
}
return err
}
func (r sqlRepository) DeleteBookmark(id string) error {
user, _ := request.UserFrom(r.ctx)
del := Delete(bookmarkTable).Where(r.bmkID(id))
_, err := r.executeSQL(del)
if err != nil {
log.Error(r.ctx, "Error removing bookmark", "id", id, "user", user.UserName)
}
return err
}
type bookmark struct {
UserID string `json:"user_id"`
ItemID string `json:"item_id"`
ItemType string `json:"item_type"`
Comment string `json:"comment"`
Position int64 `json:"position"`
ChangedBy string `json:"changed_by"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
func (r sqlRepository) GetBookmarks() (model.Bookmarks, error) {
user, _ := request.UserFrom(r.ctx)
idField := r.tableName + ".id"
sq := r.newSelect().Columns(r.tableName + ".*")
sq = r.withAnnotation(sq, idField)
sq = r.withBookmark(sq, idField).Where(NotEq{bookmarkTable + ".item_id": nil})
var mfs dbMediaFiles // TODO Decouple from media_file
err := r.queryAll(sq, &mfs)
if err != nil {
log.Error(r.ctx, "Error getting mediafiles with bookmarks", "user", user.UserName, err)
return nil, err
}
ids := make([]string, len(mfs))
mfMap := make(map[string]int)
for i, mf := range mfs {
ids[i] = mf.ID
mfMap[mf.ID] = i
}
sq = Select("*").From(bookmarkTable).Where(r.bmkID(ids...))
var bmks []bookmark
err = r.queryAll(sq, &bmks)
if err != nil {
log.Error(r.ctx, "Error getting bookmarks", "user", user.UserName, "ids", ids, err)
return nil, err
}
resp := make(model.Bookmarks, len(bmks))
for i, bmk := range bmks {
if itemIdx, ok := mfMap[bmk.ItemID]; !ok {
log.Debug(r.ctx, "Invalid bookmark", "id", bmk.ItemID, "user", user.UserName)
continue
} else {
resp[i] = model.Bookmark{
Comment: bmk.Comment,
Position: bmk.Position,
CreatedAt: bmk.CreatedAt,
UpdatedAt: bmk.UpdatedAt,
ChangedBy: bmk.ChangedBy,
Item: *mfs[itemIdx].MediaFile,
}
}
}
return resp, nil
}
func (r sqlRepository) cleanBookmarks() error {
del := Delete(bookmarkTable).Where(Eq{"item_type": r.tableName}).Where("item_id not in (select id from " + r.tableName + ")")
c, err := r.executeSQL(del)
if err != nil {
return fmt.Errorf("error cleaning up %s bookmarks: %w", r.tableName, err)
}
if c > 0 {
log.Debug(r.ctx, "Clean-up bookmarks", "totalDeleted", c, "itemType", r.tableName)
}
return nil
}

View File

@@ -0,0 +1,74 @@
package persistence
import (
"context"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("sqlBookmarks", func() {
var mr model.MediaFileRepository
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid"})
mr = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
})
Describe("Bookmarks", func() {
It("returns an empty collection if there are no bookmarks", func() {
Expect(mr.GetBookmarks()).To(BeEmpty())
})
It("saves and overrides bookmarks", func() {
By("Saving the bookmark")
Expect(mr.AddBookmark(songAntenna.ID, "this is a comment", 123)).To(BeNil())
bms, err := mr.GetBookmarks()
Expect(err).ToNot(HaveOccurred())
Expect(bms).To(HaveLen(1))
Expect(bms[0].Item.ID).To(Equal(songAntenna.ID))
Expect(bms[0].Item.Title).To(Equal(songAntenna.Title))
Expect(bms[0].Comment).To(Equal("this is a comment"))
Expect(bms[0].Position).To(Equal(int64(123)))
created := bms[0].CreatedAt
updated := bms[0].UpdatedAt
Expect(created.IsZero()).To(BeFalse())
Expect(updated).To(BeTemporally(">=", created))
By("Overriding the bookmark")
Expect(mr.AddBookmark(songAntenna.ID, "another comment", 333)).To(BeNil())
bms, err = mr.GetBookmarks()
Expect(err).ToNot(HaveOccurred())
Expect(bms[0].Item.ID).To(Equal(songAntenna.ID))
Expect(bms[0].Comment).To(Equal("another comment"))
Expect(bms[0].Position).To(Equal(int64(333)))
Expect(bms[0].CreatedAt).To(Equal(created))
Expect(bms[0].UpdatedAt).To(BeTemporally(">=", updated))
By("Saving another bookmark")
Expect(mr.AddBookmark(songComeTogether.ID, "one more comment", 444)).To(BeNil())
bms, err = mr.GetBookmarks()
Expect(err).ToNot(HaveOccurred())
Expect(bms).To(HaveLen(2))
By("Delete bookmark")
Expect(mr.DeleteBookmark(songAntenna.ID)).To(Succeed())
bms, err = mr.GetBookmarks()
Expect(err).ToNot(HaveOccurred())
Expect(bms).To(HaveLen(1))
Expect(bms[0].Item.ID).To(Equal(songComeTogether.ID))
Expect(bms[0].Item.Title).To(Equal(songComeTogether.Title))
Expect(mr.DeleteBookmark(songComeTogether.ID)).To(Succeed())
Expect(mr.GetBookmarks()).To(BeEmpty())
})
})
})

View File

@@ -0,0 +1,119 @@
package persistence
import (
"encoding/json"
"fmt"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/slice"
)
type participant struct {
ID string `json:"id"`
Name string `json:"name"`
SubRole string `json:"subRole,omitempty"`
}
// flatParticipant represents a flattened participant structure for SQL processing
type flatParticipant struct {
ArtistID string `json:"artist_id"`
Role string `json:"role"`
SubRole string `json:"sub_role,omitempty"`
}
func marshalParticipants(participants model.Participants) string {
dbParticipants := make(map[model.Role][]participant)
for role, artists := range participants {
for _, artist := range artists {
dbParticipants[role] = append(dbParticipants[role], participant{ID: artist.ID, SubRole: artist.SubRole, Name: artist.Name})
}
}
res, _ := json.Marshal(dbParticipants)
return string(res)
}
func unmarshalParticipants(data string) (model.Participants, error) {
var dbParticipants map[model.Role][]participant
err := json.Unmarshal([]byte(data), &dbParticipants)
if err != nil {
return nil, fmt.Errorf("parsing participants: %w", err)
}
participants := make(model.Participants, len(dbParticipants))
for role, participantList := range dbParticipants {
artists := slice.Map(participantList, func(p participant) model.Participant {
return model.Participant{Artist: model.Artist{ID: p.ID, Name: p.Name}, SubRole: p.SubRole}
})
participants[role] = artists
}
return participants, nil
}
func (r sqlRepository) updateParticipants(itemID string, participants model.Participants) error {
ids := participants.AllIDs()
sqd := Delete(r.tableName + "_artists").Where(And{Eq{r.tableName + "_id": itemID}, NotEq{"artist_id": ids}})
_, err := r.executeSQL(sqd)
if err != nil {
return err
}
if len(participants) == 0 {
return nil
}
var flatParticipants []flatParticipant
for role, artists := range participants {
for _, artist := range artists {
flatParticipants = append(flatParticipants, flatParticipant{
ArtistID: artist.ID,
Role: role.String(),
SubRole: artist.SubRole,
})
}
}
participantsJSON, err := json.Marshal(flatParticipants)
if err != nil {
return fmt.Errorf("marshaling participants: %w", err)
}
// Build the INSERT query using json_each and INNER JOIN to artist table
// to automatically filter out non-existent artist IDs
query := fmt.Sprintf(`
INSERT INTO %[1]s_artists (%[1]s_id, artist_id, role, sub_role)
SELECT ?,
json_extract(value, '$.artist_id') as artist_id,
json_extract(value, '$.role') as role,
COALESCE(json_extract(value, '$.sub_role'), '') as sub_role
-- Parse the flat JSON array: [{"artist_id": "id", "role": "role", "sub_role": "subRole"}]
FROM json_each(?) -- Iterate through each array element
-- CRITICAL: Only insert records for artists that actually exist in the database
JOIN artist ON artist.id = json_extract(value, '$.artist_id') -- Filter out non-existent artist IDs via INNER JOIN
-- Handle duplicate insertions gracefully (e.g., if called multiple times)
ON CONFLICT (artist_id, %[1]s_id, role, sub_role) DO NOTHING -- Ignore duplicates
`, r.tableName)
_, err = r.executeSQL(Expr(query, itemID, string(participantsJSON)))
return err
}
func (r *sqlRepository) getParticipants(m *model.MediaFile) (model.Participants, error) {
artistRepo := NewArtistRepository(r.ctx, r.db, nil)
ids := m.Participants.AllIDs()
artists, err := artistRepo.GetAll(model.QueryOptions{Filters: Eq{"artist.id": ids}})
if err != nil {
return nil, fmt.Errorf("getting participants: %w", err)
}
artistMap := slice.ToMap(artists, func(a model.Artist) (string, model.Artist) {
return a.ID, a
})
p := m.Participants
for role, artistList := range p {
for idx, artist := range artistList {
if a, ok := artistMap[artist.ID]; ok {
p[role][idx].Artist = a
}
}
}
return p, nil
}

180
persistence/sql_restful.go Normal file
View File

@@ -0,0 +1,180 @@
package persistence
import (
"cmp"
"context"
"fmt"
"reflect"
"strings"
"sync"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/fatih/structs"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
)
type filterFunc = func(field string, value any) Sqlizer
func (r *sqlRepository) parseRestFilters(ctx context.Context, options rest.QueryOptions) Sqlizer {
if len(options.Filters) == 0 {
return nil
}
filters := And{}
for f, v := range options.Filters {
// Ignore filters with empty values
if v == "" {
continue
}
// Look for a custom filter function
f = strings.ToLower(f)
if ff, ok := r.filterMappings[f]; ok {
if filter := ff(f, v); filter != nil {
filters = append(filters, filter)
}
continue
}
// Ignore invalid filters (not based on a field or filter function)
if r.isFieldWhiteListed != nil && !r.isFieldWhiteListed(f) {
log.Warn(ctx, "Ignoring filter not whitelisted", "filter", f, "table", r.tableName)
continue
}
// For fields ending in "id", use an exact match
if strings.HasSuffix(f, "id") {
filters = append(filters, eqFilter(f, v))
continue
}
// Default to a "starts with" filter
filters = append(filters, startsWithFilter(f, v))
}
return filters
}
func (r *sqlRepository) parseRestOptions(ctx context.Context, options ...rest.QueryOptions) model.QueryOptions {
qo := model.QueryOptions{}
if len(options) > 0 {
qo.Sort, qo.Order = r.sanitizeSort(options[0].Sort, options[0].Order)
qo.Max = options[0].Max
qo.Offset = options[0].Offset
if seed, ok := options[0].Filters["seed"].(string); ok {
qo.Seed = seed
delete(options[0].Filters, "seed")
}
qo.Filters = r.parseRestFilters(ctx, options[0])
}
return qo
}
func (r sqlRepository) sanitizeSort(sort, order string) (string, string) {
if sort != "" {
sort = toSnakeCase(sort)
if mapped, ok := r.sortMappings[sort]; ok {
sort = mapped
} else {
if !r.isFieldWhiteListed(sort) {
log.Warn(r.ctx, "Ignoring sort not whitelisted", "sort", sort, "table", r.tableName)
sort = ""
}
}
}
if order != "" {
order = strings.ToLower(order)
if order != "desc" {
order = "asc"
}
}
return sort, order
}
func eqFilter(field string, value any) Sqlizer {
return Eq{field: value}
}
func startsWithFilter(field string, value any) Sqlizer {
return Like{field: fmt.Sprintf("%s%%", value)}
}
func containsFilter(field string) func(string, any) Sqlizer {
return func(_ string, value any) Sqlizer {
return Like{field: fmt.Sprintf("%%%s%%", value)}
}
}
func booleanFilter(field string, value any) Sqlizer {
v := strings.ToLower(value.(string))
return Eq{field: v == "true"}
}
func fullTextFilter(tableName string, mbidFields ...string) func(string, any) Sqlizer {
return func(field string, value any) Sqlizer {
v := strings.ToLower(value.(string))
cond := cmp.Or(
mbidExpr(tableName, v, mbidFields...),
fullTextExpr(tableName, v),
)
return cond
}
}
func substringFilter(field string, value any) Sqlizer {
parts := strings.Fields(value.(string))
filters := And{}
for _, part := range parts {
filters = append(filters, Like{field: "%" + part + "%"})
}
return filters
}
func idFilter(tableName string) func(string, any) Sqlizer {
return func(field string, value any) Sqlizer { return Eq{tableName + ".id": value} }
}
func invalidFilter(ctx context.Context) func(string, any) Sqlizer {
return func(field string, value any) Sqlizer {
log.Warn(ctx, "Invalid filter", "fieldName", field, "value", value)
return Eq{"1": "0"}
}
}
var (
whiteList = map[string]map[string]struct{}{}
mutex sync.RWMutex
)
func registerModelWhiteList(instance any) fieldWhiteListedFunc {
name := reflect.TypeOf(instance).String()
registerFieldWhiteList(name, instance)
return getFieldWhiteListedFunc(name)
}
func registerFieldWhiteList(name string, instance any) {
mutex.Lock()
defer mutex.Unlock()
if whiteList[name] != nil {
return
}
m := structs.Map(instance)
whiteList[name] = map[string]struct{}{}
for k := range m {
whiteList[name][toSnakeCase(k)] = struct{}{}
}
ma := structs.Map(model.Annotations{})
for k := range ma {
whiteList[name][toSnakeCase(k)] = struct{}{}
}
}
type fieldWhiteListedFunc func(field string) bool
func getFieldWhiteListedFunc(tableName string) fieldWhiteListedFunc {
return func(field string) bool {
mutex.RLock()
defer mutex.RUnlock()
if _, ok := whiteList[tableName]; !ok {
return false
}
_, ok := whiteList[tableName][field]
return ok
}
}

View File

@@ -0,0 +1,235 @@
package persistence
import (
"context"
"strings"
"github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/conf/configtest"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("sqlRestful", func() {
Describe("parseRestFilters", func() {
var r sqlRepository
var options rest.QueryOptions
BeforeEach(func() {
r = sqlRepository{}
})
It("returns nil if filters is empty", func() {
options.Filters = nil
Expect(r.parseRestFilters(context.Background(), options)).To(BeNil())
})
It(`returns nil if tries a filter with fullTextExpr("'")`, func() {
r.filterMappings = map[string]filterFunc{
"name": fullTextFilter("table"),
}
options.Filters = map[string]interface{}{"name": "'"}
Expect(r.parseRestFilters(context.Background(), options)).To(BeEmpty())
})
It("does not add nill filters", func() {
r.filterMappings = map[string]filterFunc{
"name": func(string, any) squirrel.Sqlizer {
return nil
},
}
options.Filters = map[string]interface{}{"name": "joe"}
Expect(r.parseRestFilters(context.Background(), options)).To(BeEmpty())
})
It("returns a '=' condition for 'id' filter", func() {
options.Filters = map[string]interface{}{"id": "123"}
Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Eq{"id": "123"}}))
})
It("returns a 'in' condition for multiples 'id' filters", func() {
options.Filters = map[string]interface{}{"id": []string{"123", "456"}}
Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Eq{"id": []string{"123", "456"}}}))
})
It("returns a 'like' condition for other filters", func() {
options.Filters = map[string]interface{}{"name": "joe"}
Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Like{"name": "joe%"}}))
})
It("uses the custom filter", func() {
r.filterMappings = map[string]filterFunc{
"test": func(field string, value interface{}) squirrel.Sqlizer {
return squirrel.Gt{field: value}
},
}
options.Filters = map[string]interface{}{"test": 100}
Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Gt{"test": 100}}))
})
})
Describe("fullTextFilter function", func() {
var filter filterFunc
var tableName string
var mbidFields []string
BeforeEach(func() {
DeferCleanup(configtest.SetupConfig())
tableName = "test_table"
mbidFields = []string{"mbid", "artist_mbid"}
filter = fullTextFilter(tableName, mbidFields...)
})
Context("when value is a valid UUID", func() {
It("returns only the mbid filter (precedence over full text)", func() {
uuid := "550e8400-e29b-41d4-a716-446655440000"
result := filter("search", uuid)
expected := squirrel.Or{
squirrel.Eq{"test_table.mbid": uuid},
squirrel.Eq{"test_table.artist_mbid": uuid},
}
Expect(result).To(Equal(expected))
})
It("falls back to full text when no mbid fields are provided", func() {
noMbidFilter := fullTextFilter(tableName)
uuid := "550e8400-e29b-41d4-a716-446655440000"
result := noMbidFilter("search", uuid)
// mbidExpr with no fields returns nil, so cmp.Or falls back to fullTextExpr
expected := squirrel.And{
squirrel.Like{"test_table.full_text": "% 550e8400-e29b-41d4-a716-446655440000%"},
}
Expect(result).To(Equal(expected))
})
})
Context("when value is not a valid UUID", func() {
It("returns full text search condition only", func() {
result := filter("search", "beatles")
// mbidExpr returns nil for non-UUIDs, so fullTextExpr result is returned directly
expected := squirrel.And{
squirrel.Like{"test_table.full_text": "% beatles%"},
}
Expect(result).To(Equal(expected))
})
It("handles multi-word search terms", func() {
result := filter("search", "the beatles abbey road")
// Should return And condition directly
andCondition, ok := result.(squirrel.And)
Expect(ok).To(BeTrue())
Expect(andCondition).To(HaveLen(4))
// Check that all words are present (order may vary)
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% the%"}))
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% beatles%"}))
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% abbey%"}))
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% road%"}))
})
})
Context("when SearchFullString config changes behavior", func() {
It("uses different separator with SearchFullString=false", func() {
conf.Server.SearchFullString = false
result := filter("search", "test query")
andCondition, ok := result.(squirrel.And)
Expect(ok).To(BeTrue())
Expect(andCondition).To(HaveLen(2))
// Check that all words are present with leading space (order may vary)
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% test%"}))
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% query%"}))
})
It("uses no separator with SearchFullString=true", func() {
conf.Server.SearchFullString = true
result := filter("search", "test query")
andCondition, ok := result.(squirrel.And)
Expect(ok).To(BeTrue())
Expect(andCondition).To(HaveLen(2))
// Check that all words are present without leading space (order may vary)
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "%test%"}))
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "%query%"}))
})
})
Context("edge cases", func() {
It("returns nil for empty string", func() {
result := filter("search", "")
Expect(result).To(BeNil())
})
It("returns nil for string with only whitespace", func() {
result := filter("search", " ")
Expect(result).To(BeNil())
})
It("handles special characters that are sanitized", func() {
result := filter("search", "don't")
expected := squirrel.And{
squirrel.Like{"test_table.full_text": "% dont%"}, // str.SanitizeStrings removes quotes
}
Expect(result).To(Equal(expected))
})
It("returns nil for single quote (SQL injection protection)", func() {
result := filter("search", "'")
Expect(result).To(BeNil())
})
It("handles mixed case UUIDs", func() {
uuid := "550E8400-E29B-41D4-A716-446655440000"
result := filter("search", uuid)
// Should return only mbid filter (uppercase UUID should work)
expected := squirrel.Or{
squirrel.Eq{"test_table.mbid": strings.ToLower(uuid)},
squirrel.Eq{"test_table.artist_mbid": strings.ToLower(uuid)},
}
Expect(result).To(Equal(expected))
})
It("handles invalid UUID format gracefully", func() {
result := filter("search", "550e8400-invalid-uuid")
// Should return full text filter since UUID is invalid
expected := squirrel.And{
squirrel.Like{"test_table.full_text": "% 550e8400-invalid-uuid%"},
}
Expect(result).To(Equal(expected))
})
It("handles empty mbid fields array", func() {
emptyMbidFilter := fullTextFilter(tableName, []string{}...)
result := emptyMbidFilter("search", "test")
// mbidExpr with empty fields returns nil, so cmp.Or falls back to fullTextExpr
expected := squirrel.And{
squirrel.Like{"test_table.full_text": "% test%"},
}
Expect(result).To(Equal(expected))
})
It("converts value to lowercase before processing", func() {
result := filter("search", "TEST")
// The function converts to lowercase internally
expected := squirrel.And{
squirrel.Like{"test_table.full_text": "% test%"},
}
Expect(result).To(Equal(expected))
})
})
})
})

77
persistence/sql_search.go Normal file
View File

@@ -0,0 +1,77 @@
package persistence
import (
"strings"
. "github.com/Masterminds/squirrel"
"github.com/google/uuid"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/str"
)
func formatFullText(text ...string) string {
fullText := str.SanitizeStrings(text...)
return " " + fullText
}
// doSearch performs a full-text search with the specified parameters.
// The naturalOrder is used to sort results when no full-text filter is applied. It is useful for cases like
// OpenSubsonic, where an empty search query should return all results in a natural order. Normally the parameter
// should be `tableName + ".rowid"`, but some repositories (ex: artist) may use a different natural order.
func (r sqlRepository) doSearch(sq SelectBuilder, q string, offset, size int, results any, naturalOrder string, orderBys ...string) error {
q = strings.TrimSpace(q)
q = strings.TrimSuffix(q, "*")
if len(q) < 2 {
return nil
}
filter := fullTextExpr(r.tableName, q)
if filter != nil {
sq = sq.Where(filter)
sq = sq.OrderBy(orderBys...)
} else {
// This is to speed up the results of `search3?query=""`, for OpenSubsonic
// If the filter is empty, we sort by the specified natural order.
sq = sq.OrderBy(naturalOrder)
}
sq = sq.Where(Eq{r.tableName + ".missing": false})
sq = sq.Limit(uint64(size)).Offset(uint64(offset))
return r.queryAll(sq, results, model.QueryOptions{Offset: offset})
}
func (r sqlRepository) searchByMBID(sq SelectBuilder, mbid string, mbidFields []string, results any) error {
sq = sq.Where(mbidExpr(r.tableName, mbid, mbidFields...))
sq = sq.Where(Eq{r.tableName + ".missing": false})
return r.queryAll(sq, results)
}
func mbidExpr(tableName, mbid string, mbidFields ...string) Sqlizer {
if uuid.Validate(mbid) != nil || len(mbidFields) == 0 {
return nil
}
mbid = strings.ToLower(mbid)
var cond []Sqlizer
for _, mbidField := range mbidFields {
cond = append(cond, Eq{tableName + "." + mbidField: mbid})
}
return Or(cond)
}
func fullTextExpr(tableName string, s string) Sqlizer {
q := str.SanitizeStrings(s)
if q == "" {
return nil
}
var sep string
if !conf.Server.SearchFullString {
sep = " "
}
parts := strings.Split(q, " ")
filters := And{}
for _, part := range parts {
filters = append(filters, Like{tableName + ".full_text": "%" + sep + part + "%"})
}
return filters
}

View File

@@ -0,0 +1,14 @@
package persistence
import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("sqlRepository", func() {
Describe("formatFullText", func() {
It("prefixes with a space", func() {
Expect(formatFullText("legiao urbana")).To(Equal(" legiao urbana"))
})
})
})

168
persistence/sql_tags.go Normal file
View File

@@ -0,0 +1,168 @@
package persistence
import (
"context"
"encoding/json"
"fmt"
"strings"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
// Format of a tag in the DB
type dbTag struct {
ID string `json:"id"`
Value string `json:"value"`
}
type dbTags map[model.TagName][]dbTag
func unmarshalTags(data string) (model.Tags, error) {
var dbTags dbTags
err := json.Unmarshal([]byte(data), &dbTags)
if err != nil {
return nil, fmt.Errorf("parsing tags: %w", err)
}
res := make(model.Tags, len(dbTags))
for name, tags := range dbTags {
res[name] = make([]string, len(tags))
for i, tag := range tags {
res[name][i] = tag.Value
}
}
return res, nil
}
func marshalTags(tags model.Tags) string {
dbTags := dbTags{}
for name, values := range tags {
for _, value := range values {
t := model.NewTag(name, value)
dbTags[name] = append(dbTags[name], dbTag{ID: t.ID, Value: value})
}
}
res, _ := json.Marshal(dbTags)
return string(res)
}
func tagIDFilter(name string, idValue any) Sqlizer {
name = strings.TrimSuffix(name, "_id")
return Exists(
fmt.Sprintf(`json_tree(tags, "$.%s")`, name),
And{
NotEq{"json_tree.atom": nil},
Eq{"value": idValue},
},
)
}
// tagLibraryIdFilter filters tags based on library access through the library_tag table
func tagLibraryIdFilter(_ string, value interface{}) Sqlizer {
return Eq{"library_tag.library_id": value}
}
// baseTagRepository provides common functionality for all tag-based repositories.
// It handles CRUD operations with optional filtering by tag name.
type baseTagRepository struct {
sqlRepository
tagFilter *model.TagName // nil = no filter (all tags), non-nil = filter by specific tag name
}
// newBaseTagRepository creates a new base tag repository with optional tag filtering.
// If tagFilter is nil, the repository will work with all tags.
// If tagFilter is provided, the repository will only work with tags of that specific name.
func newBaseTagRepository(ctx context.Context, db dbx.Builder, tagFilter *model.TagName) *baseTagRepository {
r := &baseTagRepository{
tagFilter: tagFilter,
}
r.ctx = ctx
r.db = db
r.tableName = "tag"
r.registerModel(&model.Tag{}, map[string]filterFunc{
"name": containsFilter("tag_value"),
"library_id": tagLibraryIdFilter,
})
r.setSortMappings(map[string]string{
"name": "tag_value",
})
return r
}
// applyLibraryFiltering adds the appropriate library joins based on user context
func (r *baseTagRepository) applyLibraryFiltering(sq SelectBuilder) SelectBuilder {
// Add library_tag join
sq = sq.LeftJoin("library_tag on library_tag.tag_id = tag.id")
// For authenticated users, also join with user_library to filter by accessible libraries
user := loggedUser(r.ctx)
if user.ID != invalidUserId {
sq = sq.Join("user_library on user_library.library_id = library_tag.library_id AND user_library.user_id = ?", user.ID)
}
return sq
}
// newSelect overrides the base implementation to apply tag name filtering and library filtering.
func (r *baseTagRepository) newSelect(options ...model.QueryOptions) SelectBuilder {
sq := r.sqlRepository.newSelect(options...)
// Apply tag name filtering if specified
if r.tagFilter != nil {
sq = sq.Where(Eq{"tag.tag_name": *r.tagFilter})
}
// Apply library filtering and set up aggregation columns
sq = r.applyLibraryFiltering(sq).Columns(
"tag.id",
"tag.tag_name",
"tag.tag_value",
"COALESCE(SUM(library_tag.album_count), 0) as album_count",
"COALESCE(SUM(library_tag.media_file_count), 0) as song_count",
).GroupBy("tag.id", "tag.tag_name", "tag.tag_value")
return sq
}
// ResourceRepository interface implementation
func (r *baseTagRepository) Count(options ...rest.QueryOptions) (int64, error) {
sq := Select("COUNT(DISTINCT tag.id)").From("tag")
// Apply tag name filtering if specified
if r.tagFilter != nil {
sq = sq.Where(Eq{"tag.tag_name": *r.tagFilter})
}
// Apply library filtering
sq = r.applyLibraryFiltering(sq)
return r.count(sq, r.parseRestOptions(r.ctx, options...))
}
func (r *baseTagRepository) Read(id string) (interface{}, error) {
query := r.newSelect().Where(Eq{"id": id})
var res model.Tag
err := r.queryOne(query, &res)
return &res, err
}
func (r *baseTagRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
query := r.newSelect(r.parseRestOptions(r.ctx, options...))
var res model.TagList
err := r.queryAll(query, &res)
return res, err
}
func (r *baseTagRepository) EntityName() string {
return "tag"
}
func (r *baseTagRepository) NewInstance() interface{} {
return model.Tag{}
}
// Interface compliance check
var _ model.ResourceRepository = (*baseTagRepository)(nil)

View File

@@ -0,0 +1,263 @@
package persistence
import (
"context"
"time"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/conf/configtest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
const (
adminUserID = "userid"
regularUserID = "2222"
libraryID1 = 1
libraryID2 = 2
libraryID3 = 3
tagNameGenre = "genre"
tagValueRock = "rock"
tagValuePop = "pop"
tagValueJazz = "jazz"
)
var _ = Describe("Tag Library Filtering", func() {
var (
tagRockID = id.NewTagID(tagNameGenre, tagValueRock)
tagPopID = id.NewTagID(tagNameGenre, tagValuePop)
tagJazzID = id.NewTagID(tagNameGenre, tagValueJazz)
)
expectTagValues := func(tagList model.TagList, expected []string) {
tagValues := make([]string, len(tagList))
for i, tag := range tagList {
tagValues[i] = tag.TagValue
}
Expect(tagValues).To(ContainElements(expected))
}
BeforeEach(func() {
DeferCleanup(configtest.SetupConfig())
// Generate unique path suffix to avoid conflicts with other tests
uniqueSuffix := time.Now().Format("20060102150405.000")
// Clean up database
db := GetDBXBuilder()
_, err := db.NewQuery("DELETE FROM library_tag").Execute()
Expect(err).ToNot(HaveOccurred())
_, err = db.NewQuery("DELETE FROM tag").Execute()
Expect(err).ToNot(HaveOccurred())
_, err = db.NewQuery("DELETE FROM user_library WHERE user_id != {:admin} AND user_id != {:regular}").
Bind(dbx.Params{"admin": adminUserID, "regular": regularUserID}).Execute()
Expect(err).ToNot(HaveOccurred())
_, err = db.NewQuery("DELETE FROM library WHERE id > 1").Execute()
Expect(err).ToNot(HaveOccurred())
// Create test libraries with unique names and paths to avoid conflicts with other tests
_, err = db.NewQuery("INSERT INTO library (id, name, path) VALUES ({:id}, {:name}, {:path})").
Bind(dbx.Params{"id": libraryID2, "name": "Library 2-" + uniqueSuffix, "path": "/music/lib2-" + uniqueSuffix}).Execute()
Expect(err).ToNot(HaveOccurred())
_, err = db.NewQuery("INSERT INTO library (id, name, path) VALUES ({:id}, {:name}, {:path})").
Bind(dbx.Params{"id": libraryID3, "name": "Library 3-" + uniqueSuffix, "path": "/music/lib3-" + uniqueSuffix}).Execute()
Expect(err).ToNot(HaveOccurred())
// Give admin access to all libraries
for _, libID := range []int{libraryID1, libraryID2, libraryID3} {
_, err = db.NewQuery("INSERT OR IGNORE INTO user_library (user_id, library_id) VALUES ({:user}, {:lib})").
Bind(dbx.Params{"user": adminUserID, "lib": libID}).Execute()
Expect(err).ToNot(HaveOccurred())
}
// Create test tags
adminCtx := request.WithUser(log.NewContext(context.TODO()), adminUser)
tagRepo := NewTagRepository(adminCtx, GetDBXBuilder())
createTag := func(libraryID int, name, value string) {
tag := model.Tag{ID: id.NewTagID(name, value), TagName: model.TagName(name), TagValue: value}
err := tagRepo.Add(libraryID, tag)
Expect(err).ToNot(HaveOccurred())
}
createTag(libraryID1, tagNameGenre, tagValueRock)
createTag(libraryID2, tagNameGenre, tagValuePop)
createTag(libraryID3, tagNameGenre, tagValueJazz)
createTag(libraryID2, tagNameGenre, tagValueRock) // Rock appears in both lib1 and lib2
// Set tag counts (manually for testing)
setCounts := func(tagID string, libID, albums, songs int) {
_, err := db.NewQuery("UPDATE library_tag SET album_count = {:albums}, media_file_count = {:songs} WHERE tag_id = {:tag} AND library_id = {:lib}").
Bind(dbx.Params{"albums": albums, "songs": songs, "tag": tagID, "lib": libID}).Execute()
Expect(err).ToNot(HaveOccurred())
}
setCounts(tagRockID, libraryID1, 5, 20)
setCounts(tagPopID, libraryID2, 3, 10)
setCounts(tagJazzID, libraryID3, 2, 8)
setCounts(tagRockID, libraryID2, 1, 4)
// Give regular user access to library 2 only
_, err = db.NewQuery("INSERT INTO user_library (user_id, library_id) VALUES ({:user}, {:lib})").
Bind(dbx.Params{"user": regularUserID, "lib": libraryID2}).Execute()
Expect(err).ToNot(HaveOccurred())
})
Describe("TagRepository Library Filtering", func() {
// Helper to create repository and read all tags
readAllTags := func(user *model.User, filters ...rest.QueryOptions) model.TagList {
var ctx context.Context
if user != nil {
ctx = request.WithUser(log.NewContext(context.TODO()), *user)
} else {
ctx = context.Background() // Headless context
}
tagRepo := NewTagRepository(ctx, GetDBXBuilder())
repo := tagRepo.(model.ResourceRepository)
var opts rest.QueryOptions
if len(filters) > 0 {
opts = filters[0]
}
tags, err := repo.ReadAll(opts)
Expect(err).ToNot(HaveOccurred())
return tags.(model.TagList)
}
// Helper to count tags
countTags := func(user *model.User) int64 {
var ctx context.Context
if user != nil {
ctx = request.WithUser(log.NewContext(context.TODO()), *user)
} else {
ctx = context.Background()
}
tagRepo := NewTagRepository(ctx, GetDBXBuilder())
repo := tagRepo.(model.ResourceRepository)
count, err := repo.Count()
Expect(err).ToNot(HaveOccurred())
return count
}
Context("Admin User", func() {
It("should see all tags regardless of library", func() {
tags := readAllTags(&adminUser)
Expect(tags).To(HaveLen(3))
})
})
Context("Regular User with Limited Library Access", func() {
It("should only see tags from accessible libraries", func() {
tags := readAllTags(&regularUser)
// Should see rock (libraries 1,2) and pop (library 2), but not jazz (library 3)
Expect(tags).To(HaveLen(2))
})
It("should respect explicit library_id filters within accessible libraries", func() {
tags := readAllTags(&regularUser, rest.QueryOptions{
Filters: map[string]interface{}{"library_id": libraryID2},
})
// Should see only tags from library 2: pop and rock(lib2)
Expect(tags).To(HaveLen(2))
expectTagValues(tags, []string{tagValuePop, tagValueRock})
})
It("should not return tags when filtering by inaccessible library", func() {
tags := readAllTags(&regularUser, rest.QueryOptions{
Filters: map[string]interface{}{"library_id": libraryID3},
})
// Should return no tags since user can't access library 3
Expect(tags).To(HaveLen(0))
})
It("should filter by library 1 correctly", func() {
tags := readAllTags(&regularUser, rest.QueryOptions{
Filters: map[string]interface{}{"library_id": libraryID1},
})
// Should see only rock from library 1
Expect(tags).To(HaveLen(1))
Expect(tags[0].TagValue).To(Equal(tagValueRock))
})
})
Context("Headless Processes (No User Context)", func() {
It("should see all tags from all libraries when no user is in context", func() {
tags := readAllTags(nil) // nil = headless context
// Should see all tags from all libraries (no filtering applied)
Expect(tags).To(HaveLen(3))
expectTagValues(tags, []string{tagValueRock, tagValuePop, tagValueJazz})
})
It("should count all tags from all libraries when no user is in context", func() {
count := countTags(nil)
// Should count all tags from all libraries
Expect(count).To(Equal(int64(3)))
})
It("should calculate proper statistics from all libraries for headless processes", func() {
tags := readAllTags(nil)
// Find the rock tag (appears in libraries 1 and 2)
var rockTag *model.Tag
for _, tag := range tags {
if tag.TagValue == tagValueRock {
rockTag = &tag
break
}
}
Expect(rockTag).ToNot(BeNil())
// Should have stats from all libraries where rock appears
// Library 1: 5 albums, 20 songs
// Library 2: 1 album, 4 songs
// Total: 6 albums, 24 songs
Expect(rockTag.AlbumCount).To(Equal(6))
Expect(rockTag.SongCount).To(Equal(24))
})
It("should allow headless processes to apply explicit library_id filters", func() {
tags := readAllTags(nil, rest.QueryOptions{
Filters: map[string]interface{}{"library_id": libraryID3},
})
// Should see only jazz from library 3
Expect(tags).To(HaveLen(1))
Expect(tags[0].TagValue).To(Equal(tagValueJazz))
})
})
Context("Admin User with Explicit Library Filtering", func() {
It("should see all tags when no filter is applied", func() {
tags := readAllTags(&adminUser)
Expect(tags).To(HaveLen(3))
})
It("should respect explicit library_id filters", func() {
tags := readAllTags(&adminUser, rest.QueryOptions{
Filters: map[string]interface{}{"library_id": libraryID3},
})
// Should see only jazz from library 3
Expect(tags).To(HaveLen(1))
Expect(tags[0].TagValue).To(Equal(tagValueJazz))
})
It("should filter by library 2 correctly", func() {
tags := readAllTags(&adminUser, rest.QueryOptions{
Filters: map[string]interface{}{"library_id": libraryID2},
})
// Should see pop and rock from library 2
Expect(tags).To(HaveLen(2))
expectTagValues(tags, []string{tagValuePop, tagValueRock})
})
})
})
})

View File

@@ -0,0 +1,99 @@
package persistence
import (
"context"
"fmt"
"slices"
"time"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type tagRepository struct {
*baseTagRepository
}
func NewTagRepository(ctx context.Context, db dbx.Builder) model.TagRepository {
return &tagRepository{
baseTagRepository: newBaseTagRepository(ctx, db, nil), // nil = no filter, works with all tags
}
}
func (r *tagRepository) Add(libraryID int, tags ...model.Tag) error {
for chunk := range slices.Chunk(tags, 200) {
sq := Insert(r.tableName).Columns("id", "tag_name", "tag_value").
Suffix("on conflict (id) do nothing")
for _, t := range chunk {
sq = sq.Values(t.ID, t.TagName, t.TagValue)
}
_, err := r.executeSQL(sq)
if err != nil {
return err
}
// Create library_tag entries for library filtering
libSq := Insert("library_tag").Columns("tag_id", "library_id", "album_count", "media_file_count").
Suffix("on conflict (tag_id, library_id) do nothing")
for _, t := range chunk {
libSq = libSq.Values(t.ID, libraryID, 0, 0)
}
_, err = r.executeSQL(libSq)
if err != nil {
return fmt.Errorf("adding library_tag entries: %w", err)
}
}
return nil
}
// UpdateCounts updates the library_tag table with per-library statistics.
// Only genres are being updated for now.
func (r *tagRepository) UpdateCounts() error {
template := `
INSERT INTO library_tag (tag_id, library_id, %[1]s_count)
SELECT jt.value as tag_id, %[1]s.library_id, count(distinct %[1]s.id) as %[1]s_count
FROM %[1]s
JOIN json_tree(%[1]s.tags, '$.genre') as jt ON jt.atom IS NOT NULL AND jt.key = 'id'
JOIN tag ON tag.id = jt.value
GROUP BY jt.value, %[1]s.library_id
ON CONFLICT (tag_id, library_id)
DO UPDATE SET %[1]s_count = excluded.%[1]s_count;
`
for _, table := range []string{"album", "media_file"} {
start := time.Now()
query := Expr(fmt.Sprintf(template, table))
c, err := r.executeSQL(query)
log.Debug(r.ctx, "Updated library tag counts", "table", table, "elapsed", time.Since(start), "updated", c)
if err != nil {
return fmt.Errorf("updating %s library tag counts: %w", table, err)
}
}
return nil
}
func (r *tagRepository) purgeUnused() error {
del := Delete(r.tableName).Where(`
id not in (select jt.value
from album left join json_tree(album.tags, '$') as jt
where atom is not null
and key = 'id'
UNION
select jt.value
from media_file left join json_tree(media_file.tags, '$') as jt
where atom is not null
and key = 'id')
`)
c, err := r.executeSQL(del)
if err != nil {
return fmt.Errorf("error purging %s unused tags: %w", r.tableName, err)
}
if c > 0 {
log.Debug(r.ctx, "Purged unused tags", "totalDeleted", c, "table", r.tableName)
}
return err
}
var _ model.ResourceRepository = &tagRepository{}

View File

@@ -0,0 +1,311 @@
package persistence
import (
"context"
"slices"
"strings"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/conf/configtest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
var _ = Describe("TagRepository", func() {
var repo model.TagRepository
var restRepo model.ResourceRepository
var ctx context.Context
BeforeEach(func() {
DeferCleanup(configtest.SetupConfig())
ctx = request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", UserName: "johndoe", IsAdmin: true})
tagRepo := NewTagRepository(ctx, GetDBXBuilder())
repo = tagRepo
restRepo = tagRepo.(model.ResourceRepository)
// Clean the database before each test to ensure isolation
db := GetDBXBuilder()
_, err := db.NewQuery("DELETE FROM tag").Execute()
Expect(err).ToNot(HaveOccurred())
_, err = db.NewQuery("DELETE FROM library_tag").Execute()
Expect(err).ToNot(HaveOccurred())
// Ensure library 1 exists (if it doesn't already)
_, err = db.NewQuery("INSERT OR IGNORE INTO library (id, name, path, default_new_users) VALUES (1, 'Test Library', '/test', true)").Execute()
Expect(err).ToNot(HaveOccurred())
// Ensure the admin user has access to library 1
_, err = db.NewQuery("INSERT OR IGNORE INTO user_library (user_id, library_id) VALUES ('userid', 1)").Execute()
Expect(err).ToNot(HaveOccurred())
// Add comprehensive test data that covers all test scenarios
newTag := func(name, value string) model.Tag {
return model.Tag{ID: id.NewTagID(name, value), TagName: model.TagName(name), TagValue: value}
}
err = repo.Add(1,
// Genre tags
newTag("genre", "rock"),
newTag("genre", "pop"),
newTag("genre", "jazz"),
newTag("genre", "electronic"),
newTag("genre", "classical"),
newTag("genre", "ambient"),
newTag("genre", "techno"),
newTag("genre", "house"),
newTag("genre", "trance"),
newTag("genre", "Alternative Rock"),
newTag("genre", "Blues"),
newTag("genre", "Country"),
// Mood tags
newTag("mood", "happy"),
newTag("mood", "sad"),
newTag("mood", "energetic"),
newTag("mood", "calm"),
// Other tag types
newTag("instrument", "guitar"),
newTag("instrument", "piano"),
newTag("decade", "1980s"),
newTag("decade", "1990s"),
)
Expect(err).ToNot(HaveOccurred())
})
Describe("Add", func() {
It("should handle adding new tags", func() {
newTag := model.Tag{
ID: id.NewTagID("genre", "experimental"),
TagName: "genre",
TagValue: "experimental",
}
err := repo.Add(1, newTag)
Expect(err).ToNot(HaveOccurred())
// Verify tag was added
result, err := restRepo.Read(newTag.ID)
Expect(err).ToNot(HaveOccurred())
resultTag := result.(*model.Tag)
Expect(resultTag.TagValue).To(Equal("experimental"))
// Check count increased
count, err := restRepo.Count()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(21))) // 20 from dataset + 1 new
})
It("should handle duplicate tags gracefully", func() {
// Try to add a duplicate tag
duplicateTag := model.Tag{
ID: id.NewTagID("genre", "rock"), // This already exists
TagName: "genre",
TagValue: "rock",
}
count, err := restRepo.Count()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(20))) // Still 20 tags
err = repo.Add(1, duplicateTag)
Expect(err).ToNot(HaveOccurred()) // Should not error
// Count should remain the same
count, err = restRepo.Count()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(20))) // Still 20 tags
})
})
Describe("UpdateCounts", func() {
It("should update tag counts successfully", func() {
err := repo.UpdateCounts()
Expect(err).ToNot(HaveOccurred())
})
It("should handle empty database gracefully", func() {
// Clear the database first
db := GetDBXBuilder()
_, err := db.NewQuery("DELETE FROM tag").Execute()
Expect(err).ToNot(HaveOccurred())
err = repo.UpdateCounts()
Expect(err).ToNot(HaveOccurred())
})
It("should handle albums with non-existent tag IDs in JSON gracefully", func() {
// Regression test for foreign key constraint error
// Create an album with tag IDs in JSON that don't exist in tag table
db := GetDBXBuilder()
// First, create a non-existent tag ID (this simulates tags in JSON that aren't in tag table)
nonExistentTagID := id.NewTagID("genre", "nonexistent-genre")
// Create album with JSON containing the non-existent tag ID
albumWithBadTags := `{"genre":[{"id":"` + nonExistentTagID + `","value":"nonexistent-genre"}]}`
// Insert album directly into database with the problematic JSON
_, err := db.NewQuery("INSERT INTO album (id, name, library_id, tags) VALUES ({:id}, {:name}, {:lib}, {:tags})").
Bind(dbx.Params{
"id": "test-album-bad-tags",
"name": "Album With Bad Tags",
"lib": 1,
"tags": albumWithBadTags,
}).Execute()
Expect(err).ToNot(HaveOccurred())
// This should not fail with foreign key constraint error
err = repo.UpdateCounts()
Expect(err).ToNot(HaveOccurred())
// Cleanup
_, err = db.NewQuery("DELETE FROM album WHERE id = {:id}").
Bind(dbx.Params{"id": "test-album-bad-tags"}).Execute()
Expect(err).ToNot(HaveOccurred())
})
It("should handle media files with non-existent tag IDs in JSON gracefully", func() {
// Regression test for foreign key constraint error with media files
db := GetDBXBuilder()
// Create a non-existent tag ID
nonExistentTagID := id.NewTagID("genre", "another-nonexistent-genre")
// Create media file with JSON containing the non-existent tag ID
mediaFileWithBadTags := `{"genre":[{"id":"` + nonExistentTagID + `","value":"another-nonexistent-genre"}]}`
// Insert media file directly into database with the problematic JSON
_, err := db.NewQuery("INSERT INTO media_file (id, title, library_id, tags) VALUES ({:id}, {:title}, {:lib}, {:tags})").
Bind(dbx.Params{
"id": "test-media-bad-tags",
"title": "Media File With Bad Tags",
"lib": 1,
"tags": mediaFileWithBadTags,
}).Execute()
Expect(err).ToNot(HaveOccurred())
// This should not fail with foreign key constraint error
err = repo.UpdateCounts()
Expect(err).ToNot(HaveOccurred())
// Cleanup
_, err = db.NewQuery("DELETE FROM media_file WHERE id = {:id}").
Bind(dbx.Params{"id": "test-media-bad-tags"}).Execute()
Expect(err).ToNot(HaveOccurred())
})
})
Describe("Count", func() {
It("should return correct count of tags", func() {
count, err := restRepo.Count()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(int64(20))) // From the test dataset
})
})
Describe("Read", func() {
It("should return existing tag", func() {
rockID := id.NewTagID("genre", "rock")
result, err := restRepo.Read(rockID)
Expect(err).ToNot(HaveOccurred())
resultTag := result.(*model.Tag)
Expect(resultTag.ID).To(Equal(rockID))
Expect(resultTag.TagName).To(Equal(model.TagName("genre")))
Expect(resultTag.TagValue).To(Equal("rock"))
})
It("should return error for non-existent tag", func() {
_, err := restRepo.Read("non-existent-id")
Expect(err).To(HaveOccurred())
})
})
Describe("ReadAll", func() {
It("should return all tags from dataset", func() {
result, err := restRepo.ReadAll()
Expect(err).ToNot(HaveOccurred())
tags := result.(model.TagList)
Expect(tags).To(HaveLen(20))
})
It("should filter tags by partial value correctly", func() {
options := rest.QueryOptions{
Filters: map[string]interface{}{"name": "%rock%"}, // Tags containing 'rock'
}
result, err := restRepo.ReadAll(options)
Expect(err).ToNot(HaveOccurred())
tags := result.(model.TagList)
Expect(tags).To(HaveLen(2)) // "rock" and "Alternative Rock"
// Verify all returned tags contain 'rock' in their value
for _, tag := range tags {
Expect(strings.ToLower(tag.TagValue)).To(ContainSubstring("rock"))
}
})
It("should filter tags by partial value using LIKE", func() {
options := rest.QueryOptions{
Filters: map[string]interface{}{"name": "%e%"}, // Tags containing 'e'
}
result, err := restRepo.ReadAll(options)
Expect(err).ToNot(HaveOccurred())
tags := result.(model.TagList)
Expect(tags).To(HaveLen(8)) // electronic, house, trance, energetic, Blues, decade x2, Alternative Rock
// Verify all returned tags contain 'e' in their value
for _, tag := range tags {
Expect(strings.ToLower(tag.TagValue)).To(ContainSubstring("e"))
}
})
It("should sort tags by value ascending", func() {
options := rest.QueryOptions{
Filters: map[string]interface{}{"name": "%r%"}, // Tags containing 'r'
Sort: "name",
Order: "asc",
}
result, err := restRepo.ReadAll(options)
Expect(err).ToNot(HaveOccurred())
tags := result.(model.TagList)
Expect(tags).To(HaveLen(7))
Expect(slices.IsSortedFunc(tags, func(a, b model.Tag) int {
return strings.Compare(strings.ToLower(a.TagValue), strings.ToLower(b.TagValue))
}))
})
It("should sort tags by value descending", func() {
options := rest.QueryOptions{
Filters: map[string]interface{}{"name": "%r%"}, // Tags containing 'r'
Sort: "name",
Order: "desc",
}
result, err := restRepo.ReadAll(options)
Expect(err).ToNot(HaveOccurred())
tags := result.(model.TagList)
Expect(tags).To(HaveLen(7))
Expect(slices.IsSortedFunc(tags, func(a, b model.Tag) int {
return strings.Compare(strings.ToLower(b.TagValue), strings.ToLower(a.TagValue)) // Descending order
}))
})
})
Describe("EntityName", func() {
It("should return correct entity name", func() {
name := restRepo.EntityName()
Expect(name).To(Equal("tag"))
})
})
Describe("NewInstance", func() {
It("should return new tag instance", func() {
instance := restRepo.NewInstance()
Expect(instance).To(BeAssignableToTypeOf(model.Tag{}))
})
})
})

View File

@@ -0,0 +1,112 @@
package persistence
import (
"context"
"errors"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type transcodingRepository struct {
sqlRepository
}
func NewTranscodingRepository(ctx context.Context, db dbx.Builder) model.TranscodingRepository {
r := &transcodingRepository{}
r.ctx = ctx
r.db = db
r.registerModel(&model.Transcoding{}, nil)
return r
}
func (r *transcodingRepository) Get(id string) (*model.Transcoding, error) {
sel := r.newSelect().Columns("*").Where(Eq{"id": id})
var res model.Transcoding
err := r.queryOne(sel, &res)
return &res, err
}
func (r *transcodingRepository) CountAll(qo ...model.QueryOptions) (int64, error) {
return r.count(Select(), qo...)
}
func (r *transcodingRepository) FindByFormat(format string) (*model.Transcoding, error) {
sel := r.newSelect().Columns("*").Where(Eq{"target_format": format})
var res model.Transcoding
err := r.queryOne(sel, &res)
return &res, err
}
func (r *transcodingRepository) Put(t *model.Transcoding) error {
if !loggedUser(r.ctx).IsAdmin {
return rest.ErrPermissionDenied
}
_, err := r.put(t.ID, t)
return err
}
func (r *transcodingRepository) Count(options ...rest.QueryOptions) (int64, error) {
return r.count(Select(), r.parseRestOptions(r.ctx, options...))
}
func (r *transcodingRepository) Read(id string) (interface{}, error) {
return r.Get(id)
}
func (r *transcodingRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
sel := r.newSelect(r.parseRestOptions(r.ctx, options...)).Columns("*")
res := model.Transcodings{}
err := r.queryAll(sel, &res)
return res, err
}
func (r *transcodingRepository) EntityName() string {
return "transcoding"
}
func (r *transcodingRepository) NewInstance() interface{} {
return &model.Transcoding{}
}
func (r *transcodingRepository) Save(entity interface{}) (string, error) {
if !loggedUser(r.ctx).IsAdmin {
return "", rest.ErrPermissionDenied
}
t := entity.(*model.Transcoding)
id, err := r.put(t.ID, t)
if errors.Is(err, model.ErrNotFound) {
return "", rest.ErrNotFound
}
return id, err
}
func (r *transcodingRepository) Update(id string, entity interface{}, cols ...string) error {
if !loggedUser(r.ctx).IsAdmin {
return rest.ErrPermissionDenied
}
t := entity.(*model.Transcoding)
t.ID = id
_, err := r.put(id, t)
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
func (r *transcodingRepository) Delete(id string) error {
if !loggedUser(r.ctx).IsAdmin {
return rest.ErrPermissionDenied
}
err := r.delete(Eq{"id": id})
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
var _ model.TranscodingRepository = (*transcodingRepository)(nil)
var _ rest.Repository = (*transcodingRepository)(nil)
var _ rest.Persistable = (*transcodingRepository)(nil)

View File

@@ -0,0 +1,96 @@
package persistence
import (
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TranscodingRepository", func() {
var repo model.TranscodingRepository
var adminRepo model.TranscodingRepository
BeforeEach(func() {
ctx := log.NewContext(GinkgoT().Context())
ctx = request.WithUser(ctx, regularUser)
repo = NewTranscodingRepository(ctx, GetDBXBuilder())
adminCtx := log.NewContext(GinkgoT().Context())
adminCtx = request.WithUser(adminCtx, adminUser)
adminRepo = NewTranscodingRepository(adminCtx, GetDBXBuilder())
})
AfterEach(func() {
// Clean up any transcoding created during the tests
tc, err := adminRepo.FindByFormat("test_format")
if err == nil {
err = adminRepo.(*transcodingRepository).Delete(tc.ID)
Expect(err).ToNot(HaveOccurred())
}
})
Describe("Admin User", func() {
It("creates a new transcoding", func() {
base, err := adminRepo.CountAll()
Expect(err).ToNot(HaveOccurred())
err = adminRepo.Put(&model.Transcoding{ID: "new", Name: "new", TargetFormat: "test_format", DefaultBitRate: 320, Command: "ffmpeg"})
Expect(err).ToNot(HaveOccurred())
count, err := adminRepo.CountAll()
Expect(err).ToNot(HaveOccurred())
Expect(count).To(Equal(base + 1))
})
It("updates an existing transcoding", func() {
tr := &model.Transcoding{ID: "upd", Name: "old", TargetFormat: "test_format", DefaultBitRate: 100, Command: "ffmpeg"}
Expect(adminRepo.Put(tr)).To(Succeed())
tr.Name = "updated"
err := adminRepo.Put(tr)
Expect(err).ToNot(HaveOccurred())
res, err := adminRepo.FindByFormat("test_format")
Expect(err).ToNot(HaveOccurred())
Expect(res.Name).To(Equal("updated"))
})
It("deletes a transcoding", func() {
err := adminRepo.Put(&model.Transcoding{ID: "to-delete", Name: "temp", TargetFormat: "test_format", DefaultBitRate: 256, Command: "ffmpeg"})
Expect(err).ToNot(HaveOccurred())
err = adminRepo.(*transcodingRepository).Delete("to-delete")
Expect(err).ToNot(HaveOccurred())
_, err = adminRepo.Get("to-delete")
Expect(err).To(MatchError(model.ErrNotFound))
})
})
Describe("Regular User", func() {
It("fails to create", func() {
err := repo.Put(&model.Transcoding{ID: "bad", Name: "bad", TargetFormat: "test_format", DefaultBitRate: 64, Command: "ffmpeg"})
Expect(err).To(Equal(rest.ErrPermissionDenied))
})
It("fails to update", func() {
tr := &model.Transcoding{ID: "updreg", Name: "old", TargetFormat: "test_format", DefaultBitRate: 64, Command: "ffmpeg"}
Expect(adminRepo.Put(tr)).To(Succeed())
tr.Name = "bad"
err := repo.Put(tr)
Expect(err).To(Equal(rest.ErrPermissionDenied))
//_ = adminRepo.(*transcodingRepository).Delete("updreg")
})
It("fails to delete", func() {
tr := &model.Transcoding{ID: "delreg", Name: "temp", TargetFormat: "test_format", DefaultBitRate: 64, Command: "ffmpeg"}
Expect(adminRepo.Put(tr)).To(Succeed())
err := repo.(*transcodingRepository).Delete("delreg")
Expect(err).To(Equal(rest.ErrPermissionDenied))
//_ = adminRepo.(*transcodingRepository).Delete("delreg")
})
})
})

View File

@@ -0,0 +1,63 @@
package persistence
import (
"context"
"errors"
. "github.com/Masterminds/squirrel"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type userPropsRepository struct {
sqlRepository
}
func NewUserPropsRepository(ctx context.Context, db dbx.Builder) model.UserPropsRepository {
r := &userPropsRepository{}
r.ctx = ctx
r.db = db
r.tableName = "user_props"
return r
}
func (r userPropsRepository) Put(userId, key string, value string) error {
update := Update(r.tableName).Set("value", value).Where(And{Eq{"user_id": userId}, Eq{"key": key}})
count, err := r.executeSQL(update)
if err != nil {
return err
}
if count > 0 {
return nil
}
insert := Insert(r.tableName).Columns("user_id", "key", "value").Values(userId, key, value)
_, err = r.executeSQL(insert)
return err
}
func (r userPropsRepository) Get(userId, key string) (string, error) {
sel := Select("value").From(r.tableName).Where(And{Eq{"user_id": userId}, Eq{"key": key}})
resp := struct {
Value string
}{}
err := r.queryOne(sel, &resp)
if err != nil {
return "", err
}
return resp.Value, nil
}
func (r userPropsRepository) DefaultGet(userId, key string, defaultValue string) (string, error) {
value, err := r.Get(userId, key)
if errors.Is(err, model.ErrNotFound) {
return defaultValue, nil
}
if err != nil {
return defaultValue, err
}
return value, nil
}
func (r userPropsRepository) Delete(userId, key string) error {
return r.delete(And{Eq{"user_id": userId}, Eq{"key": key}})
}

View File

@@ -0,0 +1,475 @@
package persistence
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/consts"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/utils"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
type userRepository struct {
sqlRepository
}
type dbUser struct {
*model.User `structs:",flatten"`
LibrariesJSON string `structs:"-" json:"-"`
}
func (u *dbUser) PostScan() error {
if u.LibrariesJSON != "" {
if err := json.Unmarshal([]byte(u.LibrariesJSON), &u.User.Libraries); err != nil {
return fmt.Errorf("parsing user libraries from db: %w", err)
}
}
return nil
}
type dbUsers []dbUser
func (us dbUsers) toModels() model.Users {
return slice.Map(us, func(u dbUser) model.User { return *u.User })
}
var (
once sync.Once
encKey []byte
)
func NewUserRepository(ctx context.Context, db dbx.Builder) model.UserRepository {
r := &userRepository{}
r.ctx = ctx
r.db = db
r.tableName = "user"
r.registerModel(&model.User{}, map[string]filterFunc{
"id": idFilter(r.tableName),
"password": invalidFilter(ctx),
"name": r.withTableName(startsWithFilter),
})
once.Do(func() {
_ = r.initPasswordEncryptionKey()
})
return r
}
// selectUserWithLibraries returns a SelectBuilder that includes library information
func (r *userRepository) selectUserWithLibraries(options ...model.QueryOptions) SelectBuilder {
return r.newSelect(options...).
Columns(`user.*`,
`COALESCE(json_group_array(json_object(
'id', library.id,
'name', library.name,
'path', library.path,
'remote_path', library.remote_path,
'last_scan_at', library.last_scan_at,
'last_scan_started_at', library.last_scan_started_at,
'full_scan_in_progress', library.full_scan_in_progress,
'updated_at', library.updated_at,
'created_at', library.created_at
)) FILTER (WHERE library.id IS NOT NULL), '[]') AS libraries_json`).
LeftJoin("user_library ul ON user.id = ul.user_id").
LeftJoin("library ON ul.library_id = library.id").
GroupBy("user.id")
}
func (r *userRepository) CountAll(qo ...model.QueryOptions) (int64, error) {
return r.count(Select(), qo...)
}
func (r *userRepository) Get(id string) (*model.User, error) {
sel := r.selectUserWithLibraries().Where(Eq{"user.id": id})
var res dbUser
err := r.queryOne(sel, &res)
if err != nil {
return nil, err
}
return res.User, nil
}
func (r *userRepository) GetAll(options ...model.QueryOptions) (model.Users, error) {
sel := r.selectUserWithLibraries(options...)
var res dbUsers
err := r.queryAll(sel, &res)
if err != nil {
return nil, err
}
return res.toModels(), nil
}
func (r *userRepository) Put(u *model.User) error {
if u.ID == "" {
u.ID = id.NewRandom()
}
u.UpdatedAt = time.Now()
if u.NewPassword != "" {
_ = r.encryptPassword(u)
}
values, err := toSQLArgs(*u)
if err != nil {
return fmt.Errorf("error converting user to SQL args: %w", err)
}
delete(values, "current_password")
// Save/update the user
update := Update(r.tableName).Where(Eq{"id": u.ID}).SetMap(values)
count, err := r.executeSQL(update)
if err != nil {
return err
}
isNewUser := count == 0
if isNewUser {
values["created_at"] = time.Now()
insert := Insert(r.tableName).SetMap(values)
_, err = r.executeSQL(insert)
if err != nil {
return err
}
}
// Auto-assign all libraries to admin users in a single SQL operation
if u.IsAdmin {
sql := Expr(
"INSERT OR IGNORE INTO user_library (user_id, library_id) SELECT ?, id FROM library",
u.ID,
)
if _, err := r.executeSQL(sql); err != nil {
return fmt.Errorf("failed to assign all libraries to admin user: %w", err)
}
} else if isNewUser { // Only for new regular users
// Auto-assign default libraries to new regular users
sql := Expr(
"INSERT OR IGNORE INTO user_library (user_id, library_id) SELECT ?, id FROM library WHERE default_new_users = true",
u.ID,
)
if _, err := r.executeSQL(sql); err != nil {
return fmt.Errorf("failed to assign default libraries to new user: %w", err)
}
}
return nil
}
func (r *userRepository) FindFirstAdmin() (*model.User, error) {
sel := r.selectUserWithLibraries(model.QueryOptions{Sort: "updated_at", Max: 1}).Where(Eq{"user.is_admin": true})
var usr dbUser
err := r.queryOne(sel, &usr)
if err != nil {
return nil, err
}
return usr.User, nil
}
func (r *userRepository) FindByUsername(username string) (*model.User, error) {
sel := r.selectUserWithLibraries().Where(Expr("user.user_name = ? COLLATE NOCASE", username))
var usr dbUser
err := r.queryOne(sel, &usr)
if err != nil {
return nil, err
}
return usr.User, nil
}
func (r *userRepository) FindByUsernameWithPassword(username string) (*model.User, error) {
usr, err := r.FindByUsername(username)
if err != nil {
return nil, err
}
_ = r.decryptPassword(usr)
return usr, nil
}
func (r *userRepository) UpdateLastLoginAt(id string) error {
upd := Update(r.tableName).Where(Eq{"id": id}).Set("last_login_at", time.Now())
_, err := r.executeSQL(upd)
return err
}
func (r *userRepository) UpdateLastAccessAt(id string) error {
now := time.Now()
upd := Update(r.tableName).Where(Eq{"id": id}).Set("last_access_at", now)
_, err := r.executeSQL(upd)
return err
}
func (r *userRepository) Count(options ...rest.QueryOptions) (int64, error) {
usr := loggedUser(r.ctx)
if !usr.IsAdmin {
return 0, rest.ErrPermissionDenied
}
return r.CountAll(r.parseRestOptions(r.ctx, options...))
}
func (r *userRepository) Read(id string) (any, error) {
usr := loggedUser(r.ctx)
if !usr.IsAdmin && usr.ID != id {
return nil, rest.ErrPermissionDenied
}
usr, err := r.Get(id)
if errors.Is(err, model.ErrNotFound) {
return nil, rest.ErrNotFound
}
return usr, err
}
func (r *userRepository) ReadAll(options ...rest.QueryOptions) (any, error) {
usr := loggedUser(r.ctx)
if !usr.IsAdmin {
return nil, rest.ErrPermissionDenied
}
return r.GetAll(r.parseRestOptions(r.ctx, options...))
}
func (r *userRepository) EntityName() string {
return "user"
}
func (r *userRepository) NewInstance() any {
return &model.User{}
}
func (r *userRepository) Save(entity any) (string, error) {
usr := loggedUser(r.ctx)
if !usr.IsAdmin {
return "", rest.ErrPermissionDenied
}
u := entity.(*model.User)
if err := validateUsernameUnique(r, u); err != nil {
return "", err
}
err := r.Put(u)
if err != nil {
return "", err
}
return u.ID, err
}
func (r *userRepository) Update(id string, entity any, _ ...string) error {
u := entity.(*model.User)
u.ID = id
usr := loggedUser(r.ctx)
if !usr.IsAdmin && usr.ID != u.ID {
return rest.ErrPermissionDenied
}
if !usr.IsAdmin {
if !conf.Server.EnableUserEditing {
return rest.ErrPermissionDenied
}
u.IsAdmin = false
u.UserName = usr.UserName
}
// Decrypt the user's existing password before validating. This is required otherwise the existing password entered by the user will never match.
if err := r.decryptPassword(usr); err != nil {
return err
}
if err := validatePasswordChange(u, usr); err != nil {
return err
}
if err := validateUsernameUnique(r, u); err != nil {
return err
}
err := r.Put(u)
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
func validatePasswordChange(newUser *model.User, logged *model.User) error {
err := &rest.ValidationError{Errors: map[string]string{}}
if logged.IsAdmin && newUser.ID != logged.ID {
return nil
}
if newUser.NewPassword == "" {
if newUser.CurrentPassword == "" {
return nil
}
err.Errors["password"] = "ra.validation.required"
}
if !strings.HasPrefix(logged.Password, consts.PasswordAutogenPrefix) {
if newUser.CurrentPassword == "" {
err.Errors["currentPassword"] = "ra.validation.required"
}
if newUser.CurrentPassword != logged.Password {
err.Errors["currentPassword"] = "ra.validation.passwordDoesNotMatch"
}
}
if len(err.Errors) > 0 {
return err
}
return nil
}
func validateUsernameUnique(r model.UserRepository, u *model.User) error {
usr, err := r.FindByUsername(u.UserName)
if errors.Is(err, model.ErrNotFound) {
return nil
}
if err != nil {
return err
}
if usr.ID != u.ID {
return &rest.ValidationError{Errors: map[string]string{"userName": "ra.validation.unique"}}
}
return nil
}
func (r *userRepository) Delete(id string) error {
usr := loggedUser(r.ctx)
if !usr.IsAdmin {
return rest.ErrPermissionDenied
}
err := r.delete(Eq{"id": id})
if errors.Is(err, model.ErrNotFound) {
return rest.ErrNotFound
}
return err
}
func keyTo32Bytes(input string) []byte {
data := sha256.Sum256([]byte(input))
return data[0:]
}
func (r *userRepository) initPasswordEncryptionKey() error {
encKey = keyTo32Bytes(consts.DefaultEncryptionKey)
if conf.Server.PasswordEncryptionKey == "" {
return nil
}
key := keyTo32Bytes(conf.Server.PasswordEncryptionKey)
keySum := fmt.Sprintf("%x", sha256.Sum256(key))
props := NewPropertyRepository(r.ctx, r.db)
savedKeySum, err := props.Get(consts.PasswordsEncryptedKey)
// If passwords are already encrypted
if err == nil {
if savedKeySum != keySum {
log.Error("Password Encryption Key changed! Users won't be able to login!")
return errors.New("passwordEncryptionKey changed")
}
encKey = key
return nil
}
// if not, try to re-encrypt all current passwords with new encryption key,
// assuming they were encrypted with the DefaultEncryptionKey
sql := r.newSelect().Columns("id", "user_name", "password")
users := model.Users{}
err = r.queryAll(sql, &users)
if err != nil {
log.Error("Could not encrypt all passwords", err)
return err
}
log.Warn("New PasswordEncryptionKey set. Encrypting all passwords", "numUsers", len(users))
if err = r.decryptAllPasswords(users); err != nil {
return err
}
encKey = key
for i := range users {
u := users[i]
u.NewPassword = u.Password
if err := r.encryptPassword(&u); err == nil {
upd := Update(r.tableName).Set("password", u.NewPassword).Where(Eq{"id": u.ID})
_, err = r.executeSQL(upd)
if err != nil {
log.Error("Password NOT encrypted! This may cause problems!", "user", u.UserName, "id", u.ID, err)
} else {
log.Warn("Password encrypted successfully", "user", u.UserName, "id", u.ID)
}
}
}
err = props.Put(consts.PasswordsEncryptedKey, keySum)
if err != nil {
log.Error("Could not flag passwords as encrypted. It will cause login errors", err)
return err
}
return nil
}
// encrypts u.NewPassword
func (r *userRepository) encryptPassword(u *model.User) error {
encPassword, err := utils.Encrypt(r.ctx, encKey, u.NewPassword)
if err != nil {
log.Error(r.ctx, "Error encrypting user's password", "user", u.UserName, err)
return err
}
u.NewPassword = encPassword
return nil
}
// decrypts u.Password
func (r *userRepository) decryptPassword(u *model.User) error {
plaintext, err := utils.Decrypt(r.ctx, encKey, u.Password)
if err != nil {
log.Error(r.ctx, "Error decrypting user's password", "user", u.UserName, err)
return err
}
u.Password = plaintext
return nil
}
func (r *userRepository) decryptAllPasswords(users model.Users) error {
for i := range users {
if err := r.decryptPassword(&users[i]); err != nil {
return err
}
}
return nil
}
// Library association methods
func (r *userRepository) GetUserLibraries(userID string) (model.Libraries, error) {
sel := Select("l.*").
From("library l").
Join("user_library ul ON l.id = ul.library_id").
Where(Eq{"ul.user_id": userID}).
OrderBy("l.name")
var res model.Libraries
err := r.queryAll(sel, &res)
return res, err
}
func (r *userRepository) SetUserLibraries(userID string, libraryIDs []int) error {
// Remove existing associations
delSql := Delete("user_library").Where(Eq{"user_id": userID})
if _, err := r.executeSQL(delSql); err != nil {
return err
}
// Add new associations
if len(libraryIDs) > 0 {
insert := Insert("user_library").Columns("user_id", "library_id")
for _, libID := range libraryIDs {
insert = insert.Values(userID, libID)
}
_, err := r.executeSQL(insert)
return err
}
return nil
}
var _ model.UserRepository = (*userRepository)(nil)
var _ rest.Repository = (*userRepository)(nil)
var _ rest.Persistable = (*userRepository)(nil)

View File

@@ -0,0 +1,573 @@
package persistence
import (
"context"
"errors"
"slices"
"github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/consts"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/id"
"github.com/navidrome/navidrome/tests"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("UserRepository", func() {
var repo model.UserRepository
BeforeEach(func() {
repo = NewUserRepository(log.NewContext(GinkgoT().Context()), GetDBXBuilder())
})
Describe("Put/Get/FindByUsername", func() {
usr := model.User{
ID: "123",
UserName: "AdMiN",
Name: "Admin",
Email: "admin@admin.com",
NewPassword: "wordpass",
IsAdmin: true,
}
It("saves the user to the DB", func() {
Expect(repo.Put(&usr)).To(BeNil())
})
It("returns the newly created user", func() {
actual, err := repo.Get("123")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Name).To(Equal("Admin"))
})
It("find the user by case-insensitive username", func() {
actual, err := repo.FindByUsername("aDmIn")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Name).To(Equal("Admin"))
})
It("find the user by username and decrypts the password", func() {
actual, err := repo.FindByUsernameWithPassword("aDmIn")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Name).To(Equal("Admin"))
Expect(actual.Password).To(Equal("wordpass"))
})
It("updates the name and keep the same password", func() {
usr.Name = "Jane Doe"
usr.NewPassword = ""
Expect(repo.Put(&usr)).To(BeNil())
actual, err := repo.FindByUsernameWithPassword("admin")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Name).To(Equal("Jane Doe"))
Expect(actual.Password).To(Equal("wordpass"))
})
It("updates password if specified", func() {
usr.NewPassword = "newpass"
Expect(repo.Put(&usr)).To(BeNil())
actual, err := repo.FindByUsernameWithPassword("admin")
Expect(err).ToNot(HaveOccurred())
Expect(actual.Password).To(Equal("newpass"))
})
})
Describe("validatePasswordChange", func() {
var loggedUser *model.User
BeforeEach(func() {
loggedUser = &model.User{ID: "1", UserName: "logan"}
})
It("does nothing if passwords are not specified", func() {
user := &model.User{ID: "2", UserName: "johndoe"}
err := validatePasswordChange(user, loggedUser)
Expect(err).ToNot(HaveOccurred())
})
Context("Autogenerated password (used with Reverse Proxy Authentication)", func() {
var user model.User
BeforeEach(func() {
loggedUser.IsAdmin = false
loggedUser.Password = consts.PasswordAutogenPrefix + id.NewRandom()
})
It("does nothing if passwords are not specified", func() {
user = *loggedUser
err := validatePasswordChange(&user, loggedUser)
Expect(err).ToNot(HaveOccurred())
})
It("does not requires currentPassword for regular user", func() {
user = *loggedUser
user.CurrentPassword = ""
user.NewPassword = "new"
err := validatePasswordChange(&user, loggedUser)
Expect(err).ToNot(HaveOccurred())
})
It("does not requires currentPassword for admin", func() {
loggedUser.IsAdmin = true
user = *loggedUser
user.CurrentPassword = ""
user.NewPassword = "new"
err := validatePasswordChange(&user, loggedUser)
Expect(err).ToNot(HaveOccurred())
})
})
Context("Logged User is admin", func() {
BeforeEach(func() {
loggedUser.IsAdmin = true
})
It("can change other user's passwords without currentPassword", func() {
user := &model.User{ID: "2", UserName: "johndoe"}
user.NewPassword = "new"
err := validatePasswordChange(user, loggedUser)
Expect(err).ToNot(HaveOccurred())
})
It("requires currentPassword to change its own", func() {
user := *loggedUser
user.NewPassword = "new"
err := validatePasswordChange(&user, loggedUser)
var verr *rest.ValidationError
errors.As(err, &verr)
Expect(verr.Errors).To(HaveLen(1))
Expect(verr.Errors).To(HaveKeyWithValue("currentPassword", "ra.validation.required"))
})
It("does not allow to change password to empty string", func() {
loggedUser.Password = "abc123"
user := *loggedUser
user.CurrentPassword = "abc123"
err := validatePasswordChange(&user, loggedUser)
var verr *rest.ValidationError
errors.As(err, &verr)
Expect(verr.Errors).To(HaveLen(1))
Expect(verr.Errors).To(HaveKeyWithValue("password", "ra.validation.required"))
})
It("fails if currentPassword does not match", func() {
loggedUser.Password = "abc123"
user := *loggedUser
user.CurrentPassword = "current"
user.NewPassword = "new"
err := validatePasswordChange(&user, loggedUser)
var verr *rest.ValidationError
errors.As(err, &verr)
Expect(verr.Errors).To(HaveLen(1))
Expect(verr.Errors).To(HaveKeyWithValue("currentPassword", "ra.validation.passwordDoesNotMatch"))
})
It("can change own password if requirements are met", func() {
loggedUser.Password = "abc123"
user := *loggedUser
user.CurrentPassword = "abc123"
user.NewPassword = "new"
err := validatePasswordChange(&user, loggedUser)
Expect(err).ToNot(HaveOccurred())
})
})
Context("Logged User is a regular user", func() {
BeforeEach(func() {
loggedUser.IsAdmin = false
})
It("requires currentPassword", func() {
user := *loggedUser
user.NewPassword = "new"
err := validatePasswordChange(&user, loggedUser)
var verr *rest.ValidationError
errors.As(err, &verr)
Expect(verr.Errors).To(HaveLen(1))
Expect(verr.Errors).To(HaveKeyWithValue("currentPassword", "ra.validation.required"))
})
It("does not allow to change password to empty string", func() {
loggedUser.Password = "abc123"
user := *loggedUser
user.CurrentPassword = "abc123"
err := validatePasswordChange(&user, loggedUser)
var verr *rest.ValidationError
errors.As(err, &verr)
Expect(verr.Errors).To(HaveLen(1))
Expect(verr.Errors).To(HaveKeyWithValue("password", "ra.validation.required"))
})
It("fails if currentPassword does not match", func() {
loggedUser.Password = "abc123"
user := *loggedUser
user.CurrentPassword = "current"
user.NewPassword = "new"
err := validatePasswordChange(&user, loggedUser)
var verr *rest.ValidationError
errors.As(err, &verr)
Expect(verr.Errors).To(HaveLen(1))
Expect(verr.Errors).To(HaveKeyWithValue("currentPassword", "ra.validation.passwordDoesNotMatch"))
})
It("can change own password if requirements are met", func() {
loggedUser.Password = "abc123"
user := *loggedUser
user.CurrentPassword = "abc123"
user.NewPassword = "new"
err := validatePasswordChange(&user, loggedUser)
Expect(err).ToNot(HaveOccurred())
})
})
})
Describe("validateUsernameUnique", func() {
var repo *tests.MockedUserRepo
var existingUser *model.User
BeforeEach(func() {
existingUser = &model.User{ID: "1", UserName: "johndoe"}
repo = tests.CreateMockUserRepo()
err := repo.Put(existingUser)
Expect(err).ToNot(HaveOccurred())
})
It("allows unique usernames", func() {
var newUser = &model.User{ID: "2", UserName: "unique_username"}
err := validateUsernameUnique(repo, newUser)
Expect(err).ToNot(HaveOccurred())
})
It("returns ValidationError if username already exists", func() {
var newUser = &model.User{ID: "2", UserName: "johndoe"}
err := validateUsernameUnique(repo, newUser)
var verr *rest.ValidationError
isValidationError := errors.As(err, &verr)
Expect(isValidationError).To(BeTrue())
Expect(verr.Errors).To(HaveKeyWithValue("userName", "ra.validation.unique"))
})
It("returns generic error if repository call fails", func() {
repo.Error = errors.New("fake error")
var newUser = &model.User{ID: "2", UserName: "newuser"}
err := validateUsernameUnique(repo, newUser)
Expect(err).To(MatchError("fake error"))
})
})
Describe("Library Association Methods", func() {
var userID string
var library1, library2 model.Library
BeforeEach(func() {
// Create a test user first to satisfy foreign key constraints
testUser := model.User{
ID: "test-user-id",
UserName: "testuser",
Name: "Test User",
Email: "test@example.com",
NewPassword: "password",
IsAdmin: false,
}
Expect(repo.Put(&testUser)).To(BeNil())
userID = testUser.ID
library1 = model.Library{ID: 0, Name: "Library 500", Path: "/path/500"}
library2 = model.Library{ID: 0, Name: "Library 501", Path: "/path/501"}
// Create test libraries
libRepo := NewLibraryRepository(log.NewContext(context.TODO()), GetDBXBuilder())
Expect(libRepo.Put(&library1)).To(BeNil())
Expect(libRepo.Put(&library2)).To(BeNil())
})
AfterEach(func() {
// Clean up user-library associations to ensure test isolation
_ = repo.SetUserLibraries(userID, []int{})
// Clean up test libraries to ensure isolation between test groups
libRepo := NewLibraryRepository(log.NewContext(context.TODO()), GetDBXBuilder())
_ = libRepo.(*libraryRepository).delete(squirrel.Eq{"id": []int{library1.ID, library2.ID}})
})
Describe("GetUserLibraries", func() {
It("returns empty list when user has no library associations", func() {
libraries, err := repo.GetUserLibraries("non-existent-user")
Expect(err).ToNot(HaveOccurred())
Expect(libraries).To(HaveLen(0))
})
It("returns user's associated libraries", func() {
err := repo.SetUserLibraries(userID, []int{library1.ID, library2.ID})
Expect(err).ToNot(HaveOccurred())
libraries, err := repo.GetUserLibraries(userID)
Expect(err).ToNot(HaveOccurred())
Expect(libraries).To(HaveLen(2))
libIDs := []int{libraries[0].ID, libraries[1].ID}
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
})
})
Describe("SetUserLibraries", func() {
It("sets user's library associations", func() {
libraryIDs := []int{library1.ID, library2.ID}
err := repo.SetUserLibraries(userID, libraryIDs)
Expect(err).ToNot(HaveOccurred())
libraries, err := repo.GetUserLibraries(userID)
Expect(err).ToNot(HaveOccurred())
Expect(libraries).To(HaveLen(2))
})
It("replaces existing associations", func() {
// Set initial associations
err := repo.SetUserLibraries(userID, []int{library1.ID, library2.ID})
Expect(err).ToNot(HaveOccurred())
// Replace with just one library
err = repo.SetUserLibraries(userID, []int{library1.ID})
Expect(err).ToNot(HaveOccurred())
libraries, err := repo.GetUserLibraries(userID)
Expect(err).ToNot(HaveOccurred())
Expect(libraries).To(HaveLen(1))
Expect(libraries[0].ID).To(Equal(library1.ID))
})
It("removes all associations when passed empty slice", func() {
// Set initial associations
err := repo.SetUserLibraries(userID, []int{library1.ID, library2.ID})
Expect(err).ToNot(HaveOccurred())
// Remove all
err = repo.SetUserLibraries(userID, []int{})
Expect(err).ToNot(HaveOccurred())
libraries, err := repo.GetUserLibraries(userID)
Expect(err).ToNot(HaveOccurred())
Expect(libraries).To(HaveLen(0))
})
})
})
Describe("Admin User Auto-Assignment", func() {
var (
libRepo model.LibraryRepository
library1 model.Library
library2 model.Library
initialLibCount int
)
BeforeEach(func() {
libRepo = NewLibraryRepository(log.NewContext(context.TODO()), GetDBXBuilder())
// Count initial libraries
existingLibs, err := libRepo.GetAll()
Expect(err).ToNot(HaveOccurred())
initialLibCount = len(existingLibs)
library1 = model.Library{ID: 0, Name: "Admin Test Library 1", Path: "/admin/test/path1"}
library2 = model.Library{ID: 0, Name: "Admin Test Library 2", Path: "/admin/test/path2"}
// Create test libraries
Expect(libRepo.Put(&library1)).To(BeNil())
Expect(libRepo.Put(&library2)).To(BeNil())
})
AfterEach(func() {
// Clean up test libraries and their associations
_ = libRepo.(*libraryRepository).delete(squirrel.Eq{"id": []int{library1.ID, library2.ID}})
// Clean up user-library associations for these test libraries
_, _ = repo.(*userRepository).executeSQL(squirrel.Delete("user_library").Where(squirrel.Eq{"library_id": []int{library1.ID, library2.ID}}))
})
It("automatically assigns all libraries to admin users when created", func() {
adminUser := model.User{
ID: "admin-user-id-1",
UserName: "adminuser1",
Name: "Admin User",
Email: "admin1@example.com",
NewPassword: "password",
IsAdmin: true,
}
err := repo.Put(&adminUser)
Expect(err).ToNot(HaveOccurred())
// Admin should automatically have access to all libraries (including existing ones)
libraries, err := repo.GetUserLibraries(adminUser.ID)
Expect(err).ToNot(HaveOccurred())
Expect(libraries).To(HaveLen(initialLibCount + 2)) // Initial libraries + our 2 test libraries
libIDs := make([]int, len(libraries))
for i, lib := range libraries {
libIDs[i] = lib.ID
}
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
})
It("automatically assigns all libraries to admin users when updated", func() {
// Create regular user first
regularUser := model.User{
ID: "regular-user-id-1",
UserName: "regularuser1",
Name: "Regular User",
Email: "regular1@example.com",
NewPassword: "password",
IsAdmin: false,
}
err := repo.Put(&regularUser)
Expect(err).ToNot(HaveOccurred())
// Give them access to just one library
err = repo.SetUserLibraries(regularUser.ID, []int{library1.ID})
Expect(err).ToNot(HaveOccurred())
// Promote to admin
regularUser.IsAdmin = true
err = repo.Put(&regularUser)
Expect(err).ToNot(HaveOccurred())
// Should now have access to all libraries (including existing ones)
libraries, err := repo.GetUserLibraries(regularUser.ID)
Expect(err).ToNot(HaveOccurred())
Expect(libraries).To(HaveLen(initialLibCount + 2)) // Initial libraries + our 2 test libraries
libIDs := make([]int, len(libraries))
for i, lib := range libraries {
libIDs[i] = lib.ID
}
// Should include our test libraries plus all existing ones
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
})
It("assigns default libraries to regular users", func() {
regularUser := model.User{
ID: "regular-user-id-2",
UserName: "regularuser2",
Name: "Regular User",
Email: "regular2@example.com",
NewPassword: "password",
IsAdmin: false,
}
err := repo.Put(&regularUser)
Expect(err).ToNot(HaveOccurred())
// Regular user should be assigned to default libraries (library ID 1 from migration)
libraries, err := repo.GetUserLibraries(regularUser.ID)
Expect(err).ToNot(HaveOccurred())
Expect(libraries).To(HaveLen(1))
Expect(libraries[0].ID).To(Equal(1))
Expect(libraries[0].DefaultNewUsers).To(BeTrue())
})
})
Describe("Libraries Field Population", func() {
var (
libRepo model.LibraryRepository
library1 model.Library
library2 model.Library
testUser model.User
)
BeforeEach(func() {
libRepo = NewLibraryRepository(log.NewContext(context.TODO()), GetDBXBuilder())
library1 = model.Library{ID: 0, Name: "Field Test Library 1", Path: "/field/test/path1"}
library2 = model.Library{ID: 0, Name: "Field Test Library 2", Path: "/field/test/path2"}
// Create test libraries
Expect(libRepo.Put(&library1)).To(BeNil())
Expect(libRepo.Put(&library2)).To(BeNil())
// Create test user
testUser = model.User{
ID: "field-test-user",
UserName: "fieldtestuser",
Name: "Field Test User",
Email: "fieldtest@example.com",
NewPassword: "password",
IsAdmin: false,
}
Expect(repo.Put(&testUser)).To(BeNil())
// Assign libraries to user
Expect(repo.SetUserLibraries(testUser.ID, []int{library1.ID, library2.ID})).To(BeNil())
})
AfterEach(func() {
// Clean up test libraries and their associations
_ = libRepo.(*libraryRepository).delete(squirrel.Eq{"id": []int{library1.ID, library2.ID}})
_ = repo.(*userRepository).delete(squirrel.Eq{"id": testUser.ID})
// Clean up user-library associations for these test libraries
_, _ = repo.(*userRepository).executeSQL(squirrel.Delete("user_library").Where(squirrel.Eq{"library_id": []int{library1.ID, library2.ID}}))
})
It("populates Libraries field when getting a single user", func() {
user, err := repo.Get(testUser.ID)
Expect(err).ToNot(HaveOccurred())
Expect(user.Libraries).To(HaveLen(2))
libIDs := []int{user.Libraries[0].ID, user.Libraries[1].ID}
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
// Check that library details are properly populated
for _, lib := range user.Libraries {
switch lib.ID {
case library1.ID:
Expect(lib.Name).To(Equal("Field Test Library 1"))
Expect(lib.Path).To(Equal("/field/test/path1"))
case library2.ID:
Expect(lib.Name).To(Equal("Field Test Library 2"))
Expect(lib.Path).To(Equal("/field/test/path2"))
}
}
})
It("populates Libraries field when getting all users", func() {
users, err := repo.(*userRepository).GetAll()
Expect(err).ToNot(HaveOccurred())
// Find our test user in the results
found := slices.IndexFunc(users, func(u model.User) bool { return u.ID == testUser.ID })
Expect(found).ToNot(Equal(-1))
foundUser := users[found]
Expect(foundUser).ToNot(BeNil())
Expect(foundUser.Libraries).To(HaveLen(2))
libIDs := []int{foundUser.Libraries[0].ID, foundUser.Libraries[1].ID}
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
})
It("populates Libraries field when finding user by username", func() {
user, err := repo.FindByUsername(testUser.UserName)
Expect(err).ToNot(HaveOccurred())
Expect(user.Libraries).To(HaveLen(2))
libIDs := []int{user.Libraries[0].ID, user.Libraries[1].ID}
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
})
It("returns default Libraries array for new regular users", func() {
// Create a user with no explicit library associations - should get default libraries
userWithoutLibs := model.User{
ID: "no-libs-user",
UserName: "nolibsuser",
Name: "No Libs User",
Email: "nolibs@example.com",
NewPassword: "password",
IsAdmin: false,
}
Expect(repo.Put(&userWithoutLibs)).To(BeNil())
defer func() { _ = repo.(*userRepository).delete(squirrel.Eq{"id": userWithoutLibs.ID}) }()
user, err := repo.Get(userWithoutLibs.ID)
Expect(err).ToNot(HaveOccurred())
Expect(user.Libraries).ToNot(BeNil())
// Regular users should be assigned to default libraries (library ID 1 from migration)
Expect(user.Libraries).To(HaveLen(1))
Expect(user.Libraries[0].ID).To(Equal(1))
})
})
Describe("filters", func() {
It("qualifies id filter with table name", func() {
r := repo.(*userRepository)
qo := r.parseRestOptions(r.ctx, rest.QueryOptions{Filters: map[string]any{"id": "123"}})
sel := r.selectUserWithLibraries(qo)
query, _, err := r.toSQL(sel)
Expect(err).NotTo(HaveOccurred())
Expect(query).To(ContainSubstring("user.id = {:p0}"))
})
})
})