update
Some checks failed
Pipeline: Test, Lint, Build / Get version info (push) Has been cancelled
Pipeline: Test, Lint, Build / Lint Go code (push) Has been cancelled
Pipeline: Test, Lint, Build / Test Go code (push) Has been cancelled
Pipeline: Test, Lint, Build / Test JS code (push) Has been cancelled
Pipeline: Test, Lint, Build / Lint i18n files (push) Has been cancelled
Pipeline: Test, Lint, Build / Check Docker configuration (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (darwin/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (darwin/arm64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/386) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v5) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v6) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v7) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (windows/386) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (windows/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Push to GHCR (push) Has been cancelled
Pipeline: Test, Lint, Build / Push to Docker Hub (push) Has been cancelled
Pipeline: Test, Lint, Build / Cleanup digest artifacts (push) Has been cancelled
Pipeline: Test, Lint, Build / Build Windows installers (push) Has been cancelled
Pipeline: Test, Lint, Build / Package/Release (push) Has been cancelled
Pipeline: Test, Lint, Build / Upload Linux PKG (push) Has been cancelled
Close stale issues and PRs / stale (push) Has been cancelled
POEditor import / update-translations (push) Has been cancelled
Some checks failed
Pipeline: Test, Lint, Build / Get version info (push) Has been cancelled
Pipeline: Test, Lint, Build / Lint Go code (push) Has been cancelled
Pipeline: Test, Lint, Build / Test Go code (push) Has been cancelled
Pipeline: Test, Lint, Build / Test JS code (push) Has been cancelled
Pipeline: Test, Lint, Build / Lint i18n files (push) Has been cancelled
Pipeline: Test, Lint, Build / Check Docker configuration (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (darwin/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (darwin/arm64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/386) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v5) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v6) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm/v7) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (linux/arm64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (windows/386) (push) Has been cancelled
Pipeline: Test, Lint, Build / Build (windows/amd64) (push) Has been cancelled
Pipeline: Test, Lint, Build / Push to GHCR (push) Has been cancelled
Pipeline: Test, Lint, Build / Push to Docker Hub (push) Has been cancelled
Pipeline: Test, Lint, Build / Cleanup digest artifacts (push) Has been cancelled
Pipeline: Test, Lint, Build / Build Windows installers (push) Has been cancelled
Pipeline: Test, Lint, Build / Package/Release (push) Has been cancelled
Pipeline: Test, Lint, Build / Upload Linux PKG (push) Has been cancelled
Close stale issues and PRs / stale (push) Has been cancelled
POEditor import / update-translations (push) Has been cancelled
This commit is contained in:
442
persistence/album_repository.go
Normal file
442
persistence/album_repository.go
Normal file
@@ -0,0 +1,442 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type albumRepository struct {
|
||||
sqlRepository
|
||||
ms *MeilisearchService
|
||||
}
|
||||
|
||||
type dbAlbum struct {
|
||||
*model.Album `structs:",flatten"`
|
||||
Discs string `structs:"-" json:"discs"`
|
||||
Participants string `structs:"-" json:"-"`
|
||||
Tags string `structs:"-" json:"-"`
|
||||
FolderIDs string `structs:"-" json:"-"`
|
||||
}
|
||||
|
||||
func (a *dbAlbum) PostScan() error {
|
||||
var err error
|
||||
if a.Discs != "" {
|
||||
if err = json.Unmarshal([]byte(a.Discs), &a.Album.Discs); err != nil {
|
||||
return fmt.Errorf("parsing album discs from db: %w", err)
|
||||
}
|
||||
}
|
||||
a.Album.Participants, err = unmarshalParticipants(a.Participants)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing album from db: %w", err)
|
||||
}
|
||||
if a.Tags != "" {
|
||||
a.Album.Tags, err = unmarshalTags(a.Tags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing album from db: %w", err)
|
||||
}
|
||||
a.Genre, a.Genres = a.Album.Tags.ToGenres()
|
||||
}
|
||||
if a.FolderIDs != "" {
|
||||
var ids []string
|
||||
if err = json.Unmarshal([]byte(a.FolderIDs), &ids); err != nil {
|
||||
return fmt.Errorf("parsing album folder_ids from db: %w", err)
|
||||
}
|
||||
a.Album.FolderIDs = ids
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *dbAlbum) PostMapArgs(args map[string]any) error {
|
||||
fullText := []string{a.Name, a.SortAlbumName, a.AlbumArtist}
|
||||
fullText = append(fullText, a.Album.Participants.AllNames()...)
|
||||
fullText = append(fullText, slices.Collect(maps.Values(a.Album.Discs))...)
|
||||
fullText = append(fullText, a.Album.Tags[model.TagAlbumVersion]...)
|
||||
fullText = append(fullText, a.Album.Tags[model.TagCatalogNumber]...)
|
||||
args["full_text"] = formatFullText(fullText...)
|
||||
|
||||
args["tags"] = marshalTags(a.Album.Tags)
|
||||
args["participants"] = marshalParticipants(a.Album.Participants)
|
||||
|
||||
folderIDs, err := json.Marshal(a.Album.FolderIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling album folder_ids: %w", err)
|
||||
}
|
||||
args["folder_ids"] = string(folderIDs)
|
||||
|
||||
b, err := json.Marshal(a.Album.Discs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling album discs: %w", err)
|
||||
}
|
||||
args["discs"] = string(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbAlbums []dbAlbum
|
||||
|
||||
func (as dbAlbums) toModels() model.Albums {
|
||||
return slice.Map(as, func(a dbAlbum) model.Album { return *a.Album })
|
||||
}
|
||||
|
||||
func NewAlbumRepository(ctx context.Context, db dbx.Builder, ms *MeilisearchService) model.AlbumRepository {
|
||||
r := &albumRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.ms = ms
|
||||
r.tableName = "album"
|
||||
r.registerModel(&model.Album{}, albumFilters())
|
||||
r.setSortMappings(map[string]string{
|
||||
"name": "order_album_name, order_album_artist_name",
|
||||
"artist": "compilation, order_album_artist_name, order_album_name",
|
||||
"album_artist": "compilation, order_album_artist_name, order_album_name",
|
||||
// TODO Rename this to just year (or date)
|
||||
"max_year": "coalesce(nullif(original_date,''), cast(max_year as text)), release_date, name",
|
||||
"random": "random",
|
||||
"recently_added": recentlyAddedSort(),
|
||||
"starred_at": "starred, starred_at",
|
||||
"rated_at": "rating, rated_at",
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
var albumFilters = sync.OnceValue(func() map[string]filterFunc {
|
||||
filters := map[string]filterFunc{
|
||||
"id": idFilter("album"),
|
||||
"name": fullTextFilter("album", "mbz_album_id", "mbz_release_group_id"),
|
||||
"compilation": booleanFilter,
|
||||
"artist_id": artistFilter,
|
||||
"year": yearFilter,
|
||||
"recently_played": recentlyPlayedFilter,
|
||||
"starred": booleanFilter,
|
||||
"has_rating": hasRatingFilter,
|
||||
"missing": booleanFilter,
|
||||
"genre_id": tagIDFilter,
|
||||
"role_total_id": allRolesFilter,
|
||||
"library_id": libraryIdFilter,
|
||||
}
|
||||
// Add all album tags as filters
|
||||
for tag := range model.AlbumLevelTags() {
|
||||
filters[string(tag)] = tagIDFilter
|
||||
}
|
||||
|
||||
for role := range model.AllRoles {
|
||||
filters["role_"+role+"_id"] = artistRoleFilter
|
||||
}
|
||||
|
||||
return filters
|
||||
})
|
||||
|
||||
func recentlyAddedSort() string {
|
||||
if conf.Server.RecentlyAddedByModTime {
|
||||
return "updated_at"
|
||||
}
|
||||
return "created_at"
|
||||
}
|
||||
|
||||
func recentlyPlayedFilter(string, interface{}) Sqlizer {
|
||||
return Gt{"play_count": 0}
|
||||
}
|
||||
|
||||
func hasRatingFilter(string, interface{}) Sqlizer {
|
||||
return Gt{"rating": 0}
|
||||
}
|
||||
|
||||
func yearFilter(_ string, value interface{}) Sqlizer {
|
||||
return Or{
|
||||
And{
|
||||
Gt{"min_year": 0},
|
||||
LtOrEq{"min_year": value},
|
||||
GtOrEq{"max_year": value},
|
||||
},
|
||||
Eq{"max_year": value},
|
||||
}
|
||||
}
|
||||
|
||||
func artistFilter(_ string, value interface{}) Sqlizer {
|
||||
return Or{
|
||||
Exists("json_tree(participants, '$.albumartist')", Eq{"value": value}),
|
||||
Exists("json_tree(participants, '$.artist')", Eq{"value": value}),
|
||||
}
|
||||
}
|
||||
|
||||
func artistRoleFilter(name string, value interface{}) Sqlizer {
|
||||
roleName := strings.TrimSuffix(strings.TrimPrefix(name, "role_"), "_id")
|
||||
|
||||
// Check if the role name is valid. If not, return an invalid filter
|
||||
if _, ok := model.AllRoles[roleName]; !ok {
|
||||
return Gt{"": nil}
|
||||
}
|
||||
return Exists(fmt.Sprintf("json_tree(participants, '$.%s')", roleName), Eq{"value": value})
|
||||
}
|
||||
|
||||
func allRolesFilter(_ string, value interface{}) Sqlizer {
|
||||
return Like{"participants": fmt.Sprintf(`%%"%s"%%`, value)}
|
||||
}
|
||||
|
||||
func (r *albumRepository) CountAll(options ...model.QueryOptions) (int64, error) {
|
||||
query := r.newSelect()
|
||||
query = r.withAnnotation(query, "album.id")
|
||||
query = r.applyLibraryFilter(query)
|
||||
return r.count(query, options...)
|
||||
}
|
||||
|
||||
func (r *albumRepository) Exists(id string) (bool, error) {
|
||||
return r.exists(Eq{"album.id": id})
|
||||
}
|
||||
|
||||
func (r *albumRepository) Put(al *model.Album) error {
|
||||
al.ImportedAt = time.Now()
|
||||
id, err := r.put(al.ID, &dbAlbum{Album: al})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
al.ID = id
|
||||
if len(al.Participants) > 0 {
|
||||
err = r.updateParticipants(al.ID, al.Participants)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.ms != nil {
|
||||
r.ms.IndexAlbum(al)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO Move external metadata to a separated table
|
||||
func (r *albumRepository) UpdateExternalInfo(al *model.Album) error {
|
||||
_, err := r.put(al.ID, &dbAlbum{Album: al}, "description", "small_image_url", "medium_image_url", "large_image_url", "external_url", "external_info_updated_at")
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *albumRepository) selectAlbum(options ...model.QueryOptions) SelectBuilder {
|
||||
sql := r.newSelect(options...).Columns("album.*", "library.path as library_path", "library.name as library_name").
|
||||
LeftJoin("library on album.library_id = library.id")
|
||||
sql = r.withAnnotation(sql, "album.id")
|
||||
return r.applyLibraryFilter(sql)
|
||||
}
|
||||
|
||||
func (r *albumRepository) Get(id string) (*model.Album, error) {
|
||||
res, err := r.GetAll(model.QueryOptions{Filters: Eq{"album.id": id}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return nil, model.ErrNotFound
|
||||
}
|
||||
return &res[0], nil
|
||||
}
|
||||
|
||||
func (r *albumRepository) GetAll(options ...model.QueryOptions) (model.Albums, error) {
|
||||
sq := r.selectAlbum(options...)
|
||||
var res dbAlbums
|
||||
err := r.queryAll(sq, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.toModels(), err
|
||||
}
|
||||
|
||||
func (r *albumRepository) CopyAttributes(fromID, toID string, columns ...string) error {
|
||||
var from dbx.NullStringMap
|
||||
err := r.queryOne(Select(columns...).From(r.tableName).Where(Eq{"id": fromID}), &from)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting album to copy fields from: %w", err)
|
||||
}
|
||||
to := make(map[string]interface{})
|
||||
for _, col := range columns {
|
||||
to[col] = from[col]
|
||||
}
|
||||
_, err = r.executeSQL(Update(r.tableName).SetMap(to).Where(Eq{"id": toID}))
|
||||
return err
|
||||
}
|
||||
|
||||
// Touch flags an album as being scanned by the scanner, but not necessarily updated.
|
||||
// This is used for when missing tracks are detected for an album during scan.
|
||||
func (r *albumRepository) Touch(ids ...string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
for ids := range slices.Chunk(ids, 200) {
|
||||
upd := Update(r.tableName).Set("imported_at", time.Now()).Where(Eq{"id": ids})
|
||||
c, err := r.executeSQL(upd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error touching albums: %w", err)
|
||||
}
|
||||
log.Debug(r.ctx, "Touching albums", "ids", ids, "updated", c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TouchByMissingFolder touches all albums that have missing folders
|
||||
func (r *albumRepository) TouchByMissingFolder() (int64, error) {
|
||||
upd := Update(r.tableName).Set("imported_at", time.Now()).
|
||||
Where(And{
|
||||
NotEq{"folder_ids": nil},
|
||||
ConcatExpr("EXISTS (SELECT 1 FROM json_each(folder_ids) AS je JOIN main.folder AS f ON je.value = f.id WHERE f.missing = true)"),
|
||||
})
|
||||
c, err := r.executeSQL(upd)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error touching albums by missing folder: %w", err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetTouchedAlbums returns all albums that were touched by the scanner for a given library, in the
|
||||
// current library scan run.
|
||||
// It does not need to load participants, as they are not used by the scanner.
|
||||
func (r *albumRepository) GetTouchedAlbums(libID int) (model.AlbumCursor, error) {
|
||||
query := r.selectAlbum().
|
||||
Where(And{
|
||||
Eq{"library.id": libID},
|
||||
ConcatExpr("album.imported_at > library.last_scan_at"),
|
||||
})
|
||||
cursor, err := queryWithStableResults[dbAlbum](r.sqlRepository, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(yield func(model.Album, error) bool) {
|
||||
for a, err := range cursor {
|
||||
if a.Album == nil {
|
||||
yield(model.Album{}, fmt.Errorf("unexpected nil album: %v", a))
|
||||
return
|
||||
}
|
||||
if !yield(*a.Album, err) || err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshPlayCounts updates the play count and last play date annotations for all albums, based
|
||||
// on the media files associated with them.
|
||||
func (r *albumRepository) RefreshPlayCounts() (int64, error) {
|
||||
query := Expr(`
|
||||
with play_counts as (
|
||||
select user_id, album_id, sum(play_count) as total_play_count, max(play_date) as last_play_date
|
||||
from media_file
|
||||
join annotation on item_id = media_file.id
|
||||
group by user_id, album_id
|
||||
)
|
||||
insert into annotation (user_id, item_id, item_type, play_count, play_date)
|
||||
select user_id, album_id, 'album', total_play_count, last_play_date
|
||||
from play_counts
|
||||
where total_play_count > 0
|
||||
on conflict (user_id, item_id, item_type) do update
|
||||
set play_count = excluded.play_count,
|
||||
play_date = excluded.play_date;
|
||||
`)
|
||||
return r.executeSQL(query)
|
||||
}
|
||||
|
||||
func (r *albumRepository) purgeEmpty(libraryIDs ...int) error {
|
||||
del := Delete(r.tableName).Where("id not in (select distinct(album_id) from media_file)")
|
||||
// If libraryIDs are specified, only purge albums from those libraries
|
||||
if len(libraryIDs) > 0 {
|
||||
del = del.Where(Eq{"library_id": libraryIDs})
|
||||
}
|
||||
c, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
return fmt.Errorf("purging empty albums: %w", err)
|
||||
}
|
||||
// TODO: Delete from Meilisearch.
|
||||
// Since purgeEmpty executes a DELETE statement without returning IDs, we can't easily sync Meilisearch here.
|
||||
// Ideally we should select IDs first, then delete. But this is a cleanup task.
|
||||
// For now we skip Meilisearch deletion here as it requires more changes.
|
||||
// The stale entries in Meilisearch will just return empty results when fetched from DB, which Search handles.
|
||||
if c > 0 {
|
||||
log.Debug(r.ctx, "Purged empty albums", "totalDeleted", c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *albumRepository) Search(q string, offset int, size int, options ...model.QueryOptions) (model.Albums, error) {
|
||||
var res dbAlbums
|
||||
if uuid.Validate(q) == nil {
|
||||
err := r.searchByMBID(r.selectAlbum(options...), q, []string{"mbz_album_id", "mbz_release_group_id"}, &res)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("searching album by MBID %q: %w", q, err)
|
||||
}
|
||||
} else {
|
||||
if r.ms != nil {
|
||||
ids, err := r.ms.Search("albums", q, offset, size)
|
||||
if err == nil {
|
||||
if len(ids) == 0 {
|
||||
return model.Albums{}, nil
|
||||
}
|
||||
// Fetch matching albums from the database
|
||||
albums, err := r.GetAll(model.QueryOptions{Filters: Eq{"album.id": ids}})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching albums from meilisearch ids: %w", err)
|
||||
}
|
||||
// Reorder results to match Meilisearch order
|
||||
idMap := make(map[string]model.Album, len(albums))
|
||||
for _, a := range albums {
|
||||
idMap[a.ID] = a
|
||||
}
|
||||
sorted := make(model.Albums, 0, len(albums))
|
||||
for _, id := range ids {
|
||||
if a, ok := idMap[id]; ok {
|
||||
sorted = append(sorted, a)
|
||||
}
|
||||
}
|
||||
return sorted, nil
|
||||
}
|
||||
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
|
||||
}
|
||||
err := r.doSearch(r.selectAlbum(options...), q, offset, size, &res, "album.rowid", "name")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("searching album by query %q: %w", q, err)
|
||||
}
|
||||
}
|
||||
return res.toModels(), nil
|
||||
}
|
||||
|
||||
func (r *albumRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *albumRepository) Read(id string) (interface{}, error) {
|
||||
return r.Get(id)
|
||||
}
|
||||
|
||||
func (r *albumRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
if len(options) > 0 && r.ms != nil {
|
||||
if name, ok := options[0].Filters["name"].(string); ok && name != "" {
|
||||
ids, err := r.ms.Search("albums", name, 0, 10000)
|
||||
if err == nil {
|
||||
log.Debug(r.ctx, "Meilisearch found matches", "count", len(ids), "query", name)
|
||||
delete(options[0].Filters, "name")
|
||||
options[0].Filters["id"] = ids
|
||||
} else {
|
||||
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *albumRepository) EntityName() string {
|
||||
return "album"
|
||||
}
|
||||
|
||||
func (r *albumRepository) NewInstance() interface{} {
|
||||
return &model.Album{}
|
||||
}
|
||||
|
||||
var _ model.AlbumRepository = (*albumRepository)(nil)
|
||||
var _ model.ResourceRepository = (*albumRepository)(nil)
|
||||
524
persistence/album_repository_test.go
Normal file
524
persistence/album_repository_test.go
Normal file
@@ -0,0 +1,524 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/consts"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("AlbumRepository", func() {
|
||||
var albumRepo *albumRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := request.WithUser(GinkgoT().Context(), model.User{ID: "userid", UserName: "johndoe"})
|
||||
albumRepo = NewAlbumRepository(ctx, GetDBXBuilder(), nil).(*albumRepository)
|
||||
})
|
||||
|
||||
Describe("Get", func() {
|
||||
var Get = func(id string) (*model.Album, error) {
|
||||
album, err := albumRepo.Get(id)
|
||||
if album != nil {
|
||||
album.ImportedAt = time.Time{}
|
||||
}
|
||||
return album, err
|
||||
}
|
||||
It("returns an existent album", func() {
|
||||
Expect(Get("103")).To(Equal(&albumRadioactivity))
|
||||
})
|
||||
It("returns ErrNotFound when the album does not exist", func() {
|
||||
_, err := Get("666")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetAll", func() {
|
||||
var GetAll = func(opts ...model.QueryOptions) (model.Albums, error) {
|
||||
albums, err := albumRepo.GetAll(opts...)
|
||||
for i := range albums {
|
||||
albums[i].ImportedAt = time.Time{}
|
||||
}
|
||||
return albums, err
|
||||
}
|
||||
|
||||
It("returns all records", func() {
|
||||
Expect(GetAll()).To(Equal(testAlbums))
|
||||
})
|
||||
|
||||
It("returns all records sorted", func() {
|
||||
Expect(GetAll(model.QueryOptions{Sort: "name"})).To(Equal(model.Albums{
|
||||
albumAbbeyRoad,
|
||||
albumMultiDisc,
|
||||
albumRadioactivity,
|
||||
albumSgtPeppers,
|
||||
}))
|
||||
})
|
||||
|
||||
It("returns all records sorted desc", func() {
|
||||
Expect(GetAll(model.QueryOptions{Sort: "name", Order: "desc"})).To(Equal(model.Albums{
|
||||
albumSgtPeppers,
|
||||
albumRadioactivity,
|
||||
albumMultiDisc,
|
||||
albumAbbeyRoad,
|
||||
}))
|
||||
})
|
||||
|
||||
It("paginates the result", func() {
|
||||
Expect(GetAll(model.QueryOptions{Offset: 1, Max: 1})).To(Equal(model.Albums{
|
||||
albumAbbeyRoad,
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Album.PlayCount", func() {
|
||||
// Implementation is in withAnnotation() method
|
||||
DescribeTable("normalizes play count when AlbumPlayCountMode is absolute",
|
||||
func(songCount, playCount, expected int) {
|
||||
conf.Server.AlbumPlayCountMode = consts.AlbumPlayCountModeAbsolute
|
||||
|
||||
newID := id.NewRandom()
|
||||
Expect(albumRepo.Put(&model.Album{LibraryID: 1, ID: newID, Name: "name", SongCount: songCount})).To(Succeed())
|
||||
for i := 0; i < playCount; i++ {
|
||||
Expect(albumRepo.IncPlayCount(newID, time.Now())).To(Succeed())
|
||||
}
|
||||
|
||||
album, err := albumRepo.Get(newID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(album.PlayCount).To(Equal(int64(expected)))
|
||||
},
|
||||
Entry("1 song, 0 plays", 1, 0, 0),
|
||||
Entry("1 song, 4 plays", 1, 4, 4),
|
||||
Entry("3 songs, 6 plays", 3, 6, 6),
|
||||
Entry("10 songs, 6 plays", 10, 6, 6),
|
||||
Entry("70 songs, 70 plays", 70, 70, 70),
|
||||
Entry("10 songs, 50 plays", 10, 50, 50),
|
||||
Entry("120 songs, 121 plays", 120, 121, 121),
|
||||
)
|
||||
|
||||
DescribeTable("normalizes play count when AlbumPlayCountMode is normalized",
|
||||
func(songCount, playCount, expected int) {
|
||||
conf.Server.AlbumPlayCountMode = consts.AlbumPlayCountModeNormalized
|
||||
|
||||
newID := id.NewRandom()
|
||||
Expect(albumRepo.Put(&model.Album{LibraryID: 1, ID: newID, Name: "name", SongCount: songCount})).To(Succeed())
|
||||
for i := 0; i < playCount; i++ {
|
||||
Expect(albumRepo.IncPlayCount(newID, time.Now())).To(Succeed())
|
||||
}
|
||||
|
||||
album, err := albumRepo.Get(newID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(album.PlayCount).To(Equal(int64(expected)))
|
||||
},
|
||||
Entry("1 song, 0 plays", 1, 0, 0),
|
||||
Entry("1 song, 4 plays", 1, 4, 4),
|
||||
Entry("3 songs, 6 plays", 3, 6, 2),
|
||||
Entry("10 songs, 6 plays", 10, 6, 1),
|
||||
Entry("70 songs, 70 plays", 70, 70, 1),
|
||||
Entry("10 songs, 50 plays", 10, 50, 5),
|
||||
Entry("120 songs, 121 plays", 120, 121, 1),
|
||||
)
|
||||
})
|
||||
|
||||
Describe("dbAlbum mapping", func() {
|
||||
var (
|
||||
a model.Album
|
||||
dba *dbAlbum
|
||||
args map[string]any
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
a = al(model.Album{ID: "1", Name: "name"})
|
||||
dba = &dbAlbum{Album: &a, Participants: "{}"}
|
||||
args = make(map[string]any)
|
||||
})
|
||||
|
||||
Describe("PostScan", func() {
|
||||
It("parses Discs correctly", func() {
|
||||
dba.Discs = `{"1":"disc1","2":"disc2"}`
|
||||
Expect(dba.PostScan()).To(Succeed())
|
||||
Expect(dba.Album.Discs).To(Equal(model.Discs{1: "disc1", 2: "disc2"}))
|
||||
})
|
||||
|
||||
It("parses Participants correctly", func() {
|
||||
dba.Participants = `{"composer":[{"id":"1","name":"Composer 1"}],` +
|
||||
`"artist":[{"id":"2","name":"Artist 2"},{"id":"3","name":"Artist 3","subRole":"subRole"}]}`
|
||||
Expect(dba.PostScan()).To(Succeed())
|
||||
Expect(dba.Album.Participants).To(HaveLen(2))
|
||||
Expect(dba.Album.Participants).To(HaveKeyWithValue(
|
||||
model.RoleFromString("composer"),
|
||||
model.ParticipantList{{Artist: model.Artist{ID: "1", Name: "Composer 1"}}},
|
||||
))
|
||||
Expect(dba.Album.Participants).To(HaveKeyWithValue(
|
||||
model.RoleFromString("artist"),
|
||||
model.ParticipantList{{Artist: model.Artist{ID: "2", Name: "Artist 2"}}, {Artist: model.Artist{ID: "3", Name: "Artist 3"}, SubRole: "subRole"}},
|
||||
))
|
||||
})
|
||||
|
||||
It("parses Tags correctly", func() {
|
||||
dba.Tags = `{"genre":[{"id":"1","value":"rock"},{"id":"2","value":"pop"}],"mood":[{"id":"3","value":"happy"}]}`
|
||||
Expect(dba.PostScan()).To(Succeed())
|
||||
Expect(dba.Album.Tags).To(HaveKeyWithValue(
|
||||
model.TagName("mood"), []string{"happy"},
|
||||
))
|
||||
Expect(dba.Album.Tags).To(HaveKeyWithValue(
|
||||
model.TagName("genre"), []string{"rock", "pop"},
|
||||
))
|
||||
Expect(dba.Album.Genre).To(Equal("rock"))
|
||||
Expect(dba.Album.Genres).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("parses Paths correctly", func() {
|
||||
dba.FolderIDs = `["folder1","folder2"]`
|
||||
Expect(dba.PostScan()).To(Succeed())
|
||||
Expect(dba.Album.FolderIDs).To(Equal([]string{"folder1", "folder2"}))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("PostMapArgs", func() {
|
||||
It("maps full_text correctly", func() {
|
||||
Expect(dba.PostMapArgs(args)).To(Succeed())
|
||||
Expect(args).To(HaveKeyWithValue("full_text", " name"))
|
||||
})
|
||||
|
||||
It("maps tags correctly", func() {
|
||||
dba.Album.Tags = model.Tags{"genre": {"rock", "pop"}, "mood": {"happy"}}
|
||||
Expect(dba.PostMapArgs(args)).To(Succeed())
|
||||
Expect(args).To(HaveKeyWithValue("tags",
|
||||
`{"genre":[{"id":"5qDZoz1FBC36K73YeoJ2lF","value":"rock"},{"id":"4H0KjnlS2ob9nKLL0zHOqB",`+
|
||||
`"value":"pop"}],"mood":[{"id":"1F4tmb516DIlHKFT1KzE1Z","value":"happy"}]}`,
|
||||
))
|
||||
})
|
||||
|
||||
It("maps participants correctly", func() {
|
||||
dba.Album.Participants = model.Participants{
|
||||
model.RoleAlbumArtist: model.ParticipantList{_p("AA1", "AlbumArtist1")},
|
||||
model.RoleComposer: model.ParticipantList{{Artist: model.Artist{ID: "C1", Name: "Composer1"}, SubRole: "composer"}},
|
||||
}
|
||||
Expect(dba.PostMapArgs(args)).To(Succeed())
|
||||
Expect(args).To(HaveKeyWithValue(
|
||||
"participants",
|
||||
`{"albumartist":[{"id":"AA1","name":"AlbumArtist1"}],`+
|
||||
`"composer":[{"id":"C1","name":"Composer1","subRole":"composer"}]}`,
|
||||
))
|
||||
})
|
||||
|
||||
It("maps discs correctly", func() {
|
||||
dba.Album.Discs = model.Discs{1: "disc1", 2: "disc2"}
|
||||
Expect(dba.PostMapArgs(args)).To(Succeed())
|
||||
Expect(args).To(HaveKeyWithValue("discs", `{"1":"disc1","2":"disc2"}`))
|
||||
})
|
||||
|
||||
It("maps paths correctly", func() {
|
||||
dba.Album.FolderIDs = []string{"folder1", "folder2"}
|
||||
Expect(dba.PostMapArgs(args)).To(Succeed())
|
||||
Expect(args).To(HaveKeyWithValue("folder_ids", `["folder1","folder2"]`))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("dbAlbums.toModels", func() {
|
||||
It("converts dbAlbums to model.Albums", func() {
|
||||
dba := dbAlbums{
|
||||
{Album: &model.Album{ID: "1", Name: "name", SongCount: 2, Annotations: model.Annotations{PlayCount: 4}}},
|
||||
{Album: &model.Album{ID: "2", Name: "name2", SongCount: 3, Annotations: model.Annotations{PlayCount: 6}}},
|
||||
}
|
||||
albums := dba.toModels()
|
||||
for i := range dba {
|
||||
Expect(albums[i].ID).To(Equal(dba[i].Album.ID))
|
||||
Expect(albums[i].Name).To(Equal(dba[i].Album.Name))
|
||||
Expect(albums[i].SongCount).To(Equal(dba[i].Album.SongCount))
|
||||
Expect(albums[i].PlayCount).To(Equal(dba[i].Album.PlayCount))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("artistRoleFilter", func() {
|
||||
DescribeTable("creates correct SQL expressions for artist roles",
|
||||
func(filterName, artistID, expectedSQL string) {
|
||||
sqlizer := artistRoleFilter(filterName, artistID)
|
||||
sql, args, err := sqlizer.ToSql()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sql).To(Equal(expectedSQL))
|
||||
Expect(args).To(Equal([]interface{}{artistID}))
|
||||
},
|
||||
Entry("artist role", "role_artist_id", "123",
|
||||
"exists (select 1 from json_tree(participants, '$.artist') where value = ?)"),
|
||||
Entry("albumartist role", "role_albumartist_id", "456",
|
||||
"exists (select 1 from json_tree(participants, '$.albumartist') where value = ?)"),
|
||||
Entry("composer role", "role_composer_id", "789",
|
||||
"exists (select 1 from json_tree(participants, '$.composer') where value = ?)"),
|
||||
)
|
||||
|
||||
It("works with the actual filter map", func() {
|
||||
filters := albumFilters()
|
||||
|
||||
for roleName := range model.AllRoles {
|
||||
filterName := "role_" + roleName + "_id"
|
||||
filterFunc, exists := filters[filterName]
|
||||
Expect(exists).To(BeTrue(), fmt.Sprintf("Filter %s should exist", filterName))
|
||||
|
||||
sqlizer := filterFunc(filterName, "test-id")
|
||||
sql, args, err := sqlizer.ToSql()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sql).To(Equal(fmt.Sprintf("exists (select 1 from json_tree(participants, '$.%s') where value = ?)", roleName)))
|
||||
Expect(args).To(Equal([]interface{}{"test-id"}))
|
||||
}
|
||||
})
|
||||
|
||||
It("rejects invalid roles", func() {
|
||||
sqlizer := artistRoleFilter("role_invalid_id", "123")
|
||||
_, _, err := sqlizer.ToSql()
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects invalid filter names", func() {
|
||||
sqlizer := artistRoleFilter("invalid_name", "123")
|
||||
_, _, err := sqlizer.ToSql()
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Participant Foreign Key Handling", func() {
|
||||
// albumArtistRecord represents a record in the album_artists table
|
||||
type albumArtistRecord struct {
|
||||
ArtistID string `db:"artist_id"`
|
||||
Role string `db:"role"`
|
||||
SubRole string `db:"sub_role"`
|
||||
}
|
||||
|
||||
var artistRepo *artistRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := request.WithUser(GinkgoT().Context(), adminUser)
|
||||
artistRepo = NewArtistRepository(ctx, GetDBXBuilder(), nil).(*artistRepository)
|
||||
})
|
||||
|
||||
// Helper to verify album_artists records
|
||||
verifyAlbumArtists := func(albumID string, expected []albumArtistRecord) {
|
||||
GinkgoHelper()
|
||||
var actual []albumArtistRecord
|
||||
sq := squirrel.Select("artist_id", "role", "sub_role").
|
||||
From("album_artists").
|
||||
Where(squirrel.Eq{"album_id": albumID}).
|
||||
OrderBy("role", "artist_id", "sub_role")
|
||||
|
||||
err := albumRepo.queryAll(sq, &actual)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual).To(Equal(expected))
|
||||
}
|
||||
|
||||
It("verifies that participant records are actually inserted into database", func() {
|
||||
// Create a real artist in the database first
|
||||
artist := &model.Artist{
|
||||
ID: "real-artist-1",
|
||||
Name: "Real Artist",
|
||||
OrderArtistName: "real artist",
|
||||
SortArtistName: "Artist, Real",
|
||||
}
|
||||
err := createArtistWithLibrary(artistRepo, artist, 1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create an album with participants that reference the real artist
|
||||
album := &model.Album{
|
||||
LibraryID: 1,
|
||||
ID: "test-album-db-insert",
|
||||
Name: "Test Album DB Insert",
|
||||
AlbumArtistID: "real-artist-1",
|
||||
AlbumArtist: "Real Artist",
|
||||
Participants: model.Participants{
|
||||
model.RoleArtist: {
|
||||
{Artist: model.Artist{ID: "real-artist-1", Name: "Real Artist"}},
|
||||
},
|
||||
model.RoleComposer: {
|
||||
{Artist: model.Artist{ID: "real-artist-1", Name: "Real Artist"}, SubRole: "primary"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Insert the album
|
||||
err = albumRepo.Put(album)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify that participant records were actually inserted into album_artists table
|
||||
expected := []albumArtistRecord{
|
||||
{ArtistID: "real-artist-1", Role: "artist", SubRole: ""},
|
||||
{ArtistID: "real-artist-1", Role: "composer", SubRole: "primary"},
|
||||
}
|
||||
verifyAlbumArtists(album.ID, expected)
|
||||
|
||||
// Clean up the test artist and album created for this test
|
||||
_, _ = artistRepo.executeSQL(squirrel.Delete("artist").Where(squirrel.Eq{"id": artist.ID}))
|
||||
_, _ = albumRepo.executeSQL(squirrel.Delete("album").Where(squirrel.Eq{"id": album.ID}))
|
||||
})
|
||||
|
||||
It("filters out invalid artist IDs leaving only valid participants in database", func() {
|
||||
// Create two real artists in the database
|
||||
artist1 := &model.Artist{
|
||||
ID: "real-artist-mix-1",
|
||||
Name: "Real Artist 1",
|
||||
OrderArtistName: "real artist 1",
|
||||
}
|
||||
artist2 := &model.Artist{
|
||||
ID: "real-artist-mix-2",
|
||||
Name: "Real Artist 2",
|
||||
OrderArtistName: "real artist 2",
|
||||
}
|
||||
err := createArtistWithLibrary(artistRepo, artist1, 1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = createArtistWithLibrary(artistRepo, artist2, 1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create an album with mix of valid and invalid artist IDs
|
||||
album := &model.Album{
|
||||
LibraryID: 1,
|
||||
ID: "test-album-mixed-validity",
|
||||
Name: "Test Album Mixed Validity",
|
||||
AlbumArtistID: "real-artist-mix-1",
|
||||
AlbumArtist: "Real Artist 1",
|
||||
Participants: model.Participants{
|
||||
model.RoleArtist: {
|
||||
{Artist: model.Artist{ID: "real-artist-mix-1", Name: "Real Artist 1"}},
|
||||
{Artist: model.Artist{ID: "non-existent-mix-1", Name: "Non Existent 1"}},
|
||||
{Artist: model.Artist{ID: "real-artist-mix-2", Name: "Real Artist 2"}},
|
||||
},
|
||||
model.RoleComposer: {
|
||||
{Artist: model.Artist{ID: "non-existent-mix-2", Name: "Non Existent 2"}},
|
||||
{Artist: model.Artist{ID: "real-artist-mix-1", Name: "Real Artist 1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// This should not fail - only valid artists should be inserted
|
||||
err = albumRepo.Put(album)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify that only valid artist IDs were inserted into album_artists table
|
||||
// Non-existent artists should be filtered out by the INNER JOIN
|
||||
expected := []albumArtistRecord{
|
||||
{ArtistID: "real-artist-mix-1", Role: "artist", SubRole: ""},
|
||||
{ArtistID: "real-artist-mix-2", Role: "artist", SubRole: ""},
|
||||
{ArtistID: "real-artist-mix-1", Role: "composer", SubRole: ""},
|
||||
}
|
||||
verifyAlbumArtists(album.ID, expected)
|
||||
|
||||
// Clean up the test artists and album created for this test
|
||||
artistIDs := []string{artist1.ID, artist2.ID}
|
||||
_, _ = artistRepo.executeSQL(squirrel.Delete("artist").Where(squirrel.Eq{"id": artistIDs}))
|
||||
_, _ = albumRepo.executeSQL(squirrel.Delete("album").Where(squirrel.Eq{"id": album.ID}))
|
||||
})
|
||||
|
||||
It("handles complex nested JSON with multiple roles and sub-roles", func() {
|
||||
// Create 4 artists for this test
|
||||
artists := []*model.Artist{
|
||||
{ID: "complex-artist-1", Name: "Lead Vocalist", OrderArtistName: "lead vocalist"},
|
||||
{ID: "complex-artist-2", Name: "Guitarist", OrderArtistName: "guitarist"},
|
||||
{ID: "complex-artist-3", Name: "Producer", OrderArtistName: "producer"},
|
||||
{ID: "complex-artist-4", Name: "Engineer", OrderArtistName: "engineer"},
|
||||
}
|
||||
|
||||
for _, artist := range artists {
|
||||
err := createArtistWithLibrary(artistRepo, artist, 1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
// Create album with complex participant structure
|
||||
album := &model.Album{
|
||||
LibraryID: 1,
|
||||
ID: "test-album-complex-json",
|
||||
Name: "Test Album Complex JSON",
|
||||
AlbumArtistID: "complex-artist-1",
|
||||
AlbumArtist: "Lead Vocalist",
|
||||
Participants: model.Participants{
|
||||
model.RoleArtist: {
|
||||
{Artist: model.Artist{ID: "complex-artist-1", Name: "Lead Vocalist"}},
|
||||
{Artist: model.Artist{ID: "complex-artist-2", Name: "Guitarist"}, SubRole: "lead guitar"},
|
||||
{Artist: model.Artist{ID: "complex-artist-2", Name: "Guitarist"}, SubRole: "rhythm guitar"},
|
||||
},
|
||||
model.RoleProducer: {
|
||||
{Artist: model.Artist{ID: "complex-artist-3", Name: "Producer"}, SubRole: "executive"},
|
||||
},
|
||||
model.RoleEngineer: {
|
||||
{Artist: model.Artist{ID: "complex-artist-4", Name: "Engineer"}, SubRole: "mixing"},
|
||||
{Artist: model.Artist{ID: "complex-artist-4", Name: "Engineer"}, SubRole: "mastering"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := albumRepo.Put(album)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify complex JSON structure was correctly parsed and inserted
|
||||
expected := []albumArtistRecord{
|
||||
{ArtistID: "complex-artist-1", Role: "artist", SubRole: ""},
|
||||
{ArtistID: "complex-artist-2", Role: "artist", SubRole: "lead guitar"},
|
||||
{ArtistID: "complex-artist-2", Role: "artist", SubRole: "rhythm guitar"},
|
||||
{ArtistID: "complex-artist-4", Role: "engineer", SubRole: "mastering"},
|
||||
{ArtistID: "complex-artist-4", Role: "engineer", SubRole: "mixing"},
|
||||
{ArtistID: "complex-artist-3", Role: "producer", SubRole: "executive"},
|
||||
}
|
||||
verifyAlbumArtists(album.ID, expected)
|
||||
|
||||
// Clean up the test artists and album created for this test
|
||||
artistIDs := make([]string, len(artists))
|
||||
for i, artist := range artists {
|
||||
artistIDs[i] = artist.ID
|
||||
}
|
||||
_, _ = artistRepo.executeSQL(squirrel.Delete("artist").Where(squirrel.Eq{"id": artistIDs}))
|
||||
_, _ = albumRepo.executeSQL(squirrel.Delete("album").Where(squirrel.Eq{"id": album.ID}))
|
||||
})
|
||||
|
||||
It("handles albums with non-existent artist IDs without constraint errors", func() {
|
||||
// Regression test for foreign key constraint error when album participants
|
||||
// contain artist IDs that don't exist in the artist table
|
||||
|
||||
// Create an album with participants that reference non-existent artist IDs
|
||||
album := &model.Album{
|
||||
LibraryID: 1,
|
||||
ID: "test-album-fk-constraints",
|
||||
Name: "Test Album with Invalid Artist References",
|
||||
AlbumArtistID: "non-existent-artist-1",
|
||||
AlbumArtist: "Non Existent Album Artist",
|
||||
Participants: model.Participants{
|
||||
model.RoleArtist: {
|
||||
{Artist: model.Artist{ID: "non-existent-artist-1", Name: "Non Existent Artist 1"}},
|
||||
{Artist: model.Artist{ID: "non-existent-artist-2", Name: "Non Existent Artist 2"}},
|
||||
},
|
||||
model.RoleComposer: {
|
||||
{Artist: model.Artist{ID: "non-existent-composer-1", Name: "Non Existent Composer 1"}},
|
||||
{Artist: model.Artist{ID: "non-existent-composer-2", Name: "Non Existent Composer 2"}},
|
||||
},
|
||||
model.RoleAlbumArtist: {
|
||||
{Artist: model.Artist{ID: "non-existent-album-artist-1", Name: "Non Existent Album Artist 1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// This should not fail with foreign key constraint error
|
||||
// The updateParticipants method should handle non-existent artist IDs gracefully
|
||||
err := albumRepo.Put(album)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify that no participant records were inserted since all artist IDs were invalid
|
||||
// The INNER JOIN with the artist table should filter out all non-existent artists
|
||||
verifyAlbumArtists(album.ID, []albumArtistRecord{})
|
||||
|
||||
// Clean up the test album created for this test
|
||||
_, _ = albumRepo.executeSQL(squirrel.Delete("album").Where(squirrel.Eq{"id": album.ID}))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func _p(id, name string, sortName ...string) model.Participant {
|
||||
p := model.Participant{Artist: model.Artist{ID: id, Name: name}}
|
||||
if len(sortName) > 0 {
|
||||
p.Artist.SortArtistName = sortName[0]
|
||||
}
|
||||
return p
|
||||
}
|
||||
608
persistence/artist_repository.go
Normal file
608
persistence/artist_repository.go
Normal file
@@ -0,0 +1,608 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils"
|
||||
. "github.com/navidrome/navidrome/utils/gg"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type artistRepository struct {
|
||||
sqlRepository
|
||||
indexGroups utils.IndexGroups
|
||||
ms *MeilisearchService
|
||||
}
|
||||
|
||||
type dbArtist struct {
|
||||
*model.Artist `structs:",flatten"`
|
||||
SimilarArtists string `structs:"-" json:"-"`
|
||||
LibraryStatsJSON string `structs:"-" json:"-"`
|
||||
}
|
||||
|
||||
type dbSimilarArtist struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
func (a *dbArtist) PostScan() error {
|
||||
a.Artist.Stats = make(map[model.Role]model.ArtistStats)
|
||||
|
||||
if a.LibraryStatsJSON != "" {
|
||||
var rawLibStats map[string]map[string]map[string]int64
|
||||
if err := json.Unmarshal([]byte(a.LibraryStatsJSON), &rawLibStats); err != nil {
|
||||
return fmt.Errorf("parsing artist stats from db: %w", err)
|
||||
}
|
||||
|
||||
for _, stats := range rawLibStats {
|
||||
// Sum all libraries roles stats
|
||||
for key, stat := range stats {
|
||||
// Aggregate stats into the main Artist.Stats map
|
||||
artistStats := model.ArtistStats{
|
||||
SongCount: int(stat["m"]),
|
||||
AlbumCount: int(stat["a"]),
|
||||
Size: stat["s"],
|
||||
}
|
||||
|
||||
// Store total stats into the main attributes
|
||||
if key == "total" {
|
||||
a.Artist.Size += artistStats.Size
|
||||
a.Artist.SongCount += artistStats.SongCount
|
||||
a.Artist.AlbumCount += artistStats.AlbumCount
|
||||
}
|
||||
|
||||
role := model.RoleFromString(key)
|
||||
if role == model.RoleInvalid {
|
||||
continue
|
||||
}
|
||||
|
||||
current := a.Artist.Stats[role]
|
||||
current.Size += artistStats.Size
|
||||
current.SongCount += artistStats.SongCount
|
||||
current.AlbumCount += artistStats.AlbumCount
|
||||
a.Artist.Stats[role] = current
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.Artist.SimilarArtists = nil
|
||||
if a.SimilarArtists == "" {
|
||||
return nil
|
||||
}
|
||||
var sa []dbSimilarArtist
|
||||
if err := json.Unmarshal([]byte(a.SimilarArtists), &sa); err != nil {
|
||||
return fmt.Errorf("parsing similar artists from db: %w", err)
|
||||
}
|
||||
for _, s := range sa {
|
||||
a.Artist.SimilarArtists = append(a.Artist.SimilarArtists, model.Artist{
|
||||
ID: s.ID,
|
||||
Name: s.Name,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *dbArtist) PostMapArgs(m map[string]any) error {
|
||||
sa := make([]dbSimilarArtist, 0)
|
||||
for _, s := range a.Artist.SimilarArtists {
|
||||
sa = append(sa, dbSimilarArtist{ID: s.ID, Name: s.Name})
|
||||
}
|
||||
similarArtists, _ := json.Marshal(sa)
|
||||
m["similar_artists"] = string(similarArtists)
|
||||
m["full_text"] = formatFullText(a.Name, a.SortArtistName)
|
||||
|
||||
// Do not override the sort_artist_name and mbz_artist_id fields if they are empty
|
||||
// TODO: Better way to handle this?
|
||||
if v, ok := m["sort_artist_name"]; !ok || v.(string) == "" {
|
||||
delete(m, "sort_artist_name")
|
||||
}
|
||||
if v, ok := m["mbz_artist_id"]; !ok || v.(string) == "" {
|
||||
delete(m, "mbz_artist_id")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbArtists []dbArtist
|
||||
|
||||
func (dba dbArtists) toModels() model.Artists {
|
||||
res := make(model.Artists, len(dba))
|
||||
for i := range dba {
|
||||
res[i] = *dba[i].Artist
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func NewArtistRepository(ctx context.Context, db dbx.Builder, ms *MeilisearchService) model.ArtistRepository {
|
||||
r := &artistRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.ms = ms
|
||||
r.indexGroups = utils.ParseIndexGroups(conf.Server.IndexGroups)
|
||||
r.tableName = "artist" // To be used by the idFilter below
|
||||
r.registerModel(&model.Artist{}, map[string]filterFunc{
|
||||
"id": idFilter(r.tableName),
|
||||
"name": fullTextFilter(r.tableName, "mbz_artist_id"),
|
||||
"starred": booleanFilter,
|
||||
"role": roleFilter,
|
||||
"missing": booleanFilter,
|
||||
"library_id": artistLibraryIdFilter,
|
||||
})
|
||||
r.setSortMappings(map[string]string{
|
||||
"name": "order_artist_name",
|
||||
"starred_at": "starred, starred_at",
|
||||
"rated_at": "rating, rated_at",
|
||||
"song_count": "stats->>'total'->>'m'",
|
||||
"album_count": "stats->>'total'->>'a'",
|
||||
"size": "stats->>'total'->>'s'",
|
||||
|
||||
// Stats by credits that are currently available
|
||||
"maincredit_song_count": "sum(stats->>'maincredit'->>'m')",
|
||||
"maincredit_album_count": "sum(stats->>'maincredit'->>'a')",
|
||||
"maincredit_size": "sum(stats->>'maincredit'->>'s')",
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func roleFilter(_ string, role any) Sqlizer {
|
||||
if role, ok := role.(string); ok {
|
||||
if _, ok := model.AllRoles[role]; ok {
|
||||
return Expr("JSON_EXTRACT(library_artist.stats, '$." + role + ".m') IS NOT NULL")
|
||||
}
|
||||
}
|
||||
return Eq{"1": 2}
|
||||
}
|
||||
|
||||
// artistLibraryIdFilter filters artists based on library access through the library_artist table
|
||||
func artistLibraryIdFilter(_ string, value interface{}) Sqlizer {
|
||||
return Eq{"library_artist.library_id": value}
|
||||
}
|
||||
|
||||
// applyLibraryFilterToArtistQuery applies library filtering to artist queries through the library_artist junction table
|
||||
func (r *artistRepository) applyLibraryFilterToArtistQuery(query SelectBuilder) SelectBuilder {
|
||||
user := loggedUser(r.ctx)
|
||||
// Join with library_artist first to ensure only artists with content in libraries are included
|
||||
// Exclude artists with empty stats (no actual content in the library)
|
||||
query = query.Join("library_artist on library_artist.artist_id = artist.id")
|
||||
//query = query.Join("library_artist on library_artist.artist_id = artist.id AND library_artist.stats != '{}'")
|
||||
|
||||
// Admin users see all artists from all libraries, no additional filtering needed
|
||||
if user.ID != invalidUserId && !user.IsAdmin {
|
||||
// Apply library filtering only for non-admin users by joining with their accessible libraries
|
||||
query = query.Join("user_library on user_library.library_id = library_artist.library_id AND user_library.user_id = ?", user.ID)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (r *artistRepository) selectArtist(options ...model.QueryOptions) SelectBuilder {
|
||||
// Stats Format: {"1": {"albumartist": {"m": 10, "a": 5, "s": 1024}, "artist": {...}}, "2": {...}}
|
||||
query := r.newSelect(options...).Columns("artist.*",
|
||||
"JSON_GROUP_OBJECT(library_artist.library_id, JSONB(library_artist.stats)) as library_stats_json")
|
||||
|
||||
query = r.applyLibraryFilterToArtistQuery(query)
|
||||
query = query.GroupBy("artist.id")
|
||||
return r.withAnnotation(query, "artist.id")
|
||||
}
|
||||
|
||||
func (r *artistRepository) CountAll(options ...model.QueryOptions) (int64, error) {
|
||||
query := r.newSelect()
|
||||
query = r.applyLibraryFilterToArtistQuery(query)
|
||||
query = r.withAnnotation(query, "artist.id")
|
||||
return r.count(query, options...)
|
||||
}
|
||||
|
||||
// Exists checks if an artist with the given ID exists in the database and is accessible by the current user.
|
||||
func (r *artistRepository) Exists(id string) (bool, error) {
|
||||
// Create a query using the same library filtering logic as selectArtist()
|
||||
query := r.newSelect().Columns("count(distinct artist.id) as exist").Where(Eq{"artist.id": id})
|
||||
query = r.applyLibraryFilterToArtistQuery(query)
|
||||
|
||||
var res struct{ Exist int64 }
|
||||
err := r.queryOne(query, &res)
|
||||
return res.Exist > 0, err
|
||||
}
|
||||
|
||||
func (r *artistRepository) Put(a *model.Artist, colsToUpdate ...string) error {
|
||||
dba := &dbArtist{Artist: a}
|
||||
dba.CreatedAt = P(time.Now())
|
||||
dba.UpdatedAt = dba.CreatedAt
|
||||
_, err := r.put(dba.ID, dba, colsToUpdate...)
|
||||
if err == nil && r.ms != nil {
|
||||
r.ms.IndexArtist(a)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *artistRepository) UpdateExternalInfo(a *model.Artist) error {
|
||||
dba := &dbArtist{Artist: a}
|
||||
_, err := r.put(a.ID, dba,
|
||||
"biography", "small_image_url", "medium_image_url", "large_image_url",
|
||||
"similar_artists", "external_url", "external_info_updated_at")
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *artistRepository) Get(id string) (*model.Artist, error) {
|
||||
sel := r.selectArtist().Where(Eq{"artist.id": id})
|
||||
var dba dbArtists
|
||||
if err := r.queryAll(sel, &dba); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dba) == 0 {
|
||||
return nil, model.ErrNotFound
|
||||
}
|
||||
res := dba.toModels()
|
||||
return &res[0], nil
|
||||
}
|
||||
|
||||
func (r *artistRepository) GetAll(options ...model.QueryOptions) (model.Artists, error) {
|
||||
sel := r.selectArtist(options...)
|
||||
var dba dbArtists
|
||||
err := r.queryAll(sel, &dba)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res := dba.toModels()
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *artistRepository) getIndexKey(a model.Artist) string {
|
||||
source := a.OrderArtistName
|
||||
if conf.Server.PreferSortTags {
|
||||
source = cmp.Or(a.SortArtistName, a.OrderArtistName)
|
||||
}
|
||||
name := strings.ToLower(source)
|
||||
for k, v := range r.indexGroups {
|
||||
if strings.HasPrefix(name, strings.ToLower(k)) {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return "#"
|
||||
}
|
||||
|
||||
// GetIndex returns a list of artists grouped by the first letter of their name, or by the index group if configured.
|
||||
// It can filter by roles and libraries, and optionally include artists that are missing (i.e., have no albums).
|
||||
// TODO Cache the index (recalculate at scan time)
|
||||
func (r *artistRepository) GetIndex(includeMissing bool, libraryIds []int, roles ...model.Role) (model.ArtistIndexes, error) {
|
||||
// Validate library IDs. If no library IDs are provided, return an empty index.
|
||||
if len(libraryIds) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
options := model.QueryOptions{Sort: "name"}
|
||||
if len(roles) > 0 {
|
||||
roleFilters := slice.Map(roles, func(r model.Role) Sqlizer {
|
||||
return roleFilter("role", r.String())
|
||||
})
|
||||
options.Filters = Or(roleFilters)
|
||||
}
|
||||
if !includeMissing {
|
||||
if options.Filters == nil {
|
||||
options.Filters = Eq{"artist.missing": false}
|
||||
} else {
|
||||
options.Filters = And{options.Filters, Eq{"artist.missing": false}}
|
||||
}
|
||||
}
|
||||
|
||||
libFilter := artistLibraryIdFilter("library_id", libraryIds)
|
||||
if options.Filters == nil {
|
||||
options.Filters = libFilter
|
||||
} else {
|
||||
options.Filters = And{options.Filters, libFilter}
|
||||
}
|
||||
|
||||
artists, err := r.GetAll(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result model.ArtistIndexes
|
||||
for k, v := range slice.Group(artists, r.getIndexKey) {
|
||||
result = append(result, model.ArtistIndex{ID: k, Artists: v})
|
||||
}
|
||||
slices.SortFunc(result, func(a, b model.ArtistIndex) int {
|
||||
return cmp.Compare(a.ID, b.ID)
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *artistRepository) purgeEmpty() error {
|
||||
del := Delete(r.tableName).Where("id not in (select artist_id from album_artists)")
|
||||
c, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
return fmt.Errorf("purging empty artists: %w", err)
|
||||
}
|
||||
if c > 0 {
|
||||
log.Debug(r.ctx, "Purged empty artists", "totalDeleted", c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// markMissing marks artists as missing if all their albums are missing.
|
||||
func (r *artistRepository) markMissing() error {
|
||||
q := Expr(`
|
||||
with artists_with_non_missing_albums as (
|
||||
select distinct aa.artist_id
|
||||
from album_artists aa
|
||||
join album a on aa.album_id = a.id
|
||||
where a.missing = false
|
||||
)
|
||||
update artist
|
||||
set missing = (artist.id not in (select artist_id from artists_with_non_missing_albums));
|
||||
`)
|
||||
_, err := r.executeSQL(q)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marking missing artists: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshPlayCounts updates the play count and last play date annotations for all artists, based
|
||||
// on the media files associated with them.
|
||||
func (r *artistRepository) RefreshPlayCounts() (int64, error) {
|
||||
query := Expr(`
|
||||
with play_counts as (
|
||||
select user_id, atom as artist_id, sum(play_count) as total_play_count, max(play_date) as last_play_date
|
||||
from media_file
|
||||
join annotation on item_id = media_file.id
|
||||
left join json_tree(participants, '$.artist') as jt
|
||||
where atom is not null and key = 'id'
|
||||
group by user_id, atom
|
||||
)
|
||||
insert into annotation (user_id, item_id, item_type, play_count, play_date)
|
||||
select user_id, artist_id, 'artist', total_play_count, last_play_date
|
||||
from play_counts
|
||||
where total_play_count > 0
|
||||
on conflict (user_id, item_id, item_type) do update
|
||||
set play_count = excluded.play_count,
|
||||
play_date = excluded.play_date;
|
||||
`)
|
||||
return r.executeSQL(query)
|
||||
}
|
||||
|
||||
// RefreshStats updates the stats field for artists whose associated media files were updated after the oldest recorded library scan time.
|
||||
// When allArtists is true, it refreshes stats for all artists. It processes artists in batches to handle potentially large updates.
|
||||
// This method now calculates per-library statistics and stores them in the library_artist junction table.
|
||||
func (r *artistRepository) RefreshStats(allArtists bool) (int64, error) {
|
||||
var allTouchedArtistIDs []string
|
||||
if allArtists {
|
||||
// Refresh stats for all artists
|
||||
allArtistsQuerySQL := `SELECT DISTINCT id FROM artist WHERE id <> ''`
|
||||
if err := r.db.NewQuery(allArtistsQuerySQL).Column(&allTouchedArtistIDs); err != nil {
|
||||
return 0, fmt.Errorf("fetching all artist IDs: %w", err)
|
||||
}
|
||||
log.Debug(r.ctx, "RefreshStats: Refreshing all artists.", "count", len(allTouchedArtistIDs))
|
||||
} else {
|
||||
// Only refresh artists with updated timestamps
|
||||
touchedArtistsQuerySQL := `
|
||||
SELECT DISTINCT id
|
||||
FROM artist
|
||||
WHERE updated_at > (SELECT last_scan_at FROM library ORDER BY last_scan_at ASC LIMIT 1)
|
||||
`
|
||||
if err := r.db.NewQuery(touchedArtistsQuerySQL).Column(&allTouchedArtistIDs); err != nil {
|
||||
return 0, fmt.Errorf("fetching touched artist IDs: %w", err)
|
||||
}
|
||||
log.Debug(r.ctx, "RefreshStats: Refreshing touched artists.", "count", len(allTouchedArtistIDs))
|
||||
}
|
||||
|
||||
if len(allTouchedArtistIDs) == 0 {
|
||||
log.Debug(r.ctx, "RefreshStats: No artists to update.")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Template for the batch update with placeholder markers that we'll replace
|
||||
// This now calculates per-library statistics and stores them in library_artist.stats
|
||||
batchUpdateStatsSQL := `
|
||||
WITH artist_role_counters AS (
|
||||
SELECT mfa.artist_id,
|
||||
mf.library_id,
|
||||
mfa.role,
|
||||
count(DISTINCT mf.album_id) AS album_count,
|
||||
count(DISTINCT mf.id) AS count,
|
||||
sum(mf.size) AS size
|
||||
FROM media_file_artists mfa
|
||||
JOIN media_file mf ON mfa.media_file_id = mf.id
|
||||
WHERE mfa.artist_id IN (ROLE_IDS_PLACEHOLDER) -- Will replace with actual placeholders
|
||||
GROUP BY mfa.artist_id, mf.library_id, mfa.role
|
||||
),
|
||||
artist_total_counters AS (
|
||||
SELECT mfa.artist_id,
|
||||
mf.library_id,
|
||||
'total' AS role,
|
||||
count(DISTINCT mf.album_id) AS album_count,
|
||||
count(DISTINCT mf.id) AS count,
|
||||
sum(mf.size) AS size
|
||||
FROM media_file_artists mfa
|
||||
JOIN media_file mf ON mfa.media_file_id = mf.id
|
||||
WHERE mfa.artist_id IN (ROLE_IDS_PLACEHOLDER) -- Will replace with actual placeholders
|
||||
GROUP BY mfa.artist_id, mf.library_id
|
||||
),
|
||||
artist_participant_counter AS (
|
||||
SELECT mfa.artist_id,
|
||||
mf.library_id,
|
||||
'maincredit' AS role,
|
||||
count(DISTINCT mf.album_id) AS album_count,
|
||||
count(DISTINCT mf.id) AS count,
|
||||
sum(mf.size) AS size
|
||||
FROM media_file_artists mfa
|
||||
JOIN media_file mf ON mfa.media_file_id = mf.id
|
||||
WHERE mfa.artist_id IN (ROLE_IDS_PLACEHOLDER) -- Will replace with actual placeholders
|
||||
AND mfa.role IN ('albumartist', 'artist')
|
||||
GROUP BY mfa.artist_id, mf.library_id
|
||||
),
|
||||
combined_counters AS (
|
||||
SELECT artist_id, library_id, role, album_count, count, size FROM artist_role_counters
|
||||
UNION ALL
|
||||
SELECT artist_id, library_id, role, album_count, count, size FROM artist_total_counters
|
||||
UNION ALL
|
||||
SELECT artist_id, library_id, role, album_count, count, size FROM artist_participant_counter
|
||||
),
|
||||
library_artist_counters AS (
|
||||
SELECT artist_id,
|
||||
library_id,
|
||||
json_group_object(
|
||||
role,
|
||||
json_object('a', album_count, 'm', count, 's', size)
|
||||
) AS counters
|
||||
FROM combined_counters
|
||||
GROUP BY artist_id, library_id
|
||||
)
|
||||
UPDATE library_artist
|
||||
SET stats = coalesce((SELECT counters FROM library_artist_counters lac
|
||||
WHERE lac.artist_id = library_artist.artist_id
|
||||
AND lac.library_id = library_artist.library_id), '{}')
|
||||
WHERE library_artist.artist_id IN (ROLE_IDS_PLACEHOLDER);` // Will replace with actual placeholders
|
||||
|
||||
var totalRowsAffected int64 = 0
|
||||
const batchSize = 1000
|
||||
|
||||
batchCounter := 0
|
||||
for artistIDBatch := range slice.CollectChunks(slices.Values(allTouchedArtistIDs), batchSize) {
|
||||
batchCounter++
|
||||
log.Trace(r.ctx, "RefreshStats: Processing batch", "batchNum", batchCounter, "batchSize", len(artistIDBatch))
|
||||
|
||||
// Create placeholders for each ID in the IN clauses
|
||||
placeholders := make([]string, len(artistIDBatch))
|
||||
for i := range artistIDBatch {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
// Don't add extra parentheses, the IN clause already expects them in SQL syntax
|
||||
inClause := strings.Join(placeholders, ",")
|
||||
|
||||
// Replace the placeholder markers with actual SQL placeholders
|
||||
batchSQL := strings.Replace(batchUpdateStatsSQL, "ROLE_IDS_PLACEHOLDER", inClause, 4)
|
||||
|
||||
// Create a single parameter array with all IDs (repeated 4 times for each IN clause)
|
||||
// We need to repeat each ID 4 times (once for each IN clause)
|
||||
args := make([]any, 4*len(artistIDBatch))
|
||||
for idx, id := range artistIDBatch {
|
||||
for i := range 4 {
|
||||
startIdx := i * len(artistIDBatch)
|
||||
args[startIdx+idx] = id
|
||||
}
|
||||
}
|
||||
|
||||
// Now use Expr with the expanded SQL and all parameters
|
||||
sqlizer := Expr(batchSQL, args...)
|
||||
|
||||
rowsAffected, err := r.executeSQL(sqlizer)
|
||||
if err != nil {
|
||||
return totalRowsAffected, fmt.Errorf("executing batch update for artist stats (batch %d): %w", batchCounter, err)
|
||||
}
|
||||
totalRowsAffected += rowsAffected
|
||||
}
|
||||
|
||||
// // Remove library_artist entries for artists that no longer have any content in any library
|
||||
cleanupSQL := Delete("library_artist").Where("stats = '{}'")
|
||||
cleanupRows, err := r.executeSQL(cleanupSQL)
|
||||
if err != nil {
|
||||
log.Warn(r.ctx, "Failed to cleanup empty library_artist entries", "error", err)
|
||||
} else if cleanupRows > 0 {
|
||||
log.Debug(r.ctx, "Cleaned up empty library_artist entries", "rowsDeleted", cleanupRows)
|
||||
}
|
||||
|
||||
log.Debug(r.ctx, "RefreshStats: Successfully updated stats.", "totalArtistsProcessed", len(allTouchedArtistIDs), "totalDBRowsAffected", totalRowsAffected)
|
||||
return totalRowsAffected, nil
|
||||
}
|
||||
|
||||
func (r *artistRepository) Search(q string, offset int, size int, options ...model.QueryOptions) (model.Artists, error) {
|
||||
var res dbArtists
|
||||
if uuid.Validate(q) == nil {
|
||||
err := r.searchByMBID(r.selectArtist(options...), q, []string{"mbz_artist_id"}, &res)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("searching artist by MBID %q: %w", q, err)
|
||||
}
|
||||
} else {
|
||||
if r.ms != nil {
|
||||
ids, err := r.ms.Search("artists", q, offset, size)
|
||||
if err == nil {
|
||||
if len(ids) == 0 {
|
||||
return model.Artists{}, nil
|
||||
}
|
||||
// Fetch matching artists from the database
|
||||
// We need to fetch all fields to return complete objects
|
||||
artists, err := r.GetAll(model.QueryOptions{Filters: Eq{"artist.id": ids}})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching artists from meilisearch ids: %w", err)
|
||||
}
|
||||
// Reorder results to match Meilisearch order
|
||||
idMap := make(map[string]model.Artist, len(artists))
|
||||
for _, a := range artists {
|
||||
idMap[a.ID] = a
|
||||
}
|
||||
sorted := make(model.Artists, 0, len(artists))
|
||||
for _, id := range ids {
|
||||
if a, ok := idMap[id]; ok {
|
||||
sorted = append(sorted, a)
|
||||
}
|
||||
}
|
||||
return sorted, nil
|
||||
}
|
||||
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
|
||||
}
|
||||
// Natural order for artists is more performant by ID, due to GROUP BY clause in selectArtist
|
||||
err := r.doSearch(r.selectArtist(options...), q, offset, size, &res, "artist.id",
|
||||
"sum(json_extract(stats, '$.total.m')) desc", "name")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("searching artist by query %q: %w", q, err)
|
||||
}
|
||||
}
|
||||
return res.toModels(), nil
|
||||
}
|
||||
|
||||
func (r *artistRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *artistRepository) Read(id string) (interface{}, error) {
|
||||
return r.Get(id)
|
||||
}
|
||||
|
||||
func (r *artistRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
role := "total"
|
||||
if len(options) > 0 {
|
||||
if v, ok := options[0].Filters["role"].(string); ok {
|
||||
role = v
|
||||
}
|
||||
|
||||
if r.ms != nil {
|
||||
if name, ok := options[0].Filters["name"].(string); ok && name != "" {
|
||||
ids, err := r.ms.Search("artists", name, 0, 10000)
|
||||
if err == nil {
|
||||
log.Debug(r.ctx, "Meilisearch found matches", "count", len(ids), "query", name)
|
||||
delete(options[0].Filters, "name")
|
||||
options[0].Filters["id"] = ids
|
||||
} else {
|
||||
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
r.sortMappings["song_count"] = "sum(stats->>'" + role + "'->>'m')"
|
||||
r.sortMappings["album_count"] = "sum(stats->>'" + role + "'->>'a')"
|
||||
r.sortMappings["size"] = "sum(stats->>'" + role + "'->>'s')"
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *artistRepository) EntityName() string {
|
||||
return "artist"
|
||||
}
|
||||
|
||||
func (r *artistRepository) NewInstance() interface{} {
|
||||
return &model.Artist{}
|
||||
}
|
||||
|
||||
var _ model.ArtistRepository = (*artistRepository)(nil)
|
||||
var _ model.ResourceRepository = (*artistRepository)(nil)
|
||||
772
persistence/artist_repository_test.go
Normal file
772
persistence/artist_repository_test.go
Normal file
@@ -0,0 +1,772 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/conf/configtest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
"github.com/navidrome/navidrome/utils"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Test helper functions to reduce duplication
|
||||
func createTestArtistWithMBID(id, name, mbid string) model.Artist {
|
||||
return model.Artist{
|
||||
ID: id,
|
||||
Name: name,
|
||||
MbzArtistID: mbid,
|
||||
}
|
||||
}
|
||||
|
||||
func createUserWithLibraries(userID string, libraryIDs []int) model.User {
|
||||
user := model.User{
|
||||
ID: userID,
|
||||
UserName: userID,
|
||||
Name: userID,
|
||||
Email: userID + "@test.com",
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
if len(libraryIDs) > 0 {
|
||||
user.Libraries = make(model.Libraries, len(libraryIDs))
|
||||
for i, libID := range libraryIDs {
|
||||
user.Libraries[i] = model.Library{ID: libID, Name: "Test Library", Path: "/test"}
|
||||
}
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
var _ = Describe("ArtistRepository", func() {
|
||||
|
||||
Context("Core Functionality", func() {
|
||||
Describe("GetIndexKey", func() {
|
||||
// Note: OrderArtistName should never be empty, so we don't need to test for that
|
||||
r := artistRepository{indexGroups: utils.ParseIndexGroups(conf.Server.IndexGroups)}
|
||||
|
||||
DescribeTable("returns correct index key based on PreferSortTags setting",
|
||||
func(preferSortTags bool, sortArtistName, orderArtistName, expectedKey string) {
|
||||
DeferCleanup(configtest.SetupConfig())
|
||||
conf.Server.PreferSortTags = preferSortTags
|
||||
a := model.Artist{SortArtistName: sortArtistName, OrderArtistName: orderArtistName, Name: "Test"}
|
||||
idx := GetIndexKey(&r, a)
|
||||
Expect(idx).To(Equal(expectedKey))
|
||||
},
|
||||
Entry("PreferSortTags=false, SortArtistName empty -> uses OrderArtistName", false, "", "Bar", "B"),
|
||||
Entry("PreferSortTags=false, SortArtistName not empty -> still uses OrderArtistName", false, "Foo", "Bar", "B"),
|
||||
Entry("PreferSortTags=true, SortArtistName not empty -> uses SortArtistName", true, "Foo", "Bar", "F"),
|
||||
Entry("PreferSortTags=true, SortArtistName empty -> falls back to OrderArtistName", true, "", "Bar", "B"),
|
||||
)
|
||||
})
|
||||
|
||||
Describe("roleFilter", func() {
|
||||
DescribeTable("validates roles and returns appropriate SQL expressions",
|
||||
func(role string, shouldBeValid bool) {
|
||||
result := roleFilter("", role)
|
||||
if shouldBeValid {
|
||||
expectedExpr := squirrel.Expr("JSON_EXTRACT(library_artist.stats, '$." + role + ".m') IS NOT NULL")
|
||||
Expect(result).To(Equal(expectedExpr))
|
||||
} else {
|
||||
expectedInvalid := squirrel.Eq{"1": 2}
|
||||
Expect(result).To(Equal(expectedInvalid))
|
||||
}
|
||||
},
|
||||
// Valid roles from model.AllRoles
|
||||
Entry("artist role", "artist", true),
|
||||
Entry("albumartist role", "albumartist", true),
|
||||
Entry("composer role", "composer", true),
|
||||
Entry("conductor role", "conductor", true),
|
||||
Entry("lyricist role", "lyricist", true),
|
||||
Entry("arranger role", "arranger", true),
|
||||
Entry("producer role", "producer", true),
|
||||
Entry("director role", "director", true),
|
||||
Entry("engineer role", "engineer", true),
|
||||
Entry("mixer role", "mixer", true),
|
||||
Entry("remixer role", "remixer", true),
|
||||
Entry("djmixer role", "djmixer", true),
|
||||
Entry("performer role", "performer", true),
|
||||
Entry("maincredit role", "maincredit", true),
|
||||
// Invalid roles
|
||||
Entry("invalid role - wizard", "wizard", false),
|
||||
Entry("invalid role - songanddanceman", "songanddanceman", false),
|
||||
Entry("empty string", "", false),
|
||||
Entry("SQL injection attempt", "artist') SELECT LIKE(CHAR(65,66,67,68,69,70,71),UPPER(HEX(RANDOMBLOB(500000000/2))))--", false),
|
||||
)
|
||||
|
||||
It("handles non-string input types", func() {
|
||||
expectedInvalid := squirrel.Eq{"1": 2}
|
||||
Expect(roleFilter("", 123)).To(Equal(expectedInvalid))
|
||||
Expect(roleFilter("", nil)).To(Equal(expectedInvalid))
|
||||
Expect(roleFilter("", []string{"artist"})).To(Equal(expectedInvalid))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("dbArtist mapping", func() {
|
||||
var (
|
||||
artist *model.Artist
|
||||
dba *dbArtist
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
artist = &model.Artist{ID: "1", Name: "Eddie Van Halen", SortArtistName: "Van Halen, Eddie"}
|
||||
dba = &dbArtist{Artist: artist}
|
||||
})
|
||||
|
||||
Describe("PostScan", func() {
|
||||
It("parses stats and similar artists correctly", func() {
|
||||
stats := map[string]map[string]map[string]int64{
|
||||
"1": {
|
||||
"total": {"s": 1000, "m": 10, "a": 2},
|
||||
"composer": {"s": 500, "m": 5, "a": 1},
|
||||
},
|
||||
}
|
||||
statsJSON, _ := json.Marshal(stats)
|
||||
dba.LibraryStatsJSON = string(statsJSON)
|
||||
dba.SimilarArtists = `[{"id":"2","Name":"AC/DC"},{"name":"Test;With:Sep,Chars"}]`
|
||||
|
||||
err := dba.PostScan()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(dba.Artist.Size).To(Equal(int64(1000)))
|
||||
Expect(dba.Artist.SongCount).To(Equal(10))
|
||||
Expect(dba.Artist.AlbumCount).To(Equal(2))
|
||||
Expect(dba.Artist.Stats).To(HaveLen(1))
|
||||
Expect(dba.Artist.Stats[model.RoleFromString("composer")].Size).To(Equal(int64(500)))
|
||||
Expect(dba.Artist.Stats[model.RoleFromString("composer")].SongCount).To(Equal(5))
|
||||
Expect(dba.Artist.Stats[model.RoleFromString("composer")].AlbumCount).To(Equal(1))
|
||||
Expect(dba.Artist.SimilarArtists).To(HaveLen(2))
|
||||
Expect(dba.Artist.SimilarArtists[0].ID).To(Equal("2"))
|
||||
Expect(dba.Artist.SimilarArtists[0].Name).To(Equal("AC/DC"))
|
||||
Expect(dba.Artist.SimilarArtists[1].ID).To(BeEmpty())
|
||||
Expect(dba.Artist.SimilarArtists[1].Name).To(Equal("Test;With:Sep,Chars"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("PostMapArgs", func() {
|
||||
It("maps empty similar artists correctly", func() {
|
||||
m := make(map[string]any)
|
||||
err := dba.PostMapArgs(m)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m).To(HaveKeyWithValue("similar_artists", "[]"))
|
||||
})
|
||||
|
||||
It("maps similar artists and full text correctly", func() {
|
||||
artist.SimilarArtists = []model.Artist{
|
||||
{ID: "2", Name: "AC/DC"},
|
||||
{Name: "Test;With:Sep,Chars"},
|
||||
}
|
||||
m := make(map[string]any)
|
||||
err := dba.PostMapArgs(m)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m).To(HaveKeyWithValue("similar_artists", `[{"id":"2","name":"AC/DC"},{"name":"Test;With:Sep,Chars"}]`))
|
||||
Expect(m).To(HaveKeyWithValue("full_text", " eddie halen van"))
|
||||
})
|
||||
|
||||
It("does not override empty sort_artist_name and mbz_artist_id", func() {
|
||||
m := map[string]any{
|
||||
"sort_artist_name": "",
|
||||
"mbz_artist_id": "",
|
||||
}
|
||||
err := dba.PostMapArgs(m)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(m).ToNot(HaveKey("sort_artist_name"))
|
||||
Expect(m).ToNot(HaveKey("mbz_artist_id"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Admin User Operations", func() {
|
||||
var repo model.ArtistRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := GinkgoT().Context()
|
||||
ctx = request.WithUser(ctx, adminUser)
|
||||
repo = NewArtistRepository(ctx, GetDBXBuilder(), nil).(*artistRepository)
|
||||
})
|
||||
|
||||
Describe("Basic Operations", func() {
|
||||
Describe("Count", func() {
|
||||
It("returns the number of artists in the DB", func() {
|
||||
Expect(repo.CountAll()).To(Equal(int64(2)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Exists", func() {
|
||||
It("returns true for an artist that is in the DB", func() {
|
||||
Expect(repo.Exists("3")).To(BeTrue())
|
||||
})
|
||||
It("returns false for an artist that is NOT in the DB", func() {
|
||||
Expect(repo.Exists("666")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Get", func() {
|
||||
It("retrieves existing artist data", func() {
|
||||
artist, err := repo.Get("2")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(artist.Name).To(Equal(artistKraftwerk.Name))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetIndex", func() {
|
||||
When("PreferSortTags is true", func() {
|
||||
BeforeEach(func() {
|
||||
conf.Server.PreferSortTags = true
|
||||
})
|
||||
It("returns the index when PreferSortTags is true and SortArtistName is not empty", func() {
|
||||
// Set SortArtistName to "Foo" for Beatles
|
||||
artistBeatles.SortArtistName = "Foo"
|
||||
er := repo.Put(&artistBeatles)
|
||||
Expect(er).To(BeNil())
|
||||
|
||||
idx, err := repo.GetIndex(false, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(2))
|
||||
Expect(idx[0].ID).To(Equal("F"))
|
||||
Expect(idx[0].Artists).To(HaveLen(1))
|
||||
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
|
||||
Expect(idx[1].ID).To(Equal("K"))
|
||||
Expect(idx[1].Artists).To(HaveLen(1))
|
||||
Expect(idx[1].Artists[0].Name).To(Equal(artistKraftwerk.Name))
|
||||
|
||||
// Restore the original value
|
||||
artistBeatles.SortArtistName = ""
|
||||
er = repo.Put(&artistBeatles)
|
||||
Expect(er).To(BeNil())
|
||||
})
|
||||
|
||||
// BFR Empty SortArtistName is not saved in the DB anymore
|
||||
XIt("returns the index when PreferSortTags is true and SortArtistName is empty", func() {
|
||||
idx, err := repo.GetIndex(false, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(2))
|
||||
Expect(idx[0].ID).To(Equal("B"))
|
||||
Expect(idx[0].Artists).To(HaveLen(1))
|
||||
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
|
||||
Expect(idx[1].ID).To(Equal("K"))
|
||||
Expect(idx[1].Artists).To(HaveLen(1))
|
||||
Expect(idx[1].Artists[0].Name).To(Equal(artistKraftwerk.Name))
|
||||
})
|
||||
})
|
||||
|
||||
When("PreferSortTags is false", func() {
|
||||
BeforeEach(func() {
|
||||
conf.Server.PreferSortTags = false
|
||||
})
|
||||
It("returns the index when SortArtistName is NOT empty", func() {
|
||||
// Set SortArtistName to "Foo" for Beatles
|
||||
artistBeatles.SortArtistName = "Foo"
|
||||
er := repo.Put(&artistBeatles)
|
||||
Expect(er).To(BeNil())
|
||||
|
||||
idx, err := repo.GetIndex(false, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(2))
|
||||
Expect(idx[0].ID).To(Equal("B"))
|
||||
Expect(idx[0].Artists).To(HaveLen(1))
|
||||
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
|
||||
Expect(idx[1].ID).To(Equal("K"))
|
||||
Expect(idx[1].Artists).To(HaveLen(1))
|
||||
Expect(idx[1].Artists[0].Name).To(Equal(artistKraftwerk.Name))
|
||||
|
||||
// Restore the original value
|
||||
artistBeatles.SortArtistName = ""
|
||||
er = repo.Put(&artistBeatles)
|
||||
Expect(er).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns the index when SortArtistName is empty", func() {
|
||||
idx, err := repo.GetIndex(false, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(2))
|
||||
Expect(idx[0].ID).To(Equal("B"))
|
||||
Expect(idx[0].Artists).To(HaveLen(1))
|
||||
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
|
||||
Expect(idx[1].ID).To(Equal("K"))
|
||||
Expect(idx[1].Artists).To(HaveLen(1))
|
||||
Expect(idx[1].Artists[0].Name).To(Equal(artistKraftwerk.Name))
|
||||
})
|
||||
})
|
||||
|
||||
When("filtering by role", func() {
|
||||
var raw *artistRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
raw = repo.(*artistRepository)
|
||||
// Add stats to library_artist table since stats are now stored per-library
|
||||
composerStats := `{"composer": {"s": 1000, "m": 5, "a": 2}}`
|
||||
producerStats := `{"producer": {"s": 500, "m": 3, "a": 1}}`
|
||||
|
||||
// Set Beatles as composer in library 1
|
||||
_, err := raw.executeSQL(squirrel.Insert("library_artist").
|
||||
Columns("library_id", "artist_id", "stats").
|
||||
Values(1, artistBeatles.ID, composerStats).
|
||||
Suffix("ON CONFLICT(library_id, artist_id) DO UPDATE SET stats = excluded.stats"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Set Kraftwerk as producer in library 1
|
||||
_, err = raw.executeSQL(squirrel.Insert("library_artist").
|
||||
Columns("library_id", "artist_id", "stats").
|
||||
Values(1, artistKraftwerk.ID, producerStats).
|
||||
Suffix("ON CONFLICT(library_id, artist_id) DO UPDATE SET stats = excluded.stats"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up stats from library_artist table
|
||||
_, _ = raw.executeSQL(squirrel.Update("library_artist").
|
||||
Set("stats", "{}").
|
||||
Where(squirrel.Eq{"artist_id": artistBeatles.ID, "library_id": 1}))
|
||||
_, _ = raw.executeSQL(squirrel.Update("library_artist").
|
||||
Set("stats", "{}").
|
||||
Where(squirrel.Eq{"artist_id": artistKraftwerk.ID, "library_id": 1}))
|
||||
})
|
||||
|
||||
It("returns only artists with the specified role", func() {
|
||||
idx, err := repo.GetIndex(false, []int{1}, model.RoleComposer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(1))
|
||||
Expect(idx[0].ID).To(Equal("B"))
|
||||
Expect(idx[0].Artists).To(HaveLen(1))
|
||||
Expect(idx[0].Artists[0].Name).To(Equal(artistBeatles.Name))
|
||||
})
|
||||
|
||||
It("returns artists with any of the specified roles", func() {
|
||||
idx, err := repo.GetIndex(false, []int{1}, model.RoleComposer, model.RoleProducer)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(2))
|
||||
|
||||
// Find Beatles and Kraftwerk in the results
|
||||
var beatlesFound, kraftwerkFound bool
|
||||
for _, index := range idx {
|
||||
for _, artist := range index.Artists {
|
||||
if artist.Name == artistBeatles.Name {
|
||||
beatlesFound = true
|
||||
}
|
||||
if artist.Name == artistKraftwerk.Name {
|
||||
kraftwerkFound = true
|
||||
}
|
||||
}
|
||||
}
|
||||
Expect(beatlesFound).To(BeTrue())
|
||||
Expect(kraftwerkFound).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns empty index when no artists have the specified role", func() {
|
||||
idx, err := repo.GetIndex(false, []int{1}, model.RoleDirector)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
|
||||
When("validating library IDs", func() {
|
||||
It("returns nil when no library IDs are provided", func() {
|
||||
idx, err := repo.GetIndex(false, []int{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns artists when library IDs are provided (admin user sees all content)", func() {
|
||||
// Admin users can see all content when valid library IDs are provided
|
||||
idx, err := repo.GetIndex(false, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(2))
|
||||
|
||||
// With non-existent library ID, admin users see no content because no artists are associated with that library
|
||||
idx, err = repo.GetIndex(false, []int{999})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(0)) // Even admin users need valid library associations
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("MBID and Text Search", func() {
|
||||
var lib2 model.Library
|
||||
var lr model.LibraryRepository
|
||||
var restrictedUser model.User
|
||||
var restrictedRepo model.ArtistRepository
|
||||
var headlessRepo model.ArtistRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
// Set up headless repo (no user context)
|
||||
headlessRepo = NewArtistRepository(context.Background(), GetDBXBuilder(), nil)
|
||||
|
||||
// Create library for testing access restrictions
|
||||
lib2 = model.Library{ID: 0, Name: "Artist Test Library", Path: "/artist/test/lib"}
|
||||
lr = NewLibraryRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
|
||||
err := lr.Put(&lib2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create a user with access to only library 1
|
||||
restrictedUser = createUserWithLibraries("search_user", []int{1})
|
||||
|
||||
// Create repository context for the restricted user
|
||||
ctx := request.WithUser(GinkgoT().Context(), restrictedUser)
|
||||
restrictedRepo = NewArtistRepository(ctx, GetDBXBuilder(), nil)
|
||||
|
||||
// Ensure both test artists are associated with library 1
|
||||
err = lr.AddArtist(1, artistBeatles.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = lr.AddArtist(1, artistKraftwerk.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create the restricted user in the database
|
||||
ur := NewUserRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
|
||||
err = ur.Put(&restrictedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = ur.SetUserLibraries(restrictedUser.ID, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up library 2
|
||||
lr := NewLibraryRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
|
||||
_ = lr.(*libraryRepository).delete(squirrel.Eq{"id": lib2.ID})
|
||||
})
|
||||
|
||||
DescribeTable("MBID search behavior across different user types",
|
||||
func(testRepo *model.ArtistRepository, shouldFind bool, testDesc string) {
|
||||
// Create test artist with MBID
|
||||
artistWithMBID := createTestArtistWithMBID("test-mbid-artist", "Test MBID Artist", "550e8400-e29b-41d4-a716-446655440010")
|
||||
|
||||
err := createArtistWithLibrary(*testRepo, &artistWithMBID, 1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Test the search
|
||||
results, err := (*testRepo).Search("550e8400-e29b-41d4-a716-446655440010", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
if shouldFind {
|
||||
Expect(results).To(HaveLen(1), testDesc)
|
||||
Expect(results[0].ID).To(Equal("test-mbid-artist"))
|
||||
} else {
|
||||
Expect(results).To(BeEmpty(), testDesc)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if raw, ok := (*testRepo).(*artistRepository); ok {
|
||||
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": artistWithMBID.ID}))
|
||||
}
|
||||
},
|
||||
Entry("Admin user can find artist by MBID", &repo, true, "Admin should find MBID artist"),
|
||||
Entry("Restricted user can find artist by MBID in accessible library", &restrictedRepo, true, "Restricted user should find MBID artist in accessible library"),
|
||||
Entry("Headless process can find artist by MBID", &headlessRepo, true, "Headless process should find MBID artist"),
|
||||
)
|
||||
|
||||
It("prevents restricted user from finding artist by MBID when not in accessible library", func() {
|
||||
// Create an artist in library 2 (not accessible to restricted user)
|
||||
inaccessibleArtist := createTestArtistWithMBID("inaccessible-mbid-artist", "Inaccessible MBID Artist", "a74b1b7f-71a5-4011-9441-d0b5e4122711")
|
||||
err := repo.Put(&inaccessibleArtist)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Add to library 2 (not accessible to restricted user)
|
||||
err = lr.AddArtist(lib2.ID, inaccessibleArtist.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Restricted user should not find this artist
|
||||
results, err := restrictedRepo.Search("a74b1b7f-71a5-4011-9441-d0b5e4122711", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
|
||||
// But admin should find it
|
||||
results, err = repo.Search("a74b1b7f-71a5-4011-9441-d0b5e4122711", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(1))
|
||||
|
||||
// Clean up
|
||||
if raw, ok := repo.(*artistRepository); ok {
|
||||
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": inaccessibleArtist.ID}))
|
||||
}
|
||||
})
|
||||
|
||||
Context("Text Search", func() {
|
||||
It("allows admin to find artists by name regardless of library", func() {
|
||||
results, err := repo.Search("Beatles", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].Name).To(Equal("The Beatles"))
|
||||
})
|
||||
|
||||
It("correctly prevents restricted user from finding artists by name when not in accessible library", func() {
|
||||
// Create an artist in library 2 (not accessible to restricted user)
|
||||
inaccessibleArtist := model.Artist{
|
||||
ID: "inaccessible-text-artist",
|
||||
Name: "Unique Search Name Artist",
|
||||
}
|
||||
err := repo.Put(&inaccessibleArtist)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Add to library 2 (not accessible to restricted user)
|
||||
err = lr.AddArtist(lib2.ID, inaccessibleArtist.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Restricted user should not find this artist
|
||||
results, err := restrictedRepo.Search("Unique Search Name", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty(), "Text search should respect library filtering")
|
||||
|
||||
// Clean up
|
||||
if raw, ok := repo.(*artistRepository); ok {
|
||||
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": inaccessibleArtist.ID}))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Context("Headless Processes (No User Context)", func() {
|
||||
It("should see all artists from all libraries when no user is in context", func() {
|
||||
// Add artists to different libraries
|
||||
err := lr.AddArtist(lib2.ID, artistBeatles.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Headless processes should see all artists regardless of library
|
||||
artists, err := headlessRepo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Should see all artists from all libraries
|
||||
found := false
|
||||
for _, artist := range artists {
|
||||
if artist.ID == artistBeatles.ID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).To(BeTrue(), "Headless process should see artists from all libraries")
|
||||
})
|
||||
|
||||
It("should allow headless processes to apply explicit library_id filters", func() {
|
||||
// Add artists to different libraries
|
||||
err := lr.AddArtist(lib2.ID, artistBeatles.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Filter by specific library
|
||||
artists, err := headlessRepo.GetAll(model.QueryOptions{
|
||||
Filters: squirrel.Eq{"library_id": lib2.ID},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Should see only artists from the specified library
|
||||
for _, artist := range artists {
|
||||
if artist.ID == artistBeatles.ID {
|
||||
return // Found the expected artist
|
||||
}
|
||||
}
|
||||
Expect(false).To(BeTrue(), "Should find artist from specified library")
|
||||
})
|
||||
|
||||
It("should get individual artists when no user is in context", func() {
|
||||
// Add artist to a library
|
||||
err := lr.AddArtist(lib2.ID, artistBeatles.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Headless process should be able to get the artist
|
||||
artist, err := headlessRepo.Get(artistBeatles.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(artist.ID).To(Equal(artistBeatles.ID))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Admin User Library Access", func() {
|
||||
It("sees all artists regardless of library permissions", func() {
|
||||
count, err := repo.CountAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(2)))
|
||||
|
||||
artists, err := repo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(artists).To(HaveLen(2))
|
||||
|
||||
exists, err := repo.Exists(artistBeatles.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Missing Artist Handling", func() {
|
||||
var missingArtist model.Artist
|
||||
var raw *artistRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
raw = repo.(*artistRepository)
|
||||
missingArtist = model.Artist{ID: "missing_test", Name: "Missing Artist", OrderArtistName: "missing artist"}
|
||||
|
||||
// Create and mark as missing
|
||||
err := createArtistWithLibrary(repo, &missingArtist, 1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = raw.executeSQL(squirrel.Update(raw.tableName).Set("missing", true).Where(squirrel.Eq{"id": missingArtist.ID}))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": missingArtist.ID}))
|
||||
})
|
||||
|
||||
It("missing artists are never returned by search", func() {
|
||||
// Should see missing artist in GetAll by default for admin users
|
||||
artists, err := repo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(artists).To(HaveLen(3)) // Including the missing artist
|
||||
|
||||
// Search never returns missing artists (hardcoded behavior)
|
||||
results, err := repo.Search("Missing Artist", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Context("Regular User Operations", func() {
|
||||
var restrictedRepo model.ArtistRepository
|
||||
var unauthorizedUser model.User
|
||||
|
||||
BeforeEach(func() {
|
||||
// Create a user without access to any libraries
|
||||
unauthorizedUser = model.User{ID: "restricted_user", UserName: "restricted", Name: "Restricted User", Email: "restricted@test.com", IsAdmin: false}
|
||||
|
||||
// Create repository context for the unauthorized user
|
||||
ctx := GinkgoT().Context()
|
||||
ctx = request.WithUser(ctx, unauthorizedUser)
|
||||
restrictedRepo = NewArtistRepository(ctx, GetDBXBuilder(), nil)
|
||||
})
|
||||
|
||||
Describe("Library Access Restrictions", func() {
|
||||
It("CountAll returns 0 for users without library access", func() {
|
||||
count, err := restrictedRepo.CountAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(0)))
|
||||
})
|
||||
|
||||
It("GetAll returns empty list for users without library access", func() {
|
||||
artists, err := restrictedRepo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(artists).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("Exists returns false for existing artists when user has no library access", func() {
|
||||
// These artists exist in the DB but the user has no access to them
|
||||
exists, err := restrictedRepo.Exists(artistBeatles.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeFalse())
|
||||
|
||||
exists, err = restrictedRepo.Exists(artistKraftwerk.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeFalse())
|
||||
})
|
||||
|
||||
It("Get returns ErrNotFound for existing artists when user has no library access", func() {
|
||||
_, err := restrictedRepo.Get(artistBeatles.ID)
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
|
||||
_, err = restrictedRepo.Get(artistKraftwerk.ID)
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
})
|
||||
|
||||
It("Search returns empty results for users without library access", func() {
|
||||
results, err := restrictedRepo.Search("Beatles", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
|
||||
results, err = restrictedRepo.Search("Kraftwerk", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("GetIndex returns empty index for users without library access", func() {
|
||||
idx, err := restrictedRepo.GetIndex(false, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when user gains library access", func() {
|
||||
BeforeEach(func() {
|
||||
ctx := GinkgoT().Context()
|
||||
// Give the user access to library 1
|
||||
ur := NewUserRepository(request.WithUser(ctx, adminUser), GetDBXBuilder())
|
||||
|
||||
// First create the user if not exists
|
||||
err := ur.Put(&unauthorizedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Then add library access
|
||||
err = ur.SetUserLibraries(unauthorizedUser.ID, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Update the user object with the libraries to simulate middleware behavior
|
||||
libraries, err := ur.GetUserLibraries(unauthorizedUser.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
unauthorizedUser.Libraries = libraries
|
||||
|
||||
// Recreate repository context with updated user
|
||||
ctx = request.WithUser(ctx, unauthorizedUser)
|
||||
restrictedRepo = NewArtistRepository(ctx, GetDBXBuilder(), nil)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up: remove the user's library access
|
||||
ur := NewUserRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
|
||||
_ = ur.SetUserLibraries(unauthorizedUser.ID, []int{})
|
||||
})
|
||||
|
||||
It("CountAll returns correct count after gaining access", func() {
|
||||
count, err := restrictedRepo.CountAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(2))) // Beatles and Kraftwerk
|
||||
})
|
||||
|
||||
It("GetAll returns artists after gaining access", func() {
|
||||
artists, err := restrictedRepo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(artists).To(HaveLen(2))
|
||||
|
||||
var names []string
|
||||
for _, artist := range artists {
|
||||
names = append(names, artist.Name)
|
||||
}
|
||||
Expect(names).To(ContainElements("The Beatles", "Kraftwerk"))
|
||||
})
|
||||
|
||||
It("Exists returns true for accessible artists", func() {
|
||||
exists, err := restrictedRepo.Exists(artistBeatles.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeTrue())
|
||||
|
||||
exists, err = restrictedRepo.Exists(artistKraftwerk.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(exists).To(BeTrue())
|
||||
})
|
||||
|
||||
It("GetIndex returns artists with proper library filtering", func() {
|
||||
// With valid library access, should see artists
|
||||
idx, err := restrictedRepo.GetIndex(false, []int{1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(2))
|
||||
|
||||
// With non-existent library ID, should see nothing (non-admin user)
|
||||
idx, err = restrictedRepo.GetIndex(false, []int{999})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(idx).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Helper function to create an artist with proper library association.
|
||||
// This ensures test artists always have library_artist associations to avoid orphaned artists in tests.
|
||||
func createArtistWithLibrary(repo model.ArtistRepository, artist *model.Artist, libraryID int) error {
|
||||
err := repo.Put(artist)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the artist to the specified library
|
||||
lr := NewLibraryRepository(request.WithUser(GinkgoT().Context(), adminUser), GetDBXBuilder())
|
||||
return lr.AddArtist(libraryID, artist.ID)
|
||||
}
|
||||
122
persistence/collation_test.go
Normal file
122
persistence/collation_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/navidrome/navidrome/db"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// When creating migrations that change existing columns, it is easy to miss the original collation of a column.
|
||||
// These tests enforce that the required collation of the columns and indexes in the database are kept in place.
|
||||
// This is important to ensure that the database can perform fast case-insensitive searches and sorts.
|
||||
var _ = Describe("Collation", func() {
|
||||
conn := db.Db()
|
||||
DescribeTable("Column collation",
|
||||
func(table, column string) {
|
||||
Expect(checkCollation(conn, table, column)).To(Succeed())
|
||||
},
|
||||
Entry("artist.order_artist_name", "artist", "order_artist_name"),
|
||||
Entry("artist.sort_artist_name", "artist", "sort_artist_name"),
|
||||
Entry("album.order_album_name", "album", "order_album_name"),
|
||||
Entry("album.order_album_artist_name", "album", "order_album_artist_name"),
|
||||
Entry("album.sort_album_name", "album", "sort_album_name"),
|
||||
Entry("album.sort_album_artist_name", "album", "sort_album_artist_name"),
|
||||
Entry("media_file.order_title", "media_file", "order_title"),
|
||||
Entry("media_file.order_album_name", "media_file", "order_album_name"),
|
||||
Entry("media_file.order_artist_name", "media_file", "order_artist_name"),
|
||||
Entry("media_file.sort_title", "media_file", "sort_title"),
|
||||
Entry("media_file.sort_album_name", "media_file", "sort_album_name"),
|
||||
Entry("media_file.sort_artist_name", "media_file", "sort_artist_name"),
|
||||
Entry("radio.name", "radio", "name"),
|
||||
Entry("user.name", "user", "name"),
|
||||
)
|
||||
|
||||
DescribeTable("Index collation",
|
||||
func(table, column string) {
|
||||
Expect(checkIndexUsage(conn, table, column)).To(Succeed())
|
||||
},
|
||||
Entry("artist.order_artist_name", "artist", "order_artist_name collate nocase"),
|
||||
Entry("artist.sort_artist_name", "artist", "coalesce(nullif(sort_artist_name,''),order_artist_name) collate nocase"),
|
||||
Entry("album.order_album_name", "album", "order_album_name collate nocase"),
|
||||
Entry("album.order_album_artist_name", "album", "order_album_artist_name collate nocase"),
|
||||
Entry("album.sort_album_name", "album", "coalesce(nullif(sort_album_name,''),order_album_name) collate nocase"),
|
||||
Entry("album.sort_album_artist_name", "album", "coalesce(nullif(sort_album_artist_name,''),order_album_artist_name) collate nocase"),
|
||||
Entry("media_file.order_title", "media_file", "order_title collate nocase"),
|
||||
Entry("media_file.order_album_name", "media_file", "order_album_name collate nocase"),
|
||||
Entry("media_file.order_artist_name", "media_file", "order_artist_name collate nocase"),
|
||||
Entry("media_file.sort_title", "media_file", "coalesce(nullif(sort_title,''),order_title) collate nocase"),
|
||||
Entry("media_file.sort_album_name", "media_file", "coalesce(nullif(sort_album_name,''),order_album_name) collate nocase"),
|
||||
Entry("media_file.sort_artist_name", "media_file", "coalesce(nullif(sort_artist_name,''),order_artist_name) collate nocase"),
|
||||
Entry("media_file.path", "media_file", "path collate nocase"),
|
||||
Entry("radio.name", "radio", "name collate nocase"),
|
||||
Entry("user.user_name", "user", "user_name collate nocase"),
|
||||
)
|
||||
})
|
||||
|
||||
func checkIndexUsage(conn *sql.DB, table string, column string) error {
|
||||
rows, err := conn.Query(fmt.Sprintf(`
|
||||
explain query plan select * from %[1]s
|
||||
where %[2]s = 'test'
|
||||
order by %[2]s`, table, column))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
var dummy int
|
||||
var detail string
|
||||
err = rows.Scan(&dummy, &dummy, &dummy, &detail)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if ok, _ := regexp.MatchString("SEARCH.*USING INDEX", detail); ok {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("INDEX for '%s' not used: %s", column, detail)
|
||||
}
|
||||
}
|
||||
return errors.New("no rows returned")
|
||||
}
|
||||
|
||||
func checkCollation(conn *sql.DB, table string, column string) error {
|
||||
rows, err := conn.Query(fmt.Sprintf("SELECT sql FROM sqlite_master WHERE type='table' AND tbl_name='%s'", table))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
var res string
|
||||
err = rows.Scan(&res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
re := regexp.MustCompile(fmt.Sprintf(`(?i)\b%s\b.*varchar`, column))
|
||||
if !re.MatchString(res) {
|
||||
return fmt.Errorf("column '%s' not found in table '%s'", column, table)
|
||||
}
|
||||
re = regexp.MustCompile(fmt.Sprintf(`(?i)\b%s\b.*collate\s+NOCASE`, column))
|
||||
if re.MatchString(res) {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("table '%s' not found", table)
|
||||
}
|
||||
return fmt.Errorf("column '%s' in table '%s' does not have NOCASE collation", column, table)
|
||||
}
|
||||
4
persistence/export_test.go
Normal file
4
persistence/export_test.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package persistence
|
||||
|
||||
// Definitions for testing private methods
|
||||
var GetIndexKey = (*artistRepository).getIndexKey
|
||||
216
persistence/folder_repository.go
Normal file
216
persistence/folder_repository.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type folderRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
type dbFolder struct {
|
||||
*model.Folder `structs:",flatten"`
|
||||
ImageFiles string `structs:"-" json:"-"`
|
||||
}
|
||||
|
||||
func (f *dbFolder) PostScan() error {
|
||||
var err error
|
||||
if f.ImageFiles != "" {
|
||||
if err = json.Unmarshal([]byte(f.ImageFiles), &f.Folder.ImageFiles); err != nil {
|
||||
return fmt.Errorf("parsing folder image files from db: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *dbFolder) PostMapArgs(args map[string]any) error {
|
||||
if f.Folder.ImageFiles == nil {
|
||||
args["image_files"] = "[]"
|
||||
} else {
|
||||
imgFiles, err := json.Marshal(f.Folder.ImageFiles)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling image files: %w", err)
|
||||
}
|
||||
args["image_files"] = string(imgFiles)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbFolders []dbFolder
|
||||
|
||||
func (fs dbFolders) toModels() []model.Folder {
|
||||
return slice.Map(fs, func(f dbFolder) model.Folder { return *f.Folder })
|
||||
}
|
||||
|
||||
func newFolderRepository(ctx context.Context, db dbx.Builder) model.FolderRepository {
|
||||
r := &folderRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.tableName = "folder"
|
||||
return r
|
||||
}
|
||||
|
||||
func (r folderRepository) selectFolder(options ...model.QueryOptions) SelectBuilder {
|
||||
sql := r.newSelect(options...).Columns("folder.*", "library.path as library_path").
|
||||
Join("library on library.id = folder.library_id")
|
||||
return r.applyLibraryFilter(sql)
|
||||
}
|
||||
|
||||
func (r folderRepository) Get(id string) (*model.Folder, error) {
|
||||
sq := r.selectFolder().Where(Eq{"folder.id": id})
|
||||
var res dbFolder
|
||||
err := r.queryOne(sq, &res)
|
||||
return res.Folder, err
|
||||
}
|
||||
|
||||
func (r folderRepository) GetByPath(lib model.Library, path string) (*model.Folder, error) {
|
||||
id := model.NewFolder(lib, path).ID
|
||||
return r.Get(id)
|
||||
}
|
||||
|
||||
func (r folderRepository) GetAll(opt ...model.QueryOptions) ([]model.Folder, error) {
|
||||
sq := r.selectFolder(opt...)
|
||||
var res dbFolders
|
||||
err := r.queryAll(sq, &res)
|
||||
return res.toModels(), err
|
||||
}
|
||||
|
||||
func (r folderRepository) CountAll(opt ...model.QueryOptions) (int64, error) {
|
||||
query := r.newSelect(opt...).Columns("count(*)")
|
||||
query = r.applyLibraryFilter(query)
|
||||
return r.count(query)
|
||||
}
|
||||
|
||||
func (r folderRepository) GetFolderUpdateInfo(lib model.Library, targetPaths ...string) (map[string]model.FolderUpdateInfo, error) {
|
||||
where := And{
|
||||
Eq{"library_id": lib.ID},
|
||||
Eq{"missing": false},
|
||||
}
|
||||
|
||||
// If specific paths are requested, include those folders and all their descendants
|
||||
if len(targetPaths) > 0 {
|
||||
// Collect folder IDs for exact target folders and path conditions for descendants
|
||||
folderIDs := make([]string, 0, len(targetPaths))
|
||||
pathConditions := make(Or, 0, len(targetPaths)*2)
|
||||
|
||||
for _, targetPath := range targetPaths {
|
||||
if targetPath == "" || targetPath == "." {
|
||||
// Root path - include everything in this library
|
||||
pathConditions = Or{}
|
||||
folderIDs = nil
|
||||
break
|
||||
}
|
||||
// Clean the path to normalize it. Paths stored in the folder table do not have leading/trailing slashes.
|
||||
cleanPath := strings.TrimPrefix(targetPath, string(os.PathSeparator))
|
||||
cleanPath = filepath.Clean(cleanPath)
|
||||
|
||||
// Include the target folder itself by ID
|
||||
folderIDs = append(folderIDs, model.FolderID(lib, cleanPath))
|
||||
|
||||
// Include all descendants: folders whose path field equals or starts with the target path
|
||||
// Note: Folder.Path is the directory path, so children have path = targetPath
|
||||
pathConditions = append(pathConditions, Eq{"path": cleanPath})
|
||||
pathConditions = append(pathConditions, Like{"path": cleanPath + "/%"})
|
||||
}
|
||||
|
||||
// Combine conditions: exact folder IDs OR descendant path patterns
|
||||
if len(folderIDs) > 0 {
|
||||
where = append(where, Or{Eq{"id": folderIDs}, pathConditions})
|
||||
} else if len(pathConditions) > 0 {
|
||||
where = append(where, pathConditions)
|
||||
}
|
||||
}
|
||||
|
||||
sq := r.newSelect().Columns("id", "updated_at", "hash").Where(where)
|
||||
var res []struct {
|
||||
ID string
|
||||
UpdatedAt time.Time
|
||||
Hash string
|
||||
}
|
||||
err := r.queryAll(sq, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := make(map[string]model.FolderUpdateInfo, len(res))
|
||||
for _, f := range res {
|
||||
m[f.ID] = model.FolderUpdateInfo{UpdatedAt: f.UpdatedAt, Hash: f.Hash}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (r folderRepository) Put(f *model.Folder) error {
|
||||
dbf := dbFolder{Folder: f}
|
||||
_, err := r.put(dbf.ID, &dbf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r folderRepository) MarkMissing(missing bool, ids ...string) error {
|
||||
log.Debug(r.ctx, "Marking folders as missing", "ids", ids, "missing", missing)
|
||||
for chunk := range slices.Chunk(ids, 200) {
|
||||
sq := Update(r.tableName).
|
||||
Set("missing", missing).
|
||||
Set("updated_at", time.Now()).
|
||||
Where(Eq{"id": chunk})
|
||||
_, err := r.executeSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r folderRepository) GetTouchedWithPlaylists() (model.FolderCursor, error) {
|
||||
query := r.selectFolder().Where(And{
|
||||
Eq{"missing": false},
|
||||
Gt{"num_playlists": 0},
|
||||
ConcatExpr("folder.updated_at > library.last_scan_at"),
|
||||
})
|
||||
cursor, err := queryWithStableResults[dbFolder](r.sqlRepository, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(yield func(model.Folder, error) bool) {
|
||||
for f, err := range cursor {
|
||||
if !yield(*f.Folder, err) || err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r folderRepository) purgeEmpty(libraryIDs ...int) error {
|
||||
sq := Delete(r.tableName).Where(And{
|
||||
Eq{"num_audio_files": 0},
|
||||
Eq{"num_playlists": 0},
|
||||
Eq{"image_files": "[]"},
|
||||
ConcatExpr("id not in (select parent_id from folder)"),
|
||||
ConcatExpr("id not in (select folder_id from media_file)"),
|
||||
})
|
||||
// If libraryIDs are specified, only purge folders from those libraries
|
||||
if len(libraryIDs) > 0 {
|
||||
sq = sq.Where(Eq{"library_id": libraryIDs})
|
||||
}
|
||||
c, err := r.executeSQL(sq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("purging empty folders: %w", err)
|
||||
}
|
||||
if c > 0 {
|
||||
log.Debug(r.ctx, "Purging empty folders", "totalDeleted", c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ model.FolderRepository = (*folderRepository)(nil)
|
||||
213
persistence/folder_repository_test.go
Normal file
213
persistence/folder_repository_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ = Describe("FolderRepository", func() {
|
||||
var repo model.FolderRepository
|
||||
var ctx context.Context
|
||||
var conn *dbx.DB
|
||||
var testLib, otherLib model.Library
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx = request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid"})
|
||||
conn = GetDBXBuilder()
|
||||
repo = newFolderRepository(ctx, conn)
|
||||
|
||||
// Use existing library ID 1 from test fixtures
|
||||
libRepo := NewLibraryRepository(ctx, conn)
|
||||
lib, err := libRepo.Get(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
testLib = *lib
|
||||
|
||||
// Create a second library with its own folder to verify isolation
|
||||
otherLib = model.Library{Name: "Other Library", Path: "/other/path"}
|
||||
Expect(libRepo.Put(&otherLib)).To(Succeed())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up only test folders created by our tests (paths starting with "Test")
|
||||
// This prevents interference with fixture data needed by other tests
|
||||
_, _ = conn.NewQuery("DELETE FROM folder WHERE library_id = 1 AND path LIKE 'Test%'").Execute()
|
||||
_, _ = conn.NewQuery(fmt.Sprintf("DELETE FROM library WHERE id = %d", otherLib.ID)).Execute()
|
||||
})
|
||||
|
||||
Describe("GetFolderUpdateInfo", func() {
|
||||
Context("with no target paths", func() {
|
||||
It("returns all folders in the library", func() {
|
||||
// Create test folders with unique names to avoid conflicts
|
||||
folder1 := model.NewFolder(testLib, "TestGetLastUpdates/Folder1")
|
||||
folder2 := model.NewFolder(testLib, "TestGetLastUpdates/Folder2")
|
||||
|
||||
err := repo.Put(folder1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = repo.Put(folder2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
otherFolder := model.NewFolder(otherLib, "TestOtherLib/Folder")
|
||||
err = repo.Put(otherFolder)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Query all folders (no target paths) - should only return folders from testLib
|
||||
results, err := repo.GetFolderUpdateInfo(testLib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Should include folders from testLib
|
||||
Expect(results).To(HaveKey(folder1.ID))
|
||||
Expect(results).To(HaveKey(folder2.ID))
|
||||
// Should NOT include folders from other library
|
||||
Expect(results).ToNot(HaveKey(otherFolder.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Context("with specific target paths", func() {
|
||||
It("returns folder info for existing folders", func() {
|
||||
// Create test folders with unique names
|
||||
folder1 := model.NewFolder(testLib, "TestSpecific/Rock")
|
||||
folder2 := model.NewFolder(testLib, "TestSpecific/Jazz")
|
||||
folder3 := model.NewFolder(testLib, "TestSpecific/Classical")
|
||||
|
||||
err := repo.Put(folder1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = repo.Put(folder2)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = repo.Put(folder3)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Query specific paths
|
||||
results, err := repo.GetFolderUpdateInfo(testLib, "TestSpecific/Rock", "TestSpecific/Classical")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(2))
|
||||
|
||||
// Verify folder IDs are in results
|
||||
Expect(results).To(HaveKey(folder1.ID))
|
||||
Expect(results).To(HaveKey(folder3.ID))
|
||||
Expect(results).ToNot(HaveKey(folder2.ID))
|
||||
|
||||
// Verify update info is populated
|
||||
Expect(results[folder1.ID].UpdatedAt).ToNot(BeZero())
|
||||
Expect(results[folder1.ID].Hash).To(Equal(folder1.Hash))
|
||||
})
|
||||
|
||||
It("includes all child folders when querying parent", func() {
|
||||
// Create a parent folder with multiple children
|
||||
parent := model.NewFolder(testLib, "TestParent/Music")
|
||||
child1 := model.NewFolder(testLib, "TestParent/Music/Rock/Queen")
|
||||
child2 := model.NewFolder(testLib, "TestParent/Music/Jazz")
|
||||
otherParent := model.NewFolder(testLib, "TestParent2/Music/Jazz")
|
||||
|
||||
Expect(repo.Put(parent)).To(Succeed())
|
||||
Expect(repo.Put(child1)).To(Succeed())
|
||||
Expect(repo.Put(child2)).To(Succeed())
|
||||
|
||||
// Query the parent folder - should return parent and all children
|
||||
results, err := repo.GetFolderUpdateInfo(testLib, "TestParent/Music")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(3))
|
||||
Expect(results).To(HaveKey(parent.ID))
|
||||
Expect(results).To(HaveKey(child1.ID))
|
||||
Expect(results).To(HaveKey(child2.ID))
|
||||
Expect(results).ToNot(HaveKey(otherParent.ID))
|
||||
})
|
||||
|
||||
It("excludes children from other libraries", func() {
|
||||
// Create parent in testLib
|
||||
parent := model.NewFolder(testLib, "TestIsolation/Parent")
|
||||
child := model.NewFolder(testLib, "TestIsolation/Parent/Child")
|
||||
|
||||
Expect(repo.Put(parent)).To(Succeed())
|
||||
Expect(repo.Put(child)).To(Succeed())
|
||||
|
||||
// Create similar path in other library
|
||||
otherParent := model.NewFolder(otherLib, "TestIsolation/Parent")
|
||||
otherChild := model.NewFolder(otherLib, "TestIsolation/Parent/Child")
|
||||
|
||||
Expect(repo.Put(otherParent)).To(Succeed())
|
||||
Expect(repo.Put(otherChild)).To(Succeed())
|
||||
|
||||
// Query should only return folders from testLib
|
||||
results, err := repo.GetFolderUpdateInfo(testLib, "TestIsolation/Parent")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(2))
|
||||
Expect(results).To(HaveKey(parent.ID))
|
||||
Expect(results).To(HaveKey(child.ID))
|
||||
Expect(results).ToNot(HaveKey(otherParent.ID))
|
||||
Expect(results).ToNot(HaveKey(otherChild.ID))
|
||||
})
|
||||
|
||||
It("excludes missing children when querying parent", func() {
|
||||
// Create parent and children, mark one as missing
|
||||
parent := model.NewFolder(testLib, "TestMissingChild/Parent")
|
||||
child1 := model.NewFolder(testLib, "TestMissingChild/Parent/Child1")
|
||||
child2 := model.NewFolder(testLib, "TestMissingChild/Parent/Child2")
|
||||
child2.Missing = true
|
||||
|
||||
Expect(repo.Put(parent)).To(Succeed())
|
||||
Expect(repo.Put(child1)).To(Succeed())
|
||||
Expect(repo.Put(child2)).To(Succeed())
|
||||
|
||||
// Query parent - should only return parent and non-missing child
|
||||
results, err := repo.GetFolderUpdateInfo(testLib, "TestMissingChild/Parent")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(2))
|
||||
Expect(results).To(HaveKey(parent.ID))
|
||||
Expect(results).To(HaveKey(child1.ID))
|
||||
Expect(results).ToNot(HaveKey(child2.ID))
|
||||
})
|
||||
|
||||
It("handles mix of existing and non-existing target paths", func() {
|
||||
// Create folders for one path but not the other
|
||||
existingParent := model.NewFolder(testLib, "TestMixed/Exists")
|
||||
existingChild := model.NewFolder(testLib, "TestMixed/Exists/Child")
|
||||
|
||||
Expect(repo.Put(existingParent)).To(Succeed())
|
||||
Expect(repo.Put(existingChild)).To(Succeed())
|
||||
|
||||
// Query both existing and non-existing paths
|
||||
results, err := repo.GetFolderUpdateInfo(testLib, "TestMixed/Exists", "TestMixed/DoesNotExist")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(2))
|
||||
Expect(results).To(HaveKey(existingParent.ID))
|
||||
Expect(results).To(HaveKey(existingChild.ID))
|
||||
})
|
||||
|
||||
It("handles empty folder path as root", func() {
|
||||
// Test querying for root folder without creating it (fixtures should have one)
|
||||
rootFolderID := model.FolderID(testLib, ".")
|
||||
|
||||
results, err := repo.GetFolderUpdateInfo(testLib, "")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Should return the root folder if it exists
|
||||
if len(results) > 0 {
|
||||
Expect(results).To(HaveKey(rootFolderID))
|
||||
}
|
||||
})
|
||||
|
||||
It("returns empty map for non-existent folders", func() {
|
||||
results, err := repo.GetFolderUpdateInfo(testLib, "NonExistent/Path")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("skips missing folders", func() {
|
||||
// Create a folder and mark it as missing
|
||||
folder := model.NewFolder(testLib, "TestMissing/Folder")
|
||||
folder.Missing = true
|
||||
err := repo.Put(folder)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
results, err := repo.GetFolderUpdateInfo(testLib, "TestMissing/Folder")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
52
persistence/genre_repository.go
Normal file
52
persistence/genre_repository.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type genreRepository struct {
|
||||
*baseTagRepository
|
||||
}
|
||||
|
||||
func NewGenreRepository(ctx context.Context, db dbx.Builder) model.GenreRepository {
|
||||
genreFilter := model.TagGenre
|
||||
return &genreRepository{
|
||||
baseTagRepository: newBaseTagRepository(ctx, db, &genreFilter),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *genreRepository) selectGenre(opt ...model.QueryOptions) SelectBuilder {
|
||||
return r.newSelect(opt...).Columns("tag.tag_value as name")
|
||||
}
|
||||
|
||||
func (r *genreRepository) GetAll(opt ...model.QueryOptions) (model.Genres, error) {
|
||||
sq := r.selectGenre(opt...)
|
||||
res := model.Genres{}
|
||||
err := r.queryAll(sq, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Override ResourceRepository methods to return Genre objects instead of Tag objects
|
||||
|
||||
func (r *genreRepository) Read(id string) (interface{}, error) {
|
||||
sel := r.selectGenre().Where(Eq{"tag.id": id})
|
||||
var res model.Genre
|
||||
err := r.queryOne(sel, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *genreRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *genreRepository) NewInstance() interface{} {
|
||||
return &model.Genre{}
|
||||
}
|
||||
|
||||
var _ model.GenreRepository = (*genreRepository)(nil)
|
||||
var _ model.ResourceRepository = (*genreRepository)(nil)
|
||||
329
persistence/genre_repository_test.go
Normal file
329
persistence/genre_repository_test.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("GenreRepository", func() {
|
||||
var repo model.GenreRepository
|
||||
var restRepo model.ResourceRepository
|
||||
var tagRepo model.TagRepository
|
||||
var ctx context.Context
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx = request.WithUser(GinkgoT().Context(), model.User{ID: "userid", UserName: "johndoe", IsAdmin: true})
|
||||
genreRepo := NewGenreRepository(ctx, GetDBXBuilder())
|
||||
repo = genreRepo
|
||||
restRepo = genreRepo.(model.ResourceRepository)
|
||||
tagRepo = NewTagRepository(ctx, GetDBXBuilder())
|
||||
|
||||
// Clear any existing tags to ensure test isolation
|
||||
db := GetDBXBuilder()
|
||||
_, err := db.NewQuery("DELETE FROM tag").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Ensure library 1 exists and user has access to it
|
||||
_, err = db.NewQuery("INSERT OR IGNORE INTO library (id, name, path, default_new_users) VALUES (1, 'Test Library', '/test', true)").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = db.NewQuery("INSERT OR IGNORE INTO user_library (user_id, library_id) VALUES ('userid', 1)").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Add comprehensive test data that covers all test scenarios
|
||||
newTag := func(name, value string) model.Tag {
|
||||
return model.Tag{ID: id.NewTagID(name, value), TagName: model.TagName(name), TagValue: value}
|
||||
}
|
||||
|
||||
err = tagRepo.Add(1,
|
||||
newTag("genre", "rock"),
|
||||
newTag("genre", "pop"),
|
||||
newTag("genre", "jazz"),
|
||||
newTag("genre", "electronic"),
|
||||
newTag("genre", "classical"),
|
||||
newTag("genre", "ambient"),
|
||||
newTag("genre", "techno"),
|
||||
newTag("genre", "house"),
|
||||
newTag("genre", "trance"),
|
||||
newTag("genre", "Alternative Rock"),
|
||||
newTag("genre", "Blues"),
|
||||
newTag("genre", "Country"),
|
||||
// These should not be counted as genres
|
||||
newTag("mood", "happy"),
|
||||
newTag("mood", "ambient"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Describe("GetAll", func() {
|
||||
It("should return all genres", func() {
|
||||
genres, err := repo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).To(HaveLen(12))
|
||||
|
||||
// Verify that all returned items are genres (TagName = "genre")
|
||||
genreNames := make([]string, len(genres))
|
||||
for i, genre := range genres {
|
||||
genreNames[i] = genre.Name
|
||||
}
|
||||
Expect(genreNames).To(ContainElement("rock"))
|
||||
Expect(genreNames).To(ContainElement("pop"))
|
||||
Expect(genreNames).To(ContainElement("jazz"))
|
||||
// Should not contain mood tags
|
||||
Expect(genreNames).ToNot(ContainElement("happy"))
|
||||
})
|
||||
|
||||
It("should support query options", func() {
|
||||
// Test with limiting results
|
||||
genres, err := repo.GetAll(model.QueryOptions{Max: 1})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("should handle empty results gracefully", func() {
|
||||
// Clear all genre tags
|
||||
_, err := GetDBXBuilder().NewQuery("DELETE FROM tag WHERE tag_name = 'genre'").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
genres, err := repo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).To(BeEmpty())
|
||||
})
|
||||
Describe("filtering and sorting", func() {
|
||||
It("should filter by name using like match", func() {
|
||||
// Test filtering by partial name match using the "name" filter which maps to containsFilter("tag_value")
|
||||
options := model.QueryOptions{
|
||||
Filters: squirrel.Like{"tag_value": "%rock%"}, // Direct field access
|
||||
}
|
||||
genres, err := repo.GetAll(options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).To(HaveLen(2)) // Should match "rock" and "Alternative Rock"
|
||||
|
||||
// Verify all returned genres contain "rock" in their name
|
||||
for _, genre := range genres {
|
||||
Expect(strings.ToLower(genre.Name)).To(ContainSubstring("rock"))
|
||||
}
|
||||
})
|
||||
|
||||
It("should sort by name in ascending order", func() {
|
||||
// Test sorting by name with the fixed mapping
|
||||
options := model.QueryOptions{
|
||||
Filters: squirrel.Like{"tag_value": "%e%"}, // Should match genres containing "e"
|
||||
Sort: "name",
|
||||
}
|
||||
genres, err := repo.GetAll(options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).To(HaveLen(7))
|
||||
|
||||
Expect(slices.IsSortedFunc(genres, func(a, b model.Genre) int {
|
||||
return strings.Compare(b.Name, a.Name) // Inverted to check descending order
|
||||
}))
|
||||
})
|
||||
|
||||
It("should sort by name in descending order", func() {
|
||||
// Test sorting by name in descending order
|
||||
options := model.QueryOptions{
|
||||
Filters: squirrel.Like{"tag_value": "%e%"}, // Should match genres containing "e"
|
||||
Sort: "name",
|
||||
Order: "desc",
|
||||
}
|
||||
genres, err := repo.GetAll(options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).To(HaveLen(7))
|
||||
|
||||
Expect(slices.IsSortedFunc(genres, func(a, b model.Genre) int {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
}))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
It("should return correct count of genres", func() {
|
||||
count, err := restRepo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(12))) // We have 12 genre tags
|
||||
})
|
||||
|
||||
It("should handle zero count", func() {
|
||||
// Clear all genre tags
|
||||
_, err := GetDBXBuilder().NewQuery("DELETE FROM tag WHERE tag_name = 'genre'").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
count, err := restRepo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(BeZero())
|
||||
})
|
||||
|
||||
It("should only count genre tags", func() {
|
||||
// Add a non-genre tag
|
||||
nonGenreTag := model.Tag{
|
||||
ID: id.NewTagID("mood", "energetic"),
|
||||
TagName: "mood",
|
||||
TagValue: "energetic",
|
||||
}
|
||||
err := tagRepo.Add(1, nonGenreTag)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
count, err := restRepo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Count should not include the mood tag
|
||||
Expect(count).To(Equal(int64(12))) // Should still be 12 genre tags
|
||||
})
|
||||
|
||||
It("should filter by name using like match", func() {
|
||||
// Test filtering by partial name match using the "name" filter which maps to containsFilter("tag_value")
|
||||
options := rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"name": "%rock%"},
|
||||
}
|
||||
count, err := restRepo.Count(options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(BeNumerically("==", 2))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Read", func() {
|
||||
It("should return existing genre", func() {
|
||||
// Use one of the existing genres from our consolidated dataset
|
||||
genreID := id.NewTagID("genre", "rock")
|
||||
result, err := restRepo.Read(genreID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
genre := result.(*model.Genre)
|
||||
Expect(genre.ID).To(Equal(genreID))
|
||||
Expect(genre.Name).To(Equal("rock"))
|
||||
})
|
||||
|
||||
It("should return error for non-existent genre", func() {
|
||||
_, err := restRepo.Read("non-existent-id")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should not return non-genre tags", func() {
|
||||
moodID := id.NewTagID("mood", "happy") // This exists as a mood tag, not genre
|
||||
_, err := restRepo.Read(moodID)
|
||||
Expect(err).To(HaveOccurred()) // Should not find it as a genre
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ReadAll", func() {
|
||||
It("should return all genres through ReadAll", func() {
|
||||
result, err := restRepo.ReadAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
genres := result.(model.Genres)
|
||||
Expect(genres).To(HaveLen(12)) // We have 12 genre tags
|
||||
|
||||
genreNames := make([]string, len(genres))
|
||||
for i, genre := range genres {
|
||||
genreNames[i] = genre.Name
|
||||
}
|
||||
// Check for some of our consolidated dataset genres
|
||||
Expect(genreNames).To(ContainElement("rock"))
|
||||
Expect(genreNames).To(ContainElement("pop"))
|
||||
Expect(genreNames).To(ContainElement("jazz"))
|
||||
})
|
||||
|
||||
It("should support rest query options", func() {
|
||||
result, err := restRepo.ReadAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(result).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Library Filtering", func() {
|
||||
Context("Headless Processes (No User Context)", func() {
|
||||
var headlessRepo model.GenreRepository
|
||||
var headlessRestRepo model.ResourceRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
// Create a repository with no user context (headless)
|
||||
headlessGenreRepo := NewGenreRepository(context.Background(), GetDBXBuilder())
|
||||
headlessRepo = headlessGenreRepo
|
||||
headlessRestRepo = headlessGenreRepo.(model.ResourceRepository)
|
||||
|
||||
// Add genres to different libraries
|
||||
db := GetDBXBuilder()
|
||||
_, err := db.NewQuery("INSERT OR IGNORE INTO library (id, name, path) VALUES (2, 'Test Library 2', '/test2')").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Add tags to different libraries
|
||||
newTag := func(name, value string) model.Tag {
|
||||
return model.Tag{ID: id.NewTagID(name, value), TagName: model.TagName(name), TagValue: value}
|
||||
}
|
||||
|
||||
err = tagRepo.Add(2, newTag("genre", "jazz"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should see all genres from all libraries when no user is in context", func() {
|
||||
// Headless processes should see all genres regardless of library
|
||||
genres, err := headlessRepo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Should see genres from all libraries
|
||||
var genreNames []string
|
||||
for _, genre := range genres {
|
||||
genreNames = append(genreNames, genre.Name)
|
||||
}
|
||||
|
||||
// Should include both rock (library 1) and jazz (library 2)
|
||||
Expect(genreNames).To(ContainElement("rock"))
|
||||
Expect(genreNames).To(ContainElement("jazz"))
|
||||
})
|
||||
|
||||
It("should count all genres from all libraries when no user is in context", func() {
|
||||
count, err := headlessRestRepo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Should count all genres from all libraries
|
||||
Expect(count).To(BeNumerically(">=", 2))
|
||||
})
|
||||
|
||||
It("should allow headless processes to apply explicit library_id filters", func() {
|
||||
// Filter by specific library
|
||||
genres, err := headlessRestRepo.ReadAll(rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"library_id": 2},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
genreList := genres.(model.Genres)
|
||||
// Should see only genres from library 2
|
||||
Expect(genreList).To(HaveLen(1))
|
||||
Expect(genreList[0].Name).To(Equal("jazz"))
|
||||
})
|
||||
|
||||
It("should get individual genres when no user is in context", func() {
|
||||
// Get all genres first to find an ID
|
||||
genres, err := headlessRepo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).ToNot(BeEmpty())
|
||||
|
||||
// Headless process should be able to get the genre
|
||||
genre, err := headlessRestRepo.Read(genres[0].ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genre).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("EntityName", func() {
|
||||
It("should return correct entity name", func() {
|
||||
name := restRepo.EntityName()
|
||||
Expect(name).To(Equal("tag")) // Genre repository uses tag table
|
||||
})
|
||||
})
|
||||
|
||||
Describe("NewInstance", func() {
|
||||
It("should return new genre instance", func() {
|
||||
instance := restRepo.NewInstance()
|
||||
Expect(instance).To(BeAssignableToTypeOf(&model.Genre{}))
|
||||
})
|
||||
})
|
||||
})
|
||||
92
persistence/helpers.go
Normal file
92
persistence/helpers.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/fatih/structs"
|
||||
)
|
||||
|
||||
type PostMapper interface {
|
||||
PostMapArgs(map[string]any) error
|
||||
}
|
||||
|
||||
func toSQLArgs(rec interface{}) (map[string]interface{}, error) {
|
||||
m := structs.Map(rec)
|
||||
for k, v := range m {
|
||||
switch t := v.(type) {
|
||||
case *time.Time:
|
||||
if t != nil {
|
||||
m[k] = *t
|
||||
}
|
||||
case driver.Valuer:
|
||||
var err error
|
||||
m[k], err = t.Value()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if r, ok := rec.(PostMapper); ok {
|
||||
err := r.PostMapArgs(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
|
||||
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
|
||||
|
||||
func toSnakeCase(str string) string {
|
||||
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
|
||||
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
|
||||
return strings.ToLower(snake)
|
||||
}
|
||||
|
||||
var matchUnderscore = regexp.MustCompile("_([A-Za-z])")
|
||||
|
||||
func toCamelCase(str string) string {
|
||||
return matchUnderscore.ReplaceAllStringFunc(str, func(s string) string {
|
||||
return strings.ToUpper(strings.Replace(s, "_", "", -1))
|
||||
})
|
||||
}
|
||||
|
||||
func Exists(subTable string, cond squirrel.Sqlizer) existsCond {
|
||||
return existsCond{subTable: subTable, cond: cond, not: false}
|
||||
}
|
||||
|
||||
func NotExists(subTable string, cond squirrel.Sqlizer) existsCond {
|
||||
return existsCond{subTable: subTable, cond: cond, not: true}
|
||||
}
|
||||
|
||||
type existsCond struct {
|
||||
subTable string
|
||||
cond squirrel.Sqlizer
|
||||
not bool
|
||||
}
|
||||
|
||||
func (e existsCond) ToSql() (string, []interface{}, error) {
|
||||
sql, args, err := e.cond.ToSql()
|
||||
sql = fmt.Sprintf("exists (select 1 from %s where %s)", e.subTable, sql)
|
||||
if e.not {
|
||||
sql = "not " + sql
|
||||
}
|
||||
return sql, args, err
|
||||
}
|
||||
|
||||
var sortOrderRegex = regexp.MustCompile(`order_([a-z_]+)`)
|
||||
|
||||
// Convert the order_* columns to an expression using sort_* columns. Example:
|
||||
// sort_album_name -> (coalesce(nullif(sort_album_name,”),order_album_name) collate nocase)
|
||||
// It finds order column names anywhere in the substring
|
||||
func mapSortOrder(tableName, order string) string {
|
||||
order = strings.ToLower(order)
|
||||
repl := fmt.Sprintf("(coalesce(nullif(%[1]s.sort_$1,''),%[1]s.order_$1) collate nocase)", tableName)
|
||||
return sortOrderRegex.ReplaceAllString(order, repl)
|
||||
}
|
||||
106
persistence/helpers_test.go
Normal file
106
persistence/helpers_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Helpers", func() {
|
||||
Describe("toSnakeCase", func() {
|
||||
It("converts camelCase", func() {
|
||||
Expect(toSnakeCase("camelCase")).To(Equal("camel_case"))
|
||||
})
|
||||
It("converts PascalCase", func() {
|
||||
Expect(toSnakeCase("PascalCase")).To(Equal("pascal_case"))
|
||||
})
|
||||
It("converts ALLCAPS", func() {
|
||||
Expect(toSnakeCase("ALLCAPS")).To(Equal("allcaps"))
|
||||
})
|
||||
It("does not converts snake_case", func() {
|
||||
Expect(toSnakeCase("snake_case")).To(Equal("snake_case"))
|
||||
})
|
||||
})
|
||||
Describe("toCamelCase", func() {
|
||||
It("converts snake_case", func() {
|
||||
Expect(toCamelCase("snake_case")).To(Equal("snakeCase"))
|
||||
})
|
||||
It("converts PascalCase", func() {
|
||||
Expect(toCamelCase("PascalCase")).To(Equal("PascalCase"))
|
||||
})
|
||||
It("converts camelCase", func() {
|
||||
Expect(toCamelCase("camelCase")).To(Equal("camelCase"))
|
||||
})
|
||||
It("converts ALLCAPS", func() {
|
||||
Expect(toCamelCase("ALLCAPS")).To(Equal("ALLCAPS"))
|
||||
})
|
||||
})
|
||||
Describe("toSQLArgs", func() {
|
||||
type Embed struct{}
|
||||
type Model struct {
|
||||
Embed `structs:"-"`
|
||||
ID string `structs:"id" json:"id"`
|
||||
AlbumId string `structs:"album_id" json:"albumId"`
|
||||
PlayCount int `structs:"play_count" json:"playCount"`
|
||||
UpdatedAt *time.Time `structs:"updated_at"`
|
||||
CreatedAt time.Time `structs:"created_at"`
|
||||
}
|
||||
|
||||
It("returns a map with snake_case keys", func() {
|
||||
now := time.Now()
|
||||
m := &Model{ID: "123", AlbumId: "456", CreatedAt: now, UpdatedAt: &now, PlayCount: 2}
|
||||
args, err := toSQLArgs(m)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(args).To(SatisfyAll(
|
||||
HaveKeyWithValue("id", "123"),
|
||||
HaveKeyWithValue("album_id", "456"),
|
||||
HaveKeyWithValue("play_count", 2),
|
||||
HaveKeyWithValue("updated_at", BeTemporally("~", now)),
|
||||
HaveKeyWithValue("created_at", BeTemporally("~", now)),
|
||||
Not(HaveKey("Embed")),
|
||||
))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Exists", func() {
|
||||
It("constructs the correct EXISTS query", func() {
|
||||
e := Exists("album", squirrel.Eq{"id": 1})
|
||||
sql, args, err := e.ToSql()
|
||||
Expect(sql).To(Equal("exists (select 1 from album where id = ?)"))
|
||||
Expect(args).To(ConsistOf(1))
|
||||
Expect(err).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("NotExists", func() {
|
||||
It("constructs the correct NOT EXISTS query", func() {
|
||||
e := NotExists("artist", squirrel.ConcatExpr("id = artist_id"))
|
||||
sql, args, err := e.ToSql()
|
||||
Expect(sql).To(Equal("not exists (select 1 from artist where id = artist_id)"))
|
||||
Expect(args).To(BeEmpty())
|
||||
Expect(err).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("mapSortOrder", func() {
|
||||
It("does not change the sort string if there are no order columns", func() {
|
||||
sort := "album_name asc"
|
||||
mapped := mapSortOrder("album", sort)
|
||||
Expect(mapped).To(Equal(sort))
|
||||
})
|
||||
It("changes order columns to sort expression", func() {
|
||||
sort := "ORDER_ALBUM_NAME asc"
|
||||
mapped := mapSortOrder("album", sort)
|
||||
Expect(mapped).To(Equal(`(coalesce(nullif(album.sort_album_name,''),album.order_album_name)` +
|
||||
` collate nocase) asc`))
|
||||
})
|
||||
It("changes multiple order columns to sort expressions", func() {
|
||||
sort := "compilation, order_title asc, order_album_artist_name desc, year desc"
|
||||
mapped := mapSortOrder("album", sort)
|
||||
Expect(mapped).To(Equal(`compilation, (coalesce(nullif(album.sort_title,''),album.order_title) collate nocase) asc,` +
|
||||
` (coalesce(nullif(album.sort_album_artist_name,''),album.order_album_artist_name) collate nocase) desc, year desc`))
|
||||
})
|
||||
})
|
||||
})
|
||||
347
persistence/library_repository.go
Normal file
347
persistence/library_repository.go
Normal file
@@ -0,0 +1,347 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/run"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type libraryRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
var (
|
||||
libCache = map[int]string{}
|
||||
libLock sync.RWMutex
|
||||
)
|
||||
|
||||
func NewLibraryRepository(ctx context.Context, db dbx.Builder) model.LibraryRepository {
|
||||
r := &libraryRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.registerModel(&model.Library{}, nil)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *libraryRepository) Get(id int) (*model.Library, error) {
|
||||
sq := r.newSelect().Columns("*").Where(Eq{"id": id})
|
||||
var res model.Library
|
||||
err := r.queryOne(sq, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *libraryRepository) GetPath(id int) (string, error) {
|
||||
l := func() string {
|
||||
libLock.RLock()
|
||||
defer libLock.RUnlock()
|
||||
if l, ok := libCache[id]; ok {
|
||||
return l
|
||||
}
|
||||
return ""
|
||||
}()
|
||||
if l != "" {
|
||||
return l, nil
|
||||
}
|
||||
|
||||
libLock.Lock()
|
||||
defer libLock.Unlock()
|
||||
libs, err := r.GetAll()
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error loading libraries from DB", err)
|
||||
return "", err
|
||||
}
|
||||
for _, l := range libs {
|
||||
libCache[l.ID] = l.Path
|
||||
}
|
||||
if l, ok := libCache[id]; ok {
|
||||
return l, nil
|
||||
} else {
|
||||
return "", model.ErrNotFound
|
||||
}
|
||||
}
|
||||
|
||||
func (r *libraryRepository) Put(l *model.Library) error {
|
||||
if l.ID == model.DefaultLibraryID {
|
||||
currentLib, err := r.Get(1)
|
||||
// if we are creating it, it's ok.
|
||||
if err == nil { // it exists, so we are updating it
|
||||
if currentLib.Path != l.Path {
|
||||
return fmt.Errorf("%w: path for library with ID 1 cannot be changed", model.ErrValidation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
l.UpdatedAt = time.Now()
|
||||
if l.ID == 0 {
|
||||
// Insert with autoassigned ID
|
||||
l.CreatedAt = time.Now()
|
||||
err = r.db.Model(l).Insert()
|
||||
} else {
|
||||
// Try to update first
|
||||
cols := map[string]any{
|
||||
"name": l.Name,
|
||||
"path": l.Path,
|
||||
"remote_path": l.RemotePath,
|
||||
"default_new_users": l.DefaultNewUsers,
|
||||
"updated_at": l.UpdatedAt,
|
||||
}
|
||||
sq := Update(r.tableName).SetMap(cols).Where(Eq{"id": l.ID})
|
||||
rowsAffected, updateErr := r.executeSQL(sq)
|
||||
if updateErr != nil {
|
||||
return updateErr
|
||||
}
|
||||
|
||||
// If no rows were affected, the record doesn't exist, so insert it
|
||||
if rowsAffected == 0 {
|
||||
l.CreatedAt = time.Now()
|
||||
l.UpdatedAt = time.Now()
|
||||
err = r.db.Model(l).Insert()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Auto-assign all libraries to all admin users
|
||||
sql := Expr(`
|
||||
INSERT INTO user_library (user_id, library_id)
|
||||
SELECT u.id, l.id
|
||||
FROM user u
|
||||
CROSS JOIN library l
|
||||
WHERE u.is_admin = true
|
||||
ON CONFLICT (user_id, library_id) DO NOTHING;`,
|
||||
)
|
||||
if _, err = r.executeSQL(sql); err != nil {
|
||||
return fmt.Errorf("failed to assign library to admin users: %w", err)
|
||||
}
|
||||
|
||||
libLock.Lock()
|
||||
defer libLock.Unlock()
|
||||
libCache[l.ID] = l.Path
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO Remove this method when we have a proper UI to add libraries
|
||||
// This is a temporary method to store the music folder path from the config in the DB
|
||||
func (r *libraryRepository) StoreMusicFolder() error {
|
||||
sq := Update(r.tableName).Set("path", conf.Server.MusicFolder).
|
||||
Set("updated_at", time.Now()).
|
||||
Where(Eq{"id": model.DefaultLibraryID})
|
||||
_, err := r.executeSQL(sq)
|
||||
if err != nil {
|
||||
libLock.Lock()
|
||||
defer libLock.Unlock()
|
||||
libCache[model.DefaultLibraryID] = conf.Server.MusicFolder
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *libraryRepository) AddArtist(id int, artistID string) error {
|
||||
sq := Insert("library_artist").Columns("library_id", "artist_id").Values(id, artistID).
|
||||
Suffix(`on conflict(library_id, artist_id) do nothing`)
|
||||
_, err := r.executeSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *libraryRepository) ScanBegin(id int, fullScan bool) error {
|
||||
sq := Update(r.tableName).
|
||||
Set("last_scan_started_at", time.Now()).
|
||||
Set("full_scan_in_progress", fullScan).
|
||||
Where(Eq{"id": id})
|
||||
_, err := r.executeSQL(sq)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *libraryRepository) ScanEnd(id int) error {
|
||||
sq := Update(r.tableName).
|
||||
Set("last_scan_at", time.Now()).
|
||||
Set("full_scan_in_progress", false).
|
||||
Set("last_scan_started_at", time.Time{}).
|
||||
Where(Eq{"id": id})
|
||||
_, err := r.executeSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// https://www.sqlite.org/pragma.html#pragma_optimize
|
||||
// Use mask 0x10000 to check table sizes without running ANALYZE
|
||||
// Running ANALYZE can cause query planner issues with expression-based collation indexes
|
||||
if conf.Server.DevOptimizeDB {
|
||||
_, err = r.executeSQL(Expr("PRAGMA optimize=0x10000;"))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *libraryRepository) ScanInProgress() (bool, error) {
|
||||
query := r.newSelect().Where(NotEq{"last_scan_started_at": time.Time{}})
|
||||
count, err := r.count(query)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *libraryRepository) RefreshStats(id int) error {
|
||||
var songsRes, albumsRes, artistsRes, foldersRes, filesRes, missingRes struct{ Count int64 }
|
||||
var sizeRes struct{ Sum int64 }
|
||||
var durationRes struct{ Sum float64 }
|
||||
|
||||
err := run.Parallel(
|
||||
func() error {
|
||||
return r.queryOne(Select("count(*) as count").From("media_file").Where(Eq{"library_id": id, "missing": false}), &songsRes)
|
||||
},
|
||||
func() error {
|
||||
return r.queryOne(Select("count(*) as count").From("album").Where(Eq{"library_id": id, "missing": false}), &albumsRes)
|
||||
},
|
||||
func() error {
|
||||
return r.queryOne(Select("count(*) as count").From("library_artist la").
|
||||
Join("artist a on la.artist_id = a.id").
|
||||
Where(Eq{"la.library_id": id, "a.missing": false}), &artistsRes)
|
||||
},
|
||||
func() error {
|
||||
return r.queryOne(Select("count(*) as count").From("folder").
|
||||
Where(And{
|
||||
Eq{"library_id": id, "missing": false},
|
||||
Gt{"num_audio_files": 0},
|
||||
}), &foldersRes)
|
||||
},
|
||||
func() error {
|
||||
return r.queryOne(Select("ifnull(sum(num_audio_files + num_playlists + json_array_length(image_files)),0) as count").
|
||||
From("folder").Where(Eq{"library_id": id, "missing": false}), &filesRes)
|
||||
},
|
||||
func() error {
|
||||
return r.queryOne(Select("count(*) as count").From("media_file").Where(Eq{"library_id": id, "missing": true}), &missingRes)
|
||||
},
|
||||
func() error {
|
||||
return r.queryOne(Select("ifnull(sum(size),0) as sum").From("album").Where(Eq{"library_id": id, "missing": false}), &sizeRes)
|
||||
},
|
||||
func() error {
|
||||
return r.queryOne(Select("ifnull(sum(duration),0) as sum").From("album").Where(Eq{"library_id": id, "missing": false}), &durationRes)
|
||||
},
|
||||
)()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sq := Update(r.tableName).
|
||||
Set("total_songs", songsRes.Count).
|
||||
Set("total_albums", albumsRes.Count).
|
||||
Set("total_artists", artistsRes.Count).
|
||||
Set("total_folders", foldersRes.Count).
|
||||
Set("total_files", filesRes.Count).
|
||||
Set("total_missing_files", missingRes.Count).
|
||||
Set("total_size", sizeRes.Sum).
|
||||
Set("total_duration", durationRes.Sum).
|
||||
Set("updated_at", time.Now()).
|
||||
Where(Eq{"id": id})
|
||||
_, err = r.executeSQL(sq)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *libraryRepository) Delete(id int) error {
|
||||
if !loggedUser(r.ctx).IsAdmin {
|
||||
return model.ErrNotAuthorized
|
||||
}
|
||||
if id == 1 {
|
||||
return fmt.Errorf("%w: library with ID 1 cannot be deleted", model.ErrValidation)
|
||||
}
|
||||
|
||||
err := r.delete(Eq{"id": id})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Clear cache entry for this library only if DB operation was successful
|
||||
libLock.Lock()
|
||||
defer libLock.Unlock()
|
||||
delete(libCache, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *libraryRepository) GetAll(ops ...model.QueryOptions) (model.Libraries, error) {
|
||||
sq := r.newSelect(ops...).Columns("*")
|
||||
res := model.Libraries{}
|
||||
err := r.queryAll(sq, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *libraryRepository) CountAll(ops ...model.QueryOptions) (int64, error) {
|
||||
sq := r.newSelect(ops...)
|
||||
return r.count(sq)
|
||||
}
|
||||
|
||||
// User-library association methods
|
||||
|
||||
func (r *libraryRepository) GetUsersWithLibraryAccess(libraryID int) (model.Users, error) {
|
||||
sel := Select("u.*").
|
||||
From("user u").
|
||||
Join("user_library ul ON u.id = ul.user_id").
|
||||
Where(Eq{"ul.library_id": libraryID}).
|
||||
OrderBy("u.name")
|
||||
|
||||
var res model.Users
|
||||
err := r.queryAll(sel, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
// REST interface methods
|
||||
|
||||
func (r *libraryRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *libraryRepository) Read(id string) (interface{}, error) {
|
||||
idInt, err := strconv.Atoi(id)
|
||||
if err != nil {
|
||||
log.Trace(r.ctx, "invalid library id: %s", id, err)
|
||||
return nil, rest.ErrNotFound
|
||||
}
|
||||
return r.Get(idInt)
|
||||
}
|
||||
|
||||
func (r *libraryRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *libraryRepository) EntityName() string {
|
||||
return "library"
|
||||
}
|
||||
|
||||
func (r *libraryRepository) NewInstance() interface{} {
|
||||
return &model.Library{}
|
||||
}
|
||||
|
||||
func (r *libraryRepository) Save(entity interface{}) (string, error) {
|
||||
lib := entity.(*model.Library)
|
||||
lib.ID = 0 // Reset ID to ensure we create a new library
|
||||
err := r.Put(lib)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.Itoa(lib.ID), nil
|
||||
}
|
||||
|
||||
func (r *libraryRepository) Update(id string, entity interface{}, cols ...string) error {
|
||||
lib := entity.(*model.Library)
|
||||
idInt, err := strconv.Atoi(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid library ID: %s", id)
|
||||
}
|
||||
|
||||
lib.ID = idInt
|
||||
return r.Put(lib)
|
||||
}
|
||||
|
||||
var _ model.LibraryRepository = (*libraryRepository)(nil)
|
||||
var _ rest.Repository = (*libraryRepository)(nil)
|
||||
203
persistence/library_repository_test.go
Normal file
203
persistence/library_repository_test.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ = Describe("LibraryRepository", func() {
|
||||
var repo model.LibraryRepository
|
||||
var ctx context.Context
|
||||
var conn *dbx.DB
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx = request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid"})
|
||||
conn = GetDBXBuilder()
|
||||
repo = NewLibraryRepository(ctx, conn)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up test libraries (keep ID 1 which is the default library)
|
||||
_, _ = conn.NewQuery("DELETE FROM library WHERE id > 1").Execute()
|
||||
})
|
||||
|
||||
Describe("Put", func() {
|
||||
Context("when ID is 0", func() {
|
||||
It("inserts a new library with autoassigned ID", func() {
|
||||
lib := &model.Library{
|
||||
ID: 0,
|
||||
Name: "Test Library",
|
||||
Path: "/music/test",
|
||||
}
|
||||
|
||||
err := repo.Put(lib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(lib.ID).To(BeNumerically(">", 0))
|
||||
Expect(lib.CreatedAt).ToNot(BeZero())
|
||||
Expect(lib.UpdatedAt).ToNot(BeZero())
|
||||
|
||||
// Verify it was inserted
|
||||
savedLib, err := repo.Get(lib.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(savedLib.Name).To(Equal("Test Library"))
|
||||
Expect(savedLib.Path).To(Equal("/music/test"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when ID is non-zero and record exists", func() {
|
||||
It("updates the existing record", func() {
|
||||
// First create a library
|
||||
lib := &model.Library{
|
||||
ID: 0,
|
||||
Name: "Original Library",
|
||||
Path: "/music/original",
|
||||
}
|
||||
err := repo.Put(lib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
originalID := lib.ID
|
||||
originalCreatedAt := lib.CreatedAt
|
||||
|
||||
// Now update it
|
||||
lib.Name = "Updated Library"
|
||||
lib.Path = "/music/updated"
|
||||
err = repo.Put(lib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify it was updated, not inserted
|
||||
Expect(lib.ID).To(Equal(originalID))
|
||||
Expect(lib.CreatedAt).To(Equal(originalCreatedAt))
|
||||
Expect(lib.UpdatedAt).To(BeTemporally(">", originalCreatedAt))
|
||||
|
||||
// Verify the changes were saved
|
||||
savedLib, err := repo.Get(lib.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(savedLib.Name).To(Equal("Updated Library"))
|
||||
Expect(savedLib.Path).To(Equal("/music/updated"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when ID is non-zero but record doesn't exist", func() {
|
||||
It("inserts a new record with the specified ID", func() {
|
||||
lib := &model.Library{
|
||||
ID: 999,
|
||||
Name: "New Library with ID",
|
||||
Path: "/music/new",
|
||||
}
|
||||
|
||||
// Ensure the record doesn't exist
|
||||
_, err := repo.Get(999)
|
||||
Expect(err).To(HaveOccurred())
|
||||
|
||||
// Put should insert it
|
||||
err = repo.Put(lib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(lib.ID).To(Equal(999))
|
||||
Expect(lib.CreatedAt).ToNot(BeZero())
|
||||
Expect(lib.UpdatedAt).ToNot(BeZero())
|
||||
|
||||
// Verify it was inserted with the correct ID
|
||||
savedLib, err := repo.Get(999)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(savedLib.ID).To(Equal(999))
|
||||
Expect(savedLib.Name).To(Equal("New Library with ID"))
|
||||
Expect(savedLib.Path).To(Equal("/music/new"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
It("refreshes stats", func() {
|
||||
libBefore, err := repo.Get(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(repo.RefreshStats(1)).To(Succeed())
|
||||
libAfter, err := repo.Get(1)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libAfter.UpdatedAt).To(BeTemporally(">", libBefore.UpdatedAt))
|
||||
|
||||
var songsRes, albumsRes, artistsRes, foldersRes, filesRes, missingRes struct{ Count int64 }
|
||||
var sizeRes struct{ Sum int64 }
|
||||
var durationRes struct{ Sum float64 }
|
||||
|
||||
Expect(conn.NewQuery("select count(*) as count from media_file where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&songsRes)).To(Succeed())
|
||||
Expect(conn.NewQuery("select count(*) as count from album where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&albumsRes)).To(Succeed())
|
||||
Expect(conn.NewQuery("select count(*) as count from library_artist la join artist a on la.artist_id = a.id where la.library_id = {:id} and a.missing = 0").Bind(dbx.Params{"id": 1}).One(&artistsRes)).To(Succeed())
|
||||
Expect(conn.NewQuery("select count(*) as count from folder where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&foldersRes)).To(Succeed())
|
||||
Expect(conn.NewQuery("select ifnull(sum(num_audio_files + num_playlists + json_array_length(image_files)),0) as count from folder where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&filesRes)).To(Succeed())
|
||||
Expect(conn.NewQuery("select count(*) as count from media_file where library_id = {:id} and missing = 1").Bind(dbx.Params{"id": 1}).One(&missingRes)).To(Succeed())
|
||||
Expect(conn.NewQuery("select ifnull(sum(size),0) as sum from album where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&sizeRes)).To(Succeed())
|
||||
Expect(conn.NewQuery("select ifnull(sum(duration),0) as sum from album where library_id = {:id} and missing = 0").Bind(dbx.Params{"id": 1}).One(&durationRes)).To(Succeed())
|
||||
|
||||
Expect(libAfter.TotalSongs).To(Equal(int(songsRes.Count)))
|
||||
Expect(libAfter.TotalAlbums).To(Equal(int(albumsRes.Count)))
|
||||
Expect(libAfter.TotalArtists).To(Equal(int(artistsRes.Count)))
|
||||
Expect(libAfter.TotalFolders).To(Equal(int(foldersRes.Count)))
|
||||
Expect(libAfter.TotalFiles).To(Equal(int(filesRes.Count)))
|
||||
Expect(libAfter.TotalMissingFiles).To(Equal(int(missingRes.Count)))
|
||||
Expect(libAfter.TotalSize).To(Equal(sizeRes.Sum))
|
||||
Expect(libAfter.TotalDuration).To(Equal(durationRes.Sum))
|
||||
})
|
||||
|
||||
Describe("ScanBegin and ScanEnd", func() {
|
||||
var lib *model.Library
|
||||
|
||||
BeforeEach(func() {
|
||||
lib = &model.Library{
|
||||
ID: 0,
|
||||
Name: "Test Scan Library",
|
||||
Path: "/music/test-scan",
|
||||
}
|
||||
err := repo.Put(lib)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
DescribeTable("ScanBegin",
|
||||
func(fullScan bool, expectedFullScanInProgress bool) {
|
||||
err := repo.ScanBegin(lib.ID, fullScan)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
updatedLib, err := repo.Get(lib.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(updatedLib.LastScanStartedAt).ToNot(BeZero())
|
||||
Expect(updatedLib.FullScanInProgress).To(Equal(expectedFullScanInProgress))
|
||||
},
|
||||
Entry("sets FullScanInProgress to true for full scan", true, true),
|
||||
Entry("sets FullScanInProgress to false for quick scan", false, false),
|
||||
)
|
||||
|
||||
Context("ScanEnd", func() {
|
||||
BeforeEach(func() {
|
||||
err := repo.ScanBegin(lib.ID, true)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("sets LastScanAt and clears FullScanInProgress and LastScanStartedAt", func() {
|
||||
err := repo.ScanEnd(lib.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
updatedLib, err := repo.Get(lib.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(updatedLib.LastScanAt).ToNot(BeZero())
|
||||
Expect(updatedLib.FullScanInProgress).To(BeFalse())
|
||||
Expect(updatedLib.LastScanStartedAt).To(BeZero())
|
||||
})
|
||||
|
||||
It("sets LastScanAt to be after LastScanStartedAt", func() {
|
||||
libBefore, err := repo.Get(lib.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = repo.ScanEnd(lib.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
libAfter, err := repo.Get(lib.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libAfter.LastScanAt).To(BeTemporally(">=", libBefore.LastScanStartedAt))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
477
persistence/mediafile_repository.go
Normal file
477
persistence/mediafile_repository.go
Normal file
@@ -0,0 +1,477 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type mediaFileRepository struct {
|
||||
sqlRepository
|
||||
ms *MeilisearchService
|
||||
}
|
||||
|
||||
type dbMediaFile struct {
|
||||
*model.MediaFile `structs:",flatten"`
|
||||
Participants string `structs:"-" json:"-"`
|
||||
Tags string `structs:"-" json:"-"`
|
||||
// These are necessary to map the correct names (rg_*) to the correct fields (RG*)
|
||||
// without using `db` struct tags in the model.MediaFile struct
|
||||
RgAlbumGain *float64 `structs:"-" json:"-"`
|
||||
RgAlbumPeak *float64 `structs:"-" json:"-"`
|
||||
RgTrackGain *float64 `structs:"-" json:"-"`
|
||||
RgTrackPeak *float64 `structs:"-" json:"-"`
|
||||
}
|
||||
|
||||
func (m *dbMediaFile) PostScan() error {
|
||||
m.RGTrackGain = m.RgTrackGain
|
||||
m.RGTrackPeak = m.RgTrackPeak
|
||||
m.RGAlbumGain = m.RgAlbumGain
|
||||
m.RGAlbumPeak = m.RgAlbumPeak
|
||||
var err error
|
||||
m.MediaFile.Participants, err = unmarshalParticipants(m.Participants)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing media_file from db: %w", err)
|
||||
}
|
||||
if m.Tags != "" {
|
||||
m.MediaFile.Tags, err = unmarshalTags(m.Tags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing media_file from db: %w", err)
|
||||
}
|
||||
m.Genre, m.Genres = m.MediaFile.Tags.ToGenres()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *dbMediaFile) PostMapArgs(args map[string]any) error {
|
||||
fullText := []string{m.FullTitle(), m.Album, m.Artist, m.AlbumArtist,
|
||||
m.SortTitle, m.SortAlbumName, m.SortArtistName, m.SortAlbumArtistName, m.DiscSubtitle}
|
||||
fullText = append(fullText, m.MediaFile.Participants.AllNames()...)
|
||||
args["full_text"] = formatFullText(fullText...)
|
||||
args["tags"] = marshalTags(m.MediaFile.Tags)
|
||||
args["participants"] = marshalParticipants(m.MediaFile.Participants)
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbMediaFiles []dbMediaFile
|
||||
|
||||
func (m dbMediaFiles) toModels() model.MediaFiles {
|
||||
return slice.Map(m, func(mf dbMediaFile) model.MediaFile { return *mf.MediaFile })
|
||||
}
|
||||
|
||||
func NewMediaFileRepository(ctx context.Context, db dbx.Builder, ms *MeilisearchService) model.MediaFileRepository {
|
||||
r := &mediaFileRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.ms = ms
|
||||
r.tableName = "media_file"
|
||||
r.registerModel(&model.MediaFile{}, mediaFileFilter())
|
||||
r.setSortMappings(map[string]string{
|
||||
"title": "order_title",
|
||||
"artist": "order_artist_name, order_album_name, release_date, disc_number, track_number",
|
||||
"album_artist": "order_album_artist_name, order_album_name, release_date, disc_number, track_number",
|
||||
"album": "order_album_name, album_id, disc_number, track_number, order_artist_name, title",
|
||||
"random": "random",
|
||||
"created_at": "media_file.created_at",
|
||||
"recently_added": mediaFileRecentlyAddedSort(),
|
||||
"starred_at": "starred, starred_at",
|
||||
"rated_at": "rating, rated_at",
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
var mediaFileFilter = sync.OnceValue(func() map[string]filterFunc {
|
||||
filters := map[string]filterFunc{
|
||||
"id": idFilter("media_file"),
|
||||
"title": fullTextFilter("media_file", "mbz_recording_id", "mbz_release_track_id"),
|
||||
"starred": booleanFilter,
|
||||
"genre_id": tagIDFilter,
|
||||
"missing": booleanFilter,
|
||||
"artists_id": artistFilter,
|
||||
"library_id": libraryIdFilter,
|
||||
}
|
||||
// Add all album tags as filters
|
||||
for tag := range model.TagMappings() {
|
||||
if _, exists := filters[string(tag)]; !exists {
|
||||
filters[string(tag)] = tagIDFilter
|
||||
}
|
||||
}
|
||||
return filters
|
||||
})
|
||||
|
||||
func mediaFileRecentlyAddedSort() string {
|
||||
if conf.Server.RecentlyAddedByModTime {
|
||||
return "media_file.updated_at"
|
||||
}
|
||||
return "media_file.created_at"
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) CountAll(options ...model.QueryOptions) (int64, error) {
|
||||
query := r.newSelect()
|
||||
query = r.withAnnotation(query, "media_file.id")
|
||||
query = r.applyLibraryFilter(query)
|
||||
return r.count(query, options...)
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) Exists(id string) (bool, error) {
|
||||
return r.exists(Eq{"media_file.id": id})
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) Put(m *model.MediaFile) error {
|
||||
m.CreatedAt = time.Now()
|
||||
id, err := r.putByMatch(Eq{"path": m.Path, "library_id": m.LibraryID}, m.ID, &dbMediaFile{MediaFile: m})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.ID = id
|
||||
err = r.updateParticipants(m.ID, m.Participants)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.ms != nil {
|
||||
r.ms.IndexMediaFile(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) selectMediaFile(options ...model.QueryOptions) SelectBuilder {
|
||||
sql := r.newSelect(options...).Columns("media_file.*", "library.path as library_path", "library.name as library_name").
|
||||
LeftJoin("library on media_file.library_id = library.id")
|
||||
sql = r.withAnnotation(sql, "media_file.id")
|
||||
sql = r.withBookmark(sql, "media_file.id")
|
||||
return r.applyLibraryFilter(sql)
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) Get(id string) (*model.MediaFile, error) {
|
||||
res, err := r.GetAll(model.QueryOptions{Filters: Eq{"media_file.id": id}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(res) == 0 {
|
||||
return nil, model.ErrNotFound
|
||||
}
|
||||
return &res[0], nil
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) GetWithParticipants(id string) (*model.MediaFile, error) {
|
||||
m, err := r.Get(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.Participants, err = r.getParticipants(m)
|
||||
return m, err
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) GetAll(options ...model.QueryOptions) (model.MediaFiles, error) {
|
||||
sq := r.selectMediaFile(options...)
|
||||
var res dbMediaFiles
|
||||
err := r.queryAll(sq, &res, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.toModels(), nil
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) GetCursor(options ...model.QueryOptions) (model.MediaFileCursor, error) {
|
||||
sq := r.selectMediaFile(options...)
|
||||
cursor, err := queryWithStableResults[dbMediaFile](r.sqlRepository, sq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(yield func(model.MediaFile, error) bool) {
|
||||
for m, err := range cursor {
|
||||
if m.MediaFile == nil {
|
||||
yield(model.MediaFile{}, fmt.Errorf("unexpected nil mediafile: %v", m))
|
||||
return
|
||||
}
|
||||
if !yield(*m.MediaFile, err) || err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FindByPaths finds media files by their paths.
|
||||
// The paths can be library-qualified (format: "libraryID:path") or unqualified ("path").
|
||||
// Library-qualified paths search within the specified library, while unqualified paths
|
||||
// search across all libraries for backward compatibility.
|
||||
func (r *mediaFileRepository) FindByPaths(paths []string) (model.MediaFiles, error) {
|
||||
query := Or{}
|
||||
|
||||
for _, path := range paths {
|
||||
parts := strings.SplitN(path, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
// Library-qualified path: "libraryID:path"
|
||||
libraryID, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
// Invalid format, skip
|
||||
continue
|
||||
}
|
||||
relativePath := parts[1]
|
||||
query = append(query, And{
|
||||
Eq{"path collate nocase": relativePath},
|
||||
Eq{"library_id": libraryID},
|
||||
})
|
||||
} else {
|
||||
// Unqualified path: search across all libraries
|
||||
query = append(query, Eq{"path collate nocase": path})
|
||||
}
|
||||
}
|
||||
|
||||
if len(query) == 0 {
|
||||
return model.MediaFiles{}, nil
|
||||
}
|
||||
|
||||
sel := r.newSelect().Columns("*").Where(query)
|
||||
var res dbMediaFiles
|
||||
if err := r.queryAll(sel, &res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.toModels(), nil
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) Delete(id string) error {
|
||||
err := r.delete(Eq{"id": id})
|
||||
if err == nil && r.ms != nil {
|
||||
r.ms.DeleteMediaFile(id)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) DeleteAllMissing() (int64, error) {
|
||||
user := loggedUser(r.ctx)
|
||||
if !user.IsAdmin {
|
||||
return 0, rest.ErrPermissionDenied
|
||||
}
|
||||
del := Delete(r.tableName).Where(Eq{"missing": true})
|
||||
var ids []string
|
||||
if r.ms != nil {
|
||||
_ = r.db.Select("id").From(r.tableName).Where(dbx.HashExp{"missing": true}).Column(&ids)
|
||||
}
|
||||
c, err := r.executeSQL(del)
|
||||
if err == nil && r.ms != nil && len(ids) > 0 {
|
||||
r.ms.DeleteMediaFiles(ids)
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) DeleteMissing(ids []string) error {
|
||||
user := loggedUser(r.ctx)
|
||||
if !user.IsAdmin {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
err := r.delete(
|
||||
And{
|
||||
Eq{"missing": true},
|
||||
Eq{"id": ids},
|
||||
},
|
||||
)
|
||||
if err == nil && r.ms != nil {
|
||||
r.ms.DeleteMediaFiles(ids)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) MarkMissing(missing bool, mfs ...*model.MediaFile) error {
|
||||
ids := slice.SeqFunc(mfs, func(m *model.MediaFile) string { return m.ID })
|
||||
for chunk := range slice.CollectChunks(ids, 200) {
|
||||
upd := Update(r.tableName).
|
||||
Set("missing", missing).
|
||||
Set("updated_at", time.Now()).
|
||||
Where(Eq{"id": chunk})
|
||||
c, err := r.executeSQL(upd)
|
||||
if err != nil || c == 0 {
|
||||
log.Error(r.ctx, "Error setting mediafile missing flag", "ids", chunk, err)
|
||||
return err
|
||||
}
|
||||
log.Debug(r.ctx, "Marked missing mediafiles", "total", c, "ids", chunk)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) MarkMissingByFolder(missing bool, folderIDs ...string) error {
|
||||
for chunk := range slices.Chunk(folderIDs, 200) {
|
||||
upd := Update(r.tableName).
|
||||
Set("missing", missing).
|
||||
Set("updated_at", time.Now()).
|
||||
Where(And{
|
||||
Eq{"folder_id": chunk},
|
||||
Eq{"missing": !missing},
|
||||
})
|
||||
c, err := r.executeSQL(upd)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error setting mediafile missing flag", "folderIDs", chunk, err)
|
||||
return err
|
||||
}
|
||||
log.Debug(r.ctx, "Marked missing mediafiles from missing folders", "total", c, "folders", chunk)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMissingAndMatching returns all mediafiles that are missing and their potential matches (comparing PIDs)
|
||||
// that were added/updated after the last scan started. The result is ordered by PID.
|
||||
// It does not need to load bookmarks, annotations and participants, as they are not used by the scanner.
|
||||
func (r *mediaFileRepository) GetMissingAndMatching(libId int) (model.MediaFileCursor, error) {
|
||||
subQ := r.newSelect().Columns("pid").
|
||||
Where(And{
|
||||
Eq{"media_file.missing": true},
|
||||
Eq{"library_id": libId},
|
||||
})
|
||||
subQText, subQArgs, err := subQ.PlaceholderFormat(Question).ToSql()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sel := r.newSelect().Columns("media_file.*", "library.path as library_path", "library.name as library_name").
|
||||
LeftJoin("library on media_file.library_id = library.id").
|
||||
Where("pid in ("+subQText+")", subQArgs...).
|
||||
Where(Or{
|
||||
Eq{"missing": true},
|
||||
ConcatExpr("media_file.created_at > library.last_scan_started_at"),
|
||||
}).
|
||||
OrderBy("pid")
|
||||
cursor, err := queryWithStableResults[dbMediaFile](r.sqlRepository, sel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(yield func(model.MediaFile, error) bool) {
|
||||
for m, err := range cursor {
|
||||
if !yield(*m.MediaFile, err) || err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FindRecentFilesByMBZTrackID finds recently added files by MusicBrainz Track ID in other libraries
|
||||
func (r *mediaFileRepository) FindRecentFilesByMBZTrackID(missing model.MediaFile, since time.Time) (model.MediaFiles, error) {
|
||||
sel := r.selectMediaFile().Where(And{
|
||||
NotEq{"media_file.library_id": missing.LibraryID},
|
||||
Eq{"media_file.mbz_release_track_id": missing.MbzReleaseTrackID},
|
||||
NotEq{"media_file.mbz_release_track_id": ""}, // Exclude empty MBZ Track IDs
|
||||
Eq{"media_file.suffix": missing.Suffix},
|
||||
Gt{"media_file.created_at": since},
|
||||
Eq{"media_file.missing": false},
|
||||
}).OrderBy("media_file.created_at DESC")
|
||||
|
||||
var res dbMediaFiles
|
||||
err := r.queryAll(sel, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.toModels(), nil
|
||||
}
|
||||
|
||||
// FindRecentFilesByProperties finds recently added files by intrinsic properties in other libraries
|
||||
func (r *mediaFileRepository) FindRecentFilesByProperties(missing model.MediaFile, since time.Time) (model.MediaFiles, error) {
|
||||
sel := r.selectMediaFile().Where(And{
|
||||
NotEq{"media_file.library_id": missing.LibraryID},
|
||||
Eq{"media_file.title": missing.Title},
|
||||
Eq{"media_file.size": missing.Size},
|
||||
Eq{"media_file.suffix": missing.Suffix},
|
||||
Eq{"media_file.disc_number": missing.DiscNumber},
|
||||
Eq{"media_file.track_number": missing.TrackNumber},
|
||||
Eq{"media_file.album": missing.Album},
|
||||
Eq{"media_file.mbz_release_track_id": ""}, // Exclude files with MBZ Track ID
|
||||
Gt{"media_file.created_at": since},
|
||||
Eq{"media_file.missing": false},
|
||||
}).OrderBy("media_file.created_at DESC")
|
||||
|
||||
var res dbMediaFiles
|
||||
err := r.queryAll(sel, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.toModels(), nil
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) Search(q string, offset int, size int, options ...model.QueryOptions) (model.MediaFiles, error) {
|
||||
var res dbMediaFiles
|
||||
if uuid.Validate(q) == nil {
|
||||
err := r.searchByMBID(r.selectMediaFile(options...), q, []string{"mbz_recording_id", "mbz_release_track_id"}, &res)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("searching media_file by MBID %q: %w", q, err)
|
||||
}
|
||||
} else {
|
||||
if r.ms != nil {
|
||||
ids, err := r.ms.Search("mediafiles", q, offset, size)
|
||||
if err == nil {
|
||||
if len(ids) == 0 {
|
||||
return model.MediaFiles{}, nil
|
||||
}
|
||||
// Fetch matching media files from the database
|
||||
// We need to fetch all fields to return complete objects
|
||||
mfs, err := r.GetAll(model.QueryOptions{Filters: Eq{"media_file.id": ids}})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching media_files from meilisearch ids: %w", err)
|
||||
}
|
||||
// Reorder results to match Meilisearch order
|
||||
idMap := make(map[string]model.MediaFile, len(mfs))
|
||||
for _, mf := range mfs {
|
||||
idMap[mf.ID] = mf
|
||||
}
|
||||
sorted := make(model.MediaFiles, 0, len(mfs))
|
||||
for _, id := range ids {
|
||||
if mf, ok := idMap[id]; ok {
|
||||
sorted = append(sorted, mf)
|
||||
}
|
||||
}
|
||||
return sorted, nil
|
||||
}
|
||||
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
|
||||
}
|
||||
err := r.doSearch(r.selectMediaFile(options...), q, offset, size, &res, "media_file.rowid", "title")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("searching media_file by query %q: %w", q, err)
|
||||
}
|
||||
}
|
||||
return res.toModels(), nil
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) Read(id string) (interface{}, error) {
|
||||
return r.Get(id)
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
if len(options) > 0 && r.ms != nil {
|
||||
if title, ok := options[0].Filters["title"].(string); ok && title != "" {
|
||||
ids, err := r.ms.Search("mediafiles", title, 0, 10000)
|
||||
if err == nil {
|
||||
log.Debug(r.ctx, "Meilisearch found matches", "count", len(ids), "query", title)
|
||||
delete(options[0].Filters, "title")
|
||||
options[0].Filters["id"] = ids
|
||||
} else {
|
||||
log.Warn(r.ctx, "Meilisearch search failed, falling back to SQL", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) EntityName() string {
|
||||
return "mediafile"
|
||||
}
|
||||
|
||||
func (r *mediaFileRepository) NewInstance() interface{} {
|
||||
return &model.MediaFile{}
|
||||
}
|
||||
|
||||
var _ model.MediaFileRepository = (*mediaFileRepository)(nil)
|
||||
var _ model.ResourceRepository = (*mediaFileRepository)(nil)
|
||||
413
persistence/mediafile_repository_test.go
Normal file
413
persistence/mediafile_repository_test.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/conf/configtest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ = Describe("MediaRepository", func() {
|
||||
var mr model.MediaFileRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid"})
|
||||
mr = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
|
||||
})
|
||||
|
||||
It("gets mediafile from the DB", func() {
|
||||
actual, err := mr.Get("1004")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
actual.CreatedAt = time.Time{}
|
||||
Expect(actual).To(Equal(&songAntenna))
|
||||
})
|
||||
|
||||
It("returns ErrNotFound", func() {
|
||||
_, err := mr.Get("56")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
|
||||
It("counts the number of mediafiles in the DB", func() {
|
||||
Expect(mr.CountAll()).To(Equal(int64(10)))
|
||||
})
|
||||
|
||||
It("returns songs ordered by lyrics with a specific title/artist", func() {
|
||||
// attempt to mimic filters.SongsByArtistTitleWithLyricsFirst, except we want all items
|
||||
results, err := mr.GetAll(model.QueryOptions{
|
||||
Sort: "lyrics, updated_at",
|
||||
Order: "desc",
|
||||
Filters: squirrel.And{
|
||||
squirrel.Eq{"title": "Antenna"},
|
||||
squirrel.Or{
|
||||
Exists("json_tree(participants, '$.albumartist')", squirrel.Eq{"value": "Kraftwerk"}),
|
||||
Exists("json_tree(participants, '$.artist')", squirrel.Eq{"value": "Kraftwerk"}),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
Expect(err).To(BeNil())
|
||||
Expect(results).To(HaveLen(3))
|
||||
Expect(results[0].Lyrics).To(Equal(`[{"lang":"xxx","line":[{"value":"This is a set of lyrics"}],"synced":false}]`))
|
||||
for _, item := range results[1:] {
|
||||
Expect(item.Lyrics).To(Equal("[]"))
|
||||
Expect(item.Title).To(Equal("Antenna"))
|
||||
Expect(item.Participants[model.RoleArtist][0].Name).To(Equal("Kraftwerk"))
|
||||
}
|
||||
})
|
||||
|
||||
It("checks existence of mediafiles in the DB", func() {
|
||||
Expect(mr.Exists(songAntenna.ID)).To(BeTrue())
|
||||
Expect(mr.Exists("666")).To(BeFalse())
|
||||
})
|
||||
|
||||
It("delete tracks by id", func() {
|
||||
newID := id.NewRandom()
|
||||
Expect(mr.Put(&model.MediaFile{LibraryID: 1, ID: newID})).To(Succeed())
|
||||
|
||||
Expect(mr.Delete(newID)).To(Succeed())
|
||||
|
||||
_, err := mr.Get(newID)
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
|
||||
It("deletes all missing files", func() {
|
||||
new1 := model.MediaFile{ID: id.NewRandom(), LibraryID: 1}
|
||||
new2 := model.MediaFile{ID: id.NewRandom(), LibraryID: 1}
|
||||
Expect(mr.Put(&new1)).To(Succeed())
|
||||
Expect(mr.Put(&new2)).To(Succeed())
|
||||
Expect(mr.MarkMissing(true, &new1, &new2)).To(Succeed())
|
||||
|
||||
adminCtx := request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", IsAdmin: true})
|
||||
adminRepo := NewMediaFileRepository(adminCtx, GetDBXBuilder(), nil)
|
||||
|
||||
// Ensure the files are marked as missing and we have 2 of them
|
||||
count, err := adminRepo.CountAll(model.QueryOptions{Filters: squirrel.Eq{"missing": true}})
|
||||
Expect(count).To(BeNumerically("==", 2))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
count, err = adminRepo.DeleteAllMissing()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(BeNumerically("==", 2))
|
||||
|
||||
_, err = mr.Get(new1.ID)
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
_, err = mr.Get(new2.ID)
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
|
||||
Context("Annotations", func() {
|
||||
It("increments play count when the tracks does not have annotations", func() {
|
||||
id := "incplay.firsttime"
|
||||
Expect(mr.Put(&model.MediaFile{LibraryID: 1, ID: id})).To(BeNil())
|
||||
playDate := time.Now()
|
||||
Expect(mr.IncPlayCount(id, playDate)).To(BeNil())
|
||||
|
||||
mf, err := mr.Get(id)
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
Expect(mf.PlayDate.Unix()).To(Equal(playDate.Unix()))
|
||||
Expect(mf.PlayCount).To(Equal(int64(1)))
|
||||
})
|
||||
|
||||
It("preserves play date if and only if provided date is older", func() {
|
||||
id := "incplay.playdate"
|
||||
Expect(mr.Put(&model.MediaFile{LibraryID: 1, ID: id})).To(BeNil())
|
||||
playDate := time.Now()
|
||||
Expect(mr.IncPlayCount(id, playDate)).To(BeNil())
|
||||
mf, err := mr.Get(id)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(mf.PlayDate.Unix()).To(Equal(playDate.Unix()))
|
||||
Expect(mf.PlayCount).To(Equal(int64(1)))
|
||||
|
||||
playDateLate := playDate.AddDate(0, 0, 1)
|
||||
Expect(mr.IncPlayCount(id, playDateLate)).To(BeNil())
|
||||
mf, err = mr.Get(id)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(mf.PlayDate.Unix()).To(Equal(playDateLate.Unix()))
|
||||
Expect(mf.PlayCount).To(Equal(int64(2)))
|
||||
|
||||
playDateEarly := playDate.AddDate(0, 0, -1)
|
||||
Expect(mr.IncPlayCount(id, playDateEarly)).To(BeNil())
|
||||
mf, err = mr.Get(id)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(mf.PlayDate.Unix()).To(Equal(playDateLate.Unix()))
|
||||
Expect(mf.PlayCount).To(Equal(int64(3)))
|
||||
})
|
||||
|
||||
It("increments play count on newly starred items", func() {
|
||||
id := "star.incplay"
|
||||
Expect(mr.Put(&model.MediaFile{LibraryID: 1, ID: id})).To(BeNil())
|
||||
Expect(mr.SetStar(true, id)).To(BeNil())
|
||||
playDate := time.Now()
|
||||
Expect(mr.IncPlayCount(id, playDate)).To(BeNil())
|
||||
|
||||
mf, err := mr.Get(id)
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
Expect(mf.PlayDate.Unix()).To(Equal(playDate.Unix()))
|
||||
Expect(mf.PlayCount).To(Equal(int64(1)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Sort options", func() {
|
||||
Context("recently_added sort", func() {
|
||||
var testMediaFiles []model.MediaFile
|
||||
|
||||
BeforeEach(func() {
|
||||
DeferCleanup(configtest.SetupConfig())
|
||||
|
||||
// Create test media files with specific timestamps
|
||||
testMediaFiles = []model.MediaFile{
|
||||
{
|
||||
ID: id.NewRandom(),
|
||||
LibraryID: 1,
|
||||
Title: "Old Song",
|
||||
Path: "/test/old.mp3",
|
||||
},
|
||||
{
|
||||
ID: id.NewRandom(),
|
||||
LibraryID: 1,
|
||||
Title: "Middle Song",
|
||||
Path: "/test/middle.mp3",
|
||||
},
|
||||
{
|
||||
ID: id.NewRandom(),
|
||||
LibraryID: 1,
|
||||
Title: "New Song",
|
||||
Path: "/test/new.mp3",
|
||||
},
|
||||
}
|
||||
|
||||
// Insert test data first
|
||||
for i := range testMediaFiles {
|
||||
Expect(mr.Put(&testMediaFiles[i])).To(Succeed())
|
||||
}
|
||||
|
||||
// Then manually update timestamps using direct SQL to bypass the repository logic
|
||||
db := GetDBXBuilder()
|
||||
|
||||
// Set specific timestamps for testing
|
||||
oldTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
middleTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
newTime := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// Update "Old Song": created long ago, updated recently
|
||||
_, err := db.Update("media_file",
|
||||
map[string]interface{}{
|
||||
"created_at": oldTime,
|
||||
"updated_at": newTime,
|
||||
},
|
||||
dbx.HashExp{"id": testMediaFiles[0].ID}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Update "Middle Song": created and updated at the same middle time
|
||||
_, err = db.Update("media_file",
|
||||
map[string]interface{}{
|
||||
"created_at": middleTime,
|
||||
"updated_at": middleTime,
|
||||
},
|
||||
dbx.HashExp{"id": testMediaFiles[1].ID}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Update "New Song": created recently, updated long ago
|
||||
_, err = db.Update("media_file",
|
||||
map[string]interface{}{
|
||||
"created_at": newTime,
|
||||
"updated_at": oldTime,
|
||||
},
|
||||
dbx.HashExp{"id": testMediaFiles[2].ID}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up test data
|
||||
for _, mf := range testMediaFiles {
|
||||
_ = mr.Delete(mf.ID)
|
||||
}
|
||||
})
|
||||
|
||||
When("RecentlyAddedByModTime is false", func() {
|
||||
var testRepo model.MediaFileRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
conf.Server.RecentlyAddedByModTime = false
|
||||
// Create repository AFTER setting config
|
||||
ctx := log.NewContext(GinkgoT().Context())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid"})
|
||||
testRepo = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
|
||||
})
|
||||
|
||||
It("sorts by created_at", func() {
|
||||
// Get results sorted by recently_added (should use created_at)
|
||||
results, err := testRepo.GetAll(model.QueryOptions{
|
||||
Sort: "recently_added",
|
||||
Order: "desc",
|
||||
Filters: squirrel.Eq{"media_file.id": []string{testMediaFiles[0].ID, testMediaFiles[1].ID, testMediaFiles[2].ID}},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(3))
|
||||
|
||||
// Verify sorting by created_at (newest first in descending order)
|
||||
Expect(results[0].Title).To(Equal("New Song")) // created 2022
|
||||
Expect(results[1].Title).To(Equal("Middle Song")) // created 2021
|
||||
Expect(results[2].Title).To(Equal("Old Song")) // created 2020
|
||||
})
|
||||
|
||||
It("sorts in ascending order when specified", func() {
|
||||
// Get results sorted by recently_added in ascending order
|
||||
results, err := testRepo.GetAll(model.QueryOptions{
|
||||
Sort: "recently_added",
|
||||
Order: "asc",
|
||||
Filters: squirrel.Eq{"media_file.id": []string{testMediaFiles[0].ID, testMediaFiles[1].ID, testMediaFiles[2].ID}},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(3))
|
||||
|
||||
// Verify sorting by created_at (oldest first)
|
||||
Expect(results[0].Title).To(Equal("Old Song")) // created 2020
|
||||
Expect(results[1].Title).To(Equal("Middle Song")) // created 2021
|
||||
Expect(results[2].Title).To(Equal("New Song")) // created 2022
|
||||
})
|
||||
})
|
||||
|
||||
When("RecentlyAddedByModTime is true", func() {
|
||||
var testRepo model.MediaFileRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
conf.Server.RecentlyAddedByModTime = true
|
||||
// Create repository AFTER setting config
|
||||
ctx := log.NewContext(GinkgoT().Context())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid"})
|
||||
testRepo = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
|
||||
})
|
||||
|
||||
It("sorts by updated_at", func() {
|
||||
// Get results sorted by recently_added (should use updated_at)
|
||||
results, err := testRepo.GetAll(model.QueryOptions{
|
||||
Sort: "recently_added",
|
||||
Order: "desc",
|
||||
Filters: squirrel.Eq{"media_file.id": []string{testMediaFiles[0].ID, testMediaFiles[1].ID, testMediaFiles[2].ID}},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(3))
|
||||
|
||||
// Verify sorting by updated_at (newest first in descending order)
|
||||
Expect(results[0].Title).To(Equal("Old Song")) // updated 2022
|
||||
Expect(results[1].Title).To(Equal("Middle Song")) // updated 2021
|
||||
Expect(results[2].Title).To(Equal("New Song")) // updated 2020
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Search", func() {
|
||||
Context("text search", func() {
|
||||
It("finds media files by title", func() {
|
||||
results, err := mr.Search("Antenna", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(3)) // songAntenna, songAntennaWithLyrics, songAntenna2
|
||||
for _, result := range results {
|
||||
Expect(result.Title).To(Equal("Antenna"))
|
||||
}
|
||||
})
|
||||
|
||||
It("finds media files case insensitively", func() {
|
||||
results, err := mr.Search("antenna", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(3))
|
||||
for _, result := range results {
|
||||
Expect(result.Title).To(Equal("Antenna"))
|
||||
}
|
||||
})
|
||||
|
||||
It("returns empty result when no matches found", func() {
|
||||
results, err := mr.Search("nonexistent", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
Context("MBID search", func() {
|
||||
var mediaFileWithMBID model.MediaFile
|
||||
var raw *mediaFileRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
raw = mr.(*mediaFileRepository)
|
||||
// Create a test media file with MBID
|
||||
mediaFileWithMBID = model.MediaFile{
|
||||
ID: "test-mbid-mediafile",
|
||||
Title: "Test MBID MediaFile",
|
||||
MbzRecordingID: "550e8400-e29b-41d4-a716-446655440020", // Valid UUID v4
|
||||
MbzReleaseTrackID: "550e8400-e29b-41d4-a716-446655440021", // Valid UUID v4
|
||||
LibraryID: 1,
|
||||
Path: "/test/path/test.mp3",
|
||||
}
|
||||
|
||||
// Insert the test media file into the database
|
||||
err := mr.Put(&mediaFileWithMBID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up test data using direct SQL
|
||||
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": mediaFileWithMBID.ID}))
|
||||
})
|
||||
|
||||
It("finds media file by mbz_recording_id", func() {
|
||||
results, err := mr.Search("550e8400-e29b-41d4-a716-446655440020", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].ID).To(Equal("test-mbid-mediafile"))
|
||||
Expect(results[0].Title).To(Equal("Test MBID MediaFile"))
|
||||
})
|
||||
|
||||
It("finds media file by mbz_release_track_id", func() {
|
||||
results, err := mr.Search("550e8400-e29b-41d4-a716-446655440021", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(HaveLen(1))
|
||||
Expect(results[0].ID).To(Equal("test-mbid-mediafile"))
|
||||
Expect(results[0].Title).To(Equal("Test MBID MediaFile"))
|
||||
})
|
||||
|
||||
It("returns empty result when MBID is not found", func() {
|
||||
results, err := mr.Search("550e8400-e29b-41d4-a716-446655440099", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("missing media files are never returned by search", func() {
|
||||
// Create a missing media file with MBID
|
||||
missingMediaFile := model.MediaFile{
|
||||
ID: "test-missing-mbid-mediafile",
|
||||
Title: "Test Missing MBID MediaFile",
|
||||
MbzRecordingID: "550e8400-e29b-41d4-a716-446655440022",
|
||||
LibraryID: 1,
|
||||
Path: "/test/path/missing.mp3",
|
||||
Missing: true,
|
||||
}
|
||||
|
||||
err := mr.Put(&missingMediaFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Search never returns missing media files (hardcoded behavior)
|
||||
results, err := mr.Search("550e8400-e29b-41d4-a716-446655440022", 0, 10)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(results).To(BeEmpty())
|
||||
|
||||
// Clean up
|
||||
_, _ = raw.executeSQL(squirrel.Delete(raw.tableName).Where(squirrel.Eq{"id": missingMediaFile.ID}))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
191
persistence/meilisearch.go
Normal file
191
persistence/meilisearch.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/meilisearch/meilisearch-go"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
)
|
||||
|
||||
type MeilisearchService struct {
|
||||
client meilisearch.ServiceManager
|
||||
}
|
||||
|
||||
func NewMeilisearchService() *MeilisearchService {
|
||||
if !conf.Server.Meilisearch.Enabled {
|
||||
return nil
|
||||
}
|
||||
client := meilisearch.New(conf.Server.Meilisearch.Host, meilisearch.WithAPIKey(conf.Server.Meilisearch.ApiKey))
|
||||
return &MeilisearchService{client: client}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) IndexMediaFile(mf *model.MediaFile) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.IndexMediaFiles([]model.MediaFile{*mf})
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) IndexMediaFiles(mfs []model.MediaFile) {
|
||||
if s == nil || len(mfs) == 0 {
|
||||
return
|
||||
}
|
||||
docs := make([]map[string]interface{}, len(mfs))
|
||||
for i, mf := range mfs {
|
||||
docs[i] = map[string]interface{}{
|
||||
"id": mf.ID,
|
||||
"title": mf.Title,
|
||||
"artist": mf.Artist,
|
||||
"album": mf.Album,
|
||||
"albumArtist": mf.AlbumArtist,
|
||||
"path": mf.Path,
|
||||
"year": mf.Year,
|
||||
"genre": mf.Genre,
|
||||
}
|
||||
}
|
||||
_, err := s.client.Index("mediafiles").AddDocuments(docs, nil)
|
||||
if err != nil {
|
||||
log.Error("Error indexing mediafiles", "count", len(mfs), err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) DeleteMediaFile(id string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
_, err := s.client.Index("mediafiles").DeleteDocument(id)
|
||||
if err != nil {
|
||||
log.Error("Error deleting mediafile from index", "id", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) DeleteMediaFiles(ids []string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
_, err := s.client.Index("mediafiles").DeleteDocuments(ids)
|
||||
if err != nil {
|
||||
log.Error("Error deleting mediafiles from index", "ids", ids, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) IndexAlbum(album *model.Album) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.IndexAlbums([]model.Album{*album})
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) IndexAlbums(albums []model.Album) {
|
||||
if s == nil || len(albums) == 0 {
|
||||
return
|
||||
}
|
||||
docs := make([]map[string]interface{}, len(albums))
|
||||
for i, album := range albums {
|
||||
docs[i] = map[string]interface{}{
|
||||
"id": album.ID,
|
||||
"name": album.Name,
|
||||
"artist": album.AlbumArtist,
|
||||
"albumArtist": album.AlbumArtist,
|
||||
"year": album.MinYear,
|
||||
"genre": album.Genre,
|
||||
}
|
||||
}
|
||||
_, err := s.client.Index("albums").AddDocuments(docs, nil)
|
||||
if err != nil {
|
||||
log.Error("Error indexing albums", "count", len(albums), err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) DeleteAlbum(id string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
_, err := s.client.Index("albums").DeleteDocument(id)
|
||||
if err != nil {
|
||||
log.Error("Error deleting album from index", "id", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) DeleteAlbums(ids []string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
_, err := s.client.Index("albums").DeleteDocuments(ids)
|
||||
if err != nil {
|
||||
log.Error("Error deleting albums from index", "ids", ids, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) IndexArtist(artist *model.Artist) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.IndexArtists([]model.Artist{*artist})
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) IndexArtists(artists []model.Artist) {
|
||||
if s == nil || len(artists) == 0 {
|
||||
return
|
||||
}
|
||||
docs := make([]map[string]interface{}, len(artists))
|
||||
for i, artist := range artists {
|
||||
docs[i] = map[string]interface{}{
|
||||
"id": artist.ID,
|
||||
"name": artist.Name,
|
||||
}
|
||||
}
|
||||
_, err := s.client.Index("artists").AddDocuments(docs, nil)
|
||||
if err != nil {
|
||||
log.Error("Error indexing artists", "count", len(artists), err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) DeleteArtist(id string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
_, err := s.client.Index("artists").DeleteDocument(id)
|
||||
if err != nil {
|
||||
log.Error("Error deleting artist from index", "id", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) DeleteArtists(ids []string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
_, err := s.client.Index("artists").DeleteDocuments(ids)
|
||||
if err != nil {
|
||||
log.Error("Error deleting artists from index", "ids", ids, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MeilisearchService) Search(indexName string, query string, offset, limit int) ([]string, error) {
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("meilisearch is not enabled")
|
||||
}
|
||||
searchRes, err := s.client.Index(indexName).Search(query, &meilisearch.SearchRequest{
|
||||
Offset: int64(offset),
|
||||
Limit: int64(limit),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ids []string
|
||||
for _, hit := range searchRes.Hits {
|
||||
if id, ok := hit["id"]; ok {
|
||||
var idStr string
|
||||
if err := json.Unmarshal(id, &idStr); err == nil {
|
||||
ids = append(ids, idStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle case where id might be non-string if necessary, though simpler is better for now.
|
||||
// Meilisearch returns map[string]interface{}
|
||||
return ids, nil
|
||||
}
|
||||
246
persistence/persistence.go
Normal file
246
persistence/persistence.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/navidrome/navidrome/db"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/run"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type SQLStore struct {
|
||||
db dbx.Builder
|
||||
ms *MeilisearchService
|
||||
}
|
||||
|
||||
func New(conn *sql.DB) model.DataStore {
|
||||
return &SQLStore{
|
||||
db: dbx.NewFromDB(conn, db.Driver),
|
||||
ms: NewMeilisearchService(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SQLStore) Album(ctx context.Context) model.AlbumRepository {
|
||||
return NewAlbumRepository(ctx, s.getDBXBuilder(), s.ms)
|
||||
}
|
||||
|
||||
func (s *SQLStore) Artist(ctx context.Context) model.ArtistRepository {
|
||||
return NewArtistRepository(ctx, s.getDBXBuilder(), s.ms)
|
||||
}
|
||||
|
||||
func (s *SQLStore) MediaFile(ctx context.Context) model.MediaFileRepository {
|
||||
return NewMediaFileRepository(ctx, s.getDBXBuilder(), s.ms)
|
||||
}
|
||||
|
||||
func (s *SQLStore) Library(ctx context.Context) model.LibraryRepository {
|
||||
return NewLibraryRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Folder(ctx context.Context) model.FolderRepository {
|
||||
return newFolderRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Genre(ctx context.Context) model.GenreRepository {
|
||||
return NewGenreRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Tag(ctx context.Context) model.TagRepository {
|
||||
return NewTagRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) PlayQueue(ctx context.Context) model.PlayQueueRepository {
|
||||
return NewPlayQueueRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Playlist(ctx context.Context) model.PlaylistRepository {
|
||||
return NewPlaylistRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Property(ctx context.Context) model.PropertyRepository {
|
||||
return NewPropertyRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Radio(ctx context.Context) model.RadioRepository {
|
||||
return NewRadioRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) UserProps(ctx context.Context) model.UserPropsRepository {
|
||||
return NewUserPropsRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Share(ctx context.Context) model.ShareRepository {
|
||||
return NewShareRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) User(ctx context.Context) model.UserRepository {
|
||||
return NewUserRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Transcoding(ctx context.Context) model.TranscodingRepository {
|
||||
return NewTranscodingRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Player(ctx context.Context) model.PlayerRepository {
|
||||
return NewPlayerRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) ScrobbleBuffer(ctx context.Context) model.ScrobbleBufferRepository {
|
||||
return NewScrobbleBufferRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Scrobble(ctx context.Context) model.ScrobbleRepository {
|
||||
return NewScrobbleRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Resource(ctx context.Context, m interface{}) model.ResourceRepository {
|
||||
switch m.(type) {
|
||||
case model.User:
|
||||
return s.User(ctx).(model.ResourceRepository)
|
||||
case model.Transcoding:
|
||||
return s.Transcoding(ctx).(model.ResourceRepository)
|
||||
case model.Player:
|
||||
return s.Player(ctx).(model.ResourceRepository)
|
||||
case model.Artist:
|
||||
return s.Artist(ctx).(model.ResourceRepository)
|
||||
case model.Album:
|
||||
return s.Album(ctx).(model.ResourceRepository)
|
||||
case model.MediaFile:
|
||||
return s.MediaFile(ctx).(model.ResourceRepository)
|
||||
case model.Genre:
|
||||
return s.Genre(ctx).(model.ResourceRepository)
|
||||
case model.Playlist:
|
||||
return s.Playlist(ctx).(model.ResourceRepository)
|
||||
case model.Radio:
|
||||
return s.Radio(ctx).(model.ResourceRepository)
|
||||
case model.Share:
|
||||
return s.Share(ctx).(model.ResourceRepository)
|
||||
case model.Tag:
|
||||
return s.Tag(ctx).(model.ResourceRepository)
|
||||
}
|
||||
log.Error("Resource not implemented", "model", reflect.TypeOf(m).Name())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) ReindexAll(ctx context.Context) error {
|
||||
if s.ms == nil {
|
||||
return nil
|
||||
}
|
||||
log.Info("Starting full re-index")
|
||||
// Index Artists
|
||||
artists, err := s.Artist(ctx).GetAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching artists: %w", err)
|
||||
}
|
||||
s.ms.IndexArtists(artists)
|
||||
log.Info(ctx, "Indexed artists", "count", len(artists))
|
||||
|
||||
// Index Albums
|
||||
albums, err := s.Album(ctx).GetAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching albums: %w", err)
|
||||
}
|
||||
s.ms.IndexAlbums(albums)
|
||||
log.Info(ctx, "Indexed albums", "count", len(albums))
|
||||
|
||||
// Index MediaFiles
|
||||
mfs, err := s.MediaFile(ctx).GetAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching media files: %w", err)
|
||||
}
|
||||
batchSize := 2000
|
||||
for i := 0; i < len(mfs); i += batchSize {
|
||||
end := i + batchSize
|
||||
if end > len(mfs) {
|
||||
end = len(mfs)
|
||||
}
|
||||
s.ms.IndexMediaFiles(mfs[i:end])
|
||||
log.Info(ctx, "Indexed media files batch", "start", i, "end", end)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) WithTx(block func(tx model.DataStore) error, scope ...string) error {
|
||||
var msg string
|
||||
if len(scope) > 0 {
|
||||
msg = scope[0]
|
||||
}
|
||||
start := time.Now()
|
||||
conn, inTx := s.db.(*dbx.DB)
|
||||
if !inTx {
|
||||
log.Trace("Nested Transaction started", "scope", msg)
|
||||
conn = dbx.NewFromDB(db.Db(), db.Driver)
|
||||
} else {
|
||||
log.Trace("Transaction started", "scope", msg)
|
||||
}
|
||||
return conn.Transactional(func(tx *dbx.Tx) error {
|
||||
newDb := &SQLStore{db: tx}
|
||||
err := block(newDb)
|
||||
if !inTx {
|
||||
log.Trace("Nested Transaction finished", "scope", msg, "elapsed", time.Since(start), err)
|
||||
} else {
|
||||
log.Trace("Transaction finished", "scope", msg, "elapsed", time.Since(start), err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SQLStore) WithTxImmediate(block func(tx model.DataStore) error, scope ...string) error {
|
||||
ctx := context.Background()
|
||||
return s.WithTx(func(tx model.DataStore) error {
|
||||
// Workaround to force the transaction to be upgraded to immediate mode to avoid deadlocks
|
||||
// See https://berthub.eu/articles/posts/a-brief-post-on-sqlite3-database-locked-despite-timeout/
|
||||
_ = tx.Property(ctx).Put("tmp_lock_flag", "")
|
||||
defer func() {
|
||||
_ = tx.Property(ctx).Delete("tmp_lock_flag")
|
||||
}()
|
||||
|
||||
return block(tx)
|
||||
}, scope...)
|
||||
}
|
||||
|
||||
func (s *SQLStore) GC(ctx context.Context, libraryIDs ...int) error {
|
||||
trace := func(ctx context.Context, msg string, f func() error) func() error {
|
||||
return func() error {
|
||||
start := time.Now()
|
||||
err := f()
|
||||
log.Debug(ctx, "GC: "+msg, "elapsed", time.Since(start), err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If libraryIDs are provided, scope operations to those libraries where possible
|
||||
scoped := len(libraryIDs) > 0
|
||||
if scoped {
|
||||
log.Debug(ctx, "GC: Running selective garbage collection", "libraryIDs", libraryIDs)
|
||||
}
|
||||
|
||||
err := run.Sequentially(
|
||||
trace(ctx, "purge empty albums", func() error { return s.Album(ctx).(*albumRepository).purgeEmpty(libraryIDs...) }),
|
||||
trace(ctx, "purge empty artists", func() error { return s.Artist(ctx).(*artistRepository).purgeEmpty() }),
|
||||
trace(ctx, "mark missing artists", func() error { return s.Artist(ctx).(*artistRepository).markMissing() }),
|
||||
trace(ctx, "purge empty folders", func() error { return s.Folder(ctx).(*folderRepository).purgeEmpty(libraryIDs...) }),
|
||||
trace(ctx, "clean album annotations", func() error { return s.Album(ctx).(*albumRepository).cleanAnnotations() }),
|
||||
trace(ctx, "clean artist annotations", func() error { return s.Artist(ctx).(*artistRepository).cleanAnnotations() }),
|
||||
trace(ctx, "clean media file annotations", func() error { return s.MediaFile(ctx).(*mediaFileRepository).cleanAnnotations() }),
|
||||
trace(ctx, "clean media file bookmarks", func() error { return s.MediaFile(ctx).(*mediaFileRepository).cleanBookmarks() }),
|
||||
trace(ctx, "purge non used tags", func() error { return s.Tag(ctx).(*tagRepository).purgeUnused() }),
|
||||
trace(ctx, "remove orphan playlist tracks", func() error { return s.Playlist(ctx).(*playlistRepository).removeOrphans() }),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error(ctx, "Error tidying up database", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLStore) getDBXBuilder() dbx.Builder {
|
||||
if s.db == nil {
|
||||
return dbx.NewFromDB(db.Db(), db.Driver)
|
||||
}
|
||||
return s.db
|
||||
}
|
||||
271
persistence/persistence_suite_test.go
Normal file
271
persistence/persistence_suite_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/db"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
"github.com/navidrome/navidrome/tests"
|
||||
"github.com/navidrome/navidrome/utils/gg"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
func TestPersistence(t *testing.T) {
|
||||
tests.Init(t, true)
|
||||
|
||||
//os.Remove("./test-123.db")
|
||||
//conf.Server.DbPath = "./test-123.db"
|
||||
conf.Server.DbPath = "file::memory:?cache=shared&_foreign_keys=on"
|
||||
defer db.Init(context.Background())()
|
||||
log.SetLevel(log.LevelFatal)
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Persistence Suite")
|
||||
}
|
||||
|
||||
func mf(mf model.MediaFile) model.MediaFile {
|
||||
mf.Tags = model.Tags{}
|
||||
mf.LibraryID = 1
|
||||
mf.LibraryPath = "music" // Default folder
|
||||
mf.LibraryName = "Music Library"
|
||||
mf.Participants = model.Participants{
|
||||
model.RoleArtist: model.ParticipantList{
|
||||
model.Participant{Artist: model.Artist{ID: mf.ArtistID, Name: mf.Artist}},
|
||||
},
|
||||
}
|
||||
if mf.Lyrics == "" {
|
||||
mf.Lyrics = "[]"
|
||||
}
|
||||
return mf
|
||||
}
|
||||
|
||||
func al(al model.Album) model.Album {
|
||||
al.LibraryID = 1
|
||||
al.LibraryPath = "music"
|
||||
al.LibraryName = "Music Library"
|
||||
al.Discs = model.Discs{}
|
||||
al.Tags = model.Tags{}
|
||||
al.Participants = model.Participants{}
|
||||
return al
|
||||
}
|
||||
|
||||
var (
|
||||
artistKraftwerk = model.Artist{ID: "2", Name: "Kraftwerk", OrderArtistName: "kraftwerk"}
|
||||
artistBeatles = model.Artist{ID: "3", Name: "The Beatles", OrderArtistName: "beatles"}
|
||||
testArtists = model.Artists{
|
||||
artistKraftwerk,
|
||||
artistBeatles,
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
albumSgtPeppers = al(model.Album{ID: "101", Name: "Sgt Peppers", AlbumArtist: "The Beatles", OrderAlbumName: "sgt peppers", AlbumArtistID: "3", EmbedArtPath: p("/beatles/1/sgt/a day.mp3"), SongCount: 1, MaxYear: 1967})
|
||||
albumAbbeyRoad = al(model.Album{ID: "102", Name: "Abbey Road", AlbumArtist: "The Beatles", OrderAlbumName: "abbey road", AlbumArtistID: "3", EmbedArtPath: p("/beatles/1/come together.mp3"), SongCount: 1, MaxYear: 1969})
|
||||
albumRadioactivity = al(model.Album{ID: "103", Name: "Radioactivity", AlbumArtist: "Kraftwerk", OrderAlbumName: "radioactivity", AlbumArtistID: "2", EmbedArtPath: p("/kraft/radio/radio.mp3"), SongCount: 2})
|
||||
albumMultiDisc = al(model.Album{ID: "104", Name: "Multi Disc Album", AlbumArtist: "Test Artist", OrderAlbumName: "multi disc album", AlbumArtistID: "1", EmbedArtPath: p("/test/multi/disc1/track1.mp3"), SongCount: 4})
|
||||
testAlbums = model.Albums{
|
||||
albumSgtPeppers,
|
||||
albumAbbeyRoad,
|
||||
albumRadioactivity,
|
||||
albumMultiDisc,
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
songDayInALife = mf(model.MediaFile{ID: "1001", Title: "A Day In A Life", ArtistID: "3", Artist: "The Beatles", AlbumID: "101", Album: "Sgt Peppers", Path: p("/beatles/1/sgt/a day.mp3")})
|
||||
songComeTogether = mf(model.MediaFile{ID: "1002", Title: "Come Together", ArtistID: "3", Artist: "The Beatles", AlbumID: "102", Album: "Abbey Road", Path: p("/beatles/1/come together.mp3")})
|
||||
songRadioactivity = mf(model.MediaFile{ID: "1003", Title: "Radioactivity", ArtistID: "2", Artist: "Kraftwerk", AlbumID: "103", Album: "Radioactivity", Path: p("/kraft/radio/radio.mp3")})
|
||||
songAntenna = mf(model.MediaFile{ID: "1004", Title: "Antenna", ArtistID: "2", Artist: "Kraftwerk",
|
||||
AlbumID: "103",
|
||||
Path: p("/kraft/radio/antenna.mp3"),
|
||||
RGAlbumGain: gg.P(1.0), RGAlbumPeak: gg.P(2.0), RGTrackGain: gg.P(3.0), RGTrackPeak: gg.P(4.0),
|
||||
})
|
||||
songAntennaWithLyrics = mf(model.MediaFile{
|
||||
ID: "1005",
|
||||
Title: "Antenna",
|
||||
ArtistID: "2",
|
||||
Artist: "Kraftwerk",
|
||||
AlbumID: "103",
|
||||
Lyrics: `[{"lang":"xxx","line":[{"value":"This is a set of lyrics"}],"synced":false}]`,
|
||||
})
|
||||
songAntenna2 = mf(model.MediaFile{ID: "1006", Title: "Antenna", ArtistID: "2", Artist: "Kraftwerk", AlbumID: "103"})
|
||||
// Multi-disc album tracks (intentionally out of order to test sorting)
|
||||
songDisc2Track11 = mf(model.MediaFile{ID: "2001", Title: "Disc 2 Track 11", ArtistID: "1", Artist: "Test Artist", AlbumID: "104", Album: "Multi Disc Album", DiscNumber: 2, TrackNumber: 11, Path: p("/test/multi/disc2/track11.mp3"), OrderAlbumName: "multi disc album", OrderArtistName: "test artist"})
|
||||
songDisc1Track01 = mf(model.MediaFile{ID: "2002", Title: "Disc 1 Track 1", ArtistID: "1", Artist: "Test Artist", AlbumID: "104", Album: "Multi Disc Album", DiscNumber: 1, TrackNumber: 1, Path: p("/test/multi/disc1/track1.mp3"), OrderAlbumName: "multi disc album", OrderArtistName: "test artist"})
|
||||
songDisc2Track01 = mf(model.MediaFile{ID: "2003", Title: "Disc 2 Track 1", ArtistID: "1", Artist: "Test Artist", AlbumID: "104", Album: "Multi Disc Album", DiscNumber: 2, TrackNumber: 1, Path: p("/test/multi/disc2/track1.mp3"), OrderAlbumName: "multi disc album", OrderArtistName: "test artist"})
|
||||
songDisc1Track02 = mf(model.MediaFile{ID: "2004", Title: "Disc 1 Track 2", ArtistID: "1", Artist: "Test Artist", AlbumID: "104", Album: "Multi Disc Album", DiscNumber: 1, TrackNumber: 2, Path: p("/test/multi/disc1/track2.mp3"), OrderAlbumName: "multi disc album", OrderArtistName: "test artist"})
|
||||
testSongs = model.MediaFiles{
|
||||
songDayInALife,
|
||||
songComeTogether,
|
||||
songRadioactivity,
|
||||
songAntenna,
|
||||
songAntennaWithLyrics,
|
||||
songAntenna2,
|
||||
songDisc2Track11,
|
||||
songDisc1Track01,
|
||||
songDisc2Track01,
|
||||
songDisc1Track02,
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
radioWithoutHomePage = model.Radio{ID: "1235", StreamUrl: "https://example.com:8000/1/stream.mp3", HomePageUrl: "", Name: "No Homepage"}
|
||||
radioWithHomePage = model.Radio{ID: "5010", StreamUrl: "https://example.com/stream.mp3", Name: "Example Radio", HomePageUrl: "https://example.com"}
|
||||
testRadios = model.Radios{radioWithoutHomePage, radioWithHomePage}
|
||||
)
|
||||
|
||||
var (
|
||||
plsBest model.Playlist
|
||||
plsCool model.Playlist
|
||||
testPlaylists []*model.Playlist
|
||||
)
|
||||
|
||||
var (
|
||||
adminUser = model.User{ID: "userid", UserName: "userid", Name: "admin", Email: "admin@email.com", IsAdmin: true}
|
||||
regularUser = model.User{ID: "2222", UserName: "regular-user", Name: "Regular User", Email: "regular@example.com"}
|
||||
testUsers = model.Users{adminUser, regularUser}
|
||||
)
|
||||
|
||||
func p(path string) string {
|
||||
return filepath.FromSlash(path)
|
||||
}
|
||||
|
||||
// Initialize test DB
|
||||
// TODO Load this data setup from file(s)
|
||||
var _ = BeforeSuite(func() {
|
||||
conn := GetDBXBuilder()
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, adminUser)
|
||||
|
||||
ur := NewUserRepository(ctx, conn)
|
||||
for i := range testUsers {
|
||||
err := ur.Put(&testUsers[i])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Associate users with library 1 (default test library)
|
||||
for i := range testUsers {
|
||||
err := ur.SetUserLibraries(testUsers[i].ID, []int{1})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
alr := NewAlbumRepository(ctx, conn, nil).(*albumRepository)
|
||||
for i := range testAlbums {
|
||||
a := testAlbums[i]
|
||||
err := alr.Put(&a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
arr := NewArtistRepository(ctx, conn, nil)
|
||||
for i := range testArtists {
|
||||
a := testArtists[i]
|
||||
err := arr.Put(&a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Associate artists with library 1 (default test library)
|
||||
lr := NewLibraryRepository(ctx, conn)
|
||||
for i := range testArtists {
|
||||
err := lr.AddArtist(1, testArtists[i].ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
mr := NewMediaFileRepository(ctx, conn, nil)
|
||||
for i := range testSongs {
|
||||
err := mr.Put(&testSongs[i])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
rar := NewRadioRepository(ctx, conn)
|
||||
for i := range testRadios {
|
||||
r := testRadios[i]
|
||||
err := rar.Put(&r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
plsBest = model.Playlist{
|
||||
Name: "Best",
|
||||
Comment: "No Comments",
|
||||
OwnerID: "userid",
|
||||
OwnerName: "userid",
|
||||
Public: true,
|
||||
SongCount: 2,
|
||||
}
|
||||
plsBest.AddMediaFilesByID([]string{"1001", "1003"})
|
||||
plsCool = model.Playlist{Name: "Cool", OwnerID: "userid", OwnerName: "userid"}
|
||||
plsCool.AddMediaFilesByID([]string{"1004"})
|
||||
testPlaylists = []*model.Playlist{&plsBest, &plsCool}
|
||||
|
||||
pr := NewPlaylistRepository(ctx, conn)
|
||||
for i := range testPlaylists {
|
||||
err := pr.Put(testPlaylists[i])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare annotations
|
||||
if err := arr.SetStar(true, artistBeatles.ID); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ar, err := arr.Get(artistBeatles.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if ar == nil {
|
||||
panic("artist not found after SetStar")
|
||||
}
|
||||
artistBeatles.Starred = true
|
||||
artistBeatles.StarredAt = ar.StarredAt
|
||||
testArtists[1] = artistBeatles
|
||||
|
||||
if err := alr.SetStar(true, albumRadioactivity.ID); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
al, err := alr.Get(albumRadioactivity.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if al == nil {
|
||||
panic("album not found after SetStar")
|
||||
}
|
||||
albumRadioactivity.Starred = true
|
||||
albumRadioactivity.StarredAt = al.StarredAt
|
||||
testAlbums[2] = albumRadioactivity
|
||||
|
||||
if err := mr.SetStar(true, songComeTogether.ID); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mf, err := mr.Get(songComeTogether.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
songComeTogether.Starred = true
|
||||
songComeTogether.StarredAt = mf.StarredAt
|
||||
testSongs[1] = songComeTogether
|
||||
})
|
||||
|
||||
func GetDBXBuilder() *dbx.DB {
|
||||
return dbx.NewFromDB(db.Db(), db.Dialect)
|
||||
}
|
||||
58
persistence/persistence_test.go
Normal file
58
persistence/persistence_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/navidrome/navidrome/db"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("SQLStore", func() {
|
||||
var ds model.DataStore
|
||||
var ctx context.Context
|
||||
BeforeEach(func() {
|
||||
ds = New(db.Db())
|
||||
ctx = context.Background()
|
||||
})
|
||||
Describe("WithTx", func() {
|
||||
Context("When block returns nil", func() {
|
||||
It("commits changes to the DB", func() {
|
||||
err := ds.WithTx(func(tx model.DataStore) error {
|
||||
pl := tx.Player(ctx)
|
||||
err := pl.Put(&model.Player{ID: "666", UserId: "userid"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
pr := tx.Property(ctx)
|
||||
err = pr.Put("777", "value")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return nil
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ds.Player(ctx).Get("666")).To(Equal(&model.Player{ID: "666", UserId: "userid", Username: "userid"}))
|
||||
Expect(ds.Property(ctx).Get("777")).To(Equal("value"))
|
||||
})
|
||||
})
|
||||
Context("When block returns an error", func() {
|
||||
It("rollbacks changes to the DB", func() {
|
||||
err := ds.WithTx(func(tx model.DataStore) error {
|
||||
pr := tx.Property(ctx)
|
||||
err := pr.Put("999", "value")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Will fail as it is missing the UserName
|
||||
pl := tx.Player(ctx)
|
||||
err = pl.Put(&model.Player{ID: "888"})
|
||||
Expect(err).To(HaveOccurred())
|
||||
return err
|
||||
})
|
||||
Expect(err).To(HaveOccurred())
|
||||
_, err = ds.Property(ctx).Get("999")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
_, err = ds.Player(ctx).Get("888")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
169
persistence/player_repository.go
Normal file
169
persistence/player_repository.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type playerRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewPlayerRepository(ctx context.Context, db dbx.Builder) model.PlayerRepository {
|
||||
r := &playerRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.registerModel(&model.Player{}, map[string]filterFunc{
|
||||
"name": containsFilter("player.name"),
|
||||
})
|
||||
r.setSortMappings(map[string]string{
|
||||
"user_name": "username", //TODO rename all user_name and userName to username
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *playerRepository) Put(p *model.Player) error {
|
||||
_, err := r.put(p.ID, p)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *playerRepository) selectPlayer(options ...model.QueryOptions) SelectBuilder {
|
||||
return r.newSelect(options...).
|
||||
Columns("player.*").
|
||||
Join("user ON player.user_id = user.id").
|
||||
Columns("user.user_name username")
|
||||
}
|
||||
|
||||
func (r *playerRepository) Get(id string) (*model.Player, error) {
|
||||
sel := r.selectPlayer().Where(Eq{"player.id": id})
|
||||
var res model.Player
|
||||
err := r.queryOne(sel, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *playerRepository) FindMatch(userId, client, userAgent string) (*model.Player, error) {
|
||||
sel := r.selectPlayer().Where(And{
|
||||
Eq{"client": client},
|
||||
Eq{"user_agent": userAgent},
|
||||
Eq{"user_id": userId},
|
||||
})
|
||||
var res model.Player
|
||||
err := r.queryOne(sel, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *playerRepository) newRestSelect(options ...model.QueryOptions) SelectBuilder {
|
||||
s := r.selectPlayer(options...)
|
||||
return s.Where(r.addRestriction())
|
||||
}
|
||||
|
||||
func (r *playerRepository) addRestriction(sql ...Sqlizer) Sqlizer {
|
||||
s := And{}
|
||||
if len(sql) > 0 {
|
||||
s = append(s, sql[0])
|
||||
}
|
||||
u := loggedUser(r.ctx)
|
||||
if u.IsAdmin {
|
||||
return s
|
||||
}
|
||||
return append(s, Eq{"user_id": u.ID})
|
||||
}
|
||||
|
||||
func (r *playerRepository) CountByClient(options ...model.QueryOptions) (map[string]int64, error) {
|
||||
sel := r.newSelect(options...).
|
||||
Columns(
|
||||
"case when client = 'NavidromeUI' then name else client end as player",
|
||||
"count(*) as count",
|
||||
).GroupBy("client")
|
||||
var res []struct {
|
||||
Player string
|
||||
Count int64
|
||||
}
|
||||
err := r.queryAll(sel, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts := make(map[string]int64, len(res))
|
||||
for _, c := range res {
|
||||
counts[c.Player] = c.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (r *playerRepository) CountAll(options ...model.QueryOptions) (int64, error) {
|
||||
return r.count(r.newRestSelect(), options...)
|
||||
}
|
||||
|
||||
func (r *playerRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *playerRepository) Read(id string) (interface{}, error) {
|
||||
sel := r.newRestSelect().Where(Eq{"player.id": id})
|
||||
var res model.Player
|
||||
err := r.queryOne(sel, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *playerRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
sel := r.newRestSelect(r.parseRestOptions(r.ctx, options...))
|
||||
res := model.Players{}
|
||||
err := r.queryAll(sel, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *playerRepository) EntityName() string {
|
||||
return "player"
|
||||
}
|
||||
|
||||
func (r *playerRepository) NewInstance() interface{} {
|
||||
return &model.Player{}
|
||||
}
|
||||
|
||||
func (r *playerRepository) isPermitted(p *model.Player) bool {
|
||||
u := loggedUser(r.ctx)
|
||||
return u.IsAdmin || p.UserId == u.ID
|
||||
}
|
||||
|
||||
func (r *playerRepository) Save(entity interface{}) (string, error) {
|
||||
t := entity.(*model.Player)
|
||||
if !r.isPermitted(t) {
|
||||
return "", rest.ErrPermissionDenied
|
||||
}
|
||||
id, err := r.put(t.ID, t)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return "", rest.ErrNotFound
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
func (r *playerRepository) Update(id string, entity interface{}, cols ...string) error {
|
||||
t := entity.(*model.Player)
|
||||
t.ID = id
|
||||
if !r.isPermitted(t) {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
_, err := r.put(id, t, cols...)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *playerRepository) Delete(id string) error {
|
||||
filter := r.addRestriction(And{Eq{"player.id": id}})
|
||||
err := r.delete(filter)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var _ model.PlayerRepository = (*playerRepository)(nil)
|
||||
var _ rest.Repository = (*playerRepository)(nil)
|
||||
var _ rest.Persistable = (*playerRepository)(nil)
|
||||
247
persistence/player_repository_test.go
Normal file
247
persistence/player_repository_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ = Describe("PlayerRepository", func() {
|
||||
var adminRepo *playerRepository
|
||||
var database *dbx.DB
|
||||
|
||||
var (
|
||||
adminPlayer1 = model.Player{ID: "1", Name: "NavidromeUI [Firefox/Linux]", UserAgent: "Firefox/Linux", UserId: adminUser.ID, Username: adminUser.UserName, Client: "NavidromeUI", IP: "127.0.0.1", ReportRealPath: true, ScrobbleEnabled: true}
|
||||
adminPlayer2 = model.Player{ID: "2", Name: "GenericClient [Chrome/Windows]", IP: "192.168.0.5", UserAgent: "Chrome/Windows", UserId: adminUser.ID, Username: adminUser.UserName, Client: "GenericClient", MaxBitRate: 128}
|
||||
regularPlayer = model.Player{ID: "3", Name: "NavidromeUI [Safari/macOS]", UserAgent: "Safari/macOS", UserId: regularUser.ID, Username: regularUser.UserName, Client: "NavidromeUI", ReportRealPath: true, ScrobbleEnabled: false}
|
||||
|
||||
players = model.Players{adminPlayer1, adminPlayer2, regularPlayer}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, adminUser)
|
||||
|
||||
database = GetDBXBuilder()
|
||||
adminRepo = NewPlayerRepository(ctx, database).(*playerRepository)
|
||||
|
||||
for idx := range players {
|
||||
err := adminRepo.Put(&players[idx])
|
||||
Expect(err).To(BeNil())
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
items, err := adminRepo.ReadAll()
|
||||
Expect(err).To(BeNil())
|
||||
players, ok := items.(model.Players)
|
||||
Expect(ok).To(BeTrue())
|
||||
for i := range players {
|
||||
err = adminRepo.Delete(players[i].ID)
|
||||
Expect(err).To(BeNil())
|
||||
}
|
||||
})
|
||||
|
||||
Describe("EntityName", func() {
|
||||
It("returns the right name", func() {
|
||||
Expect(adminRepo.EntityName()).To(Equal("player"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("FindMatch", func() {
|
||||
It("finds existing match", func() {
|
||||
player, err := adminRepo.FindMatch(adminUser.ID, "NavidromeUI", "Firefox/Linux")
|
||||
Expect(err).To(BeNil())
|
||||
Expect(*player).To(Equal(adminPlayer1))
|
||||
})
|
||||
|
||||
It("doesn't find bad match", func() {
|
||||
_, err := adminRepo.FindMatch(regularUser.ID, "NavidromeUI", "Firefox/Linux")
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Get", func() {
|
||||
It("Gets an existing item from user", func() {
|
||||
player, err := adminRepo.Get(adminPlayer1.ID)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(*player).To(Equal(adminPlayer1))
|
||||
})
|
||||
|
||||
It("Gets an existing item from another user", func() {
|
||||
player, err := adminRepo.Get(regularPlayer.ID)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(*player).To(Equal(regularPlayer))
|
||||
})
|
||||
|
||||
It("does not get nonexistent item", func() {
|
||||
_, err := adminRepo.Get("i don't exist")
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
DescribeTableSubtree("per context", func(admin bool, players model.Players, userPlayer model.Player, otherPlayer model.Player) {
|
||||
var repo *playerRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
if admin {
|
||||
repo = adminRepo
|
||||
} else {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, regularUser)
|
||||
repo = NewPlayerRepository(ctx, database).(*playerRepository)
|
||||
}
|
||||
})
|
||||
|
||||
baseCount := int64(len(players))
|
||||
|
||||
Describe("Count", func() {
|
||||
It("should return all", func() {
|
||||
count, err := repo.Count()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(count).To(Equal(baseCount))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Delete", func() {
|
||||
DescribeTable("item type", func(player model.Player) {
|
||||
err := repo.Delete(player.ID)
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
isReal := player.UserId != ""
|
||||
canDelete := admin || player.UserId == userPlayer.UserId
|
||||
|
||||
count, err := repo.Count()
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
if isReal && canDelete {
|
||||
Expect(count).To(Equal(baseCount - 1))
|
||||
} else {
|
||||
Expect(count).To(Equal(baseCount))
|
||||
}
|
||||
|
||||
item, err := repo.Get(player.ID)
|
||||
if !isReal || canDelete {
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
} else {
|
||||
Expect(*item).To(Equal(player))
|
||||
}
|
||||
},
|
||||
Entry("same user", userPlayer),
|
||||
Entry("other item", otherPlayer),
|
||||
Entry("fake item", model.Player{}),
|
||||
)
|
||||
})
|
||||
|
||||
Describe("Read", func() {
|
||||
It("can read from current user", func() {
|
||||
player, err := repo.Read(userPlayer.ID)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(player).To(Equal(&userPlayer))
|
||||
})
|
||||
|
||||
It("can read from other user or fail if not admin", func() {
|
||||
player, err := repo.Read(otherPlayer.ID)
|
||||
if admin {
|
||||
Expect(err).To(BeNil())
|
||||
Expect(player).To(Equal(&otherPlayer))
|
||||
} else {
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
}
|
||||
})
|
||||
|
||||
It("does not get nonexistent item", func() {
|
||||
_, err := repo.Read("i don't exist")
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ReadAll", func() {
|
||||
It("should get all items", func() {
|
||||
data, err := repo.ReadAll()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(data).To(Equal(players))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Save", func() {
|
||||
DescribeTable("item type", func(player model.Player) {
|
||||
clone := player
|
||||
clone.ID = ""
|
||||
clone.IP = "192.168.1.1"
|
||||
id, err := repo.Save(&clone)
|
||||
|
||||
if clone.UserId == "" {
|
||||
Expect(err).To(HaveOccurred())
|
||||
} else if !admin && player.Username == adminPlayer1.Username {
|
||||
Expect(err).To(Equal(rest.ErrPermissionDenied))
|
||||
clone.UserId = ""
|
||||
} else {
|
||||
Expect(err).To(BeNil())
|
||||
Expect(id).ToNot(BeEmpty())
|
||||
}
|
||||
|
||||
count, err := repo.Count()
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
clone.ID = id
|
||||
newItem, err := repo.Get(id)
|
||||
|
||||
if clone.UserId == "" {
|
||||
Expect(count).To(Equal(baseCount))
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
} else {
|
||||
Expect(count).To(Equal(baseCount + 1))
|
||||
Expect(err).To(BeNil())
|
||||
Expect(*newItem).To(Equal(clone))
|
||||
}
|
||||
},
|
||||
Entry("same user", userPlayer),
|
||||
Entry("other item", otherPlayer),
|
||||
Entry("fake item", model.Player{}),
|
||||
)
|
||||
})
|
||||
|
||||
Describe("Update", func() {
|
||||
DescribeTable("item type", func(player model.Player) {
|
||||
clone := player
|
||||
clone.IP = "192.168.1.1"
|
||||
clone.MaxBitRate = 10000
|
||||
err := repo.Update(clone.ID, &clone, "ip")
|
||||
|
||||
if clone.UserId == "" {
|
||||
Expect(err).To(HaveOccurred())
|
||||
} else if !admin && player.Username == adminPlayer1.Username {
|
||||
Expect(err).To(Equal(rest.ErrPermissionDenied))
|
||||
clone.IP = player.IP
|
||||
} else {
|
||||
Expect(err).To(BeNil())
|
||||
}
|
||||
|
||||
clone.MaxBitRate = player.MaxBitRate
|
||||
newItem, err := repo.Get(clone.ID)
|
||||
|
||||
if player.UserId == "" {
|
||||
Expect(err).To(Equal(model.ErrNotFound))
|
||||
} else if !admin && player.UserId == adminUser.ID {
|
||||
Expect(*newItem).To(Equal(player))
|
||||
} else {
|
||||
Expect(*newItem).To(Equal(clone))
|
||||
}
|
||||
},
|
||||
Entry("same user", userPlayer),
|
||||
Entry("other item", otherPlayer),
|
||||
Entry("fake item", model.Player{}),
|
||||
)
|
||||
})
|
||||
},
|
||||
Entry("admin context", true, players, adminPlayer1, regularPlayer),
|
||||
Entry("regular context", false, model.Players{regularPlayer}, regularPlayer, adminPlayer1),
|
||||
)
|
||||
})
|
||||
528
persistence/playlist_repository.go
Normal file
528
persistence/playlist_repository.go
Normal file
@@ -0,0 +1,528 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/criteria"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type playlistRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
type dbPlaylist struct {
|
||||
model.Playlist `structs:",flatten"`
|
||||
Rules sql.NullString `structs:"-"`
|
||||
}
|
||||
|
||||
func (p *dbPlaylist) PostScan() error {
|
||||
if p.Rules.String != "" {
|
||||
return json.Unmarshal([]byte(p.Rules.String), &p.Playlist.Rules)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p dbPlaylist) PostMapArgs(args map[string]any) error {
|
||||
var err error
|
||||
if p.Playlist.IsSmartPlaylist() {
|
||||
args["rules"], err = json.Marshal(p.Playlist.Rules)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid criteria expression: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
delete(args, "rules")
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewPlaylistRepository(ctx context.Context, db dbx.Builder) model.PlaylistRepository {
|
||||
r := &playlistRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.registerModel(&model.Playlist{}, map[string]filterFunc{
|
||||
"q": playlistFilter,
|
||||
"smart": smartPlaylistFilter,
|
||||
})
|
||||
r.setSortMappings(map[string]string{
|
||||
"owner_name": "owner_name",
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func playlistFilter(_ string, value interface{}) Sqlizer {
|
||||
return Or{
|
||||
substringFilter("playlist.name", value),
|
||||
substringFilter("playlist.comment", value),
|
||||
}
|
||||
}
|
||||
|
||||
func smartPlaylistFilter(string, interface{}) Sqlizer {
|
||||
return Or{
|
||||
Eq{"rules": ""},
|
||||
Eq{"rules": nil},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *playlistRepository) userFilter() Sqlizer {
|
||||
user := loggedUser(r.ctx)
|
||||
if user.IsAdmin {
|
||||
return And{}
|
||||
}
|
||||
return Or{
|
||||
Eq{"public": true},
|
||||
Eq{"owner_id": user.ID},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *playlistRepository) CountAll(options ...model.QueryOptions) (int64, error) {
|
||||
sq := Select().Where(r.userFilter())
|
||||
return r.count(sq, options...)
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Exists(id string) (bool, error) {
|
||||
return r.exists(And{Eq{"id": id}, r.userFilter()})
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Delete(id string) error {
|
||||
usr := loggedUser(r.ctx)
|
||||
if !usr.IsAdmin {
|
||||
pls, err := r.Get(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if pls.OwnerID != usr.ID {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
return r.delete(And{Eq{"id": id}, r.userFilter()})
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Put(p *model.Playlist) error {
|
||||
pls := dbPlaylist{Playlist: *p}
|
||||
if pls.ID == "" {
|
||||
pls.CreatedAt = time.Now()
|
||||
} else {
|
||||
ok, err := r.Exists(pls.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return model.ErrNotAuthorized
|
||||
}
|
||||
}
|
||||
pls.UpdatedAt = time.Now()
|
||||
|
||||
id, err := r.put(pls.ID, pls)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.ID = id
|
||||
|
||||
if p.IsSmartPlaylist() {
|
||||
// Do not update tracks at this point, as it may take a long time and lock the DB, breaking the scan process
|
||||
//r.refreshSmartPlaylist(p)
|
||||
return nil
|
||||
}
|
||||
// Only update tracks if they were specified
|
||||
if len(pls.Tracks) > 0 {
|
||||
return r.updateTracks(id, p.MediaFiles())
|
||||
}
|
||||
return r.refreshCounters(&pls.Playlist)
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Get(id string) (*model.Playlist, error) {
|
||||
return r.findBy(And{Eq{"playlist.id": id}, r.userFilter()})
|
||||
}
|
||||
|
||||
func (r *playlistRepository) GetWithTracks(id string, refreshSmartPlaylist, includeMissing bool) (*model.Playlist, error) {
|
||||
pls, err := r.Get(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refreshSmartPlaylist {
|
||||
r.refreshSmartPlaylist(pls)
|
||||
}
|
||||
tracks, err := r.loadTracks(Select().From("playlist_tracks").
|
||||
Where(Eq{"missing": false}).
|
||||
OrderBy("playlist_tracks.id"), id)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error loading playlist tracks ", "playlist", pls.Name, "id", pls.ID, err)
|
||||
return nil, err
|
||||
}
|
||||
pls.SetTracks(tracks)
|
||||
return pls, nil
|
||||
}
|
||||
|
||||
func (r *playlistRepository) FindByPath(path string) (*model.Playlist, error) {
|
||||
return r.findBy(Eq{"path": path})
|
||||
}
|
||||
|
||||
func (r *playlistRepository) findBy(sql Sqlizer) (*model.Playlist, error) {
|
||||
sel := r.selectPlaylist().Where(sql)
|
||||
var pls []dbPlaylist
|
||||
err := r.queryAll(sel, &pls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(pls) == 0 {
|
||||
return nil, model.ErrNotFound
|
||||
}
|
||||
|
||||
return &pls[0].Playlist, nil
|
||||
}
|
||||
|
||||
func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playlists, error) {
|
||||
sel := r.selectPlaylist(options...).Where(r.userFilter())
|
||||
var res []dbPlaylist
|
||||
err := r.queryAll(sel, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
playlists := make(model.Playlists, len(res))
|
||||
for i, p := range res {
|
||||
playlists[i] = p.Playlist
|
||||
}
|
||||
return playlists, err
|
||||
}
|
||||
|
||||
func (r *playlistRepository) GetPlaylists(mediaFileId string) (model.Playlists, error) {
|
||||
sel := r.selectPlaylist(model.QueryOptions{Sort: "name"}).
|
||||
Join("playlist_tracks on playlist.id = playlist_tracks.playlist_id").
|
||||
Where(And{Eq{"playlist_tracks.media_file_id": mediaFileId}, r.userFilter()})
|
||||
var res []dbPlaylist
|
||||
err := r.queryAll(sel, &res)
|
||||
if err != nil {
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return model.Playlists{}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
playlists := make(model.Playlists, len(res))
|
||||
for i, p := range res {
|
||||
playlists[i] = p.Playlist
|
||||
}
|
||||
return playlists, nil
|
||||
}
|
||||
|
||||
func (r *playlistRepository) selectPlaylist(options ...model.QueryOptions) SelectBuilder {
|
||||
return r.newSelect(options...).Join("user on user.id = owner_id").
|
||||
Columns(r.tableName+".*", "user.user_name as owner_name")
|
||||
}
|
||||
|
||||
func (r *playlistRepository) refreshSmartPlaylist(pls *model.Playlist) bool {
|
||||
// Only refresh if it is a smart playlist and was not refreshed within the interval provided by the refresh delay config
|
||||
if !pls.IsSmartPlaylist() || (pls.EvaluatedAt != nil && time.Since(*pls.EvaluatedAt) < conf.Server.SmartPlaylistRefreshDelay) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Never refresh other users' playlists
|
||||
usr := loggedUser(r.ctx)
|
||||
if pls.OwnerID != usr.ID {
|
||||
log.Trace(r.ctx, "Not refreshing smart playlist from other user", "playlist", pls.Name, "id", pls.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debug(r.ctx, "Refreshing smart playlist", "playlist", pls.Name, "id", pls.ID)
|
||||
start := time.Now()
|
||||
|
||||
// Remove old tracks
|
||||
del := Delete("playlist_tracks").Where(Eq{"playlist_id": pls.ID})
|
||||
_, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error deleting old smart playlist tracks", "playlist", pls.Name, "id", pls.ID, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Re-populate playlist based on Smart Playlist criteria
|
||||
rules := *pls.Rules
|
||||
|
||||
// If the playlist depends on other playlists, recursively refresh them first
|
||||
childPlaylistIds := rules.ChildPlaylistIds()
|
||||
for _, id := range childPlaylistIds {
|
||||
childPls, err := r.Get(id)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error loading child playlist", "id", pls.ID, "childId", id, err)
|
||||
return false
|
||||
}
|
||||
r.refreshSmartPlaylist(childPls)
|
||||
}
|
||||
|
||||
sq := Select("row_number() over (order by "+rules.OrderBy()+") as id", "'"+pls.ID+"' as playlist_id", "media_file.id as media_file_id").
|
||||
From("media_file").LeftJoin("annotation on (" +
|
||||
"annotation.item_id = media_file.id" +
|
||||
" AND annotation.item_type = 'media_file'" +
|
||||
" AND annotation.user_id = '" + usr.ID + "')")
|
||||
|
||||
// Only include media files from libraries the user has access to
|
||||
sq = r.applyLibraryFilter(sq, "media_file")
|
||||
|
||||
// Apply the criteria rules
|
||||
sq = r.addCriteria(sq, rules)
|
||||
insSql := Insert("playlist_tracks").Columns("id", "playlist_id", "media_file_id").Select(sq)
|
||||
_, err = r.executeSQL(insSql)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error refreshing smart playlist tracks", "playlist", pls.Name, "id", pls.ID, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Update playlist stats
|
||||
err = r.refreshCounters(pls)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error updating smart playlist stats", "playlist", pls.Name, "id", pls.ID, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Update when the playlist was last refreshed (for cache purposes)
|
||||
updSql := Update(r.tableName).Set("evaluated_at", time.Now()).Where(Eq{"id": pls.ID})
|
||||
_, err = r.executeSQL(updSql)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error updating smart playlist", "playlist", pls.Name, "id", pls.ID, err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debug(r.ctx, "Refreshed playlist", "playlist", pls.Name, "id", pls.ID, "numTracks", pls.SongCount, "elapsed", time.Since(start))
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *playlistRepository) addCriteria(sql SelectBuilder, c criteria.Criteria) SelectBuilder {
|
||||
sql = sql.Where(c)
|
||||
if c.Limit > 0 {
|
||||
sql = sql.Limit(uint64(c.Limit)).Offset(uint64(c.Offset))
|
||||
}
|
||||
if order := c.OrderBy(); order != "" {
|
||||
sql = sql.OrderBy(order)
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (r *playlistRepository) updateTracks(id string, tracks model.MediaFiles) error {
|
||||
ids := make([]string, len(tracks))
|
||||
for i := range tracks {
|
||||
ids[i] = tracks[i].ID
|
||||
}
|
||||
return r.updatePlaylist(id, ids)
|
||||
}
|
||||
|
||||
func (r *playlistRepository) updatePlaylist(playlistId string, mediaFileIds []string) error {
|
||||
if !r.isWritable(playlistId) {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
|
||||
// Remove old tracks
|
||||
del := Delete("playlist_tracks").Where(Eq{"playlist_id": playlistId})
|
||||
_, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.addTracks(playlistId, 1, mediaFileIds)
|
||||
}
|
||||
|
||||
func (r *playlistRepository) addTracks(playlistId string, startingPos int, mediaFileIds []string) error {
|
||||
// Break the track list in chunks to avoid hitting SQLITE_MAX_VARIABLE_NUMBER limit
|
||||
// Add new tracks, chunk by chunk
|
||||
pos := startingPos
|
||||
for chunk := range slices.Chunk(mediaFileIds, 200) {
|
||||
ins := Insert("playlist_tracks").Columns("playlist_id", "media_file_id", "id")
|
||||
for _, t := range chunk {
|
||||
ins = ins.Values(playlistId, t, pos)
|
||||
pos++
|
||||
}
|
||||
_, err := r.executeSQL(ins)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.refreshCounters(&model.Playlist{ID: playlistId})
|
||||
}
|
||||
|
||||
// refreshCounters updates total playlist duration, size and count
|
||||
func (r *playlistRepository) refreshCounters(pls *model.Playlist) error {
|
||||
statsSql := Select(
|
||||
"coalesce(sum(duration), 0) as duration",
|
||||
"coalesce(sum(size), 0) as size",
|
||||
"count(*) as count",
|
||||
).
|
||||
From("media_file").
|
||||
Join("playlist_tracks f on f.media_file_id = media_file.id").
|
||||
Where(Eq{"playlist_id": pls.ID})
|
||||
var res struct{ Duration, Size, Count float32 }
|
||||
err := r.queryOne(statsSql, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update playlist's total duration, size and count
|
||||
upd := Update("playlist").
|
||||
Set("duration", res.Duration).
|
||||
Set("size", res.Size).
|
||||
Set("song_count", res.Count).
|
||||
Set("updated_at", time.Now()).
|
||||
Where(Eq{"id": pls.ID})
|
||||
_, err = r.executeSQL(upd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pls.SongCount = int(res.Count)
|
||||
pls.Duration = res.Duration
|
||||
pls.Size = int64(res.Size)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *playlistRepository) loadTracks(sel SelectBuilder, id string) (model.PlaylistTracks, error) {
|
||||
sel = r.applyLibraryFilter(sel, "f")
|
||||
userID := loggedUser(r.ctx).ID
|
||||
tracksQuery := sel.
|
||||
Columns(
|
||||
"coalesce(starred, 0) as starred",
|
||||
"starred_at",
|
||||
"coalesce(play_count, 0) as play_count",
|
||||
"play_date",
|
||||
"coalesce(rating, 0) as rating",
|
||||
"rated_at",
|
||||
"f.*",
|
||||
"playlist_tracks.*",
|
||||
"library.path as library_path",
|
||||
"library.name as library_name",
|
||||
).
|
||||
LeftJoin("annotation on (" +
|
||||
"annotation.item_id = media_file_id" +
|
||||
" AND annotation.item_type = 'media_file'" +
|
||||
" AND annotation.user_id = '" + userID + "')").
|
||||
Join("media_file f on f.id = media_file_id").
|
||||
Join("library on f.library_id = library.id").
|
||||
Where(Eq{"playlist_id": id})
|
||||
tracks := dbPlaylistTracks{}
|
||||
err := r.queryAll(tracksQuery, &tracks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tracks.toModels(), err
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Read(id string) (interface{}, error) {
|
||||
return r.Get(id)
|
||||
}
|
||||
|
||||
func (r *playlistRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *playlistRepository) EntityName() string {
|
||||
return "playlist"
|
||||
}
|
||||
|
||||
func (r *playlistRepository) NewInstance() interface{} {
|
||||
return &model.Playlist{}
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Save(entity interface{}) (string, error) {
|
||||
pls := entity.(*model.Playlist)
|
||||
pls.OwnerID = loggedUser(r.ctx).ID
|
||||
pls.ID = "" // Make sure we don't override an existing playlist
|
||||
err := r.Put(pls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return pls.ID, err
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Update(id string, entity interface{}, cols ...string) error {
|
||||
pls := dbPlaylist{Playlist: *entity.(*model.Playlist)}
|
||||
current, err := r.Get(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
usr := loggedUser(r.ctx)
|
||||
if !usr.IsAdmin {
|
||||
// Only the owner can update the playlist
|
||||
if current.OwnerID != usr.ID {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
// Regular users can't change the ownership of a playlist
|
||||
if pls.OwnerID != "" && pls.OwnerID != usr.ID {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
pls.ID = id
|
||||
pls.UpdatedAt = time.Now()
|
||||
_, err = r.put(id, pls, append(cols, "updatedAt")...)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *playlistRepository) removeOrphans() error {
|
||||
sel := Select("playlist_tracks.playlist_id as id", "p.name").From("playlist_tracks").
|
||||
Join("playlist p on playlist_tracks.playlist_id = p.id").
|
||||
LeftJoin("media_file mf on playlist_tracks.media_file_id = mf.id").
|
||||
Where(Eq{"mf.id": nil}).
|
||||
GroupBy("playlist_tracks.playlist_id")
|
||||
|
||||
var pls []struct{ Id, Name string }
|
||||
err := r.queryAll(sel, &pls)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching playlists with orphan tracks: %w", err)
|
||||
}
|
||||
|
||||
for _, pl := range pls {
|
||||
log.Debug(r.ctx, "Cleaning-up orphan tracks from playlist", "id", pl.Id, "name", pl.Name)
|
||||
del := Delete("playlist_tracks").Where(And{
|
||||
ConcatExpr("media_file_id not in (select id from media_file)"),
|
||||
Eq{"playlist_id": pl.Id},
|
||||
})
|
||||
n, err := r.executeSQL(del)
|
||||
if n == 0 || err != nil {
|
||||
return fmt.Errorf("deleting orphan tracks from playlist %s: %w", pl.Name, err)
|
||||
}
|
||||
log.Debug(r.ctx, "Deleted tracks, now reordering", "id", pl.Id, "name", pl.Name, "deleted", n)
|
||||
|
||||
// Renumber the playlist if any track was removed
|
||||
if err := r.renumber(pl.Id); err != nil {
|
||||
return fmt.Errorf("renumbering playlist %s: %w", pl.Name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *playlistRepository) renumber(id string) error {
|
||||
var ids []string
|
||||
sq := Select("media_file_id").From("playlist_tracks").Where(Eq{"playlist_id": id}).OrderBy("id")
|
||||
err := r.queryAllSlice(sq, &ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.updatePlaylist(id, ids)
|
||||
}
|
||||
|
||||
func (r *playlistRepository) isWritable(playlistId string) bool {
|
||||
usr := loggedUser(r.ctx)
|
||||
if usr.IsAdmin {
|
||||
return true
|
||||
}
|
||||
pls, err := r.Get(playlistId)
|
||||
return err == nil && pls.OwnerID == usr.ID
|
||||
}
|
||||
|
||||
var _ model.PlaylistRepository = (*playlistRepository)(nil)
|
||||
var _ rest.Repository = (*playlistRepository)(nil)
|
||||
var _ rest.Persistable = (*playlistRepository)(nil)
|
||||
501
persistence/playlist_repository_test.go
Normal file
501
persistence/playlist_repository_test.go
Normal file
@@ -0,0 +1,501 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/criteria"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ = Describe("PlaylistRepository", func() {
|
||||
var repo model.PlaylistRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(GinkgoT().Context())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
|
||||
repo = NewPlaylistRepository(ctx, GetDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
It("returns the number of playlists in the DB", func() {
|
||||
Expect(repo.CountAll()).To(Equal(int64(2)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Exists", func() {
|
||||
It("returns true for an existing playlist", func() {
|
||||
Expect(repo.Exists(plsCool.ID)).To(BeTrue())
|
||||
})
|
||||
It("returns false for a non-existing playlist", func() {
|
||||
Expect(repo.Exists("666")).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Get", func() {
|
||||
It("returns an existing playlist", func() {
|
||||
p, err := repo.Get(plsBest.ID)
|
||||
Expect(err).To(BeNil())
|
||||
// Compare all but Tracks and timestamps
|
||||
p2 := *p
|
||||
p2.Tracks = plsBest.Tracks
|
||||
p2.UpdatedAt = plsBest.UpdatedAt
|
||||
p2.CreatedAt = plsBest.CreatedAt
|
||||
Expect(p2).To(Equal(plsBest))
|
||||
// Compare tracks
|
||||
for i := range p.Tracks {
|
||||
Expect(p.Tracks[i].ID).To(Equal(plsBest.Tracks[i].ID))
|
||||
}
|
||||
})
|
||||
It("returns ErrNotFound for a non-existing playlist", func() {
|
||||
_, err := repo.Get("666")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
It("returns all tracks", func() {
|
||||
pls, err := repo.GetWithTracks(plsBest.ID, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pls.Name).To(Equal(plsBest.Name))
|
||||
Expect(pls.Tracks).To(HaveLen(2))
|
||||
Expect(pls.Tracks[0].ID).To(Equal("1"))
|
||||
Expect(pls.Tracks[0].PlaylistID).To(Equal(plsBest.ID))
|
||||
Expect(pls.Tracks[0].MediaFileID).To(Equal(songDayInALife.ID))
|
||||
Expect(pls.Tracks[0].MediaFile.ID).To(Equal(songDayInALife.ID))
|
||||
Expect(pls.Tracks[1].ID).To(Equal("2"))
|
||||
Expect(pls.Tracks[1].PlaylistID).To(Equal(plsBest.ID))
|
||||
Expect(pls.Tracks[1].MediaFileID).To(Equal(songRadioactivity.ID))
|
||||
Expect(pls.Tracks[1].MediaFile.ID).To(Equal(songRadioactivity.ID))
|
||||
mfs := pls.MediaFiles()
|
||||
Expect(mfs).To(HaveLen(2))
|
||||
Expect(mfs[0].ID).To(Equal(songDayInALife.ID))
|
||||
Expect(mfs[1].ID).To(Equal(songRadioactivity.ID))
|
||||
})
|
||||
})
|
||||
|
||||
It("Put/Exists/Delete", func() {
|
||||
By("saves the playlist to the DB")
|
||||
newPls := model.Playlist{Name: "Great!", OwnerID: "userid"}
|
||||
newPls.AddMediaFilesByID([]string{"1004", "1003"})
|
||||
|
||||
By("saves the playlist to the DB")
|
||||
Expect(repo.Put(&newPls)).To(BeNil())
|
||||
|
||||
By("adds repeated songs to a playlist and keeps the order")
|
||||
newPls.AddMediaFilesByID([]string{"1004"})
|
||||
Expect(repo.Put(&newPls)).To(BeNil())
|
||||
saved, _ := repo.GetWithTracks(newPls.ID, true, false)
|
||||
Expect(saved.Tracks).To(HaveLen(3))
|
||||
Expect(saved.Tracks[0].MediaFileID).To(Equal("1004"))
|
||||
Expect(saved.Tracks[1].MediaFileID).To(Equal("1003"))
|
||||
Expect(saved.Tracks[2].MediaFileID).To(Equal("1004"))
|
||||
|
||||
By("returns the newly created playlist")
|
||||
Expect(repo.Exists(newPls.ID)).To(BeTrue())
|
||||
|
||||
By("returns deletes the playlist")
|
||||
Expect(repo.Delete(newPls.ID)).To(BeNil())
|
||||
|
||||
By("returns error if tries to retrieve the deleted playlist")
|
||||
Expect(repo.Exists(newPls.ID)).To(BeFalse())
|
||||
})
|
||||
|
||||
Describe("GetAll", func() {
|
||||
It("returns all playlists from DB", func() {
|
||||
all, err := repo.GetAll()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(all[0].ID).To(Equal(plsBest.ID))
|
||||
Expect(all[1].ID).To(Equal(plsCool.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetPlaylists", func() {
|
||||
It("returns playlists for a track", func() {
|
||||
pls, err := repo.GetPlaylists(songRadioactivity.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pls).To(HaveLen(1))
|
||||
Expect(pls[0].ID).To(Equal(plsBest.ID))
|
||||
})
|
||||
|
||||
It("returns empty when none", func() {
|
||||
pls, err := repo.GetPlaylists("9999")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pls).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Smart Playlists", func() {
|
||||
var rules *criteria.Criteria
|
||||
BeforeEach(func() {
|
||||
rules = &criteria.Criteria{
|
||||
Expression: criteria.All{
|
||||
criteria.Contains{"title": "love"},
|
||||
},
|
||||
}
|
||||
})
|
||||
Context("valid rules", func() {
|
||||
Specify("Put/Get", func() {
|
||||
newPls := model.Playlist{Name: "Great!", OwnerID: "userid", Rules: rules}
|
||||
Expect(repo.Put(&newPls)).To(Succeed())
|
||||
|
||||
savedPls, err := repo.Get(newPls.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(savedPls.Rules).To(Equal(rules))
|
||||
})
|
||||
})
|
||||
|
||||
Context("invalid rules", func() {
|
||||
It("fails to Put it in the DB", func() {
|
||||
rules = &criteria.Criteria{
|
||||
// This is invalid because "contains" cannot have multiple fields
|
||||
Expression: criteria.All{
|
||||
criteria.Contains{"genre": "Hardcore", "filetype": "mp3"},
|
||||
},
|
||||
}
|
||||
newPls := model.Playlist{Name: "Great!", OwnerID: "userid", Rules: rules}
|
||||
Expect(repo.Put(&newPls)).To(MatchError(ContainSubstring("invalid criteria expression")))
|
||||
})
|
||||
})
|
||||
|
||||
// TODO Validate these tests
|
||||
XContext("child smart playlists", func() {
|
||||
When("refresh day has expired", func() {
|
||||
It("should refresh tracks for smart playlist referenced in parent smart playlist criteria", func() {
|
||||
conf.Server.SmartPlaylistRefreshDelay = -1 * time.Second
|
||||
|
||||
nestedPls := model.Playlist{Name: "Nested", OwnerID: "userid", Rules: rules}
|
||||
Expect(repo.Put(&nestedPls)).To(Succeed())
|
||||
|
||||
parentPls := model.Playlist{Name: "Parent", OwnerID: "userid", Rules: &criteria.Criteria{
|
||||
Expression: criteria.All{
|
||||
criteria.InPlaylist{"id": nestedPls.ID},
|
||||
},
|
||||
}}
|
||||
Expect(repo.Put(&parentPls)).To(Succeed())
|
||||
|
||||
nestedPlsRead, err := repo.Get(nestedPls.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = repo.GetWithTracks(parentPls.ID, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Check that the nested playlist was refreshed by parent get by verifying evaluatedAt is updated since first nestedPls get
|
||||
nestedPlsAfterParentGet, err := repo.Get(nestedPls.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(*nestedPlsAfterParentGet.EvaluatedAt).To(BeTemporally(">", *nestedPlsRead.EvaluatedAt))
|
||||
})
|
||||
})
|
||||
|
||||
When("refresh day has not expired", func() {
|
||||
It("should NOT refresh tracks for smart playlist referenced in parent smart playlist criteria", func() {
|
||||
conf.Server.SmartPlaylistRefreshDelay = 1 * time.Hour
|
||||
|
||||
nestedPls := model.Playlist{Name: "Nested", OwnerID: "userid", Rules: rules}
|
||||
Expect(repo.Put(&nestedPls)).To(Succeed())
|
||||
|
||||
parentPls := model.Playlist{Name: "Parent", OwnerID: "userid", Rules: &criteria.Criteria{
|
||||
Expression: criteria.All{
|
||||
criteria.InPlaylist{"id": nestedPls.ID},
|
||||
},
|
||||
}}
|
||||
Expect(repo.Put(&parentPls)).To(Succeed())
|
||||
|
||||
nestedPlsRead, err := repo.Get(nestedPls.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
_, err = repo.GetWithTracks(parentPls.ID, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Check that the nested playlist was not refreshed by parent get by verifying evaluatedAt is not updated since first nestedPls get
|
||||
nestedPlsAfterParentGet, err := repo.Get(nestedPls.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(*nestedPlsAfterParentGet.EvaluatedAt).To(Equal(*nestedPlsRead.EvaluatedAt))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Playlist Track Sorting", func() {
|
||||
var testPlaylistID string
|
||||
|
||||
AfterEach(func() {
|
||||
if testPlaylistID != "" {
|
||||
Expect(repo.Delete(testPlaylistID)).To(BeNil())
|
||||
testPlaylistID = ""
|
||||
}
|
||||
})
|
||||
|
||||
It("sorts tracks correctly by album (disc and track number)", func() {
|
||||
By("creating a playlist with multi-disc album tracks in arbitrary order")
|
||||
newPls := model.Playlist{Name: "Multi-Disc Test", OwnerID: "userid"}
|
||||
// Add tracks in intentionally scrambled order
|
||||
newPls.AddMediaFilesByID([]string{"2001", "2002", "2003", "2004"})
|
||||
Expect(repo.Put(&newPls)).To(Succeed())
|
||||
testPlaylistID = newPls.ID
|
||||
|
||||
By("retrieving tracks sorted by album")
|
||||
tracksRepo := repo.Tracks(newPls.ID, false)
|
||||
tracks, err := tracksRepo.GetAll(model.QueryOptions{Sort: "album", Order: "asc"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("verifying tracks are sorted by disc number then track number")
|
||||
Expect(tracks).To(HaveLen(4))
|
||||
// Expected order: Disc 1 Track 1, Disc 1 Track 2, Disc 2 Track 1, Disc 2 Track 11
|
||||
Expect(tracks[0].MediaFileID).To(Equal("2002")) // Disc 1, Track 1
|
||||
Expect(tracks[1].MediaFileID).To(Equal("2004")) // Disc 1, Track 2
|
||||
Expect(tracks[2].MediaFileID).To(Equal("2003")) // Disc 2, Track 1
|
||||
Expect(tracks[3].MediaFileID).To(Equal("2001")) // Disc 2, Track 11
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Smart Playlists with Tag Criteria", func() {
|
||||
var mfRepo model.MediaFileRepository
|
||||
var testPlaylistID string
|
||||
var songWithGrouping, songWithoutGrouping model.MediaFile
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(GinkgoT().Context())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
|
||||
mfRepo = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
|
||||
|
||||
// Register 'grouping' as a valid tag for smart playlists
|
||||
criteria.AddTagNames([]string{"grouping"})
|
||||
|
||||
// Create a song with the grouping tag
|
||||
songWithGrouping = model.MediaFile{
|
||||
ID: "test-grouping-1",
|
||||
Title: "Song With Grouping",
|
||||
Artist: "Test Artist",
|
||||
ArtistID: "1",
|
||||
Album: "Test Album",
|
||||
AlbumID: "101",
|
||||
Path: "/test/grouping/song1.mp3",
|
||||
Tags: model.Tags{
|
||||
"grouping": []string{"My Crate"},
|
||||
},
|
||||
Participants: model.Participants{},
|
||||
LibraryID: 1,
|
||||
Lyrics: "[]",
|
||||
}
|
||||
Expect(mfRepo.Put(&songWithGrouping)).To(Succeed())
|
||||
|
||||
// Create a song without the grouping tag
|
||||
songWithoutGrouping = model.MediaFile{
|
||||
ID: "test-grouping-2",
|
||||
Title: "Song Without Grouping",
|
||||
Artist: "Test Artist",
|
||||
ArtistID: "1",
|
||||
Album: "Test Album",
|
||||
AlbumID: "101",
|
||||
Path: "/test/grouping/song2.mp3",
|
||||
Tags: model.Tags{},
|
||||
Participants: model.Participants{},
|
||||
LibraryID: 1,
|
||||
Lyrics: "[]",
|
||||
}
|
||||
Expect(mfRepo.Put(&songWithoutGrouping)).To(Succeed())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
if testPlaylistID != "" {
|
||||
_ = repo.Delete(testPlaylistID)
|
||||
testPlaylistID = ""
|
||||
}
|
||||
// Clean up test media files
|
||||
_, _ = GetDBXBuilder().Delete("media_file", dbx.HashExp{"id": "test-grouping-1"}).Execute()
|
||||
_, _ = GetDBXBuilder().Delete("media_file", dbx.HashExp{"id": "test-grouping-2"}).Execute()
|
||||
})
|
||||
|
||||
It("matches tracks with a tag value using 'contains' with empty string (issue #4728 workaround)", func() {
|
||||
By("creating a smart playlist that checks if grouping tag has any value")
|
||||
// This is the workaround for issue #4728: using 'contains' with empty string
|
||||
// generates SQL: value LIKE '%%' which matches any non-empty string
|
||||
rules := &criteria.Criteria{
|
||||
Expression: criteria.All{
|
||||
criteria.Contains{"grouping": ""},
|
||||
},
|
||||
}
|
||||
newPls := model.Playlist{Name: "Tracks with Grouping", OwnerID: "userid", Rules: rules}
|
||||
Expect(repo.Put(&newPls)).To(Succeed())
|
||||
testPlaylistID = newPls.ID
|
||||
|
||||
By("refreshing the smart playlist")
|
||||
conf.Server.SmartPlaylistRefreshDelay = -1 * time.Second // Force refresh
|
||||
pls, err := repo.GetWithTracks(newPls.ID, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("verifying only the track with grouping tag is matched")
|
||||
Expect(pls.Tracks).To(HaveLen(1))
|
||||
Expect(pls.Tracks[0].MediaFileID).To(Equal(songWithGrouping.ID))
|
||||
})
|
||||
|
||||
It("excludes tracks with a tag value using 'notContains' with empty string", func() {
|
||||
By("creating a smart playlist that checks if grouping tag is NOT set")
|
||||
rules := &criteria.Criteria{
|
||||
Expression: criteria.All{
|
||||
criteria.NotContains{"grouping": ""},
|
||||
},
|
||||
}
|
||||
newPls := model.Playlist{Name: "Tracks without Grouping", OwnerID: "userid", Rules: rules}
|
||||
Expect(repo.Put(&newPls)).To(Succeed())
|
||||
testPlaylistID = newPls.ID
|
||||
|
||||
By("refreshing the smart playlist")
|
||||
conf.Server.SmartPlaylistRefreshDelay = -1 * time.Second // Force refresh
|
||||
pls, err := repo.GetWithTracks(newPls.ID, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("verifying the track with grouping is NOT in the playlist")
|
||||
for _, track := range pls.Tracks {
|
||||
Expect(track.MediaFileID).ToNot(Equal(songWithGrouping.ID))
|
||||
}
|
||||
|
||||
By("verifying the track without grouping IS in the playlist")
|
||||
var foundWithoutGrouping bool
|
||||
for _, track := range pls.Tracks {
|
||||
if track.MediaFileID == songWithoutGrouping.ID {
|
||||
foundWithoutGrouping = true
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(foundWithoutGrouping).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Smart Playlists Library Filtering", func() {
|
||||
var mfRepo model.MediaFileRepository
|
||||
var testPlaylistID string
|
||||
var lib2ID int
|
||||
var restrictedUserID string
|
||||
var uniqueLibPath string
|
||||
|
||||
BeforeEach(func() {
|
||||
db := GetDBXBuilder()
|
||||
|
||||
// Generate unique IDs for this test run
|
||||
uniqueSuffix := time.Now().Format("20060102150405.000")
|
||||
restrictedUserID = "restricted-user-" + uniqueSuffix
|
||||
uniqueLibPath = "/music/lib2-" + uniqueSuffix
|
||||
|
||||
// Create a second library with unique name and path to avoid conflicts with other tests
|
||||
_, err := db.DB().Exec("INSERT INTO library (name, path, created_at, updated_at) VALUES (?, ?, datetime('now'), datetime('now'))", "Library 2-"+uniqueSuffix, uniqueLibPath)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = db.DB().QueryRow("SELECT last_insert_rowid()").Scan(&lib2ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create a restricted user with access only to library 1
|
||||
_, err = db.DB().Exec("INSERT INTO user (id, user_name, name, is_admin, password, created_at, updated_at) VALUES (?, ?, 'Restricted User', false, 'pass', datetime('now'), datetime('now'))", restrictedUserID, restrictedUserID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = db.DB().Exec("INSERT INTO user_library (user_id, library_id) VALUES (?, 1)", restrictedUserID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create test media files in each library
|
||||
ctx := log.NewContext(GinkgoT().Context())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
|
||||
mfRepo = NewMediaFileRepository(ctx, db, nil)
|
||||
|
||||
// Song in library 1 (accessible by restricted user)
|
||||
songLib1 := model.MediaFile{
|
||||
ID: "lib1-song",
|
||||
Title: "Song in Lib1",
|
||||
Artist: "Test Artist",
|
||||
ArtistID: "1",
|
||||
Album: "Test Album",
|
||||
AlbumID: "101",
|
||||
Path: "/music/lib1/song.mp3",
|
||||
LibraryID: 1,
|
||||
Participants: model.Participants{},
|
||||
Tags: model.Tags{},
|
||||
Lyrics: "[]",
|
||||
}
|
||||
Expect(mfRepo.Put(&songLib1)).To(Succeed())
|
||||
|
||||
// Song in library 2 (NOT accessible by restricted user)
|
||||
songLib2 := model.MediaFile{
|
||||
ID: "lib2-song",
|
||||
Title: "Song in Lib2",
|
||||
Artist: "Test Artist",
|
||||
ArtistID: "1",
|
||||
Album: "Test Album",
|
||||
AlbumID: "101",
|
||||
Path: uniqueLibPath + "/song.mp3",
|
||||
LibraryID: lib2ID,
|
||||
Participants: model.Participants{},
|
||||
Tags: model.Tags{},
|
||||
Lyrics: "[]",
|
||||
}
|
||||
Expect(mfRepo.Put(&songLib2)).To(Succeed())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
db := GetDBXBuilder()
|
||||
if testPlaylistID != "" {
|
||||
_ = repo.Delete(testPlaylistID)
|
||||
testPlaylistID = ""
|
||||
}
|
||||
// Clean up test data
|
||||
_, _ = db.Delete("media_file", dbx.HashExp{"id": "lib1-song"}).Execute()
|
||||
_, _ = db.Delete("media_file", dbx.HashExp{"id": "lib2-song"}).Execute()
|
||||
_, _ = db.Delete("user_library", dbx.HashExp{"user_id": restrictedUserID}).Execute()
|
||||
_, _ = db.Delete("user", dbx.HashExp{"id": restrictedUserID}).Execute()
|
||||
_, _ = db.DB().Exec("DELETE FROM library WHERE id = ?", lib2ID)
|
||||
})
|
||||
|
||||
It("should only include tracks from libraries the user has access to (issue #4738)", func() {
|
||||
db := GetDBXBuilder()
|
||||
ctx := log.NewContext(GinkgoT().Context())
|
||||
|
||||
// Create the smart playlist as the restricted user
|
||||
restrictedUser := model.User{ID: restrictedUserID, UserName: restrictedUserID, IsAdmin: false}
|
||||
ctx = request.WithUser(ctx, restrictedUser)
|
||||
restrictedRepo := NewPlaylistRepository(ctx, db)
|
||||
|
||||
// Create a smart playlist that matches all songs
|
||||
rules := &criteria.Criteria{
|
||||
Expression: criteria.All{
|
||||
criteria.Gt{"playCount": -1}, // Matches everything
|
||||
},
|
||||
}
|
||||
newPls := model.Playlist{Name: "All Songs", OwnerID: restrictedUserID, Rules: rules}
|
||||
Expect(restrictedRepo.Put(&newPls)).To(Succeed())
|
||||
testPlaylistID = newPls.ID
|
||||
|
||||
By("refreshing the smart playlist")
|
||||
conf.Server.SmartPlaylistRefreshDelay = -1 * time.Second // Force refresh
|
||||
pls, err := restrictedRepo.GetWithTracks(newPls.ID, true, false)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("verifying only the track from library 1 is in the playlist")
|
||||
var foundLib1Song, foundLib2Song bool
|
||||
for _, track := range pls.Tracks {
|
||||
if track.MediaFileID == "lib1-song" {
|
||||
foundLib1Song = true
|
||||
}
|
||||
if track.MediaFileID == "lib2-song" {
|
||||
foundLib2Song = true
|
||||
}
|
||||
}
|
||||
Expect(foundLib1Song).To(BeTrue(), "Song from library 1 should be in the playlist")
|
||||
Expect(foundLib2Song).To(BeFalse(), "Song from library 2 should NOT be in the playlist")
|
||||
|
||||
By("verifying playlist_tracks table only contains the accessible track")
|
||||
var playlistTracksCount int
|
||||
err = db.DB().QueryRow("SELECT count(*) FROM playlist_tracks WHERE playlist_id = ?", newPls.ID).Scan(&playlistTracksCount)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
// Count should only include tracks visible to the user (lib1-song)
|
||||
// The count may include other test songs from library 1, but NOT lib2-song
|
||||
var lib2TrackCount int
|
||||
err = db.DB().QueryRow("SELECT count(*) FROM playlist_tracks WHERE playlist_id = ? AND media_file_id = 'lib2-song'", newPls.ID).Scan(&lib2TrackCount)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(lib2TrackCount).To(Equal(0), "lib2-song should not be in playlist_tracks")
|
||||
|
||||
By("verifying SongCount matches visible tracks")
|
||||
Expect(pls.SongCount).To(Equal(len(pls.Tracks)), "SongCount should match the number of visible tracks")
|
||||
})
|
||||
})
|
||||
})
|
||||
247
persistence/playlist_track_repository.go
Normal file
247
persistence/playlist_track_repository.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
)
|
||||
|
||||
type playlistTrackRepository struct {
|
||||
sqlRepository
|
||||
playlistId string
|
||||
playlist *model.Playlist
|
||||
playlistRepo *playlistRepository
|
||||
}
|
||||
|
||||
type dbPlaylistTrack struct {
|
||||
dbMediaFile
|
||||
*model.PlaylistTrack `structs:",flatten"`
|
||||
}
|
||||
|
||||
func (t *dbPlaylistTrack) PostScan() error {
|
||||
if err := t.dbMediaFile.PostScan(); err != nil {
|
||||
return err
|
||||
}
|
||||
t.PlaylistTrack.MediaFile = *t.dbMediaFile.MediaFile
|
||||
t.PlaylistTrack.MediaFile.ID = t.MediaFileID
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbPlaylistTracks []dbPlaylistTrack
|
||||
|
||||
func (t dbPlaylistTracks) toModels() model.PlaylistTracks {
|
||||
return slice.Map(t, func(trk dbPlaylistTrack) model.PlaylistTrack {
|
||||
return *trk.PlaylistTrack
|
||||
})
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Tracks(playlistId string, refreshSmartPlaylist bool) model.PlaylistTrackRepository {
|
||||
p := &playlistTrackRepository{}
|
||||
p.playlistRepo = r
|
||||
p.playlistId = playlistId
|
||||
p.ctx = r.ctx
|
||||
p.db = r.db
|
||||
p.tableName = "playlist_tracks"
|
||||
p.registerModel(&model.PlaylistTrack{}, map[string]filterFunc{
|
||||
"missing": booleanFilter,
|
||||
"library_id": libraryIdFilter,
|
||||
})
|
||||
p.setSortMappings(
|
||||
map[string]string{
|
||||
"id": "playlist_tracks.id",
|
||||
"artist": "order_artist_name",
|
||||
"album_artist": "order_album_artist_name",
|
||||
"album": "order_album_name, album_id, disc_number, track_number, order_artist_name, title",
|
||||
"title": "order_title",
|
||||
// To make sure these fields will be whitelisted
|
||||
"duration": "duration",
|
||||
"year": "year",
|
||||
"bpm": "bpm",
|
||||
"channels": "channels",
|
||||
},
|
||||
"f") // TODO I don't like this solution, but I won't change it now as it's not the focus of BFR.
|
||||
|
||||
pls, err := r.Get(playlistId)
|
||||
if err != nil {
|
||||
log.Warn(r.ctx, "Error getting playlist's tracks", "playlistId", playlistId, err)
|
||||
return nil
|
||||
}
|
||||
if refreshSmartPlaylist {
|
||||
r.refreshSmartPlaylist(pls)
|
||||
}
|
||||
p.playlist = pls
|
||||
return p
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
query := Select().
|
||||
LeftJoin("media_file f on f.id = media_file_id").
|
||||
Where(Eq{"playlist_id": r.playlistId})
|
||||
return r.count(query, r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) Read(id string) (interface{}, error) {
|
||||
userID := loggedUser(r.ctx).ID
|
||||
sel := r.newSelect().
|
||||
LeftJoin("annotation on ("+
|
||||
"annotation.item_id = media_file_id"+
|
||||
" AND annotation.item_type = 'media_file'"+
|
||||
" AND annotation.user_id = '"+userID+"')").
|
||||
Columns(
|
||||
"coalesce(starred, 0) as starred",
|
||||
"coalesce(play_count, 0) as play_count",
|
||||
"coalesce(rating, 0) as rating",
|
||||
"starred_at",
|
||||
"play_date",
|
||||
"rated_at",
|
||||
"f.*",
|
||||
"playlist_tracks.*",
|
||||
).
|
||||
Join("media_file f on f.id = media_file_id").
|
||||
Where(And{Eq{"playlist_id": r.playlistId}, Eq{"playlist_tracks.id": id}})
|
||||
var trk dbPlaylistTrack
|
||||
err := r.queryOne(sel, &trk)
|
||||
return trk.PlaylistTrack, err
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) GetAll(options ...model.QueryOptions) (model.PlaylistTracks, error) {
|
||||
tracks, err := r.playlistRepo.loadTracks(r.newSelect(options...), r.playlistId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tracks, err
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) GetAlbumIDs(options ...model.QueryOptions) ([]string, error) {
|
||||
query := r.newSelect(options...).Columns("distinct mf.album_id").
|
||||
Join("media_file mf on mf.id = media_file_id").
|
||||
Where(Eq{"playlist_id": r.playlistId})
|
||||
var ids []string
|
||||
err := r.queryAllSlice(query, &ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) EntityName() string {
|
||||
return "playlist_tracks"
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) NewInstance() interface{} {
|
||||
return &model.PlaylistTrack{}
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) isTracksEditable() bool {
|
||||
return r.playlistRepo.isWritable(r.playlistId) && !r.playlist.IsSmartPlaylist()
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) Add(mediaFileIds []string) (int, error) {
|
||||
if !r.isTracksEditable() {
|
||||
return 0, rest.ErrPermissionDenied
|
||||
}
|
||||
|
||||
if len(mediaFileIds) > 0 {
|
||||
log.Debug(r.ctx, "Adding songs to playlist", "playlistId", r.playlistId, "mediaFileIds", mediaFileIds)
|
||||
} else {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Get next pos (ID) in playlist
|
||||
sq := r.newSelect().Columns("max(id) as max").Where(Eq{"playlist_id": r.playlistId})
|
||||
var res struct{ Max sql.NullInt32 }
|
||||
err := r.queryOne(sq, &res)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(mediaFileIds), r.playlistRepo.addTracks(r.playlistId, int(res.Max.Int32+1), mediaFileIds)
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) addMediaFileIds(cond Sqlizer) (int, error) {
|
||||
sq := Select("id").From("media_file").Where(cond).OrderBy("album_artist, album, release_date, disc_number, track_number")
|
||||
var ids []string
|
||||
err := r.queryAllSlice(sq, &ids)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error getting tracks to add to playlist", err)
|
||||
return 0, err
|
||||
}
|
||||
return r.Add(ids)
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) AddAlbums(albumIds []string) (int, error) {
|
||||
return r.addMediaFileIds(Eq{"album_id": albumIds})
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) AddArtists(artistIds []string) (int, error) {
|
||||
return r.addMediaFileIds(Eq{"album_artist_id": artistIds})
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) AddDiscs(discs []model.DiscID) (int, error) {
|
||||
if len(discs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var clauses Or
|
||||
for _, d := range discs {
|
||||
clauses = append(clauses, And{Eq{"album_id": d.AlbumID}, Eq{"release_date": d.ReleaseDate}, Eq{"disc_number": d.DiscNumber}})
|
||||
}
|
||||
return r.addMediaFileIds(clauses)
|
||||
}
|
||||
|
||||
// Get ids from all current tracks
|
||||
func (r *playlistTrackRepository) getTracks() ([]string, error) {
|
||||
all := r.newSelect().Columns("media_file_id").Where(Eq{"playlist_id": r.playlistId}).OrderBy("id")
|
||||
var ids []string
|
||||
err := r.queryAllSlice(all, &ids)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error querying current tracks from playlist", "playlistId", r.playlistId, err)
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) Delete(ids ...string) error {
|
||||
if !r.isTracksEditable() {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
err := r.delete(And{Eq{"playlist_id": r.playlistId}, Eq{"id": ids}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.playlistRepo.renumber(r.playlistId)
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) DeleteAll() error {
|
||||
if !r.isTracksEditable() {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
err := r.delete(Eq{"playlist_id": r.playlistId})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.playlistRepo.renumber(r.playlistId)
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) Reorder(pos int, newPos int) error {
|
||||
if !r.isTracksEditable() {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
ids, err := r.getTracks()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newOrder := slice.Move(ids, pos-1, newPos-1)
|
||||
return r.playlistRepo.updatePlaylist(r.playlistId, newOrder)
|
||||
}
|
||||
|
||||
var _ model.PlaylistTrackRepository = (*playlistTrackRepository)(nil)
|
||||
178
persistence/playqueue_repository.go
Normal file
178
persistence/playqueue_repository.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type playQueueRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewPlayQueueRepository(ctx context.Context, db dbx.Builder) model.PlayQueueRepository {
|
||||
r := &playQueueRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.tableName = "playqueue"
|
||||
return r
|
||||
}
|
||||
|
||||
type playQueue struct {
|
||||
ID string `structs:"id"`
|
||||
UserID string `structs:"user_id"`
|
||||
Current int `structs:"current"`
|
||||
Position int64 `structs:"position"`
|
||||
ChangedBy string `structs:"changed_by"`
|
||||
Items string `structs:"items"`
|
||||
CreatedAt time.Time `structs:"created_at"`
|
||||
UpdatedAt time.Time `structs:"updated_at"`
|
||||
}
|
||||
|
||||
func (r *playQueueRepository) Store(q *model.PlayQueue, colNames ...string) error {
|
||||
u := loggedUser(r.ctx)
|
||||
|
||||
// Always find existing playqueue for this user
|
||||
existingQueue, err := r.Retrieve(q.UserID)
|
||||
if err != nil && !errors.Is(err, model.ErrNotFound) {
|
||||
log.Error(r.ctx, "Error retrieving existing playqueue", "user", u.UserName, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Use existing ID if found, otherwise keep the provided ID (which may be empty for new records)
|
||||
if !errors.Is(err, model.ErrNotFound) && existingQueue.ID != "" {
|
||||
q.ID = existingQueue.ID
|
||||
}
|
||||
|
||||
// When no specific columns are provided, we replace the whole queue
|
||||
if len(colNames) == 0 {
|
||||
err := r.clearPlayQueue(q.UserID)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error deleting previous playqueue", "user", u.UserName, err)
|
||||
return err
|
||||
}
|
||||
if len(q.Items) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
pq := r.fromModel(q)
|
||||
if pq.ID == "" {
|
||||
pq.CreatedAt = time.Now()
|
||||
}
|
||||
pq.UpdatedAt = time.Now()
|
||||
_, err = r.put(pq.ID, pq, colNames...)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error saving playqueue", "user", u.UserName, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *playQueueRepository) RetrieveWithMediaFiles(userId string) (*model.PlayQueue, error) {
|
||||
sel := r.newSelect().Columns("*").Where(Eq{"user_id": userId})
|
||||
var res playQueue
|
||||
err := r.queryOne(sel, &res)
|
||||
q := r.toModel(&res)
|
||||
q.Items = r.loadTracks(q.Items)
|
||||
return &q, err
|
||||
}
|
||||
|
||||
func (r *playQueueRepository) Retrieve(userId string) (*model.PlayQueue, error) {
|
||||
sel := r.newSelect().Columns("*").Where(Eq{"user_id": userId})
|
||||
var res playQueue
|
||||
err := r.queryOne(sel, &res)
|
||||
q := r.toModel(&res)
|
||||
return &q, err
|
||||
}
|
||||
|
||||
func (r *playQueueRepository) fromModel(q *model.PlayQueue) playQueue {
|
||||
pq := playQueue{
|
||||
ID: q.ID,
|
||||
UserID: q.UserID,
|
||||
Current: q.Current,
|
||||
Position: q.Position,
|
||||
ChangedBy: q.ChangedBy,
|
||||
CreatedAt: q.CreatedAt,
|
||||
UpdatedAt: q.UpdatedAt,
|
||||
}
|
||||
var itemIDs []string
|
||||
for _, t := range q.Items {
|
||||
itemIDs = append(itemIDs, t.ID)
|
||||
}
|
||||
pq.Items = strings.Join(itemIDs, ",")
|
||||
return pq
|
||||
}
|
||||
|
||||
func (r *playQueueRepository) toModel(pq *playQueue) model.PlayQueue {
|
||||
q := model.PlayQueue{
|
||||
ID: pq.ID,
|
||||
UserID: pq.UserID,
|
||||
Current: pq.Current,
|
||||
Position: pq.Position,
|
||||
ChangedBy: pq.ChangedBy,
|
||||
CreatedAt: pq.CreatedAt,
|
||||
UpdatedAt: pq.UpdatedAt,
|
||||
}
|
||||
if strings.TrimSpace(pq.Items) != "" {
|
||||
tracks := strings.Split(pq.Items, ",")
|
||||
for _, t := range tracks {
|
||||
q.Items = append(q.Items, model.MediaFile{ID: t})
|
||||
}
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// loadTracks loads the tracks from the database. It receives a list of track IDs and returns a list of MediaFiles
|
||||
// in the same order as the input list.
|
||||
func (r *playQueueRepository) loadTracks(tracks model.MediaFiles) model.MediaFiles {
|
||||
if len(tracks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.db, nil)
|
||||
trackMap := map[string]model.MediaFile{}
|
||||
|
||||
// Create an iterator to collect all track IDs
|
||||
ids := slice.SeqFunc(tracks, func(t model.MediaFile) string { return t.ID })
|
||||
|
||||
// Break the list in chunks, up to 500 items, to avoid hitting SQLITE_MAX_VARIABLE_NUMBER limit
|
||||
for chunk := range slice.CollectChunks(ids, 500) {
|
||||
idsFilter := Eq{"media_file.id": chunk}
|
||||
tracks, err := mfRepo.GetAll(model.QueryOptions{Filters: idsFilter})
|
||||
if err != nil {
|
||||
u := loggedUser(r.ctx)
|
||||
log.Error(r.ctx, "Could not load playqueue/bookmark's tracks", "user", u.UserName, err)
|
||||
}
|
||||
for _, t := range tracks {
|
||||
trackMap[t.ID] = t
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new list of tracks with the same order as the original
|
||||
// Exclude tracks that are not in the DB anymore
|
||||
newTracks := make(model.MediaFiles, 0, len(tracks))
|
||||
for _, t := range tracks {
|
||||
if track, ok := trackMap[t.ID]; ok {
|
||||
newTracks = append(newTracks, track)
|
||||
}
|
||||
}
|
||||
return newTracks
|
||||
}
|
||||
|
||||
func (r *playQueueRepository) clearPlayQueue(userId string) error {
|
||||
return r.delete(Eq{"user_id": userId})
|
||||
}
|
||||
|
||||
func (r *playQueueRepository) Clear(userId string) error {
|
||||
return r.clearPlayQueue(userId)
|
||||
}
|
||||
|
||||
var _ model.PlayQueueRepository = (*playQueueRepository)(nil)
|
||||
435
persistence/playqueue_repository_test.go
Normal file
435
persistence/playqueue_repository_test.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/conf/configtest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("PlayQueueRepository", func() {
|
||||
var repo model.PlayQueueRepository
|
||||
var ctx context.Context
|
||||
|
||||
BeforeEach(func() {
|
||||
DeferCleanup(configtest.SetupConfig())
|
||||
ctx = log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
|
||||
repo = NewPlayQueueRepository(ctx, GetDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Store", func() {
|
||||
It("stores a complete playqueue", func() {
|
||||
expected := aPlayQueue("userid", 1, 123, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(expected)).To(Succeed())
|
||||
|
||||
actual, err := repo.RetrieveWithMediaFiles("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
AssertPlayQueue(expected, actual)
|
||||
Expect(countPlayQueues(repo, "userid")).To(Equal(1))
|
||||
})
|
||||
|
||||
It("replaces existing playqueue when storing without column names", func() {
|
||||
By("Storing initial playqueue")
|
||||
initial := aPlayQueue("userid", 0, 100, songComeTogether)
|
||||
Expect(repo.Store(initial)).To(Succeed())
|
||||
|
||||
By("Storing replacement playqueue")
|
||||
replacement := aPlayQueue("userid", 1, 200, songDayInALife, songAntenna)
|
||||
Expect(repo.Store(replacement)).To(Succeed())
|
||||
|
||||
actual, err := repo.RetrieveWithMediaFiles("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
AssertPlayQueue(replacement, actual)
|
||||
Expect(countPlayQueues(repo, "userid")).To(Equal(1))
|
||||
})
|
||||
|
||||
It("clears playqueue when storing empty items", func() {
|
||||
By("Storing initial playqueue")
|
||||
initial := aPlayQueue("userid", 0, 100, songComeTogether)
|
||||
Expect(repo.Store(initial)).To(Succeed())
|
||||
|
||||
By("Storing empty playqueue")
|
||||
empty := aPlayQueue("userid", 0, 0)
|
||||
Expect(repo.Store(empty)).To(Succeed())
|
||||
|
||||
By("Verifying playqueue is cleared")
|
||||
_, err := repo.Retrieve("userid")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
|
||||
It("updates only current field when specified", func() {
|
||||
By("Storing initial playqueue")
|
||||
initial := aPlayQueue("userid", 0, 100, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(initial)).To(Succeed())
|
||||
|
||||
By("Getting the existing playqueue to obtain its ID")
|
||||
existing, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("Updating only current field")
|
||||
update := &model.PlayQueue{
|
||||
ID: existing.ID, // Use existing ID for partial update
|
||||
UserID: "userid",
|
||||
Current: 1,
|
||||
ChangedBy: "test-update",
|
||||
}
|
||||
Expect(repo.Store(update, "current")).To(Succeed())
|
||||
|
||||
By("Verifying only current was updated")
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Current).To(Equal(1))
|
||||
Expect(actual.Position).To(Equal(int64(100))) // Should remain unchanged
|
||||
Expect(actual.Items).To(HaveLen(2)) // Should remain unchanged
|
||||
})
|
||||
|
||||
It("updates only position field when specified", func() {
|
||||
By("Storing initial playqueue")
|
||||
initial := aPlayQueue("userid", 1, 100, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(initial)).To(Succeed())
|
||||
|
||||
By("Getting the existing playqueue to obtain its ID")
|
||||
existing, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("Updating only position field")
|
||||
update := &model.PlayQueue{
|
||||
ID: existing.ID, // Use existing ID for partial update
|
||||
UserID: "userid",
|
||||
Position: 500,
|
||||
ChangedBy: "test-update",
|
||||
}
|
||||
Expect(repo.Store(update, "position")).To(Succeed())
|
||||
|
||||
By("Verifying only position was updated")
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Position).To(Equal(int64(500)))
|
||||
Expect(actual.Current).To(Equal(1)) // Should remain unchanged
|
||||
Expect(actual.Items).To(HaveLen(2)) // Should remain unchanged
|
||||
})
|
||||
|
||||
It("updates multiple specified fields", func() {
|
||||
By("Storing initial playqueue")
|
||||
initial := aPlayQueue("userid", 0, 100, songComeTogether)
|
||||
Expect(repo.Store(initial)).To(Succeed())
|
||||
|
||||
By("Getting the existing playqueue to obtain its ID")
|
||||
existing, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("Updating current and position fields")
|
||||
update := &model.PlayQueue{
|
||||
ID: existing.ID, // Use existing ID for partial update
|
||||
UserID: "userid",
|
||||
Current: 1,
|
||||
Position: 300,
|
||||
ChangedBy: "test-update",
|
||||
}
|
||||
Expect(repo.Store(update, "current", "position")).To(Succeed())
|
||||
|
||||
By("Verifying both fields were updated")
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Current).To(Equal(1))
|
||||
Expect(actual.Position).To(Equal(int64(300)))
|
||||
Expect(actual.Items).To(HaveLen(1)) // Should remain unchanged
|
||||
})
|
||||
|
||||
It("preserves existing data when updating with empty items list and column names", func() {
|
||||
By("Storing initial playqueue")
|
||||
initial := aPlayQueue("userid", 0, 100, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(initial)).To(Succeed())
|
||||
|
||||
By("Getting the existing playqueue to obtain its ID")
|
||||
existing, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("Updating only position with empty items")
|
||||
update := &model.PlayQueue{
|
||||
ID: existing.ID, // Use existing ID for partial update
|
||||
UserID: "userid",
|
||||
Position: 200,
|
||||
ChangedBy: "test-update",
|
||||
Items: []model.MediaFile{}, // Empty items
|
||||
}
|
||||
Expect(repo.Store(update, "position")).To(Succeed())
|
||||
|
||||
By("Verifying items are preserved")
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Position).To(Equal(int64(200)))
|
||||
Expect(actual.Items).To(HaveLen(2)) // Should remain unchanged
|
||||
})
|
||||
|
||||
It("ensures only one record per user by reusing existing record ID", func() {
|
||||
By("Storing initial playqueue")
|
||||
initial := aPlayQueue("userid", 0, 100, songComeTogether)
|
||||
Expect(repo.Store(initial)).To(Succeed())
|
||||
initialCount := countPlayQueues(repo, "userid")
|
||||
Expect(initialCount).To(Equal(1))
|
||||
|
||||
By("Storing another playqueue with different ID but same user")
|
||||
different := aPlayQueue("userid", 1, 200, songDayInALife)
|
||||
different.ID = "different-id" // Force a different ID
|
||||
Expect(repo.Store(different)).To(Succeed())
|
||||
|
||||
By("Verifying only one record exists for the user")
|
||||
finalCount := countPlayQueues(repo, "userid")
|
||||
Expect(finalCount).To(Equal(1))
|
||||
|
||||
By("Verifying the record was updated, not duplicated")
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Current).To(Equal(1)) // Should be updated value
|
||||
Expect(actual.Position).To(Equal(int64(200))) // Should be updated value
|
||||
Expect(actual.Items).To(HaveLen(1)) // Should be new items
|
||||
Expect(actual.Items[0].ID).To(Equal(songDayInALife.ID))
|
||||
})
|
||||
|
||||
It("ensures only one record per user even with partial updates", func() {
|
||||
By("Storing initial playqueue")
|
||||
initial := aPlayQueue("userid", 0, 100, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(initial)).To(Succeed())
|
||||
initialCount := countPlayQueues(repo, "userid")
|
||||
Expect(initialCount).To(Equal(1))
|
||||
|
||||
By("Storing partial update with different ID but same user")
|
||||
partialUpdate := &model.PlayQueue{
|
||||
ID: "completely-different-id", // Use a completely different ID
|
||||
UserID: "userid",
|
||||
Current: 1,
|
||||
ChangedBy: "test-partial",
|
||||
}
|
||||
Expect(repo.Store(partialUpdate, "current")).To(Succeed())
|
||||
|
||||
By("Verifying only one record still exists for the user")
|
||||
finalCount := countPlayQueues(repo, "userid")
|
||||
Expect(finalCount).To(Equal(1))
|
||||
|
||||
By("Verifying the existing record was updated with new current value")
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Current).To(Equal(1)) // Should be updated value
|
||||
Expect(actual.Position).To(Equal(int64(100))) // Should remain unchanged
|
||||
Expect(actual.Items).To(HaveLen(2)) // Should remain unchanged
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Retrieve", func() {
|
||||
It("returns notfound error if there's no playqueue for the user", func() {
|
||||
_, err := repo.Retrieve("user999")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
|
||||
It("retrieves the playqueue with only track IDs (no full MediaFile data)", func() {
|
||||
By("Storing a playqueue for the user")
|
||||
|
||||
expected := aPlayQueue("userid", 1, 123, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(expected)).To(Succeed())
|
||||
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Basic playqueue properties should match
|
||||
Expect(actual.ID).To(Equal(expected.ID))
|
||||
Expect(actual.UserID).To(Equal(expected.UserID))
|
||||
Expect(actual.Current).To(Equal(expected.Current))
|
||||
Expect(actual.Position).To(Equal(expected.Position))
|
||||
Expect(actual.ChangedBy).To(Equal(expected.ChangedBy))
|
||||
Expect(actual.Items).To(HaveLen(len(expected.Items)))
|
||||
|
||||
// Items should only contain IDs, not full MediaFile data
|
||||
for i, item := range actual.Items {
|
||||
Expect(item.ID).To(Equal(expected.Items[i].ID))
|
||||
// These fields should be empty since we're not loading full MediaFiles
|
||||
Expect(item.Title).To(BeEmpty())
|
||||
Expect(item.Path).To(BeEmpty())
|
||||
Expect(item.Album).To(BeEmpty())
|
||||
Expect(item.Artist).To(BeEmpty())
|
||||
}
|
||||
})
|
||||
|
||||
It("returns items with IDs even when some tracks don't exist in the DB", func() {
|
||||
// Add a new song to the DB
|
||||
newSong := songRadioactivity
|
||||
newSong.ID = "temp-track"
|
||||
newSong.Path = "/new-path"
|
||||
mfRepo := NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
|
||||
|
||||
Expect(mfRepo.Put(&newSong)).To(Succeed())
|
||||
|
||||
// Create a playqueue with the new song
|
||||
pq := aPlayQueue("userid", 0, 0, newSong, songAntenna)
|
||||
Expect(repo.Store(pq)).To(Succeed())
|
||||
|
||||
// Delete the new song from the database
|
||||
Expect(mfRepo.Delete("temp-track")).To(Succeed())
|
||||
|
||||
// Retrieve the playqueue with Retrieve method
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// The playqueue should still contain both track IDs (including the deleted one)
|
||||
Expect(actual.Items).To(HaveLen(2))
|
||||
Expect(actual.Items[0].ID).To(Equal("temp-track"))
|
||||
Expect(actual.Items[1].ID).To(Equal(songAntenna.ID))
|
||||
|
||||
// Items should only contain IDs, no other data
|
||||
for _, item := range actual.Items {
|
||||
Expect(item.Title).To(BeEmpty())
|
||||
Expect(item.Path).To(BeEmpty())
|
||||
Expect(item.Album).To(BeEmpty())
|
||||
Expect(item.Artist).To(BeEmpty())
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
Describe("RetrieveWithMediaFiles", func() {
|
||||
It("returns notfound error if there's no playqueue for the user", func() {
|
||||
_, err := repo.RetrieveWithMediaFiles("user999")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
|
||||
It("retrieves the playqueue with full MediaFile data", func() {
|
||||
By("Storing a playqueue for the user")
|
||||
|
||||
expected := aPlayQueue("userid", 1, 123, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(expected)).To(Succeed())
|
||||
|
||||
actual, err := repo.RetrieveWithMediaFiles("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
AssertPlayQueue(expected, actual)
|
||||
})
|
||||
|
||||
It("does not return tracks if they don't exist in the DB", func() {
|
||||
// Add a new song to the DB
|
||||
newSong := songRadioactivity
|
||||
newSong.ID = "temp-track"
|
||||
newSong.Path = "/new-path"
|
||||
mfRepo := NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
|
||||
|
||||
Expect(mfRepo.Put(&newSong)).To(Succeed())
|
||||
|
||||
// Create a playqueue with the new song
|
||||
pq := aPlayQueue("userid", 0, 0, newSong, songAntenna)
|
||||
Expect(repo.Store(pq)).To(Succeed())
|
||||
|
||||
// Retrieve the playqueue
|
||||
actual, err := repo.RetrieveWithMediaFiles("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// The playqueue should contain both tracks
|
||||
AssertPlayQueue(pq, actual)
|
||||
|
||||
// Delete the new song
|
||||
Expect(mfRepo.Delete("temp-track")).To(Succeed())
|
||||
|
||||
// Retrieve the playqueue
|
||||
actual, err = repo.RetrieveWithMediaFiles("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// The playqueue should not contain the deleted track
|
||||
Expect(actual.Items).To(HaveLen(1))
|
||||
Expect(actual.Items[0].ID).To(Equal(songAntenna.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Clear", func() {
|
||||
It("clears an existing playqueue", func() {
|
||||
By("Storing a playqueue")
|
||||
expected := aPlayQueue("userid", 1, 123, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(expected)).To(Succeed())
|
||||
|
||||
By("Verifying playqueue exists")
|
||||
_, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
By("Clearing the playqueue")
|
||||
Expect(repo.Clear("userid")).To(Succeed())
|
||||
|
||||
By("Verifying playqueue is cleared")
|
||||
_, err = repo.Retrieve("userid")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
|
||||
It("does not error when clearing non-existent playqueue", func() {
|
||||
// Clear should not error even if no playqueue exists
|
||||
Expect(repo.Clear("nonexistent-user")).To(Succeed())
|
||||
})
|
||||
|
||||
It("only clears the specified user's playqueue", func() {
|
||||
By("Creating users in the database to avoid foreign key constraints")
|
||||
userRepo := NewUserRepository(ctx, GetDBXBuilder())
|
||||
user1 := &model.User{ID: "user1", UserName: "user1", Name: "User 1", Email: "user1@test.com"}
|
||||
user2 := &model.User{ID: "user2", UserName: "user2", Name: "User 2", Email: "user2@test.com"}
|
||||
Expect(userRepo.Put(user1)).To(Succeed())
|
||||
Expect(userRepo.Put(user2)).To(Succeed())
|
||||
|
||||
By("Storing playqueues for two users")
|
||||
user1Queue := aPlayQueue("user1", 0, 100, songComeTogether)
|
||||
user2Queue := aPlayQueue("user2", 1, 200, songDayInALife)
|
||||
Expect(repo.Store(user1Queue)).To(Succeed())
|
||||
Expect(repo.Store(user2Queue)).To(Succeed())
|
||||
|
||||
By("Clearing only user1's playqueue")
|
||||
Expect(repo.Clear("user1")).To(Succeed())
|
||||
|
||||
By("Verifying user1's playqueue is cleared")
|
||||
_, err := repo.Retrieve("user1")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
|
||||
By("Verifying user2's playqueue still exists")
|
||||
actual, err := repo.Retrieve("user2")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.UserID).To(Equal("user2"))
|
||||
Expect(actual.Current).To(Equal(1))
|
||||
Expect(actual.Position).To(Equal(int64(200)))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
func countPlayQueues(repo model.PlayQueueRepository, userId string) int {
|
||||
r := repo.(*playQueueRepository)
|
||||
c, err := r.count(squirrel.Select().Where(squirrel.Eq{"user_id": userId}))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return int(c)
|
||||
}
|
||||
|
||||
func AssertPlayQueue(expected, actual *model.PlayQueue) {
|
||||
Expect(actual.ID).To(Equal(expected.ID))
|
||||
Expect(actual.UserID).To(Equal(expected.UserID))
|
||||
Expect(actual.Current).To(Equal(expected.Current))
|
||||
Expect(actual.Position).To(Equal(expected.Position))
|
||||
Expect(actual.ChangedBy).To(Equal(expected.ChangedBy))
|
||||
Expect(actual.Items).To(HaveLen(len(expected.Items)))
|
||||
for i, item := range actual.Items {
|
||||
Expect(item.Title).To(Equal(expected.Items[i].Title))
|
||||
}
|
||||
}
|
||||
|
||||
func aPlayQueue(userId string, current int, position int64, items ...model.MediaFile) *model.PlayQueue {
|
||||
createdAt := time.Now()
|
||||
updatedAt := createdAt.Add(time.Minute)
|
||||
return &model.PlayQueue{
|
||||
ID: id.NewRandom(),
|
||||
UserID: userId,
|
||||
Current: current,
|
||||
Position: position,
|
||||
ChangedBy: "test",
|
||||
Items: items,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
}
|
||||
}
|
||||
63
persistence/property_repository.go
Normal file
63
persistence/property_repository.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type propertyRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewPropertyRepository(ctx context.Context, db dbx.Builder) model.PropertyRepository {
|
||||
r := &propertyRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.tableName = "property"
|
||||
return r
|
||||
}
|
||||
|
||||
func (r propertyRepository) Put(id string, value string) error {
|
||||
update := Update(r.tableName).Set("value", value).Where(Eq{"id": id})
|
||||
count, err := r.executeSQL(update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
insert := Insert(r.tableName).Columns("id", "value").Values(id, value)
|
||||
_, err = r.executeSQL(insert)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r propertyRepository) Get(id string) (string, error) {
|
||||
sel := Select("value").From(r.tableName).Where(Eq{"id": id})
|
||||
resp := struct {
|
||||
Value string
|
||||
}{}
|
||||
err := r.queryOne(sel, &resp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.Value, nil
|
||||
}
|
||||
|
||||
func (r propertyRepository) DefaultGet(id string, defaultValue string) (string, error) {
|
||||
value, err := r.Get(id)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
if err != nil {
|
||||
return defaultValue, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (r propertyRepository) Delete(id string) error {
|
||||
return r.delete(Eq{"id": id})
|
||||
}
|
||||
34
persistence/property_repository_test.go
Normal file
34
persistence/property_repository_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Property Repository", func() {
|
||||
var pr model.PropertyRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
pr = NewPropertyRepository(log.NewContext(context.TODO()), GetDBXBuilder())
|
||||
})
|
||||
|
||||
It("saves and restore a new property", func() {
|
||||
id := "1"
|
||||
value := "a_value"
|
||||
Expect(pr.Put(id, value)).To(BeNil())
|
||||
Expect(pr.Get(id)).To(Equal("a_value"))
|
||||
})
|
||||
|
||||
It("updates a property", func() {
|
||||
Expect(pr.Put("1", "another_value")).To(BeNil())
|
||||
Expect(pr.Get("1")).To(Equal("another_value"))
|
||||
})
|
||||
|
||||
It("returns a default value if property does not exist", func() {
|
||||
Expect(pr.DefaultGet("2", "default")).To(Equal("default"))
|
||||
})
|
||||
})
|
||||
139
persistence/radio_repository.go
Normal file
139
persistence/radio_repository.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type radioRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewRadioRepository(ctx context.Context, db dbx.Builder) model.RadioRepository {
|
||||
r := &radioRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.registerModel(&model.Radio{}, map[string]filterFunc{
|
||||
"name": containsFilter("name"),
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *radioRepository) isPermitted() bool {
|
||||
user := loggedUser(r.ctx)
|
||||
return user.IsAdmin
|
||||
}
|
||||
|
||||
func (r *radioRepository) CountAll(options ...model.QueryOptions) (int64, error) {
|
||||
sql := r.newSelect()
|
||||
return r.count(sql, options...)
|
||||
}
|
||||
|
||||
func (r *radioRepository) Delete(id string) error {
|
||||
if !r.isPermitted() {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
|
||||
return r.delete(Eq{"id": id})
|
||||
}
|
||||
|
||||
func (r *radioRepository) Get(id string) (*model.Radio, error) {
|
||||
sel := r.newSelect().Where(Eq{"id": id}).Columns("*")
|
||||
res := model.Radio{}
|
||||
err := r.queryOne(sel, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *radioRepository) GetAll(options ...model.QueryOptions) (model.Radios, error) {
|
||||
sel := r.newSelect(options...).Columns("*")
|
||||
res := model.Radios{}
|
||||
err := r.queryAll(sel, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *radioRepository) Put(radio *model.Radio) error {
|
||||
if !r.isPermitted() {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
|
||||
var values map[string]interface{}
|
||||
|
||||
radio.UpdatedAt = time.Now()
|
||||
|
||||
if radio.ID == "" {
|
||||
radio.CreatedAt = time.Now()
|
||||
radio.ID = id.NewRandom()
|
||||
values, _ = toSQLArgs(*radio)
|
||||
} else {
|
||||
values, _ = toSQLArgs(*radio)
|
||||
update := Update(r.tableName).Where(Eq{"id": radio.ID}).SetMap(values)
|
||||
count, err := r.executeSQL(update)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
} else if count > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
values["created_at"] = time.Now()
|
||||
insert := Insert(r.tableName).SetMap(values)
|
||||
_, err := r.executeSQL(insert)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *radioRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *radioRepository) EntityName() string {
|
||||
return "radio"
|
||||
}
|
||||
|
||||
func (r *radioRepository) NewInstance() interface{} {
|
||||
return &model.Radio{}
|
||||
}
|
||||
|
||||
func (r *radioRepository) Read(id string) (interface{}, error) {
|
||||
return r.Get(id)
|
||||
}
|
||||
|
||||
func (r *radioRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *radioRepository) Save(entity interface{}) (string, error) {
|
||||
t := entity.(*model.Radio)
|
||||
if !r.isPermitted() {
|
||||
return "", rest.ErrPermissionDenied
|
||||
}
|
||||
err := r.Put(t)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return "", rest.ErrNotFound
|
||||
}
|
||||
return t.ID, err
|
||||
}
|
||||
|
||||
func (r *radioRepository) Update(id string, entity interface{}, cols ...string) error {
|
||||
t := entity.(*model.Radio)
|
||||
t.ID = id
|
||||
if !r.isPermitted() {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
err := r.Put(t)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var _ model.RadioRepository = (*radioRepository)(nil)
|
||||
var _ rest.Repository = (*radioRepository)(nil)
|
||||
var _ rest.Persistable = (*radioRepository)(nil)
|
||||
175
persistence/radio_repository_test.go
Normal file
175
persistence/radio_repository_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var (
|
||||
NewId string = "123-456-789"
|
||||
)
|
||||
|
||||
var _ = Describe("RadioRepository", func() {
|
||||
var repo model.RadioRepository
|
||||
|
||||
Describe("Admin User", func() {
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
|
||||
repo = NewRadioRepository(ctx, GetDBXBuilder())
|
||||
_ = repo.Put(&radioWithHomePage)
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
all, _ := repo.GetAll()
|
||||
|
||||
for _, radio := range all {
|
||||
_ = repo.Delete(radio.ID)
|
||||
}
|
||||
|
||||
for i := range testRadios {
|
||||
r := testRadios[i]
|
||||
err := repo.Put(&r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
It("returns the number of radios in the DB", func() {
|
||||
Expect(repo.CountAll()).To(Equal(int64(2)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Delete", func() {
|
||||
It("deletes existing item", func() {
|
||||
err := repo.Delete(radioWithHomePage.ID)
|
||||
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
_, err = repo.Get(radioWithHomePage.ID)
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Get", func() {
|
||||
It("returns an existing item", func() {
|
||||
res, err := repo.Get(radioWithHomePage.ID)
|
||||
|
||||
Expect(err).To(BeNil())
|
||||
Expect(res.ID).To(Equal(radioWithHomePage.ID))
|
||||
})
|
||||
|
||||
It("errors when missing", func() {
|
||||
_, err := repo.Get("notanid")
|
||||
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetAll", func() {
|
||||
It("returns all items from the DB", func() {
|
||||
all, err := repo.GetAll()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(all[0].ID).To(Equal(radioWithoutHomePage.ID))
|
||||
Expect(all[1].ID).To(Equal(radioWithHomePage.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Put", func() {
|
||||
It("successfully updates item", func() {
|
||||
err := repo.Put(&model.Radio{
|
||||
ID: radioWithHomePage.ID,
|
||||
Name: "New Name",
|
||||
StreamUrl: "https://example.com:4533/app",
|
||||
})
|
||||
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
item, err := repo.Get(radioWithHomePage.ID)
|
||||
Expect(err).To(BeNil())
|
||||
|
||||
Expect(item.HomePageUrl).To(Equal(""))
|
||||
})
|
||||
|
||||
It("successfully creates item", func() {
|
||||
err := repo.Put(&model.Radio{
|
||||
Name: "New radio",
|
||||
StreamUrl: "https://example.com:4533/app",
|
||||
})
|
||||
|
||||
Expect(err).To(BeNil())
|
||||
Expect(repo.CountAll()).To(Equal(int64(3)))
|
||||
|
||||
all, err := repo.GetAll()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(all[2].StreamUrl).To(Equal("https://example.com:4533/app"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Regular User", func() {
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: false})
|
||||
repo = NewRadioRepository(ctx, GetDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
It("returns the number of radios in the DB", func() {
|
||||
Expect(repo.CountAll()).To(Equal(int64(2)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Delete", func() {
|
||||
It("fails to delete items", func() {
|
||||
err := repo.Delete(radioWithHomePage.ID)
|
||||
|
||||
Expect(err).To(Equal(rest.ErrPermissionDenied))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Get", func() {
|
||||
It("returns an existing item", func() {
|
||||
res, err := repo.Get(radioWithHomePage.ID)
|
||||
|
||||
Expect(err).To((BeNil()))
|
||||
Expect(res.ID).To(Equal(radioWithHomePage.ID))
|
||||
})
|
||||
|
||||
It("errors when missing", func() {
|
||||
_, err := repo.Get("notanid")
|
||||
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("GetAll", func() {
|
||||
It("returns all items from the DB", func() {
|
||||
all, err := repo.GetAll()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(all[0].ID).To(Equal(radioWithoutHomePage.ID))
|
||||
Expect(all[1].ID).To(Equal(radioWithHomePage.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Put", func() {
|
||||
It("fails to update item", func() {
|
||||
err := repo.Put(&model.Radio{
|
||||
ID: radioWithHomePage.ID,
|
||||
Name: "New Name",
|
||||
StreamUrl: "https://example.com:4533/app",
|
||||
})
|
||||
|
||||
Expect(err).To(Equal(rest.ErrPermissionDenied))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
100
persistence/scrobble_buffer_repository.go
Normal file
100
persistence/scrobble_buffer_repository.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type scrobbleBufferRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
type dbScrobbleBuffer struct {
|
||||
dbMediaFile
|
||||
*model.ScrobbleEntry `structs:",flatten"`
|
||||
}
|
||||
|
||||
func (t *dbScrobbleBuffer) PostScan() error {
|
||||
if err := t.dbMediaFile.PostScan(); err != nil {
|
||||
return err
|
||||
}
|
||||
t.ScrobbleEntry.MediaFile = *t.dbMediaFile.MediaFile
|
||||
t.ScrobbleEntry.MediaFile.ID = t.MediaFileID
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewScrobbleBufferRepository(ctx context.Context, db dbx.Builder) model.ScrobbleBufferRepository {
|
||||
r := &scrobbleBufferRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.tableName = "scrobble_buffer"
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *scrobbleBufferRepository) UserIDs(service string) ([]string, error) {
|
||||
sql := Select().Columns("user_id").
|
||||
From(r.tableName).
|
||||
Where(And{
|
||||
Eq{"service": service},
|
||||
}).
|
||||
GroupBy("user_id").
|
||||
OrderBy("count(*)")
|
||||
var userIds []string
|
||||
err := r.queryAllSlice(sql, &userIds)
|
||||
return userIds, err
|
||||
}
|
||||
|
||||
func (r *scrobbleBufferRepository) Enqueue(service, userId, mediaFileId string, playTime time.Time) error {
|
||||
ins := Insert(r.tableName).SetMap(map[string]interface{}{
|
||||
"id": id.NewRandom(),
|
||||
"user_id": userId,
|
||||
"service": service,
|
||||
"media_file_id": mediaFileId,
|
||||
"play_time": playTime,
|
||||
"enqueue_time": time.Now(),
|
||||
})
|
||||
_, err := r.executeSQL(ins)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *scrobbleBufferRepository) Next(service string, userId string) (*model.ScrobbleEntry, error) {
|
||||
// Put `s.*` last or else m.id overrides s.id
|
||||
sql := Select().Columns("m.*, s.*").
|
||||
From(r.tableName+" s").
|
||||
LeftJoin("media_file m on m.id = s.media_file_id").
|
||||
Where(And{
|
||||
Eq{"service": service},
|
||||
Eq{"user_id": userId},
|
||||
}).
|
||||
OrderBy("play_time", "s.rowid").Limit(1)
|
||||
|
||||
var res dbScrobbleBuffer
|
||||
err := r.queryOne(sql, &res)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.ScrobbleEntry.Participants, err = r.getParticipants(&res.ScrobbleEntry.MediaFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.ScrobbleEntry, nil
|
||||
}
|
||||
|
||||
func (r *scrobbleBufferRepository) Dequeue(entry *model.ScrobbleEntry) error {
|
||||
return r.delete(Eq{"id": entry.ID})
|
||||
}
|
||||
|
||||
func (r *scrobbleBufferRepository) Length() (int64, error) {
|
||||
return r.count(Select())
|
||||
}
|
||||
|
||||
var _ model.ScrobbleBufferRepository = (*scrobbleBufferRepository)(nil)
|
||||
208
persistence/scrobble_buffer_repository_test.go
Normal file
208
persistence/scrobble_buffer_repository_test.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ScrobbleBufferRepository", func() {
|
||||
var scrobble model.ScrobbleBufferRepository
|
||||
var rawRepo sqlRepository
|
||||
|
||||
enqueueTime := time.Date(2025, 01, 01, 00, 00, 00, 00, time.Local)
|
||||
var ids []string
|
||||
|
||||
var insertManually = func(service, userId, mediaFileId string, playTime time.Time) {
|
||||
id := id.NewRandom()
|
||||
ids = append(ids, id)
|
||||
|
||||
ins := squirrel.Insert("scrobble_buffer").SetMap(map[string]interface{}{
|
||||
"id": id,
|
||||
"user_id": userId,
|
||||
"service": service,
|
||||
"media_file_id": mediaFileId,
|
||||
"play_time": playTime,
|
||||
"enqueue_time": enqueueTime,
|
||||
})
|
||||
_, err := rawRepo.executeSQL(ins)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", UserName: "johndoe", IsAdmin: true})
|
||||
db := GetDBXBuilder()
|
||||
scrobble = NewScrobbleBufferRepository(ctx, db)
|
||||
|
||||
rawRepo = sqlRepository{
|
||||
ctx: ctx,
|
||||
tableName: "scrobble_buffer",
|
||||
db: db,
|
||||
}
|
||||
ids = []string{}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
del := squirrel.Delete(rawRepo.tableName)
|
||||
_, err := rawRepo.executeSQL(del)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Describe("Without data", func() {
|
||||
Describe("Count", func() {
|
||||
It("returns zero when empty", func() {
|
||||
count, err := scrobble.Length()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(BeZero())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Dequeue", func() {
|
||||
It("is a no-op when deleting a nonexistent item", func() {
|
||||
err := scrobble.Dequeue(&model.ScrobbleEntry{ID: "fake"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
count, err := scrobble.Length()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(0)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Next", func() {
|
||||
It("should not fail with no item for the service", func() {
|
||||
entry, err := scrobble.Next("fake", "userid")
|
||||
Expect(entry).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UserIds", func() {
|
||||
It("should return empty list with no data", func() {
|
||||
ids, err := scrobble.UserIDs("service")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ids).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("With data", func() {
|
||||
timeA := enqueueTime.Add(24 * time.Hour)
|
||||
timeB := enqueueTime.Add(48 * time.Hour)
|
||||
timeC := enqueueTime.Add(72 * time.Hour)
|
||||
timeD := enqueueTime.Add(96 * time.Hour)
|
||||
|
||||
BeforeEach(func() {
|
||||
insertManually("a", "userid", "1001", timeB)
|
||||
insertManually("a", "userid", "1002", timeA)
|
||||
insertManually("a", "2222", "1003", timeC)
|
||||
insertManually("b", "2222", "1004", timeD)
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
It("Returns count when populated", func() {
|
||||
count, err := scrobble.Length()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(4)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Dequeue", func() {
|
||||
It("is a no-op when deleting a nonexistent item", func() {
|
||||
err := scrobble.Dequeue(&model.ScrobbleEntry{ID: "fake"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
count, err := scrobble.Length()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(4)))
|
||||
})
|
||||
|
||||
It("deletes an item when specified properly", func() {
|
||||
err := scrobble.Dequeue(&model.ScrobbleEntry{ID: ids[3]})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
count, err := scrobble.Length()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(3)))
|
||||
|
||||
entry, err := scrobble.Next("b", "2222")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(entry).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Enqueue", func() {
|
||||
DescribeTable("enqueues an item properly",
|
||||
func(service, userId, fileId string, playTime time.Time) {
|
||||
now := time.Now()
|
||||
err := scrobble.Enqueue(service, userId, fileId, playTime)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
count, err := scrobble.Length()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(5)))
|
||||
|
||||
entry, err := scrobble.Next(service, userId)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(entry).ToNot(BeNil())
|
||||
|
||||
Expect(entry.EnqueueTime).To(BeTemporally("~", now, 100*time.Millisecond))
|
||||
Expect(entry.MediaFileID).To(Equal(fileId))
|
||||
Expect(entry.PlayTime).To(BeTemporally("==", playTime))
|
||||
},
|
||||
Entry("to an existing service with multiple values", "a", "userid", "1004", enqueueTime),
|
||||
Entry("to a new service", "c", "2222", "1001", timeD),
|
||||
Entry("to an existing service as new user", "b", "userid", "1003", timeC),
|
||||
)
|
||||
})
|
||||
|
||||
Describe("Next", func() {
|
||||
DescribeTable("Returns the next item when populated",
|
||||
func(service, id string, playTime time.Time, fileId, artistId string) {
|
||||
entry, err := scrobble.Next(service, id)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(entry).ToNot(BeNil())
|
||||
|
||||
Expect(entry.Service).To(Equal(service))
|
||||
Expect(entry.UserID).To(Equal(id))
|
||||
Expect(entry.PlayTime).To(BeTemporally("==", playTime))
|
||||
Expect(entry.EnqueueTime).To(BeTemporally("==", enqueueTime))
|
||||
Expect(entry.MediaFileID).To(Equal(fileId))
|
||||
|
||||
Expect(entry.MediaFile.Participants).To(HaveLen(1))
|
||||
|
||||
artists, ok := entry.MediaFile.Participants[model.RoleArtist]
|
||||
Expect(ok).To(BeTrue(), "no artist role in participants")
|
||||
|
||||
Expect(artists).To(HaveLen(1))
|
||||
Expect(artists[0].ID).To(Equal(artistId))
|
||||
},
|
||||
|
||||
Entry("Service with multiple values for one user", "a", "userid", timeA, "1002", "3"),
|
||||
Entry("Service with users", "a", "2222", timeC, "1003", "2"),
|
||||
Entry("Service with one user", "b", "2222", timeD, "1004", "2"),
|
||||
)
|
||||
|
||||
})
|
||||
|
||||
Describe("UserIds", func() {
|
||||
It("should return ordered list for services", func() {
|
||||
ids, err := scrobble.UserIDs("a")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ids).To(Equal([]string{"2222", "userid"}))
|
||||
})
|
||||
|
||||
It("should return for a different service", func() {
|
||||
ids, err := scrobble.UserIDs("b")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(ids).To(Equal([]string{"2222"}))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
34
persistence/scrobble_repository.go
Normal file
34
persistence/scrobble_repository.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type scrobbleRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewScrobbleRepository(ctx context.Context, db dbx.Builder) model.ScrobbleRepository {
|
||||
r := &scrobbleRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.tableName = "scrobbles"
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *scrobbleRepository) RecordScrobble(mediaFileID string, submissionTime time.Time) error {
|
||||
userID := loggedUser(r.ctx).ID
|
||||
values := map[string]interface{}{
|
||||
"media_file_id": mediaFileID,
|
||||
"user_id": userID,
|
||||
"submission_time": submissionTime.Unix(),
|
||||
}
|
||||
insert := Insert(r.tableName).SetMap(values)
|
||||
_, err := r.executeSQL(insert)
|
||||
return err
|
||||
}
|
||||
84
persistence/scrobble_repository_test.go
Normal file
84
persistence/scrobble_repository_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ = Describe("ScrobbleRepository", func() {
|
||||
var repo model.ScrobbleRepository
|
||||
var rawRepo sqlRepository
|
||||
var ctx context.Context
|
||||
var fileID string
|
||||
var userID string
|
||||
|
||||
BeforeEach(func() {
|
||||
fileID = id.NewRandom()
|
||||
userID = id.NewRandom()
|
||||
ctx = request.WithUser(log.NewContext(GinkgoT().Context()), model.User{ID: userID, UserName: "johndoe", IsAdmin: true})
|
||||
db := GetDBXBuilder()
|
||||
repo = NewScrobbleRepository(ctx, db)
|
||||
|
||||
rawRepo = sqlRepository{
|
||||
ctx: ctx,
|
||||
tableName: "scrobbles",
|
||||
db: db,
|
||||
}
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
_, _ = rawRepo.db.Delete("scrobbles", dbx.HashExp{"media_file_id": fileID}).Execute()
|
||||
_, _ = rawRepo.db.Delete("media_file", dbx.HashExp{"id": fileID}).Execute()
|
||||
_, _ = rawRepo.db.Delete("user", dbx.HashExp{"id": userID}).Execute()
|
||||
})
|
||||
|
||||
Describe("RecordScrobble", func() {
|
||||
It("records a scrobble event", func() {
|
||||
submissionTime := time.Now().UTC()
|
||||
|
||||
// Insert User
|
||||
_, err := rawRepo.db.Insert("user", dbx.Params{
|
||||
"id": userID,
|
||||
"user_name": "user",
|
||||
"password": "pw",
|
||||
"created_at": time.Now(),
|
||||
"updated_at": time.Now(),
|
||||
}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Insert MediaFile
|
||||
_, err = rawRepo.db.Insert("media_file", dbx.Params{
|
||||
"id": fileID,
|
||||
"path": "path",
|
||||
"created_at": time.Now(),
|
||||
"updated_at": time.Now(),
|
||||
}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = repo.RecordScrobble(fileID, submissionTime)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify insertion
|
||||
var scrobble struct {
|
||||
MediaFileID string `db:"media_file_id"`
|
||||
UserID string `db:"user_id"`
|
||||
SubmissionTime int64 `db:"submission_time"`
|
||||
}
|
||||
err = rawRepo.db.Select("*").From("scrobbles").
|
||||
Where(dbx.HashExp{"media_file_id": fileID, "user_id": userID}).
|
||||
One(&scrobble)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(scrobble.MediaFileID).To(Equal(fileID))
|
||||
Expect(scrobble.UserID).To(Equal(userID))
|
||||
Expect(scrobble.SubmissionTime).To(Equal(submissionTime.Unix()))
|
||||
})
|
||||
})
|
||||
})
|
||||
202
persistence/share_repository.go
Normal file
202
persistence/share_repository.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type shareRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewShareRepository(ctx context.Context, db dbx.Builder) model.ShareRepository {
|
||||
r := &shareRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.registerModel(&model.Share{}, nil)
|
||||
r.setSortMappings(map[string]string{
|
||||
"username": "username",
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *shareRepository) Delete(id string) error {
|
||||
err := r.delete(Eq{"id": id})
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *shareRepository) selectShare(options ...model.QueryOptions) SelectBuilder {
|
||||
return r.newSelect(options...).Join("user u on u.id = share.user_id").
|
||||
Columns("share.*", "user_name as username")
|
||||
}
|
||||
|
||||
func (r *shareRepository) Exists(id string) (bool, error) {
|
||||
return r.exists(Eq{"id": id})
|
||||
}
|
||||
|
||||
func (r *shareRepository) Get(id string) (*model.Share, error) {
|
||||
sel := r.selectShare().Where(Eq{"share.id": id})
|
||||
var res model.Share
|
||||
err := r.queryOne(sel, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.loadMedia(&res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *shareRepository) GetAll(options ...model.QueryOptions) (model.Shares, error) {
|
||||
sq := r.selectShare(options...)
|
||||
res := model.Shares{}
|
||||
err := r.queryAll(sq, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range res {
|
||||
err = r.loadMedia(&res[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading media for share %s: %w", res[i].ID, err)
|
||||
}
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *shareRepository) loadMedia(share *model.Share) error {
|
||||
var err error
|
||||
ids := strings.Split(share.ResourceIDs, ",")
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
noMissing := func(cond Sqlizer) Sqlizer {
|
||||
return And{cond, Eq{"missing": false}}
|
||||
}
|
||||
switch share.ResourceType {
|
||||
case "artist":
|
||||
albumRepo := NewAlbumRepository(r.ctx, r.db, nil)
|
||||
share.Albums, err = albumRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"album_artist_id": ids}), Sort: "artist"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.db, nil)
|
||||
share.Tracks, err = mfRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"album_artist_id": ids}), Sort: "artist"})
|
||||
return err
|
||||
case "album":
|
||||
albumRepo := NewAlbumRepository(r.ctx, r.db, nil)
|
||||
share.Albums, err = albumRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"album.id": ids})})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.db, nil)
|
||||
share.Tracks, err = mfRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"album_id": ids}), Sort: "album"})
|
||||
return err
|
||||
case "playlist":
|
||||
// Create a context with a fake admin user, to be able to access all playlists
|
||||
ctx := request.WithUser(r.ctx, model.User{IsAdmin: true})
|
||||
plsRepo := NewPlaylistRepository(ctx, r.db)
|
||||
tracks, err := plsRepo.Tracks(ids[0], true).GetAll(model.QueryOptions{Sort: "id", Filters: noMissing(Eq{})})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tracks) >= 0 {
|
||||
share.Tracks = tracks.MediaFiles()
|
||||
}
|
||||
return nil
|
||||
case "media_file":
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.db, nil)
|
||||
tracks, err := mfRepo.GetAll(model.QueryOptions{Filters: noMissing(Eq{"media_file.id": ids})})
|
||||
share.Tracks = sortByIdPosition(tracks, ids)
|
||||
return err
|
||||
}
|
||||
log.Warn(r.ctx, "Unsupported Share ResourceType", "share", share.ID, "resourceType", share.ResourceType)
|
||||
return nil
|
||||
}
|
||||
|
||||
func sortByIdPosition(mfs model.MediaFiles, ids []string) model.MediaFiles {
|
||||
m := map[string]int{}
|
||||
for i, mf := range mfs {
|
||||
m[mf.ID] = i
|
||||
}
|
||||
var sorted model.MediaFiles
|
||||
for _, id := range ids {
|
||||
if idx, ok := m[id]; ok {
|
||||
sorted = append(sorted, mfs[idx])
|
||||
}
|
||||
}
|
||||
return sorted
|
||||
}
|
||||
|
||||
func (r *shareRepository) Update(id string, entity interface{}, cols ...string) error {
|
||||
s := entity.(*model.Share)
|
||||
// TODO Validate record
|
||||
s.ID = id
|
||||
s.UpdatedAt = time.Now()
|
||||
cols = append(cols, "updated_at")
|
||||
_, err := r.put(id, s, cols...)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *shareRepository) Save(entity interface{}) (string, error) {
|
||||
s := entity.(*model.Share)
|
||||
// TODO Validate record
|
||||
u := loggedUser(r.ctx)
|
||||
if s.UserID == "" {
|
||||
s.UserID = u.ID
|
||||
}
|
||||
s.CreatedAt = time.Now()
|
||||
s.UpdatedAt = time.Now()
|
||||
id, err := r.put(s.ID, s)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return "", rest.ErrNotFound
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
func (r *shareRepository) CountAll(options ...model.QueryOptions) (int64, error) {
|
||||
return r.count(r.selectShare(), options...)
|
||||
}
|
||||
|
||||
func (r *shareRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *shareRepository) EntityName() string {
|
||||
return "share"
|
||||
}
|
||||
|
||||
func (r *shareRepository) NewInstance() interface{} {
|
||||
return &model.Share{}
|
||||
}
|
||||
|
||||
func (r *shareRepository) Read(id string) (interface{}, error) {
|
||||
sel := r.selectShare().Where(Eq{"share.id": id})
|
||||
var res model.Share
|
||||
err := r.queryOne(sel, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *shareRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
sq := r.selectShare(r.parseRestOptions(r.ctx, options...))
|
||||
res := model.Shares{}
|
||||
err := r.queryAll(sq, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
var _ model.ShareRepository = (*shareRepository)(nil)
|
||||
var _ rest.Repository = (*shareRepository)(nil)
|
||||
var _ rest.Persistable = (*shareRepository)(nil)
|
||||
133
persistence/share_repository_test.go
Normal file
133
persistence/share_repository_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/navidrome/navidrome/conf/configtest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ShareRepository", func() {
|
||||
var repo model.ShareRepository
|
||||
var ctx context.Context
|
||||
var adminUser = model.User{ID: "admin", UserName: "admin", IsAdmin: true}
|
||||
|
||||
BeforeEach(func() {
|
||||
DeferCleanup(configtest.SetupConfig())
|
||||
ctx = request.WithUser(log.NewContext(context.TODO()), adminUser)
|
||||
repo = NewShareRepository(ctx, GetDBXBuilder())
|
||||
|
||||
// Insert the admin user into the database (required for foreign key constraint)
|
||||
ur := NewUserRepository(ctx, GetDBXBuilder())
|
||||
err := ur.Put(&adminUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Clean up shares
|
||||
db := GetDBXBuilder()
|
||||
_, err = db.NewQuery("DELETE FROM share").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Describe("Headless Access", func() {
|
||||
Context("Repository creation and basic operations", func() {
|
||||
It("should create repository successfully with no user context", func() {
|
||||
// Create repository with no user context (headless)
|
||||
headlessRepo := NewShareRepository(context.Background(), GetDBXBuilder())
|
||||
Expect(headlessRepo).ToNot(BeNil())
|
||||
})
|
||||
|
||||
It("should handle GetAll for headless processes", func() {
|
||||
// Create a simple share directly in database
|
||||
shareID := "headless-test-share"
|
||||
_, err := GetDBXBuilder().NewQuery(`
|
||||
INSERT INTO share (id, user_id, description, resource_type, resource_ids, created_at, updated_at)
|
||||
VALUES ({:id}, {:user}, {:desc}, {:type}, {:ids}, {:created}, {:updated})
|
||||
`).Bind(map[string]interface{}{
|
||||
"id": shareID,
|
||||
"user": adminUser.ID,
|
||||
"desc": "Headless Test Share",
|
||||
"type": "song",
|
||||
"ids": "song-1",
|
||||
"created": time.Now(),
|
||||
"updated": time.Now(),
|
||||
}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Headless process should see all shares
|
||||
headlessRepo := NewShareRepository(context.Background(), GetDBXBuilder())
|
||||
shares, err := headlessRepo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
found := false
|
||||
for _, s := range shares {
|
||||
if s.ID == shareID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(found).To(BeTrue(), "Headless process should see all shares")
|
||||
})
|
||||
|
||||
It("should handle individual share retrieval for headless processes", func() {
|
||||
// Create a simple share
|
||||
shareID := "headless-get-share"
|
||||
_, err := GetDBXBuilder().NewQuery(`
|
||||
INSERT INTO share (id, user_id, description, resource_type, resource_ids, created_at, updated_at)
|
||||
VALUES ({:id}, {:user}, {:desc}, {:type}, {:ids}, {:created}, {:updated})
|
||||
`).Bind(map[string]interface{}{
|
||||
"id": shareID,
|
||||
"user": adminUser.ID,
|
||||
"desc": "Headless Get Share",
|
||||
"type": "song",
|
||||
"ids": "song-2",
|
||||
"created": time.Now(),
|
||||
"updated": time.Now(),
|
||||
}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Headless process should be able to get the share
|
||||
headlessRepo := NewShareRepository(context.Background(), GetDBXBuilder())
|
||||
share, err := headlessRepo.Get(shareID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(share.ID).To(Equal(shareID))
|
||||
Expect(share.Description).To(Equal("Headless Get Share"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SQL ambiguity fix verification", func() {
|
||||
It("should handle share operations without SQL ambiguity errors", func() {
|
||||
// This test verifies that the loadMedia function doesn't cause SQL ambiguity
|
||||
// The key fix was using "album.id" instead of "id" in the album query filters
|
||||
|
||||
// Create a share that would trigger the loadMedia function
|
||||
shareID := "sql-test-share"
|
||||
_, err := GetDBXBuilder().NewQuery(`
|
||||
INSERT INTO share (id, user_id, description, resource_type, resource_ids, created_at, updated_at)
|
||||
VALUES ({:id}, {:user}, {:desc}, {:type}, {:ids}, {:created}, {:updated})
|
||||
`).Bind(map[string]interface{}{
|
||||
"id": shareID,
|
||||
"user": adminUser.ID,
|
||||
"desc": "SQL Test Share",
|
||||
"type": "album",
|
||||
"ids": "non-existent-album", // Won't find albums, but shouldn't cause SQL errors
|
||||
"created": time.Now(),
|
||||
"updated": time.Now(),
|
||||
}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// The Get operation should work without SQL ambiguity errors
|
||||
// even if no albums are found
|
||||
share, err := repo.Get(shareID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(share.ID).To(Equal(shareID))
|
||||
// Albums array should be empty since we used non-existent album ID
|
||||
Expect(share.Albums).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
130
persistence/sql_annotations.go
Normal file
130
persistence/sql_annotations.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/consts"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
)
|
||||
|
||||
const annotationTable = "annotation"
|
||||
|
||||
func (r sqlRepository) withAnnotation(query SelectBuilder, idField string) SelectBuilder {
|
||||
userID := loggedUser(r.ctx).ID
|
||||
if userID == invalidUserId {
|
||||
return query
|
||||
}
|
||||
query = query.
|
||||
LeftJoin("annotation on ("+
|
||||
"annotation.item_id = "+idField+
|
||||
" AND annotation.user_id = '"+userID+"')").
|
||||
Columns(
|
||||
"coalesce(starred, 0) as starred",
|
||||
"coalesce(rating, 0) as rating",
|
||||
"starred_at",
|
||||
"play_date",
|
||||
"rated_at",
|
||||
)
|
||||
if conf.Server.AlbumPlayCountMode == consts.AlbumPlayCountModeNormalized && r.tableName == "album" {
|
||||
query = query.Columns(
|
||||
fmt.Sprintf("round(coalesce(round(cast(play_count as float) / coalesce(%[1]s.song_count, 1), 1), 0)) as play_count", r.tableName),
|
||||
)
|
||||
} else {
|
||||
query = query.Columns("coalesce(play_count, 0) as play_count")
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (r sqlRepository) annId(itemID ...string) And {
|
||||
userID := loggedUser(r.ctx).ID
|
||||
return And{
|
||||
Eq{annotationTable + ".user_id": userID},
|
||||
Eq{annotationTable + ".item_type": r.tableName},
|
||||
Eq{annotationTable + ".item_id": itemID},
|
||||
}
|
||||
}
|
||||
|
||||
func (r sqlRepository) annUpsert(values map[string]interface{}, itemIDs ...string) error {
|
||||
upd := Update(annotationTable).Where(r.annId(itemIDs...))
|
||||
for f, v := range values {
|
||||
upd = upd.Set(f, v)
|
||||
}
|
||||
c, err := r.executeSQL(upd)
|
||||
if c == 0 || errors.Is(err, sql.ErrNoRows) {
|
||||
userID := loggedUser(r.ctx).ID
|
||||
for _, itemID := range itemIDs {
|
||||
values["user_id"] = userID
|
||||
values["item_type"] = r.tableName
|
||||
values["item_id"] = itemID
|
||||
ins := Insert(annotationTable).SetMap(values)
|
||||
_, err = r.executeSQL(ins)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) SetStar(starred bool, ids ...string) error {
|
||||
starredAt := time.Now()
|
||||
return r.annUpsert(map[string]interface{}{"starred": starred, "starred_at": starredAt}, ids...)
|
||||
}
|
||||
|
||||
func (r sqlRepository) SetRating(rating int, itemID string) error {
|
||||
ratedAt := time.Now()
|
||||
return r.annUpsert(map[string]interface{}{"rating": rating, "rated_at": ratedAt}, itemID)
|
||||
}
|
||||
|
||||
func (r sqlRepository) IncPlayCount(itemID string, ts time.Time) error {
|
||||
upd := Update(annotationTable).Where(r.annId(itemID)).
|
||||
Set("play_count", Expr("play_count+1")).
|
||||
Set("play_date", Expr("max(ifnull(play_date,''),?)", ts))
|
||||
c, err := r.executeSQL(upd)
|
||||
|
||||
if c == 0 || errors.Is(err, sql.ErrNoRows) {
|
||||
userID := loggedUser(r.ctx).ID
|
||||
values := map[string]interface{}{}
|
||||
values["user_id"] = userID
|
||||
values["item_type"] = r.tableName
|
||||
values["item_id"] = itemID
|
||||
values["play_count"] = 1
|
||||
values["play_date"] = ts
|
||||
ins := Insert(annotationTable).SetMap(values)
|
||||
_, err = r.executeSQL(ins)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) ReassignAnnotation(prevID string, newID string) error {
|
||||
if prevID == newID || prevID == "" || newID == "" {
|
||||
return nil
|
||||
}
|
||||
upd := Update(annotationTable).Where(And{
|
||||
Eq{annotationTable + ".item_type": r.tableName},
|
||||
Eq{annotationTable + ".item_id": prevID},
|
||||
}).Set("item_id", newID)
|
||||
_, err := r.executeSQL(upd)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) cleanAnnotations() error {
|
||||
del := Delete(annotationTable).Where(Eq{"item_type": r.tableName}).Where("item_id not in (select id from " + r.tableName + ")")
|
||||
c, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error cleaning up %s annotations: %w", r.tableName, err)
|
||||
}
|
||||
if c > 0 {
|
||||
log.Debug(r.ctx, "Clean-up annotations", "table", r.tableName, "totalDeleted", c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
470
persistence/sql_base_repository.go
Normal file
470
persistence/sql_base_repository.go
Normal file
@@ -0,0 +1,470 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"iter"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
id2 "github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
"github.com/navidrome/navidrome/utils/hasher"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// sqlRepository is the base repository for all SQL repositories. It provides common functions to interact with the DB.
|
||||
// When creating a new repository using this base, you must:
|
||||
//
|
||||
// - Embed this struct.
|
||||
// - Set ctx and db fields. ctx should be the context passed to the constructor method, usually obtained from the request
|
||||
// - Call registerModel with the model instance and any possible filters.
|
||||
// - If the model has a different table name than the default (lowercase of the model name), it should be set manually
|
||||
// using the tableName field.
|
||||
// - Sort mappings must be set with setSortMappings method. If a sort field is not in the map, it will be used as the name of the column.
|
||||
//
|
||||
// All fields in filters and sortMappings must be in snake_case. Only sorts and filters based on real field names or
|
||||
// defined in the mappings will be allowed.
|
||||
type sqlRepository struct {
|
||||
ctx context.Context
|
||||
tableName string
|
||||
db dbx.Builder
|
||||
|
||||
// Do not set these fields manually, they are set by the registerModel method
|
||||
filterMappings map[string]filterFunc
|
||||
isFieldWhiteListed fieldWhiteListedFunc
|
||||
// Do not set this field manually, it is set by the setSortMappings method
|
||||
sortMappings map[string]string
|
||||
}
|
||||
|
||||
const invalidUserId = "-1"
|
||||
|
||||
func loggedUser(ctx context.Context) *model.User {
|
||||
if user, ok := request.UserFrom(ctx); !ok {
|
||||
return &model.User{ID: invalidUserId}
|
||||
} else {
|
||||
return &user
|
||||
}
|
||||
}
|
||||
|
||||
func (r *sqlRepository) registerModel(instance any, filters map[string]filterFunc) {
|
||||
if r.tableName == "" {
|
||||
r.tableName = strings.TrimPrefix(reflect.TypeOf(instance).String(), "*model.")
|
||||
r.tableName = toSnakeCase(r.tableName)
|
||||
}
|
||||
r.tableName = strings.ToLower(r.tableName)
|
||||
r.isFieldWhiteListed = registerModelWhiteList(instance)
|
||||
r.filterMappings = filters
|
||||
}
|
||||
|
||||
// setSortMappings sets the mappings for the sort fields. If the sort field is not in the map, it will be used as is.
|
||||
//
|
||||
// If PreferSortTags is enabled, it will map the order fields to the corresponding sort expression,
|
||||
// which gives precedence to sort tags.
|
||||
// Ex: order_title => (coalesce(nullif(sort_title,”),order_title) collate nocase)
|
||||
// To avoid performance issues, indexes should be created for these sort expressions
|
||||
//
|
||||
// NOTE: if an individual item has spaces, it should be wrapped in parentheses. For example,
|
||||
// you should write "(lyrics != '[]')". This prevents the item being split unexpectedly.
|
||||
// Without parentheses, "lyrics != '[]'" would be mapped as simply "lyrics"
|
||||
func (r *sqlRepository) setSortMappings(mappings map[string]string, tableName ...string) {
|
||||
tn := r.tableName
|
||||
if len(tableName) > 0 {
|
||||
tn = tableName[0]
|
||||
}
|
||||
if conf.Server.PreferSortTags {
|
||||
for k, v := range mappings {
|
||||
v = mapSortOrder(tn, v)
|
||||
mappings[k] = v
|
||||
}
|
||||
}
|
||||
r.sortMappings = mappings
|
||||
}
|
||||
|
||||
func (r sqlRepository) newSelect(options ...model.QueryOptions) SelectBuilder {
|
||||
sq := Select().From(r.tableName)
|
||||
if len(options) > 0 {
|
||||
r.resetSeededRandom(options)
|
||||
sq = r.applyOptions(sq, options...)
|
||||
sq = r.applyFilters(sq, options...)
|
||||
}
|
||||
return sq
|
||||
}
|
||||
|
||||
func (r sqlRepository) applyOptions(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
|
||||
if len(options) > 0 {
|
||||
if options[0].Max > 0 {
|
||||
sq = sq.Limit(uint64(options[0].Max))
|
||||
}
|
||||
if options[0].Offset > 0 {
|
||||
sq = sq.Offset(uint64(options[0].Offset))
|
||||
}
|
||||
if options[0].Sort != "" {
|
||||
sq = sq.OrderBy(r.buildSortOrder(options[0].Sort, options[0].Order))
|
||||
}
|
||||
}
|
||||
return sq
|
||||
}
|
||||
|
||||
// TODO Change all sortMappings to have a consistent case
|
||||
func (r sqlRepository) sortMapping(sort string) string {
|
||||
if mapping, ok := r.sortMappings[sort]; ok {
|
||||
return mapping
|
||||
}
|
||||
if mapping, ok := r.sortMappings[toCamelCase(sort)]; ok {
|
||||
return mapping
|
||||
}
|
||||
sort = toSnakeCase(sort)
|
||||
if mapping, ok := r.sortMappings[sort]; ok {
|
||||
return mapping
|
||||
}
|
||||
return sort
|
||||
}
|
||||
|
||||
func (r sqlRepository) buildSortOrder(sort, order string) string {
|
||||
sort = r.sortMapping(sort)
|
||||
order = strings.ToLower(strings.TrimSpace(order))
|
||||
var reverseOrder string
|
||||
if order == "desc" {
|
||||
reverseOrder = "asc"
|
||||
} else {
|
||||
order = "asc"
|
||||
reverseOrder = "desc"
|
||||
}
|
||||
|
||||
parts := strings.FieldsFunc(sort, splitFunc(','))
|
||||
newSort := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
f := strings.FieldsFunc(p, splitFunc(' '))
|
||||
newField := make([]string, 1, len(f))
|
||||
newField[0] = f[0]
|
||||
if len(f) == 1 {
|
||||
newField = append(newField, order)
|
||||
} else {
|
||||
if f[1] == "asc" {
|
||||
newField = append(newField, order)
|
||||
} else {
|
||||
newField = append(newField, reverseOrder)
|
||||
}
|
||||
}
|
||||
newSort = append(newSort, strings.Join(newField, " "))
|
||||
}
|
||||
return strings.Join(newSort, ", ")
|
||||
}
|
||||
|
||||
func splitFunc(delimiter rune) func(c rune) bool {
|
||||
open := 0
|
||||
return func(c rune) bool {
|
||||
if c == '(' {
|
||||
open++
|
||||
return false
|
||||
}
|
||||
if open > 0 {
|
||||
if c == ')' {
|
||||
open--
|
||||
}
|
||||
return false
|
||||
}
|
||||
return c == delimiter
|
||||
}
|
||||
}
|
||||
|
||||
func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
|
||||
if len(options) > 0 && options[0].Filters != nil {
|
||||
sq = sq.Where(options[0].Filters)
|
||||
}
|
||||
return sq
|
||||
}
|
||||
|
||||
func (r *sqlRepository) withTableName(filter filterFunc) filterFunc {
|
||||
return func(field string, value any) Sqlizer {
|
||||
if r.tableName != "" {
|
||||
field = r.tableName + "." + field
|
||||
}
|
||||
return filter(field, value)
|
||||
}
|
||||
}
|
||||
|
||||
// libraryIdFilter is a filter function to be added to resources that have a library_id column.
|
||||
func libraryIdFilter(_ string, value interface{}) Sqlizer {
|
||||
return Eq{"library_id": value}
|
||||
}
|
||||
|
||||
// applyLibraryFilter adds library filtering to queries for tables that have a library_id column
|
||||
// This ensures users only see content from libraries they have access to
|
||||
func (r sqlRepository) applyLibraryFilter(sq SelectBuilder, tableName ...string) SelectBuilder {
|
||||
user := loggedUser(r.ctx)
|
||||
|
||||
// If the user is an admin, or the user ID is invalid (e.g., when no user is logged in), skip the library filter
|
||||
if user.IsAdmin || user.ID == invalidUserId {
|
||||
return sq
|
||||
}
|
||||
|
||||
table := r.tableName
|
||||
if len(tableName) > 0 {
|
||||
table = tableName[0]
|
||||
}
|
||||
|
||||
// Get user's accessible library IDs
|
||||
// Use subquery to filter by user's library access
|
||||
return sq.Where(Expr(table+".library_id IN ("+
|
||||
"SELECT ul.library_id FROM user_library ul WHERE ul.user_id = ?)", user.ID))
|
||||
}
|
||||
|
||||
func (r sqlRepository) seedKey() string {
|
||||
// Seed keys must be all lowercase, or else SQLite3 will encode it, making it not match the seed
|
||||
// used in the query. Hashing the user ID and converting it to a hex string will do the trick
|
||||
userIDHash := md5.Sum([]byte(loggedUser(r.ctx).ID))
|
||||
return fmt.Sprintf("%s|%x", r.tableName, userIDHash)
|
||||
}
|
||||
|
||||
func (r sqlRepository) resetSeededRandom(options []model.QueryOptions) {
|
||||
if len(options) == 0 || options[0].Sort != "random" {
|
||||
return
|
||||
}
|
||||
options[0].Sort = fmt.Sprintf("SEEDEDRAND('%s', %s.id)", r.seedKey(), r.tableName)
|
||||
if options[0].Seed != "" {
|
||||
hasher.SetSeed(r.seedKey(), options[0].Seed)
|
||||
return
|
||||
}
|
||||
if options[0].Offset == 0 {
|
||||
hasher.Reseed(r.seedKey())
|
||||
}
|
||||
}
|
||||
|
||||
func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
start := time.Now()
|
||||
var c int64
|
||||
res, err := r.db.NewQuery(query).Bind(args).WithContext(r.ctx).Execute()
|
||||
if res != nil {
|
||||
c, _ = res.RowsAffected()
|
||||
}
|
||||
r.logSQL(query, args, err, c, start)
|
||||
if err != nil {
|
||||
if err.Error() != "LastInsertId is not supported by this driver" {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
var placeholderRegex = regexp.MustCompile(`\?`)
|
||||
|
||||
func (r sqlRepository) toSQL(sq Sqlizer) (string, dbx.Params, error) {
|
||||
query, args, err := sq.ToSql()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
// Replace query placeholders with named params
|
||||
params := make(dbx.Params, len(args))
|
||||
counter := 0
|
||||
result := placeholderRegex.ReplaceAllStringFunc(query, func(_ string) string {
|
||||
p := fmt.Sprintf("p%d", counter)
|
||||
params[p] = args[counter]
|
||||
counter++
|
||||
return "{:" + p + "}"
|
||||
})
|
||||
return result, params, nil
|
||||
}
|
||||
|
||||
func (r sqlRepository) queryOne(sq Sqlizer, response interface{}) error {
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).One(response)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
r.logSQL(query, args, nil, 0, start)
|
||||
return model.ErrNotFound
|
||||
}
|
||||
r.logSQL(query, args, err, 1, start)
|
||||
return err
|
||||
}
|
||||
|
||||
// queryWithStableResults is a helper function to execute a query and return an iterator that will yield its results
|
||||
// from a cursor, guaranteeing that the results will be stable, even if the underlying data changes.
|
||||
func queryWithStableResults[T any](r sqlRepository, sq SelectBuilder, options ...model.QueryOptions) (iter.Seq2[T, error], error) {
|
||||
if len(options) > 0 && options[0].Offset > 0 {
|
||||
sq = r.optimizePagination(sq, options[0])
|
||||
}
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
start := time.Now()
|
||||
rows, err := r.db.NewQuery(query).Bind(args).WithContext(r.ctx).Rows()
|
||||
r.logSQL(query, args, err, -1, start)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(yield func(T, error) bool) {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var row T
|
||||
err := rows.ScanStruct(&row)
|
||||
if !yield(row, err) || err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
var empty T
|
||||
yield(empty, err)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r sqlRepository) queryAll(sq SelectBuilder, response interface{}, options ...model.QueryOptions) error {
|
||||
if len(options) > 0 && options[0].Offset > 0 {
|
||||
sq = r.optimizePagination(sq, options[0])
|
||||
}
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).All(response)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
r.logSQL(query, args, nil, -1, start)
|
||||
return model.ErrNotFound
|
||||
}
|
||||
r.logSQL(query, args, err, int64(reflect.ValueOf(response).Elem().Len()), start)
|
||||
return err
|
||||
}
|
||||
|
||||
// queryAllSlice is a helper function to query a single column and return the result in a slice
|
||||
func (r sqlRepository) queryAllSlice(sq SelectBuilder, response interface{}) error {
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
err = r.db.NewQuery(query).Bind(args).WithContext(r.ctx).Column(response)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
r.logSQL(query, args, nil, -1, start)
|
||||
return model.ErrNotFound
|
||||
}
|
||||
r.logSQL(query, args, err, int64(reflect.ValueOf(response).Elem().Len()), start)
|
||||
return err
|
||||
}
|
||||
|
||||
// optimizePagination uses a less inefficient pagination, by not using OFFSET.
|
||||
// See https://gist.github.com/ssokolow/262503
|
||||
func (r sqlRepository) optimizePagination(sq SelectBuilder, options model.QueryOptions) SelectBuilder {
|
||||
if options.Offset > conf.Server.DevOffsetOptimize {
|
||||
sq = sq.RemoveOffset()
|
||||
rowidSq := sq.RemoveColumns().Columns(r.tableName + ".rowid")
|
||||
rowidSq = rowidSq.Limit(uint64(options.Offset))
|
||||
rowidSql, args, _ := rowidSq.ToSql()
|
||||
sq = sq.Where(r.tableName+".rowid not in ("+rowidSql+")", args...)
|
||||
}
|
||||
return sq
|
||||
}
|
||||
|
||||
func (r sqlRepository) exists(cond Sqlizer) (bool, error) {
|
||||
existsQuery := Select("count(*) as exist").From(r.tableName).Where(cond)
|
||||
var res struct{ Exist int64 }
|
||||
err := r.queryOne(existsQuery, &res)
|
||||
return res.Exist > 0, err
|
||||
}
|
||||
|
||||
func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOptions) (int64, error) {
|
||||
countQuery = countQuery.
|
||||
RemoveColumns().Columns("count(distinct " + r.tableName + ".id) as count").
|
||||
RemoveOffset().RemoveLimit().
|
||||
OrderBy(r.tableName + ".id"). // To remove any ORDER BY clause that could slow down the query
|
||||
From(r.tableName)
|
||||
countQuery = r.applyFilters(countQuery, options...)
|
||||
var res struct{ Count int64 }
|
||||
err := r.queryOne(countQuery, &res)
|
||||
return res.Count, err
|
||||
}
|
||||
|
||||
func (r sqlRepository) putByMatch(filter Sqlizer, id string, m interface{}, colsToUpdate ...string) (string, error) {
|
||||
if id != "" {
|
||||
return r.put(id, m, colsToUpdate...)
|
||||
}
|
||||
existsQuery := r.newSelect().Columns("id").From(r.tableName).Where(filter)
|
||||
|
||||
var res struct{ ID string }
|
||||
err := r.queryOne(existsQuery, &res)
|
||||
if err != nil && !errors.Is(err, model.ErrNotFound) {
|
||||
return "", err
|
||||
}
|
||||
return r.put(res.ID, m, colsToUpdate...)
|
||||
}
|
||||
|
||||
func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (newId string, err error) {
|
||||
values, err := toSQLArgs(m)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error preparing values to write to DB: %w", err)
|
||||
}
|
||||
// If there's an ID, try to update first
|
||||
if id != "" {
|
||||
updateValues := map[string]interface{}{}
|
||||
|
||||
// This is a map of the columns that need to be updated, if specified
|
||||
c2upd := slice.ToMap(colsToUpdate, func(s string) (string, struct{}) {
|
||||
return toSnakeCase(s), struct{}{}
|
||||
})
|
||||
for k, v := range values {
|
||||
if _, found := c2upd[k]; len(c2upd) == 0 || found {
|
||||
updateValues[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
updateValues["id"] = id
|
||||
delete(updateValues, "created_at")
|
||||
// To avoid updating the media_file birth_time on each scan. Not the best solution, but it works for now
|
||||
// TODO move to mediafile_repository when each repo has its own upsert method
|
||||
delete(updateValues, "birth_time")
|
||||
update := Update(r.tableName).Where(Eq{"id": id}).SetMap(updateValues)
|
||||
count, err := r.executeSQL(update)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if count > 0 {
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
// If it does not have an ID OR the ID was not found (when it is a new record with predefined id)
|
||||
if id == "" {
|
||||
id = id2.NewRandom()
|
||||
values["id"] = id
|
||||
}
|
||||
insert := Insert(r.tableName).SetMap(values)
|
||||
_, err = r.executeSQL(insert)
|
||||
return id, err
|
||||
}
|
||||
|
||||
func (r sqlRepository) delete(cond Sqlizer) error {
|
||||
del := Delete(r.tableName).Where(cond)
|
||||
_, err := r.executeSQL(del)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return model.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) logSQL(sql string, args dbx.Params, err error, rowsAffected int64, start time.Time) {
|
||||
elapsed := time.Since(start)
|
||||
if err == nil || errors.Is(err, context.Canceled) {
|
||||
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
|
||||
} else {
|
||||
log.Error(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
|
||||
}
|
||||
}
|
||||
284
persistence/sql_base_repository_test.go
Normal file
284
persistence/sql_base_repository_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
"github.com/navidrome/navidrome/utils/hasher"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("sqlRepository", func() {
|
||||
var r sqlRepository
|
||||
BeforeEach(func() {
|
||||
r.ctx = request.WithUser(context.Background(), model.User{ID: "user-id"})
|
||||
r.tableName = "table"
|
||||
})
|
||||
|
||||
Describe("applyOptions", func() {
|
||||
var sq squirrel.SelectBuilder
|
||||
BeforeEach(func() {
|
||||
sq = squirrel.Select("*").From("test")
|
||||
r.sortMappings = map[string]string{
|
||||
"name": "title",
|
||||
}
|
||||
})
|
||||
It("does not add any clauses when options is empty", func() {
|
||||
sq = r.applyOptions(sq, model.QueryOptions{})
|
||||
sql, _, _ := sq.ToSql()
|
||||
Expect(sql).To(Equal("SELECT * FROM test"))
|
||||
})
|
||||
It("adds all option clauses", func() {
|
||||
sq = r.applyOptions(sq, model.QueryOptions{
|
||||
Sort: "name",
|
||||
Order: "desc",
|
||||
Max: 1,
|
||||
Offset: 2,
|
||||
})
|
||||
sql, _, _ := sq.ToSql()
|
||||
Expect(sql).To(Equal("SELECT * FROM test ORDER BY title desc LIMIT 1 OFFSET 2"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("toSQL", func() {
|
||||
It("returns error for invalid SQL", func() {
|
||||
sq := squirrel.Select("*").From("test").Where(1)
|
||||
_, _, err := r.toSQL(sq)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("returns the same query when there are no placeholders", func() {
|
||||
sq := squirrel.Select("*").From("test")
|
||||
query, params, err := r.toSQL(sq)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(query).To(Equal("SELECT * FROM test"))
|
||||
Expect(params).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("replaces one placeholder correctly", func() {
|
||||
sq := squirrel.Select("*").From("test").Where(squirrel.Eq{"id": 1})
|
||||
query, params, err := r.toSQL(sq)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(query).To(Equal("SELECT * FROM test WHERE id = {:p0}"))
|
||||
Expect(params).To(HaveKeyWithValue("p0", 1))
|
||||
})
|
||||
|
||||
It("replaces multiple placeholders correctly", func() {
|
||||
sq := squirrel.Select("*").From("test").Where(squirrel.Eq{"id": 1, "name": "test"})
|
||||
query, params, err := r.toSQL(sq)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(query).To(Equal("SELECT * FROM test WHERE id = {:p0} AND name = {:p1}"))
|
||||
Expect(params).To(HaveKeyWithValue("p0", 1))
|
||||
Expect(params).To(HaveKeyWithValue("p1", "test"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("sanitizeSort", func() {
|
||||
BeforeEach(func() {
|
||||
r.registerModel(&struct {
|
||||
Field string `structs:"field"`
|
||||
}{}, nil)
|
||||
r.sortMappings = map[string]string{
|
||||
"sort1": "mappedSort1",
|
||||
}
|
||||
})
|
||||
|
||||
When("sanitizing sort", func() {
|
||||
It("returns empty if the sort key is not found in the model nor in the mappings", func() {
|
||||
sort, _ := r.sanitizeSort("unknown", "")
|
||||
Expect(sort).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns the mapped value when sort key exists", func() {
|
||||
sort, _ := r.sanitizeSort("sort1", "")
|
||||
Expect(sort).To(Equal("mappedSort1"))
|
||||
})
|
||||
|
||||
It("is case insensitive", func() {
|
||||
sort, _ := r.sanitizeSort("Sort1", "")
|
||||
Expect(sort).To(Equal("mappedSort1"))
|
||||
})
|
||||
|
||||
It("returns the field if it is a valid field", func() {
|
||||
sort, _ := r.sanitizeSort("field", "")
|
||||
Expect(sort).To(Equal("field"))
|
||||
})
|
||||
|
||||
It("is case insensitive for fields", func() {
|
||||
sort, _ := r.sanitizeSort("FIELD", "")
|
||||
Expect(sort).To(Equal("field"))
|
||||
})
|
||||
})
|
||||
When("sanitizing order", func() {
|
||||
It("returns 'asc' if order is empty", func() {
|
||||
_, order := r.sanitizeSort("", "")
|
||||
Expect(order).To(Equal(""))
|
||||
})
|
||||
|
||||
It("returns 'asc' if order is 'asc'", func() {
|
||||
_, order := r.sanitizeSort("", "ASC")
|
||||
Expect(order).To(Equal("asc"))
|
||||
})
|
||||
|
||||
It("returns 'desc' if order is 'desc'", func() {
|
||||
_, order := r.sanitizeSort("", "desc")
|
||||
Expect(order).To(Equal("desc"))
|
||||
})
|
||||
|
||||
It("returns 'asc' if order is unknown", func() {
|
||||
_, order := r.sanitizeSort("", "something")
|
||||
Expect(order).To(Equal("asc"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("buildSortOrder", func() {
|
||||
BeforeEach(func() {
|
||||
r.sortMappings = map[string]string{}
|
||||
})
|
||||
|
||||
Context("single field", func() {
|
||||
It("sorts by specified field", func() {
|
||||
sql := r.buildSortOrder("name", "desc")
|
||||
Expect(sql).To(Equal("name desc"))
|
||||
})
|
||||
It("defaults to 'asc'", func() {
|
||||
sql := r.buildSortOrder("name", "")
|
||||
Expect(sql).To(Equal("name asc"))
|
||||
})
|
||||
It("inverts pre-defined order", func() {
|
||||
sql := r.buildSortOrder("name desc", "desc")
|
||||
Expect(sql).To(Equal("name asc"))
|
||||
})
|
||||
It("forces snake case for field names", func() {
|
||||
sql := r.buildSortOrder("AlbumArtist", "asc")
|
||||
Expect(sql).To(Equal("album_artist asc"))
|
||||
})
|
||||
})
|
||||
Context("multiple fields", func() {
|
||||
It("handles multiple fields", func() {
|
||||
sql := r.buildSortOrder("name desc,age asc, status desc ", "asc")
|
||||
Expect(sql).To(Equal("name desc, age asc, status desc"))
|
||||
})
|
||||
It("inverts multiple fields", func() {
|
||||
sql := r.buildSortOrder("name desc, age, status asc", "desc")
|
||||
Expect(sql).To(Equal("name asc, age desc, status desc"))
|
||||
})
|
||||
It("handles spaces in mapped field", func() {
|
||||
r.sortMappings = map[string]string{
|
||||
"has_lyrics": "(lyrics != '[]'), updated_at",
|
||||
}
|
||||
sql := r.buildSortOrder("has_lyrics", "desc")
|
||||
Expect(sql).To(Equal("(lyrics != '[]') desc, updated_at desc"))
|
||||
})
|
||||
|
||||
})
|
||||
Context("function fields", func() {
|
||||
It("handles functions with multiple params", func() {
|
||||
sql := r.buildSortOrder("substr(id, 7)", "asc")
|
||||
Expect(sql).To(Equal("substr(id, 7) asc"))
|
||||
})
|
||||
It("handles functions with multiple params mixed with multiple fields", func() {
|
||||
sql := r.buildSortOrder("name desc, substr(id, 7), status asc", "desc")
|
||||
Expect(sql).To(Equal("name asc, substr(id, 7) desc, status desc"))
|
||||
})
|
||||
It("handles nested functions", func() {
|
||||
sql := r.buildSortOrder("name desc, coalesce(nullif(release_date, ''), nullif(original_date, '')), status asc", "desc")
|
||||
Expect(sql).To(Equal("name asc, coalesce(nullif(release_date, ''), nullif(original_date, '')) desc, status desc"))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("resetSeededRandom", func() {
|
||||
var id string
|
||||
BeforeEach(func() {
|
||||
id = r.seedKey()
|
||||
hasher.SetSeed(id, "")
|
||||
})
|
||||
It("does not reset seed if sort is not random", func() {
|
||||
var options []model.QueryOptions
|
||||
r.resetSeededRandom(options)
|
||||
Expect(hasher.CurrentSeed(id)).To(BeEmpty())
|
||||
})
|
||||
It("resets seed if sort is random", func() {
|
||||
options := []model.QueryOptions{{Sort: "random"}}
|
||||
r.resetSeededRandom(options)
|
||||
Expect(hasher.CurrentSeed(id)).NotTo(BeEmpty())
|
||||
})
|
||||
It("resets seed if sort is random and seed is provided", func() {
|
||||
options := []model.QueryOptions{{Sort: "random", Seed: "seed"}}
|
||||
r.resetSeededRandom(options)
|
||||
Expect(hasher.CurrentSeed(id)).To(Equal("seed"))
|
||||
})
|
||||
It("keeps seed when paginating", func() {
|
||||
options := []model.QueryOptions{{Sort: "random", Seed: "seed", Offset: 0}}
|
||||
r.resetSeededRandom(options)
|
||||
Expect(hasher.CurrentSeed(id)).To(Equal("seed"))
|
||||
|
||||
options = []model.QueryOptions{{Sort: "random", Offset: 1}}
|
||||
r.resetSeededRandom(options)
|
||||
Expect(hasher.CurrentSeed(id)).To(Equal("seed"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("applyLibraryFilter", func() {
|
||||
var sq squirrel.SelectBuilder
|
||||
|
||||
BeforeEach(func() {
|
||||
sq = squirrel.Select("*").From("test_table")
|
||||
})
|
||||
|
||||
Context("Admin User", func() {
|
||||
BeforeEach(func() {
|
||||
r.ctx = request.WithUser(context.Background(), model.User{ID: "admin", IsAdmin: true})
|
||||
})
|
||||
|
||||
It("should not apply library filter for admin users", func() {
|
||||
result := r.applyLibraryFilter(sq)
|
||||
sql, _, _ := result.ToSql()
|
||||
Expect(sql).To(Equal("SELECT * FROM test_table"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Regular User", func() {
|
||||
BeforeEach(func() {
|
||||
r.ctx = request.WithUser(context.Background(), model.User{ID: "user123", IsAdmin: false})
|
||||
})
|
||||
|
||||
It("should apply library filter for regular users", func() {
|
||||
result := r.applyLibraryFilter(sq)
|
||||
sql, args, _ := result.ToSql()
|
||||
Expect(sql).To(ContainSubstring("IN (SELECT ul.library_id FROM user_library ul WHERE ul.user_id = ?)"))
|
||||
Expect(args).To(ContainElement("user123"))
|
||||
})
|
||||
|
||||
It("should use custom table name when provided", func() {
|
||||
result := r.applyLibraryFilter(sq, "custom_table")
|
||||
sql, args, _ := result.ToSql()
|
||||
Expect(sql).To(ContainSubstring("custom_table.library_id IN"))
|
||||
Expect(args).To(ContainElement("user123"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Headless Process (No User Context)", func() {
|
||||
BeforeEach(func() {
|
||||
r.ctx = context.Background() // No user context
|
||||
})
|
||||
|
||||
It("should not apply library filter for headless processes", func() {
|
||||
result := r.applyLibraryFilter(sq)
|
||||
sql, _, _ := result.ToSql()
|
||||
Expect(sql).To(Equal("SELECT * FROM test_table"))
|
||||
})
|
||||
|
||||
It("should not apply library filter even with custom table name", func() {
|
||||
result := r.applyLibraryFilter(sq, "custom_table")
|
||||
sql, _, _ := result.ToSql()
|
||||
Expect(sql).To(Equal("SELECT * FROM test_table"))
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
157
persistence/sql_bookmarks.go
Normal file
157
persistence/sql_bookmarks.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
)
|
||||
|
||||
const bookmarkTable = "bookmark"
|
||||
|
||||
func (r sqlRepository) withBookmark(query SelectBuilder, idField string) SelectBuilder {
|
||||
userID := loggedUser(r.ctx).ID
|
||||
if userID == invalidUserId {
|
||||
return query
|
||||
}
|
||||
return query.
|
||||
LeftJoin("bookmark on (" +
|
||||
"bookmark.item_id = " + idField +
|
||||
" AND bookmark.user_id = '" + userID + "')").
|
||||
Columns("coalesce(position, 0) as bookmark_position")
|
||||
}
|
||||
|
||||
func (r sqlRepository) bmkID(itemID ...string) And {
|
||||
return And{
|
||||
Eq{bookmarkTable + ".user_id": loggedUser(r.ctx).ID},
|
||||
Eq{bookmarkTable + ".item_type": r.tableName},
|
||||
Eq{bookmarkTable + ".item_id": itemID},
|
||||
}
|
||||
}
|
||||
|
||||
func (r sqlRepository) bmkUpsert(itemID, comment string, position int64) error {
|
||||
client, _ := request.ClientFrom(r.ctx)
|
||||
user, _ := request.UserFrom(r.ctx)
|
||||
values := map[string]interface{}{
|
||||
"comment": comment,
|
||||
"position": position,
|
||||
"updated_at": time.Now(),
|
||||
"changed_by": client,
|
||||
}
|
||||
|
||||
upd := Update(bookmarkTable).Where(r.bmkID(itemID)).SetMap(values)
|
||||
c, err := r.executeSQL(upd)
|
||||
if err == nil {
|
||||
log.Debug(r.ctx, "Updated bookmark", "id", itemID, "user", user.UserName, "position", position, "comment", comment)
|
||||
}
|
||||
if c == 0 || errors.Is(err, sql.ErrNoRows) {
|
||||
values["user_id"] = user.ID
|
||||
values["item_type"] = r.tableName
|
||||
values["item_id"] = itemID
|
||||
values["created_at"] = time.Now()
|
||||
values["updated_at"] = time.Now()
|
||||
ins := Insert(bookmarkTable).SetMap(values)
|
||||
_, err = r.executeSQL(ins)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Debug(r.ctx, "Added bookmark", "id", itemID, "user", user.UserName, "position", position, "comment", comment)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) AddBookmark(id, comment string, position int64) error {
|
||||
user, _ := request.UserFrom(r.ctx)
|
||||
err := r.bmkUpsert(id, comment, position)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error adding bookmark", "id", id, "user", user.UserName, "position", position, "comment", comment)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) DeleteBookmark(id string) error {
|
||||
user, _ := request.UserFrom(r.ctx)
|
||||
del := Delete(bookmarkTable).Where(r.bmkID(id))
|
||||
_, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error removing bookmark", "id", id, "user", user.UserName)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type bookmark struct {
|
||||
UserID string `json:"user_id"`
|
||||
ItemID string `json:"item_id"`
|
||||
ItemType string `json:"item_type"`
|
||||
Comment string `json:"comment"`
|
||||
Position int64 `json:"position"`
|
||||
ChangedBy string `json:"changed_by"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
func (r sqlRepository) GetBookmarks() (model.Bookmarks, error) {
|
||||
user, _ := request.UserFrom(r.ctx)
|
||||
|
||||
idField := r.tableName + ".id"
|
||||
sq := r.newSelect().Columns(r.tableName + ".*")
|
||||
sq = r.withAnnotation(sq, idField)
|
||||
sq = r.withBookmark(sq, idField).Where(NotEq{bookmarkTable + ".item_id": nil})
|
||||
var mfs dbMediaFiles // TODO Decouple from media_file
|
||||
err := r.queryAll(sq, &mfs)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error getting mediafiles with bookmarks", "user", user.UserName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ids := make([]string, len(mfs))
|
||||
mfMap := make(map[string]int)
|
||||
for i, mf := range mfs {
|
||||
ids[i] = mf.ID
|
||||
mfMap[mf.ID] = i
|
||||
}
|
||||
|
||||
sq = Select("*").From(bookmarkTable).Where(r.bmkID(ids...))
|
||||
var bmks []bookmark
|
||||
err = r.queryAll(sq, &bmks)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error getting bookmarks", "user", user.UserName, "ids", ids, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := make(model.Bookmarks, len(bmks))
|
||||
for i, bmk := range bmks {
|
||||
if itemIdx, ok := mfMap[bmk.ItemID]; !ok {
|
||||
log.Debug(r.ctx, "Invalid bookmark", "id", bmk.ItemID, "user", user.UserName)
|
||||
continue
|
||||
} else {
|
||||
resp[i] = model.Bookmark{
|
||||
Comment: bmk.Comment,
|
||||
Position: bmk.Position,
|
||||
CreatedAt: bmk.CreatedAt,
|
||||
UpdatedAt: bmk.UpdatedAt,
|
||||
ChangedBy: bmk.ChangedBy,
|
||||
Item: *mfs[itemIdx].MediaFile,
|
||||
}
|
||||
}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (r sqlRepository) cleanBookmarks() error {
|
||||
del := Delete(bookmarkTable).Where(Eq{"item_type": r.tableName}).Where("item_id not in (select id from " + r.tableName + ")")
|
||||
c, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error cleaning up %s bookmarks: %w", r.tableName, err)
|
||||
}
|
||||
if c > 0 {
|
||||
log.Debug(r.ctx, "Clean-up bookmarks", "totalDeleted", c, "itemType", r.tableName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
74
persistence/sql_bookmarks_test.go
Normal file
74
persistence/sql_bookmarks_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("sqlBookmarks", func() {
|
||||
var mr model.MediaFileRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid"})
|
||||
mr = NewMediaFileRepository(ctx, GetDBXBuilder(), nil)
|
||||
})
|
||||
|
||||
Describe("Bookmarks", func() {
|
||||
It("returns an empty collection if there are no bookmarks", func() {
|
||||
Expect(mr.GetBookmarks()).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("saves and overrides bookmarks", func() {
|
||||
By("Saving the bookmark")
|
||||
Expect(mr.AddBookmark(songAntenna.ID, "this is a comment", 123)).To(BeNil())
|
||||
|
||||
bms, err := mr.GetBookmarks()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(bms).To(HaveLen(1))
|
||||
Expect(bms[0].Item.ID).To(Equal(songAntenna.ID))
|
||||
Expect(bms[0].Item.Title).To(Equal(songAntenna.Title))
|
||||
Expect(bms[0].Comment).To(Equal("this is a comment"))
|
||||
Expect(bms[0].Position).To(Equal(int64(123)))
|
||||
created := bms[0].CreatedAt
|
||||
updated := bms[0].UpdatedAt
|
||||
Expect(created.IsZero()).To(BeFalse())
|
||||
Expect(updated).To(BeTemporally(">=", created))
|
||||
|
||||
By("Overriding the bookmark")
|
||||
Expect(mr.AddBookmark(songAntenna.ID, "another comment", 333)).To(BeNil())
|
||||
|
||||
bms, err = mr.GetBookmarks()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(bms[0].Item.ID).To(Equal(songAntenna.ID))
|
||||
Expect(bms[0].Comment).To(Equal("another comment"))
|
||||
Expect(bms[0].Position).To(Equal(int64(333)))
|
||||
Expect(bms[0].CreatedAt).To(Equal(created))
|
||||
Expect(bms[0].UpdatedAt).To(BeTemporally(">=", updated))
|
||||
|
||||
By("Saving another bookmark")
|
||||
Expect(mr.AddBookmark(songComeTogether.ID, "one more comment", 444)).To(BeNil())
|
||||
bms, err = mr.GetBookmarks()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(bms).To(HaveLen(2))
|
||||
|
||||
By("Delete bookmark")
|
||||
Expect(mr.DeleteBookmark(songAntenna.ID)).To(Succeed())
|
||||
bms, err = mr.GetBookmarks()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(bms).To(HaveLen(1))
|
||||
Expect(bms[0].Item.ID).To(Equal(songComeTogether.ID))
|
||||
Expect(bms[0].Item.Title).To(Equal(songComeTogether.Title))
|
||||
|
||||
Expect(mr.DeleteBookmark(songComeTogether.ID)).To(Succeed())
|
||||
Expect(mr.GetBookmarks()).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
119
persistence/sql_participations.go
Normal file
119
persistence/sql_participations.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
)
|
||||
|
||||
type participant struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
SubRole string `json:"subRole,omitempty"`
|
||||
}
|
||||
|
||||
// flatParticipant represents a flattened participant structure for SQL processing
|
||||
type flatParticipant struct {
|
||||
ArtistID string `json:"artist_id"`
|
||||
Role string `json:"role"`
|
||||
SubRole string `json:"sub_role,omitempty"`
|
||||
}
|
||||
|
||||
func marshalParticipants(participants model.Participants) string {
|
||||
dbParticipants := make(map[model.Role][]participant)
|
||||
for role, artists := range participants {
|
||||
for _, artist := range artists {
|
||||
dbParticipants[role] = append(dbParticipants[role], participant{ID: artist.ID, SubRole: artist.SubRole, Name: artist.Name})
|
||||
}
|
||||
}
|
||||
res, _ := json.Marshal(dbParticipants)
|
||||
return string(res)
|
||||
}
|
||||
|
||||
func unmarshalParticipants(data string) (model.Participants, error) {
|
||||
var dbParticipants map[model.Role][]participant
|
||||
err := json.Unmarshal([]byte(data), &dbParticipants)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing participants: %w", err)
|
||||
}
|
||||
|
||||
participants := make(model.Participants, len(dbParticipants))
|
||||
for role, participantList := range dbParticipants {
|
||||
artists := slice.Map(participantList, func(p participant) model.Participant {
|
||||
return model.Participant{Artist: model.Artist{ID: p.ID, Name: p.Name}, SubRole: p.SubRole}
|
||||
})
|
||||
participants[role] = artists
|
||||
}
|
||||
return participants, nil
|
||||
}
|
||||
|
||||
func (r sqlRepository) updateParticipants(itemID string, participants model.Participants) error {
|
||||
ids := participants.AllIDs()
|
||||
sqd := Delete(r.tableName + "_artists").Where(And{Eq{r.tableName + "_id": itemID}, NotEq{"artist_id": ids}})
|
||||
_, err := r.executeSQL(sqd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(participants) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var flatParticipants []flatParticipant
|
||||
for role, artists := range participants {
|
||||
for _, artist := range artists {
|
||||
flatParticipants = append(flatParticipants, flatParticipant{
|
||||
ArtistID: artist.ID,
|
||||
Role: role.String(),
|
||||
SubRole: artist.SubRole,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
participantsJSON, err := json.Marshal(flatParticipants)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling participants: %w", err)
|
||||
}
|
||||
|
||||
// Build the INSERT query using json_each and INNER JOIN to artist table
|
||||
// to automatically filter out non-existent artist IDs
|
||||
query := fmt.Sprintf(`
|
||||
INSERT INTO %[1]s_artists (%[1]s_id, artist_id, role, sub_role)
|
||||
SELECT ?,
|
||||
json_extract(value, '$.artist_id') as artist_id,
|
||||
json_extract(value, '$.role') as role,
|
||||
COALESCE(json_extract(value, '$.sub_role'), '') as sub_role
|
||||
-- Parse the flat JSON array: [{"artist_id": "id", "role": "role", "sub_role": "subRole"}]
|
||||
FROM json_each(?) -- Iterate through each array element
|
||||
-- CRITICAL: Only insert records for artists that actually exist in the database
|
||||
JOIN artist ON artist.id = json_extract(value, '$.artist_id') -- Filter out non-existent artist IDs via INNER JOIN
|
||||
-- Handle duplicate insertions gracefully (e.g., if called multiple times)
|
||||
ON CONFLICT (artist_id, %[1]s_id, role, sub_role) DO NOTHING -- Ignore duplicates
|
||||
`, r.tableName)
|
||||
|
||||
_, err = r.executeSQL(Expr(query, itemID, string(participantsJSON)))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *sqlRepository) getParticipants(m *model.MediaFile) (model.Participants, error) {
|
||||
artistRepo := NewArtistRepository(r.ctx, r.db, nil)
|
||||
ids := m.Participants.AllIDs()
|
||||
artists, err := artistRepo.GetAll(model.QueryOptions{Filters: Eq{"artist.id": ids}})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting participants: %w", err)
|
||||
}
|
||||
artistMap := slice.ToMap(artists, func(a model.Artist) (string, model.Artist) {
|
||||
return a.ID, a
|
||||
})
|
||||
p := m.Participants
|
||||
for role, artistList := range p {
|
||||
for idx, artist := range artistList {
|
||||
if a, ok := artistMap[artist.ID]; ok {
|
||||
p[role][idx].Artist = a
|
||||
}
|
||||
}
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
180
persistence/sql_restful.go
Normal file
180
persistence/sql_restful.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/fatih/structs"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
)
|
||||
|
||||
type filterFunc = func(field string, value any) Sqlizer
|
||||
|
||||
func (r *sqlRepository) parseRestFilters(ctx context.Context, options rest.QueryOptions) Sqlizer {
|
||||
if len(options.Filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
filters := And{}
|
||||
for f, v := range options.Filters {
|
||||
// Ignore filters with empty values
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
// Look for a custom filter function
|
||||
f = strings.ToLower(f)
|
||||
if ff, ok := r.filterMappings[f]; ok {
|
||||
if filter := ff(f, v); filter != nil {
|
||||
filters = append(filters, filter)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Ignore invalid filters (not based on a field or filter function)
|
||||
if r.isFieldWhiteListed != nil && !r.isFieldWhiteListed(f) {
|
||||
log.Warn(ctx, "Ignoring filter not whitelisted", "filter", f, "table", r.tableName)
|
||||
continue
|
||||
}
|
||||
// For fields ending in "id", use an exact match
|
||||
if strings.HasSuffix(f, "id") {
|
||||
filters = append(filters, eqFilter(f, v))
|
||||
continue
|
||||
}
|
||||
// Default to a "starts with" filter
|
||||
filters = append(filters, startsWithFilter(f, v))
|
||||
}
|
||||
return filters
|
||||
}
|
||||
|
||||
func (r *sqlRepository) parseRestOptions(ctx context.Context, options ...rest.QueryOptions) model.QueryOptions {
|
||||
qo := model.QueryOptions{}
|
||||
if len(options) > 0 {
|
||||
qo.Sort, qo.Order = r.sanitizeSort(options[0].Sort, options[0].Order)
|
||||
qo.Max = options[0].Max
|
||||
qo.Offset = options[0].Offset
|
||||
if seed, ok := options[0].Filters["seed"].(string); ok {
|
||||
qo.Seed = seed
|
||||
delete(options[0].Filters, "seed")
|
||||
}
|
||||
qo.Filters = r.parseRestFilters(ctx, options[0])
|
||||
}
|
||||
return qo
|
||||
}
|
||||
|
||||
func (r sqlRepository) sanitizeSort(sort, order string) (string, string) {
|
||||
if sort != "" {
|
||||
sort = toSnakeCase(sort)
|
||||
if mapped, ok := r.sortMappings[sort]; ok {
|
||||
sort = mapped
|
||||
} else {
|
||||
if !r.isFieldWhiteListed(sort) {
|
||||
log.Warn(r.ctx, "Ignoring sort not whitelisted", "sort", sort, "table", r.tableName)
|
||||
sort = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
if order != "" {
|
||||
order = strings.ToLower(order)
|
||||
if order != "desc" {
|
||||
order = "asc"
|
||||
}
|
||||
}
|
||||
return sort, order
|
||||
}
|
||||
|
||||
func eqFilter(field string, value any) Sqlizer {
|
||||
return Eq{field: value}
|
||||
}
|
||||
|
||||
func startsWithFilter(field string, value any) Sqlizer {
|
||||
return Like{field: fmt.Sprintf("%s%%", value)}
|
||||
}
|
||||
|
||||
func containsFilter(field string) func(string, any) Sqlizer {
|
||||
return func(_ string, value any) Sqlizer {
|
||||
return Like{field: fmt.Sprintf("%%%s%%", value)}
|
||||
}
|
||||
}
|
||||
|
||||
func booleanFilter(field string, value any) Sqlizer {
|
||||
v := strings.ToLower(value.(string))
|
||||
return Eq{field: v == "true"}
|
||||
}
|
||||
|
||||
func fullTextFilter(tableName string, mbidFields ...string) func(string, any) Sqlizer {
|
||||
return func(field string, value any) Sqlizer {
|
||||
v := strings.ToLower(value.(string))
|
||||
cond := cmp.Or(
|
||||
mbidExpr(tableName, v, mbidFields...),
|
||||
fullTextExpr(tableName, v),
|
||||
)
|
||||
return cond
|
||||
}
|
||||
}
|
||||
|
||||
func substringFilter(field string, value any) Sqlizer {
|
||||
parts := strings.Fields(value.(string))
|
||||
filters := And{}
|
||||
for _, part := range parts {
|
||||
filters = append(filters, Like{field: "%" + part + "%"})
|
||||
}
|
||||
return filters
|
||||
}
|
||||
|
||||
func idFilter(tableName string) func(string, any) Sqlizer {
|
||||
return func(field string, value any) Sqlizer { return Eq{tableName + ".id": value} }
|
||||
}
|
||||
|
||||
func invalidFilter(ctx context.Context) func(string, any) Sqlizer {
|
||||
return func(field string, value any) Sqlizer {
|
||||
log.Warn(ctx, "Invalid filter", "fieldName", field, "value", value)
|
||||
return Eq{"1": "0"}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
whiteList = map[string]map[string]struct{}{}
|
||||
mutex sync.RWMutex
|
||||
)
|
||||
|
||||
func registerModelWhiteList(instance any) fieldWhiteListedFunc {
|
||||
name := reflect.TypeOf(instance).String()
|
||||
registerFieldWhiteList(name, instance)
|
||||
return getFieldWhiteListedFunc(name)
|
||||
}
|
||||
|
||||
func registerFieldWhiteList(name string, instance any) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
if whiteList[name] != nil {
|
||||
return
|
||||
}
|
||||
m := structs.Map(instance)
|
||||
whiteList[name] = map[string]struct{}{}
|
||||
for k := range m {
|
||||
whiteList[name][toSnakeCase(k)] = struct{}{}
|
||||
}
|
||||
ma := structs.Map(model.Annotations{})
|
||||
for k := range ma {
|
||||
whiteList[name][toSnakeCase(k)] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
type fieldWhiteListedFunc func(field string) bool
|
||||
|
||||
func getFieldWhiteListedFunc(tableName string) fieldWhiteListedFunc {
|
||||
return func(field string) bool {
|
||||
mutex.RLock()
|
||||
defer mutex.RUnlock()
|
||||
if _, ok := whiteList[tableName]; !ok {
|
||||
return false
|
||||
}
|
||||
_, ok := whiteList[tableName][field]
|
||||
return ok
|
||||
}
|
||||
}
|
||||
235
persistence/sql_restful_test.go
Normal file
235
persistence/sql_restful_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/conf/configtest"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("sqlRestful", func() {
|
||||
Describe("parseRestFilters", func() {
|
||||
var r sqlRepository
|
||||
var options rest.QueryOptions
|
||||
|
||||
BeforeEach(func() {
|
||||
r = sqlRepository{}
|
||||
})
|
||||
|
||||
It("returns nil if filters is empty", func() {
|
||||
options.Filters = nil
|
||||
Expect(r.parseRestFilters(context.Background(), options)).To(BeNil())
|
||||
})
|
||||
|
||||
It(`returns nil if tries a filter with fullTextExpr("'")`, func() {
|
||||
r.filterMappings = map[string]filterFunc{
|
||||
"name": fullTextFilter("table"),
|
||||
}
|
||||
options.Filters = map[string]interface{}{"name": "'"}
|
||||
Expect(r.parseRestFilters(context.Background(), options)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("does not add nill filters", func() {
|
||||
r.filterMappings = map[string]filterFunc{
|
||||
"name": func(string, any) squirrel.Sqlizer {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
options.Filters = map[string]interface{}{"name": "joe"}
|
||||
Expect(r.parseRestFilters(context.Background(), options)).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns a '=' condition for 'id' filter", func() {
|
||||
options.Filters = map[string]interface{}{"id": "123"}
|
||||
Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Eq{"id": "123"}}))
|
||||
})
|
||||
|
||||
It("returns a 'in' condition for multiples 'id' filters", func() {
|
||||
options.Filters = map[string]interface{}{"id": []string{"123", "456"}}
|
||||
Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Eq{"id": []string{"123", "456"}}}))
|
||||
})
|
||||
|
||||
It("returns a 'like' condition for other filters", func() {
|
||||
options.Filters = map[string]interface{}{"name": "joe"}
|
||||
Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Like{"name": "joe%"}}))
|
||||
})
|
||||
|
||||
It("uses the custom filter", func() {
|
||||
r.filterMappings = map[string]filterFunc{
|
||||
"test": func(field string, value interface{}) squirrel.Sqlizer {
|
||||
return squirrel.Gt{field: value}
|
||||
},
|
||||
}
|
||||
options.Filters = map[string]interface{}{"test": 100}
|
||||
Expect(r.parseRestFilters(context.Background(), options)).To(Equal(squirrel.And{squirrel.Gt{"test": 100}}))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("fullTextFilter function", func() {
|
||||
var filter filterFunc
|
||||
var tableName string
|
||||
var mbidFields []string
|
||||
|
||||
BeforeEach(func() {
|
||||
DeferCleanup(configtest.SetupConfig())
|
||||
tableName = "test_table"
|
||||
mbidFields = []string{"mbid", "artist_mbid"}
|
||||
filter = fullTextFilter(tableName, mbidFields...)
|
||||
})
|
||||
|
||||
Context("when value is a valid UUID", func() {
|
||||
It("returns only the mbid filter (precedence over full text)", func() {
|
||||
uuid := "550e8400-e29b-41d4-a716-446655440000"
|
||||
result := filter("search", uuid)
|
||||
|
||||
expected := squirrel.Or{
|
||||
squirrel.Eq{"test_table.mbid": uuid},
|
||||
squirrel.Eq{"test_table.artist_mbid": uuid},
|
||||
}
|
||||
Expect(result).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("falls back to full text when no mbid fields are provided", func() {
|
||||
noMbidFilter := fullTextFilter(tableName)
|
||||
uuid := "550e8400-e29b-41d4-a716-446655440000"
|
||||
result := noMbidFilter("search", uuid)
|
||||
|
||||
// mbidExpr with no fields returns nil, so cmp.Or falls back to fullTextExpr
|
||||
expected := squirrel.And{
|
||||
squirrel.Like{"test_table.full_text": "% 550e8400-e29b-41d4-a716-446655440000%"},
|
||||
}
|
||||
Expect(result).To(Equal(expected))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when value is not a valid UUID", func() {
|
||||
It("returns full text search condition only", func() {
|
||||
result := filter("search", "beatles")
|
||||
|
||||
// mbidExpr returns nil for non-UUIDs, so fullTextExpr result is returned directly
|
||||
expected := squirrel.And{
|
||||
squirrel.Like{"test_table.full_text": "% beatles%"},
|
||||
}
|
||||
Expect(result).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("handles multi-word search terms", func() {
|
||||
result := filter("search", "the beatles abbey road")
|
||||
|
||||
// Should return And condition directly
|
||||
andCondition, ok := result.(squirrel.And)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(andCondition).To(HaveLen(4))
|
||||
|
||||
// Check that all words are present (order may vary)
|
||||
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% the%"}))
|
||||
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% beatles%"}))
|
||||
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% abbey%"}))
|
||||
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% road%"}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when SearchFullString config changes behavior", func() {
|
||||
It("uses different separator with SearchFullString=false", func() {
|
||||
conf.Server.SearchFullString = false
|
||||
result := filter("search", "test query")
|
||||
|
||||
andCondition, ok := result.(squirrel.And)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(andCondition).To(HaveLen(2))
|
||||
|
||||
// Check that all words are present with leading space (order may vary)
|
||||
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% test%"}))
|
||||
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "% query%"}))
|
||||
})
|
||||
|
||||
It("uses no separator with SearchFullString=true", func() {
|
||||
conf.Server.SearchFullString = true
|
||||
result := filter("search", "test query")
|
||||
|
||||
andCondition, ok := result.(squirrel.And)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(andCondition).To(HaveLen(2))
|
||||
|
||||
// Check that all words are present without leading space (order may vary)
|
||||
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "%test%"}))
|
||||
Expect(andCondition).To(ContainElement(squirrel.Like{"test_table.full_text": "%query%"}))
|
||||
})
|
||||
})
|
||||
|
||||
Context("edge cases", func() {
|
||||
It("returns nil for empty string", func() {
|
||||
result := filter("search", "")
|
||||
Expect(result).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil for string with only whitespace", func() {
|
||||
result := filter("search", " ")
|
||||
Expect(result).To(BeNil())
|
||||
})
|
||||
|
||||
It("handles special characters that are sanitized", func() {
|
||||
result := filter("search", "don't")
|
||||
|
||||
expected := squirrel.And{
|
||||
squirrel.Like{"test_table.full_text": "% dont%"}, // str.SanitizeStrings removes quotes
|
||||
}
|
||||
Expect(result).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("returns nil for single quote (SQL injection protection)", func() {
|
||||
result := filter("search", "'")
|
||||
Expect(result).To(BeNil())
|
||||
})
|
||||
|
||||
It("handles mixed case UUIDs", func() {
|
||||
uuid := "550E8400-E29B-41D4-A716-446655440000"
|
||||
result := filter("search", uuid)
|
||||
|
||||
// Should return only mbid filter (uppercase UUID should work)
|
||||
expected := squirrel.Or{
|
||||
squirrel.Eq{"test_table.mbid": strings.ToLower(uuid)},
|
||||
squirrel.Eq{"test_table.artist_mbid": strings.ToLower(uuid)},
|
||||
}
|
||||
Expect(result).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("handles invalid UUID format gracefully", func() {
|
||||
result := filter("search", "550e8400-invalid-uuid")
|
||||
|
||||
// Should return full text filter since UUID is invalid
|
||||
expected := squirrel.And{
|
||||
squirrel.Like{"test_table.full_text": "% 550e8400-invalid-uuid%"},
|
||||
}
|
||||
Expect(result).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("handles empty mbid fields array", func() {
|
||||
emptyMbidFilter := fullTextFilter(tableName, []string{}...)
|
||||
result := emptyMbidFilter("search", "test")
|
||||
|
||||
// mbidExpr with empty fields returns nil, so cmp.Or falls back to fullTextExpr
|
||||
expected := squirrel.And{
|
||||
squirrel.Like{"test_table.full_text": "% test%"},
|
||||
}
|
||||
Expect(result).To(Equal(expected))
|
||||
})
|
||||
|
||||
It("converts value to lowercase before processing", func() {
|
||||
result := filter("search", "TEST")
|
||||
|
||||
// The function converts to lowercase internally
|
||||
expected := squirrel.And{
|
||||
squirrel.Like{"test_table.full_text": "% test%"},
|
||||
}
|
||||
Expect(result).To(Equal(expected))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
77
persistence/sql_search.go
Normal file
77
persistence/sql_search.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/str"
|
||||
)
|
||||
|
||||
func formatFullText(text ...string) string {
|
||||
fullText := str.SanitizeStrings(text...)
|
||||
return " " + fullText
|
||||
}
|
||||
|
||||
// doSearch performs a full-text search with the specified parameters.
|
||||
// The naturalOrder is used to sort results when no full-text filter is applied. It is useful for cases like
|
||||
// OpenSubsonic, where an empty search query should return all results in a natural order. Normally the parameter
|
||||
// should be `tableName + ".rowid"`, but some repositories (ex: artist) may use a different natural order.
|
||||
func (r sqlRepository) doSearch(sq SelectBuilder, q string, offset, size int, results any, naturalOrder string, orderBys ...string) error {
|
||||
q = strings.TrimSpace(q)
|
||||
q = strings.TrimSuffix(q, "*")
|
||||
if len(q) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
filter := fullTextExpr(r.tableName, q)
|
||||
if filter != nil {
|
||||
sq = sq.Where(filter)
|
||||
sq = sq.OrderBy(orderBys...)
|
||||
} else {
|
||||
// This is to speed up the results of `search3?query=""`, for OpenSubsonic
|
||||
// If the filter is empty, we sort by the specified natural order.
|
||||
sq = sq.OrderBy(naturalOrder)
|
||||
}
|
||||
sq = sq.Where(Eq{r.tableName + ".missing": false})
|
||||
sq = sq.Limit(uint64(size)).Offset(uint64(offset))
|
||||
return r.queryAll(sq, results, model.QueryOptions{Offset: offset})
|
||||
}
|
||||
|
||||
func (r sqlRepository) searchByMBID(sq SelectBuilder, mbid string, mbidFields []string, results any) error {
|
||||
sq = sq.Where(mbidExpr(r.tableName, mbid, mbidFields...))
|
||||
sq = sq.Where(Eq{r.tableName + ".missing": false})
|
||||
|
||||
return r.queryAll(sq, results)
|
||||
}
|
||||
|
||||
func mbidExpr(tableName, mbid string, mbidFields ...string) Sqlizer {
|
||||
if uuid.Validate(mbid) != nil || len(mbidFields) == 0 {
|
||||
return nil
|
||||
}
|
||||
mbid = strings.ToLower(mbid)
|
||||
var cond []Sqlizer
|
||||
for _, mbidField := range mbidFields {
|
||||
cond = append(cond, Eq{tableName + "." + mbidField: mbid})
|
||||
}
|
||||
return Or(cond)
|
||||
}
|
||||
|
||||
func fullTextExpr(tableName string, s string) Sqlizer {
|
||||
q := str.SanitizeStrings(s)
|
||||
if q == "" {
|
||||
return nil
|
||||
}
|
||||
var sep string
|
||||
if !conf.Server.SearchFullString {
|
||||
sep = " "
|
||||
}
|
||||
parts := strings.Split(q, " ")
|
||||
filters := And{}
|
||||
for _, part := range parts {
|
||||
filters = append(filters, Like{tableName + ".full_text": "%" + sep + part + "%"})
|
||||
}
|
||||
return filters
|
||||
}
|
||||
14
persistence/sql_search_test.go
Normal file
14
persistence/sql_search_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("sqlRepository", func() {
|
||||
Describe("formatFullText", func() {
|
||||
It("prefixes with a space", func() {
|
||||
Expect(formatFullText("legiao urbana")).To(Equal(" legiao urbana"))
|
||||
})
|
||||
})
|
||||
})
|
||||
168
persistence/sql_tags.go
Normal file
168
persistence/sql_tags.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
// Format of a tag in the DB
|
||||
type dbTag struct {
|
||||
ID string `json:"id"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
type dbTags map[model.TagName][]dbTag
|
||||
|
||||
func unmarshalTags(data string) (model.Tags, error) {
|
||||
var dbTags dbTags
|
||||
err := json.Unmarshal([]byte(data), &dbTags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing tags: %w", err)
|
||||
}
|
||||
|
||||
res := make(model.Tags, len(dbTags))
|
||||
for name, tags := range dbTags {
|
||||
res[name] = make([]string, len(tags))
|
||||
for i, tag := range tags {
|
||||
res[name][i] = tag.Value
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func marshalTags(tags model.Tags) string {
|
||||
dbTags := dbTags{}
|
||||
for name, values := range tags {
|
||||
for _, value := range values {
|
||||
t := model.NewTag(name, value)
|
||||
dbTags[name] = append(dbTags[name], dbTag{ID: t.ID, Value: value})
|
||||
}
|
||||
}
|
||||
res, _ := json.Marshal(dbTags)
|
||||
return string(res)
|
||||
}
|
||||
|
||||
func tagIDFilter(name string, idValue any) Sqlizer {
|
||||
name = strings.TrimSuffix(name, "_id")
|
||||
return Exists(
|
||||
fmt.Sprintf(`json_tree(tags, "$.%s")`, name),
|
||||
And{
|
||||
NotEq{"json_tree.atom": nil},
|
||||
Eq{"value": idValue},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// tagLibraryIdFilter filters tags based on library access through the library_tag table
|
||||
func tagLibraryIdFilter(_ string, value interface{}) Sqlizer {
|
||||
return Eq{"library_tag.library_id": value}
|
||||
}
|
||||
|
||||
// baseTagRepository provides common functionality for all tag-based repositories.
|
||||
// It handles CRUD operations with optional filtering by tag name.
|
||||
type baseTagRepository struct {
|
||||
sqlRepository
|
||||
tagFilter *model.TagName // nil = no filter (all tags), non-nil = filter by specific tag name
|
||||
}
|
||||
|
||||
// newBaseTagRepository creates a new base tag repository with optional tag filtering.
|
||||
// If tagFilter is nil, the repository will work with all tags.
|
||||
// If tagFilter is provided, the repository will only work with tags of that specific name.
|
||||
func newBaseTagRepository(ctx context.Context, db dbx.Builder, tagFilter *model.TagName) *baseTagRepository {
|
||||
r := &baseTagRepository{
|
||||
tagFilter: tagFilter,
|
||||
}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.tableName = "tag"
|
||||
r.registerModel(&model.Tag{}, map[string]filterFunc{
|
||||
"name": containsFilter("tag_value"),
|
||||
"library_id": tagLibraryIdFilter,
|
||||
})
|
||||
r.setSortMappings(map[string]string{
|
||||
"name": "tag_value",
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
// applyLibraryFiltering adds the appropriate library joins based on user context
|
||||
func (r *baseTagRepository) applyLibraryFiltering(sq SelectBuilder) SelectBuilder {
|
||||
// Add library_tag join
|
||||
sq = sq.LeftJoin("library_tag on library_tag.tag_id = tag.id")
|
||||
|
||||
// For authenticated users, also join with user_library to filter by accessible libraries
|
||||
user := loggedUser(r.ctx)
|
||||
if user.ID != invalidUserId {
|
||||
sq = sq.Join("user_library on user_library.library_id = library_tag.library_id AND user_library.user_id = ?", user.ID)
|
||||
}
|
||||
|
||||
return sq
|
||||
}
|
||||
|
||||
// newSelect overrides the base implementation to apply tag name filtering and library filtering.
|
||||
func (r *baseTagRepository) newSelect(options ...model.QueryOptions) SelectBuilder {
|
||||
sq := r.sqlRepository.newSelect(options...)
|
||||
|
||||
// Apply tag name filtering if specified
|
||||
if r.tagFilter != nil {
|
||||
sq = sq.Where(Eq{"tag.tag_name": *r.tagFilter})
|
||||
}
|
||||
|
||||
// Apply library filtering and set up aggregation columns
|
||||
sq = r.applyLibraryFiltering(sq).Columns(
|
||||
"tag.id",
|
||||
"tag.tag_name",
|
||||
"tag.tag_value",
|
||||
"COALESCE(SUM(library_tag.album_count), 0) as album_count",
|
||||
"COALESCE(SUM(library_tag.media_file_count), 0) as song_count",
|
||||
).GroupBy("tag.id", "tag.tag_name", "tag.tag_value")
|
||||
|
||||
return sq
|
||||
}
|
||||
|
||||
// ResourceRepository interface implementation
|
||||
|
||||
func (r *baseTagRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
sq := Select("COUNT(DISTINCT tag.id)").From("tag")
|
||||
|
||||
// Apply tag name filtering if specified
|
||||
if r.tagFilter != nil {
|
||||
sq = sq.Where(Eq{"tag.tag_name": *r.tagFilter})
|
||||
}
|
||||
|
||||
// Apply library filtering
|
||||
sq = r.applyLibraryFiltering(sq)
|
||||
|
||||
return r.count(sq, r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *baseTagRepository) Read(id string) (interface{}, error) {
|
||||
query := r.newSelect().Where(Eq{"id": id})
|
||||
var res model.Tag
|
||||
err := r.queryOne(query, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *baseTagRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
query := r.newSelect(r.parseRestOptions(r.ctx, options...))
|
||||
var res model.TagList
|
||||
err := r.queryAll(query, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *baseTagRepository) EntityName() string {
|
||||
return "tag"
|
||||
}
|
||||
|
||||
func (r *baseTagRepository) NewInstance() interface{} {
|
||||
return model.Tag{}
|
||||
}
|
||||
|
||||
// Interface compliance check
|
||||
var _ model.ResourceRepository = (*baseTagRepository)(nil)
|
||||
263
persistence/tag_library_filtering_test.go
Normal file
263
persistence/tag_library_filtering_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/conf/configtest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
const (
|
||||
adminUserID = "userid"
|
||||
regularUserID = "2222"
|
||||
libraryID1 = 1
|
||||
libraryID2 = 2
|
||||
libraryID3 = 3
|
||||
|
||||
tagNameGenre = "genre"
|
||||
tagValueRock = "rock"
|
||||
tagValuePop = "pop"
|
||||
tagValueJazz = "jazz"
|
||||
)
|
||||
|
||||
var _ = Describe("Tag Library Filtering", func() {
|
||||
var (
|
||||
tagRockID = id.NewTagID(tagNameGenre, tagValueRock)
|
||||
tagPopID = id.NewTagID(tagNameGenre, tagValuePop)
|
||||
tagJazzID = id.NewTagID(tagNameGenre, tagValueJazz)
|
||||
)
|
||||
|
||||
expectTagValues := func(tagList model.TagList, expected []string) {
|
||||
tagValues := make([]string, len(tagList))
|
||||
for i, tag := range tagList {
|
||||
tagValues[i] = tag.TagValue
|
||||
}
|
||||
Expect(tagValues).To(ContainElements(expected))
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
DeferCleanup(configtest.SetupConfig())
|
||||
|
||||
// Generate unique path suffix to avoid conflicts with other tests
|
||||
uniqueSuffix := time.Now().Format("20060102150405.000")
|
||||
|
||||
// Clean up database
|
||||
db := GetDBXBuilder()
|
||||
_, err := db.NewQuery("DELETE FROM library_tag").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = db.NewQuery("DELETE FROM tag").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = db.NewQuery("DELETE FROM user_library WHERE user_id != {:admin} AND user_id != {:regular}").
|
||||
Bind(dbx.Params{"admin": adminUserID, "regular": regularUserID}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = db.NewQuery("DELETE FROM library WHERE id > 1").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Create test libraries with unique names and paths to avoid conflicts with other tests
|
||||
_, err = db.NewQuery("INSERT INTO library (id, name, path) VALUES ({:id}, {:name}, {:path})").
|
||||
Bind(dbx.Params{"id": libraryID2, "name": "Library 2-" + uniqueSuffix, "path": "/music/lib2-" + uniqueSuffix}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = db.NewQuery("INSERT INTO library (id, name, path) VALUES ({:id}, {:name}, {:path})").
|
||||
Bind(dbx.Params{"id": libraryID3, "name": "Library 3-" + uniqueSuffix, "path": "/music/lib3-" + uniqueSuffix}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Give admin access to all libraries
|
||||
for _, libID := range []int{libraryID1, libraryID2, libraryID3} {
|
||||
_, err = db.NewQuery("INSERT OR IGNORE INTO user_library (user_id, library_id) VALUES ({:user}, {:lib})").
|
||||
Bind(dbx.Params{"user": adminUserID, "lib": libID}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
// Create test tags
|
||||
adminCtx := request.WithUser(log.NewContext(context.TODO()), adminUser)
|
||||
tagRepo := NewTagRepository(adminCtx, GetDBXBuilder())
|
||||
|
||||
createTag := func(libraryID int, name, value string) {
|
||||
tag := model.Tag{ID: id.NewTagID(name, value), TagName: model.TagName(name), TagValue: value}
|
||||
err := tagRepo.Add(libraryID, tag)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
createTag(libraryID1, tagNameGenre, tagValueRock)
|
||||
createTag(libraryID2, tagNameGenre, tagValuePop)
|
||||
createTag(libraryID3, tagNameGenre, tagValueJazz)
|
||||
createTag(libraryID2, tagNameGenre, tagValueRock) // Rock appears in both lib1 and lib2
|
||||
|
||||
// Set tag counts (manually for testing)
|
||||
setCounts := func(tagID string, libID, albums, songs int) {
|
||||
_, err := db.NewQuery("UPDATE library_tag SET album_count = {:albums}, media_file_count = {:songs} WHERE tag_id = {:tag} AND library_id = {:lib}").
|
||||
Bind(dbx.Params{"albums": albums, "songs": songs, "tag": tagID, "lib": libID}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
|
||||
setCounts(tagRockID, libraryID1, 5, 20)
|
||||
setCounts(tagPopID, libraryID2, 3, 10)
|
||||
setCounts(tagJazzID, libraryID3, 2, 8)
|
||||
setCounts(tagRockID, libraryID2, 1, 4)
|
||||
|
||||
// Give regular user access to library 2 only
|
||||
_, err = db.NewQuery("INSERT INTO user_library (user_id, library_id) VALUES ({:user}, {:lib})").
|
||||
Bind(dbx.Params{"user": regularUserID, "lib": libraryID2}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Describe("TagRepository Library Filtering", func() {
|
||||
// Helper to create repository and read all tags
|
||||
readAllTags := func(user *model.User, filters ...rest.QueryOptions) model.TagList {
|
||||
var ctx context.Context
|
||||
if user != nil {
|
||||
ctx = request.WithUser(log.NewContext(context.TODO()), *user)
|
||||
} else {
|
||||
ctx = context.Background() // Headless context
|
||||
}
|
||||
|
||||
tagRepo := NewTagRepository(ctx, GetDBXBuilder())
|
||||
repo := tagRepo.(model.ResourceRepository)
|
||||
|
||||
var opts rest.QueryOptions
|
||||
if len(filters) > 0 {
|
||||
opts = filters[0]
|
||||
}
|
||||
|
||||
tags, err := repo.ReadAll(opts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return tags.(model.TagList)
|
||||
}
|
||||
|
||||
// Helper to count tags
|
||||
countTags := func(user *model.User) int64 {
|
||||
var ctx context.Context
|
||||
if user != nil {
|
||||
ctx = request.WithUser(log.NewContext(context.TODO()), *user)
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
tagRepo := NewTagRepository(ctx, GetDBXBuilder())
|
||||
repo := tagRepo.(model.ResourceRepository)
|
||||
|
||||
count, err := repo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
return count
|
||||
}
|
||||
|
||||
Context("Admin User", func() {
|
||||
It("should see all tags regardless of library", func() {
|
||||
tags := readAllTags(&adminUser)
|
||||
Expect(tags).To(HaveLen(3))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Regular User with Limited Library Access", func() {
|
||||
It("should only see tags from accessible libraries", func() {
|
||||
tags := readAllTags(®ularUser)
|
||||
// Should see rock (libraries 1,2) and pop (library 2), but not jazz (library 3)
|
||||
Expect(tags).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("should respect explicit library_id filters within accessible libraries", func() {
|
||||
tags := readAllTags(®ularUser, rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"library_id": libraryID2},
|
||||
})
|
||||
// Should see only tags from library 2: pop and rock(lib2)
|
||||
Expect(tags).To(HaveLen(2))
|
||||
expectTagValues(tags, []string{tagValuePop, tagValueRock})
|
||||
})
|
||||
|
||||
It("should not return tags when filtering by inaccessible library", func() {
|
||||
tags := readAllTags(®ularUser, rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"library_id": libraryID3},
|
||||
})
|
||||
// Should return no tags since user can't access library 3
|
||||
Expect(tags).To(HaveLen(0))
|
||||
})
|
||||
|
||||
It("should filter by library 1 correctly", func() {
|
||||
tags := readAllTags(®ularUser, rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"library_id": libraryID1},
|
||||
})
|
||||
// Should see only rock from library 1
|
||||
Expect(tags).To(HaveLen(1))
|
||||
Expect(tags[0].TagValue).To(Equal(tagValueRock))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Headless Processes (No User Context)", func() {
|
||||
It("should see all tags from all libraries when no user is in context", func() {
|
||||
tags := readAllTags(nil) // nil = headless context
|
||||
// Should see all tags from all libraries (no filtering applied)
|
||||
Expect(tags).To(HaveLen(3))
|
||||
expectTagValues(tags, []string{tagValueRock, tagValuePop, tagValueJazz})
|
||||
})
|
||||
|
||||
It("should count all tags from all libraries when no user is in context", func() {
|
||||
count := countTags(nil)
|
||||
// Should count all tags from all libraries
|
||||
Expect(count).To(Equal(int64(3)))
|
||||
})
|
||||
|
||||
It("should calculate proper statistics from all libraries for headless processes", func() {
|
||||
tags := readAllTags(nil)
|
||||
|
||||
// Find the rock tag (appears in libraries 1 and 2)
|
||||
var rockTag *model.Tag
|
||||
for _, tag := range tags {
|
||||
if tag.TagValue == tagValueRock {
|
||||
rockTag = &tag
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(rockTag).ToNot(BeNil())
|
||||
|
||||
// Should have stats from all libraries where rock appears
|
||||
// Library 1: 5 albums, 20 songs
|
||||
// Library 2: 1 album, 4 songs
|
||||
// Total: 6 albums, 24 songs
|
||||
Expect(rockTag.AlbumCount).To(Equal(6))
|
||||
Expect(rockTag.SongCount).To(Equal(24))
|
||||
})
|
||||
|
||||
It("should allow headless processes to apply explicit library_id filters", func() {
|
||||
tags := readAllTags(nil, rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"library_id": libraryID3},
|
||||
})
|
||||
// Should see only jazz from library 3
|
||||
Expect(tags).To(HaveLen(1))
|
||||
Expect(tags[0].TagValue).To(Equal(tagValueJazz))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Admin User with Explicit Library Filtering", func() {
|
||||
It("should see all tags when no filter is applied", func() {
|
||||
tags := readAllTags(&adminUser)
|
||||
Expect(tags).To(HaveLen(3))
|
||||
})
|
||||
|
||||
It("should respect explicit library_id filters", func() {
|
||||
tags := readAllTags(&adminUser, rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"library_id": libraryID3},
|
||||
})
|
||||
// Should see only jazz from library 3
|
||||
Expect(tags).To(HaveLen(1))
|
||||
Expect(tags[0].TagValue).To(Equal(tagValueJazz))
|
||||
})
|
||||
|
||||
It("should filter by library 2 correctly", func() {
|
||||
tags := readAllTags(&adminUser, rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"library_id": libraryID2},
|
||||
})
|
||||
// Should see pop and rock from library 2
|
||||
Expect(tags).To(HaveLen(2))
|
||||
expectTagValues(tags, []string{tagValuePop, tagValueRock})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
99
persistence/tag_repository.go
Normal file
99
persistence/tag_repository.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type tagRepository struct {
|
||||
*baseTagRepository
|
||||
}
|
||||
|
||||
func NewTagRepository(ctx context.Context, db dbx.Builder) model.TagRepository {
|
||||
return &tagRepository{
|
||||
baseTagRepository: newBaseTagRepository(ctx, db, nil), // nil = no filter, works with all tags
|
||||
}
|
||||
}
|
||||
|
||||
func (r *tagRepository) Add(libraryID int, tags ...model.Tag) error {
|
||||
for chunk := range slices.Chunk(tags, 200) {
|
||||
sq := Insert(r.tableName).Columns("id", "tag_name", "tag_value").
|
||||
Suffix("on conflict (id) do nothing")
|
||||
for _, t := range chunk {
|
||||
sq = sq.Values(t.ID, t.TagName, t.TagValue)
|
||||
}
|
||||
_, err := r.executeSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create library_tag entries for library filtering
|
||||
libSq := Insert("library_tag").Columns("tag_id", "library_id", "album_count", "media_file_count").
|
||||
Suffix("on conflict (tag_id, library_id) do nothing")
|
||||
for _, t := range chunk {
|
||||
libSq = libSq.Values(t.ID, libraryID, 0, 0)
|
||||
}
|
||||
_, err = r.executeSQL(libSq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding library_tag entries: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateCounts updates the library_tag table with per-library statistics.
|
||||
// Only genres are being updated for now.
|
||||
func (r *tagRepository) UpdateCounts() error {
|
||||
template := `
|
||||
INSERT INTO library_tag (tag_id, library_id, %[1]s_count)
|
||||
SELECT jt.value as tag_id, %[1]s.library_id, count(distinct %[1]s.id) as %[1]s_count
|
||||
FROM %[1]s
|
||||
JOIN json_tree(%[1]s.tags, '$.genre') as jt ON jt.atom IS NOT NULL AND jt.key = 'id'
|
||||
JOIN tag ON tag.id = jt.value
|
||||
GROUP BY jt.value, %[1]s.library_id
|
||||
ON CONFLICT (tag_id, library_id)
|
||||
DO UPDATE SET %[1]s_count = excluded.%[1]s_count;
|
||||
`
|
||||
|
||||
for _, table := range []string{"album", "media_file"} {
|
||||
start := time.Now()
|
||||
query := Expr(fmt.Sprintf(template, table))
|
||||
c, err := r.executeSQL(query)
|
||||
log.Debug(r.ctx, "Updated library tag counts", "table", table, "elapsed", time.Since(start), "updated", c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating %s library tag counts: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *tagRepository) purgeUnused() error {
|
||||
del := Delete(r.tableName).Where(`
|
||||
id not in (select jt.value
|
||||
from album left join json_tree(album.tags, '$') as jt
|
||||
where atom is not null
|
||||
and key = 'id'
|
||||
UNION
|
||||
select jt.value
|
||||
from media_file left join json_tree(media_file.tags, '$') as jt
|
||||
where atom is not null
|
||||
and key = 'id')
|
||||
`)
|
||||
c, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error purging %s unused tags: %w", r.tableName, err)
|
||||
}
|
||||
if c > 0 {
|
||||
log.Debug(r.ctx, "Purged unused tags", "totalDeleted", c, "table", r.tableName)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var _ model.ResourceRepository = &tagRepository{}
|
||||
311
persistence/tag_repository_test.go
Normal file
311
persistence/tag_repository_test.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/conf/configtest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ = Describe("TagRepository", func() {
|
||||
var repo model.TagRepository
|
||||
var restRepo model.ResourceRepository
|
||||
var ctx context.Context
|
||||
|
||||
BeforeEach(func() {
|
||||
DeferCleanup(configtest.SetupConfig())
|
||||
ctx = request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", UserName: "johndoe", IsAdmin: true})
|
||||
tagRepo := NewTagRepository(ctx, GetDBXBuilder())
|
||||
repo = tagRepo
|
||||
restRepo = tagRepo.(model.ResourceRepository)
|
||||
|
||||
// Clean the database before each test to ensure isolation
|
||||
db := GetDBXBuilder()
|
||||
_, err := db.NewQuery("DELETE FROM tag").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = db.NewQuery("DELETE FROM library_tag").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Ensure library 1 exists (if it doesn't already)
|
||||
_, err = db.NewQuery("INSERT OR IGNORE INTO library (id, name, path, default_new_users) VALUES (1, 'Test Library', '/test', true)").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Ensure the admin user has access to library 1
|
||||
_, err = db.NewQuery("INSERT OR IGNORE INTO user_library (user_id, library_id) VALUES ('userid', 1)").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Add comprehensive test data that covers all test scenarios
|
||||
newTag := func(name, value string) model.Tag {
|
||||
return model.Tag{ID: id.NewTagID(name, value), TagName: model.TagName(name), TagValue: value}
|
||||
}
|
||||
|
||||
err = repo.Add(1,
|
||||
// Genre tags
|
||||
newTag("genre", "rock"),
|
||||
newTag("genre", "pop"),
|
||||
newTag("genre", "jazz"),
|
||||
newTag("genre", "electronic"),
|
||||
newTag("genre", "classical"),
|
||||
newTag("genre", "ambient"),
|
||||
newTag("genre", "techno"),
|
||||
newTag("genre", "house"),
|
||||
newTag("genre", "trance"),
|
||||
newTag("genre", "Alternative Rock"),
|
||||
newTag("genre", "Blues"),
|
||||
newTag("genre", "Country"),
|
||||
// Mood tags
|
||||
newTag("mood", "happy"),
|
||||
newTag("mood", "sad"),
|
||||
newTag("mood", "energetic"),
|
||||
newTag("mood", "calm"),
|
||||
// Other tag types
|
||||
newTag("instrument", "guitar"),
|
||||
newTag("instrument", "piano"),
|
||||
newTag("decade", "1980s"),
|
||||
newTag("decade", "1990s"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Describe("Add", func() {
|
||||
It("should handle adding new tags", func() {
|
||||
newTag := model.Tag{
|
||||
ID: id.NewTagID("genre", "experimental"),
|
||||
TagName: "genre",
|
||||
TagValue: "experimental",
|
||||
}
|
||||
|
||||
err := repo.Add(1, newTag)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Verify tag was added
|
||||
result, err := restRepo.Read(newTag.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resultTag := result.(*model.Tag)
|
||||
Expect(resultTag.TagValue).To(Equal("experimental"))
|
||||
|
||||
// Check count increased
|
||||
count, err := restRepo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(21))) // 20 from dataset + 1 new
|
||||
})
|
||||
|
||||
It("should handle duplicate tags gracefully", func() {
|
||||
// Try to add a duplicate tag
|
||||
duplicateTag := model.Tag{
|
||||
ID: id.NewTagID("genre", "rock"), // This already exists
|
||||
TagName: "genre",
|
||||
TagValue: "rock",
|
||||
}
|
||||
|
||||
count, err := restRepo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(20))) // Still 20 tags
|
||||
|
||||
err = repo.Add(1, duplicateTag)
|
||||
Expect(err).ToNot(HaveOccurred()) // Should not error
|
||||
|
||||
// Count should remain the same
|
||||
count, err = restRepo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(20))) // Still 20 tags
|
||||
})
|
||||
})
|
||||
|
||||
Describe("UpdateCounts", func() {
|
||||
It("should update tag counts successfully", func() {
|
||||
err := repo.UpdateCounts()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should handle empty database gracefully", func() {
|
||||
// Clear the database first
|
||||
db := GetDBXBuilder()
|
||||
_, err := db.NewQuery("DELETE FROM tag").Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = repo.UpdateCounts()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should handle albums with non-existent tag IDs in JSON gracefully", func() {
|
||||
// Regression test for foreign key constraint error
|
||||
// Create an album with tag IDs in JSON that don't exist in tag table
|
||||
db := GetDBXBuilder()
|
||||
|
||||
// First, create a non-existent tag ID (this simulates tags in JSON that aren't in tag table)
|
||||
nonExistentTagID := id.NewTagID("genre", "nonexistent-genre")
|
||||
|
||||
// Create album with JSON containing the non-existent tag ID
|
||||
albumWithBadTags := `{"genre":[{"id":"` + nonExistentTagID + `","value":"nonexistent-genre"}]}`
|
||||
|
||||
// Insert album directly into database with the problematic JSON
|
||||
_, err := db.NewQuery("INSERT INTO album (id, name, library_id, tags) VALUES ({:id}, {:name}, {:lib}, {:tags})").
|
||||
Bind(dbx.Params{
|
||||
"id": "test-album-bad-tags",
|
||||
"name": "Album With Bad Tags",
|
||||
"lib": 1,
|
||||
"tags": albumWithBadTags,
|
||||
}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// This should not fail with foreign key constraint error
|
||||
err = repo.UpdateCounts()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Cleanup
|
||||
_, err = db.NewQuery("DELETE FROM album WHERE id = {:id}").
|
||||
Bind(dbx.Params{"id": "test-album-bad-tags"}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should handle media files with non-existent tag IDs in JSON gracefully", func() {
|
||||
// Regression test for foreign key constraint error with media files
|
||||
db := GetDBXBuilder()
|
||||
|
||||
// Create a non-existent tag ID
|
||||
nonExistentTagID := id.NewTagID("genre", "another-nonexistent-genre")
|
||||
|
||||
// Create media file with JSON containing the non-existent tag ID
|
||||
mediaFileWithBadTags := `{"genre":[{"id":"` + nonExistentTagID + `","value":"another-nonexistent-genre"}]}`
|
||||
|
||||
// Insert media file directly into database with the problematic JSON
|
||||
_, err := db.NewQuery("INSERT INTO media_file (id, title, library_id, tags) VALUES ({:id}, {:title}, {:lib}, {:tags})").
|
||||
Bind(dbx.Params{
|
||||
"id": "test-media-bad-tags",
|
||||
"title": "Media File With Bad Tags",
|
||||
"lib": 1,
|
||||
"tags": mediaFileWithBadTags,
|
||||
}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// This should not fail with foreign key constraint error
|
||||
err = repo.UpdateCounts()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Cleanup
|
||||
_, err = db.NewQuery("DELETE FROM media_file WHERE id = {:id}").
|
||||
Bind(dbx.Params{"id": "test-media-bad-tags"}).Execute()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
It("should return correct count of tags", func() {
|
||||
count, err := restRepo.Count()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(int64(20))) // From the test dataset
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Read", func() {
|
||||
It("should return existing tag", func() {
|
||||
rockID := id.NewTagID("genre", "rock")
|
||||
result, err := restRepo.Read(rockID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
resultTag := result.(*model.Tag)
|
||||
Expect(resultTag.ID).To(Equal(rockID))
|
||||
Expect(resultTag.TagName).To(Equal(model.TagName("genre")))
|
||||
Expect(resultTag.TagValue).To(Equal("rock"))
|
||||
})
|
||||
|
||||
It("should return error for non-existent tag", func() {
|
||||
_, err := restRepo.Read("non-existent-id")
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ReadAll", func() {
|
||||
It("should return all tags from dataset", func() {
|
||||
result, err := restRepo.ReadAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tags := result.(model.TagList)
|
||||
Expect(tags).To(HaveLen(20))
|
||||
})
|
||||
|
||||
It("should filter tags by partial value correctly", func() {
|
||||
options := rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"name": "%rock%"}, // Tags containing 'rock'
|
||||
}
|
||||
result, err := restRepo.ReadAll(options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tags := result.(model.TagList)
|
||||
Expect(tags).To(HaveLen(2)) // "rock" and "Alternative Rock"
|
||||
|
||||
// Verify all returned tags contain 'rock' in their value
|
||||
for _, tag := range tags {
|
||||
Expect(strings.ToLower(tag.TagValue)).To(ContainSubstring("rock"))
|
||||
}
|
||||
})
|
||||
|
||||
It("should filter tags by partial value using LIKE", func() {
|
||||
options := rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"name": "%e%"}, // Tags containing 'e'
|
||||
}
|
||||
result, err := restRepo.ReadAll(options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tags := result.(model.TagList)
|
||||
Expect(tags).To(HaveLen(8)) // electronic, house, trance, energetic, Blues, decade x2, Alternative Rock
|
||||
|
||||
// Verify all returned tags contain 'e' in their value
|
||||
for _, tag := range tags {
|
||||
Expect(strings.ToLower(tag.TagValue)).To(ContainSubstring("e"))
|
||||
}
|
||||
})
|
||||
|
||||
It("should sort tags by value ascending", func() {
|
||||
options := rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"name": "%r%"}, // Tags containing 'r'
|
||||
Sort: "name",
|
||||
Order: "asc",
|
||||
}
|
||||
result, err := restRepo.ReadAll(options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tags := result.(model.TagList)
|
||||
Expect(tags).To(HaveLen(7))
|
||||
|
||||
Expect(slices.IsSortedFunc(tags, func(a, b model.Tag) int {
|
||||
return strings.Compare(strings.ToLower(a.TagValue), strings.ToLower(b.TagValue))
|
||||
}))
|
||||
})
|
||||
|
||||
It("should sort tags by value descending", func() {
|
||||
options := rest.QueryOptions{
|
||||
Filters: map[string]interface{}{"name": "%r%"}, // Tags containing 'r'
|
||||
Sort: "name",
|
||||
Order: "desc",
|
||||
}
|
||||
result, err := restRepo.ReadAll(options)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
tags := result.(model.TagList)
|
||||
Expect(tags).To(HaveLen(7))
|
||||
|
||||
Expect(slices.IsSortedFunc(tags, func(a, b model.Tag) int {
|
||||
return strings.Compare(strings.ToLower(b.TagValue), strings.ToLower(a.TagValue)) // Descending order
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("EntityName", func() {
|
||||
It("should return correct entity name", func() {
|
||||
name := restRepo.EntityName()
|
||||
Expect(name).To(Equal("tag"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("NewInstance", func() {
|
||||
It("should return new tag instance", func() {
|
||||
instance := restRepo.NewInstance()
|
||||
Expect(instance).To(BeAssignableToTypeOf(model.Tag{}))
|
||||
})
|
||||
})
|
||||
})
|
||||
112
persistence/transcoding_repository.go
Normal file
112
persistence/transcoding_repository.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type transcodingRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewTranscodingRepository(ctx context.Context, db dbx.Builder) model.TranscodingRepository {
|
||||
r := &transcodingRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.registerModel(&model.Transcoding{}, nil)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) Get(id string) (*model.Transcoding, error) {
|
||||
sel := r.newSelect().Columns("*").Where(Eq{"id": id})
|
||||
var res model.Transcoding
|
||||
err := r.queryOne(sel, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) CountAll(qo ...model.QueryOptions) (int64, error) {
|
||||
return r.count(Select(), qo...)
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) FindByFormat(format string) (*model.Transcoding, error) {
|
||||
sel := r.newSelect().Columns("*").Where(Eq{"target_format": format})
|
||||
var res model.Transcoding
|
||||
err := r.queryOne(sel, &res)
|
||||
return &res, err
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) Put(t *model.Transcoding) error {
|
||||
if !loggedUser(r.ctx).IsAdmin {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
_, err := r.put(t.ID, t)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
return r.count(Select(), r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) Read(id string) (interface{}, error) {
|
||||
return r.Get(id)
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) ReadAll(options ...rest.QueryOptions) (interface{}, error) {
|
||||
sel := r.newSelect(r.parseRestOptions(r.ctx, options...)).Columns("*")
|
||||
res := model.Transcodings{}
|
||||
err := r.queryAll(sel, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) EntityName() string {
|
||||
return "transcoding"
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) NewInstance() interface{} {
|
||||
return &model.Transcoding{}
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) Save(entity interface{}) (string, error) {
|
||||
if !loggedUser(r.ctx).IsAdmin {
|
||||
return "", rest.ErrPermissionDenied
|
||||
}
|
||||
t := entity.(*model.Transcoding)
|
||||
id, err := r.put(t.ID, t)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return "", rest.ErrNotFound
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) Update(id string, entity interface{}, cols ...string) error {
|
||||
if !loggedUser(r.ctx).IsAdmin {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
t := entity.(*model.Transcoding)
|
||||
t.ID = id
|
||||
_, err := r.put(id, t)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *transcodingRepository) Delete(id string) error {
|
||||
if !loggedUser(r.ctx).IsAdmin {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
err := r.delete(Eq{"id": id})
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var _ model.TranscodingRepository = (*transcodingRepository)(nil)
|
||||
var _ rest.Repository = (*transcodingRepository)(nil)
|
||||
var _ rest.Persistable = (*transcodingRepository)(nil)
|
||||
96
persistence/transcoding_repository_test.go
Normal file
96
persistence/transcoding_repository_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("TranscodingRepository", func() {
|
||||
var repo model.TranscodingRepository
|
||||
var adminRepo model.TranscodingRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(GinkgoT().Context())
|
||||
ctx = request.WithUser(ctx, regularUser)
|
||||
repo = NewTranscodingRepository(ctx, GetDBXBuilder())
|
||||
|
||||
adminCtx := log.NewContext(GinkgoT().Context())
|
||||
adminCtx = request.WithUser(adminCtx, adminUser)
|
||||
adminRepo = NewTranscodingRepository(adminCtx, GetDBXBuilder())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up any transcoding created during the tests
|
||||
tc, err := adminRepo.FindByFormat("test_format")
|
||||
if err == nil {
|
||||
err = adminRepo.(*transcodingRepository).Delete(tc.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
Describe("Admin User", func() {
|
||||
It("creates a new transcoding", func() {
|
||||
base, err := adminRepo.CountAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
err = adminRepo.Put(&model.Transcoding{ID: "new", Name: "new", TargetFormat: "test_format", DefaultBitRate: 320, Command: "ffmpeg"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
count, err := adminRepo.CountAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(count).To(Equal(base + 1))
|
||||
})
|
||||
|
||||
It("updates an existing transcoding", func() {
|
||||
tr := &model.Transcoding{ID: "upd", Name: "old", TargetFormat: "test_format", DefaultBitRate: 100, Command: "ffmpeg"}
|
||||
Expect(adminRepo.Put(tr)).To(Succeed())
|
||||
tr.Name = "updated"
|
||||
err := adminRepo.Put(tr)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
res, err := adminRepo.FindByFormat("test_format")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Name).To(Equal("updated"))
|
||||
})
|
||||
|
||||
It("deletes a transcoding", func() {
|
||||
err := adminRepo.Put(&model.Transcoding{ID: "to-delete", Name: "temp", TargetFormat: "test_format", DefaultBitRate: 256, Command: "ffmpeg"})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = adminRepo.(*transcodingRepository).Delete("to-delete")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, err = adminRepo.Get("to-delete")
|
||||
Expect(err).To(MatchError(model.ErrNotFound))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Regular User", func() {
|
||||
It("fails to create", func() {
|
||||
err := repo.Put(&model.Transcoding{ID: "bad", Name: "bad", TargetFormat: "test_format", DefaultBitRate: 64, Command: "ffmpeg"})
|
||||
Expect(err).To(Equal(rest.ErrPermissionDenied))
|
||||
})
|
||||
|
||||
It("fails to update", func() {
|
||||
tr := &model.Transcoding{ID: "updreg", Name: "old", TargetFormat: "test_format", DefaultBitRate: 64, Command: "ffmpeg"}
|
||||
Expect(adminRepo.Put(tr)).To(Succeed())
|
||||
|
||||
tr.Name = "bad"
|
||||
err := repo.Put(tr)
|
||||
Expect(err).To(Equal(rest.ErrPermissionDenied))
|
||||
|
||||
//_ = adminRepo.(*transcodingRepository).Delete("updreg")
|
||||
})
|
||||
|
||||
It("fails to delete", func() {
|
||||
tr := &model.Transcoding{ID: "delreg", Name: "temp", TargetFormat: "test_format", DefaultBitRate: 64, Command: "ffmpeg"}
|
||||
Expect(adminRepo.Put(tr)).To(Succeed())
|
||||
|
||||
err := repo.(*transcodingRepository).Delete("delreg")
|
||||
Expect(err).To(Equal(rest.ErrPermissionDenied))
|
||||
|
||||
//_ = adminRepo.(*transcodingRepository).Delete("delreg")
|
||||
})
|
||||
})
|
||||
})
|
||||
63
persistence/user_props_repository.go
Normal file
63
persistence/user_props_repository.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type userPropsRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewUserPropsRepository(ctx context.Context, db dbx.Builder) model.UserPropsRepository {
|
||||
r := &userPropsRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.tableName = "user_props"
|
||||
return r
|
||||
}
|
||||
|
||||
func (r userPropsRepository) Put(userId, key string, value string) error {
|
||||
update := Update(r.tableName).Set("value", value).Where(And{Eq{"user_id": userId}, Eq{"key": key}})
|
||||
count, err := r.executeSQL(update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
insert := Insert(r.tableName).Columns("user_id", "key", "value").Values(userId, key, value)
|
||||
_, err = r.executeSQL(insert)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r userPropsRepository) Get(userId, key string) (string, error) {
|
||||
sel := Select("value").From(r.tableName).Where(And{Eq{"user_id": userId}, Eq{"key": key}})
|
||||
resp := struct {
|
||||
Value string
|
||||
}{}
|
||||
err := r.queryOne(sel, &resp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.Value, nil
|
||||
}
|
||||
|
||||
func (r userPropsRepository) DefaultGet(userId, key string, defaultValue string) (string, error) {
|
||||
value, err := r.Get(userId, key)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return defaultValue, nil
|
||||
}
|
||||
if err != nil {
|
||||
return defaultValue, err
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (r userPropsRepository) Delete(userId, key string) error {
|
||||
return r.delete(And{Eq{"user_id": userId}, Eq{"key": key}})
|
||||
}
|
||||
475
persistence/user_repository.go
Normal file
475
persistence/user_repository.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/consts"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/utils"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type userRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
type dbUser struct {
|
||||
*model.User `structs:",flatten"`
|
||||
LibrariesJSON string `structs:"-" json:"-"`
|
||||
}
|
||||
|
||||
func (u *dbUser) PostScan() error {
|
||||
if u.LibrariesJSON != "" {
|
||||
if err := json.Unmarshal([]byte(u.LibrariesJSON), &u.User.Libraries); err != nil {
|
||||
return fmt.Errorf("parsing user libraries from db: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbUsers []dbUser
|
||||
|
||||
func (us dbUsers) toModels() model.Users {
|
||||
return slice.Map(us, func(u dbUser) model.User { return *u.User })
|
||||
}
|
||||
|
||||
var (
|
||||
once sync.Once
|
||||
encKey []byte
|
||||
)
|
||||
|
||||
func NewUserRepository(ctx context.Context, db dbx.Builder) model.UserRepository {
|
||||
r := &userRepository{}
|
||||
r.ctx = ctx
|
||||
r.db = db
|
||||
r.tableName = "user"
|
||||
r.registerModel(&model.User{}, map[string]filterFunc{
|
||||
"id": idFilter(r.tableName),
|
||||
"password": invalidFilter(ctx),
|
||||
"name": r.withTableName(startsWithFilter),
|
||||
})
|
||||
once.Do(func() {
|
||||
_ = r.initPasswordEncryptionKey()
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
// selectUserWithLibraries returns a SelectBuilder that includes library information
|
||||
func (r *userRepository) selectUserWithLibraries(options ...model.QueryOptions) SelectBuilder {
|
||||
return r.newSelect(options...).
|
||||
Columns(`user.*`,
|
||||
`COALESCE(json_group_array(json_object(
|
||||
'id', library.id,
|
||||
'name', library.name,
|
||||
'path', library.path,
|
||||
'remote_path', library.remote_path,
|
||||
'last_scan_at', library.last_scan_at,
|
||||
'last_scan_started_at', library.last_scan_started_at,
|
||||
'full_scan_in_progress', library.full_scan_in_progress,
|
||||
'updated_at', library.updated_at,
|
||||
'created_at', library.created_at
|
||||
)) FILTER (WHERE library.id IS NOT NULL), '[]') AS libraries_json`).
|
||||
LeftJoin("user_library ul ON user.id = ul.user_id").
|
||||
LeftJoin("library ON ul.library_id = library.id").
|
||||
GroupBy("user.id")
|
||||
}
|
||||
|
||||
func (r *userRepository) CountAll(qo ...model.QueryOptions) (int64, error) {
|
||||
return r.count(Select(), qo...)
|
||||
}
|
||||
|
||||
func (r *userRepository) Get(id string) (*model.User, error) {
|
||||
sel := r.selectUserWithLibraries().Where(Eq{"user.id": id})
|
||||
var res dbUser
|
||||
err := r.queryOne(sel, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.User, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) GetAll(options ...model.QueryOptions) (model.Users, error) {
|
||||
sel := r.selectUserWithLibraries(options...)
|
||||
var res dbUsers
|
||||
err := r.queryAll(sel, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.toModels(), nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Put(u *model.User) error {
|
||||
if u.ID == "" {
|
||||
u.ID = id.NewRandom()
|
||||
}
|
||||
u.UpdatedAt = time.Now()
|
||||
if u.NewPassword != "" {
|
||||
_ = r.encryptPassword(u)
|
||||
}
|
||||
values, err := toSQLArgs(*u)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error converting user to SQL args: %w", err)
|
||||
}
|
||||
delete(values, "current_password")
|
||||
|
||||
// Save/update the user
|
||||
update := Update(r.tableName).Where(Eq{"id": u.ID}).SetMap(values)
|
||||
count, err := r.executeSQL(update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isNewUser := count == 0
|
||||
if isNewUser {
|
||||
values["created_at"] = time.Now()
|
||||
insert := Insert(r.tableName).SetMap(values)
|
||||
_, err = r.executeSQL(insert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-assign all libraries to admin users in a single SQL operation
|
||||
if u.IsAdmin {
|
||||
sql := Expr(
|
||||
"INSERT OR IGNORE INTO user_library (user_id, library_id) SELECT ?, id FROM library",
|
||||
u.ID,
|
||||
)
|
||||
if _, err := r.executeSQL(sql); err != nil {
|
||||
return fmt.Errorf("failed to assign all libraries to admin user: %w", err)
|
||||
}
|
||||
} else if isNewUser { // Only for new regular users
|
||||
// Auto-assign default libraries to new regular users
|
||||
sql := Expr(
|
||||
"INSERT OR IGNORE INTO user_library (user_id, library_id) SELECT ?, id FROM library WHERE default_new_users = true",
|
||||
u.ID,
|
||||
)
|
||||
if _, err := r.executeSQL(sql); err != nil {
|
||||
return fmt.Errorf("failed to assign default libraries to new user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindFirstAdmin() (*model.User, error) {
|
||||
sel := r.selectUserWithLibraries(model.QueryOptions{Sort: "updated_at", Max: 1}).Where(Eq{"user.is_admin": true})
|
||||
var usr dbUser
|
||||
err := r.queryOne(sel, &usr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return usr.User, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByUsername(username string) (*model.User, error) {
|
||||
sel := r.selectUserWithLibraries().Where(Expr("user.user_name = ? COLLATE NOCASE", username))
|
||||
var usr dbUser
|
||||
err := r.queryOne(sel, &usr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return usr.User, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByUsernameWithPassword(username string) (*model.User, error) {
|
||||
usr, err := r.FindByUsername(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = r.decryptPassword(usr)
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateLastLoginAt(id string) error {
|
||||
upd := Update(r.tableName).Where(Eq{"id": id}).Set("last_login_at", time.Now())
|
||||
_, err := r.executeSQL(upd)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateLastAccessAt(id string) error {
|
||||
now := time.Now()
|
||||
upd := Update(r.tableName).Where(Eq{"id": id}).Set("last_access_at", now)
|
||||
_, err := r.executeSQL(upd)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *userRepository) Count(options ...rest.QueryOptions) (int64, error) {
|
||||
usr := loggedUser(r.ctx)
|
||||
if !usr.IsAdmin {
|
||||
return 0, rest.ErrPermissionDenied
|
||||
}
|
||||
return r.CountAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *userRepository) Read(id string) (any, error) {
|
||||
usr := loggedUser(r.ctx)
|
||||
if !usr.IsAdmin && usr.ID != id {
|
||||
return nil, rest.ErrPermissionDenied
|
||||
}
|
||||
usr, err := r.Get(id)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return nil, rest.ErrNotFound
|
||||
}
|
||||
return usr, err
|
||||
}
|
||||
|
||||
func (r *userRepository) ReadAll(options ...rest.QueryOptions) (any, error) {
|
||||
usr := loggedUser(r.ctx)
|
||||
if !usr.IsAdmin {
|
||||
return nil, rest.ErrPermissionDenied
|
||||
}
|
||||
return r.GetAll(r.parseRestOptions(r.ctx, options...))
|
||||
}
|
||||
|
||||
func (r *userRepository) EntityName() string {
|
||||
return "user"
|
||||
}
|
||||
|
||||
func (r *userRepository) NewInstance() any {
|
||||
return &model.User{}
|
||||
}
|
||||
|
||||
func (r *userRepository) Save(entity any) (string, error) {
|
||||
usr := loggedUser(r.ctx)
|
||||
if !usr.IsAdmin {
|
||||
return "", rest.ErrPermissionDenied
|
||||
}
|
||||
u := entity.(*model.User)
|
||||
if err := validateUsernameUnique(r, u); err != nil {
|
||||
return "", err
|
||||
}
|
||||
err := r.Put(u)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return u.ID, err
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(id string, entity any, _ ...string) error {
|
||||
u := entity.(*model.User)
|
||||
u.ID = id
|
||||
usr := loggedUser(r.ctx)
|
||||
if !usr.IsAdmin && usr.ID != u.ID {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
if !usr.IsAdmin {
|
||||
if !conf.Server.EnableUserEditing {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
u.IsAdmin = false
|
||||
u.UserName = usr.UserName
|
||||
}
|
||||
|
||||
// Decrypt the user's existing password before validating. This is required otherwise the existing password entered by the user will never match.
|
||||
if err := r.decryptPassword(usr); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePasswordChange(u, usr); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateUsernameUnique(r, u); err != nil {
|
||||
return err
|
||||
}
|
||||
err := r.Put(u)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func validatePasswordChange(newUser *model.User, logged *model.User) error {
|
||||
err := &rest.ValidationError{Errors: map[string]string{}}
|
||||
if logged.IsAdmin && newUser.ID != logged.ID {
|
||||
return nil
|
||||
}
|
||||
if newUser.NewPassword == "" {
|
||||
if newUser.CurrentPassword == "" {
|
||||
return nil
|
||||
}
|
||||
err.Errors["password"] = "ra.validation.required"
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(logged.Password, consts.PasswordAutogenPrefix) {
|
||||
if newUser.CurrentPassword == "" {
|
||||
err.Errors["currentPassword"] = "ra.validation.required"
|
||||
}
|
||||
if newUser.CurrentPassword != logged.Password {
|
||||
err.Errors["currentPassword"] = "ra.validation.passwordDoesNotMatch"
|
||||
}
|
||||
}
|
||||
if len(err.Errors) > 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateUsernameUnique(r model.UserRepository, u *model.User) error {
|
||||
usr, err := r.FindByUsername(u.UserName)
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usr.ID != u.ID {
|
||||
return &rest.ValidationError{Errors: map[string]string{"userName": "ra.validation.unique"}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(id string) error {
|
||||
usr := loggedUser(r.ctx)
|
||||
if !usr.IsAdmin {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
err := r.delete(Eq{"id": id})
|
||||
if errors.Is(err, model.ErrNotFound) {
|
||||
return rest.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func keyTo32Bytes(input string) []byte {
|
||||
data := sha256.Sum256([]byte(input))
|
||||
return data[0:]
|
||||
}
|
||||
|
||||
func (r *userRepository) initPasswordEncryptionKey() error {
|
||||
encKey = keyTo32Bytes(consts.DefaultEncryptionKey)
|
||||
if conf.Server.PasswordEncryptionKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := keyTo32Bytes(conf.Server.PasswordEncryptionKey)
|
||||
keySum := fmt.Sprintf("%x", sha256.Sum256(key))
|
||||
|
||||
props := NewPropertyRepository(r.ctx, r.db)
|
||||
savedKeySum, err := props.Get(consts.PasswordsEncryptedKey)
|
||||
|
||||
// If passwords are already encrypted
|
||||
if err == nil {
|
||||
if savedKeySum != keySum {
|
||||
log.Error("Password Encryption Key changed! Users won't be able to login!")
|
||||
return errors.New("passwordEncryptionKey changed")
|
||||
}
|
||||
encKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
// if not, try to re-encrypt all current passwords with new encryption key,
|
||||
// assuming they were encrypted with the DefaultEncryptionKey
|
||||
sql := r.newSelect().Columns("id", "user_name", "password")
|
||||
users := model.Users{}
|
||||
err = r.queryAll(sql, &users)
|
||||
if err != nil {
|
||||
log.Error("Could not encrypt all passwords", err)
|
||||
return err
|
||||
}
|
||||
log.Warn("New PasswordEncryptionKey set. Encrypting all passwords", "numUsers", len(users))
|
||||
if err = r.decryptAllPasswords(users); err != nil {
|
||||
return err
|
||||
}
|
||||
encKey = key
|
||||
for i := range users {
|
||||
u := users[i]
|
||||
u.NewPassword = u.Password
|
||||
if err := r.encryptPassword(&u); err == nil {
|
||||
upd := Update(r.tableName).Set("password", u.NewPassword).Where(Eq{"id": u.ID})
|
||||
_, err = r.executeSQL(upd)
|
||||
if err != nil {
|
||||
log.Error("Password NOT encrypted! This may cause problems!", "user", u.UserName, "id", u.ID, err)
|
||||
} else {
|
||||
log.Warn("Password encrypted successfully", "user", u.UserName, "id", u.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = props.Put(consts.PasswordsEncryptedKey, keySum)
|
||||
if err != nil {
|
||||
log.Error("Could not flag passwords as encrypted. It will cause login errors", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encrypts u.NewPassword
|
||||
func (r *userRepository) encryptPassword(u *model.User) error {
|
||||
encPassword, err := utils.Encrypt(r.ctx, encKey, u.NewPassword)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error encrypting user's password", "user", u.UserName, err)
|
||||
return err
|
||||
}
|
||||
u.NewPassword = encPassword
|
||||
return nil
|
||||
}
|
||||
|
||||
// decrypts u.Password
|
||||
func (r *userRepository) decryptPassword(u *model.User) error {
|
||||
plaintext, err := utils.Decrypt(r.ctx, encKey, u.Password)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error decrypting user's password", "user", u.UserName, err)
|
||||
return err
|
||||
}
|
||||
u.Password = plaintext
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) decryptAllPasswords(users model.Users) error {
|
||||
for i := range users {
|
||||
if err := r.decryptPassword(&users[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Library association methods
|
||||
|
||||
func (r *userRepository) GetUserLibraries(userID string) (model.Libraries, error) {
|
||||
sel := Select("l.*").
|
||||
From("library l").
|
||||
Join("user_library ul ON l.id = ul.library_id").
|
||||
Where(Eq{"ul.user_id": userID}).
|
||||
OrderBy("l.name")
|
||||
|
||||
var res model.Libraries
|
||||
err := r.queryAll(sel, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *userRepository) SetUserLibraries(userID string, libraryIDs []int) error {
|
||||
// Remove existing associations
|
||||
delSql := Delete("user_library").Where(Eq{"user_id": userID})
|
||||
if _, err := r.executeSQL(delSql); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add new associations
|
||||
if len(libraryIDs) > 0 {
|
||||
insert := Insert("user_library").Columns("user_id", "library_id")
|
||||
for _, libID := range libraryIDs {
|
||||
insert = insert.Values(userID, libID)
|
||||
}
|
||||
_, err := r.executeSQL(insert)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ model.UserRepository = (*userRepository)(nil)
|
||||
var _ rest.Repository = (*userRepository)(nil)
|
||||
var _ rest.Persistable = (*userRepository)(nil)
|
||||
573
persistence/user_repository_test.go
Normal file
573
persistence/user_repository_test.go
Normal file
@@ -0,0 +1,573 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/consts"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/id"
|
||||
"github.com/navidrome/navidrome/tests"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("UserRepository", func() {
|
||||
var repo model.UserRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
repo = NewUserRepository(log.NewContext(GinkgoT().Context()), GetDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Put/Get/FindByUsername", func() {
|
||||
usr := model.User{
|
||||
ID: "123",
|
||||
UserName: "AdMiN",
|
||||
Name: "Admin",
|
||||
Email: "admin@admin.com",
|
||||
NewPassword: "wordpass",
|
||||
IsAdmin: true,
|
||||
}
|
||||
It("saves the user to the DB", func() {
|
||||
Expect(repo.Put(&usr)).To(BeNil())
|
||||
})
|
||||
It("returns the newly created user", func() {
|
||||
actual, err := repo.Get("123")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Name).To(Equal("Admin"))
|
||||
})
|
||||
It("find the user by case-insensitive username", func() {
|
||||
actual, err := repo.FindByUsername("aDmIn")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Name).To(Equal("Admin"))
|
||||
})
|
||||
It("find the user by username and decrypts the password", func() {
|
||||
actual, err := repo.FindByUsernameWithPassword("aDmIn")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Name).To(Equal("Admin"))
|
||||
Expect(actual.Password).To(Equal("wordpass"))
|
||||
})
|
||||
It("updates the name and keep the same password", func() {
|
||||
usr.Name = "Jane Doe"
|
||||
usr.NewPassword = ""
|
||||
Expect(repo.Put(&usr)).To(BeNil())
|
||||
|
||||
actual, err := repo.FindByUsernameWithPassword("admin")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Name).To(Equal("Jane Doe"))
|
||||
Expect(actual.Password).To(Equal("wordpass"))
|
||||
})
|
||||
It("updates password if specified", func() {
|
||||
usr.NewPassword = "newpass"
|
||||
Expect(repo.Put(&usr)).To(BeNil())
|
||||
|
||||
actual, err := repo.FindByUsernameWithPassword("admin")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(actual.Password).To(Equal("newpass"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("validatePasswordChange", func() {
|
||||
var loggedUser *model.User
|
||||
|
||||
BeforeEach(func() {
|
||||
loggedUser = &model.User{ID: "1", UserName: "logan"}
|
||||
})
|
||||
|
||||
It("does nothing if passwords are not specified", func() {
|
||||
user := &model.User{ID: "2", UserName: "johndoe"}
|
||||
err := validatePasswordChange(user, loggedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
Context("Autogenerated password (used with Reverse Proxy Authentication)", func() {
|
||||
var user model.User
|
||||
BeforeEach(func() {
|
||||
loggedUser.IsAdmin = false
|
||||
loggedUser.Password = consts.PasswordAutogenPrefix + id.NewRandom()
|
||||
})
|
||||
It("does nothing if passwords are not specified", func() {
|
||||
user = *loggedUser
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
It("does not requires currentPassword for regular user", func() {
|
||||
user = *loggedUser
|
||||
user.CurrentPassword = ""
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
It("does not requires currentPassword for admin", func() {
|
||||
loggedUser.IsAdmin = true
|
||||
user = *loggedUser
|
||||
user.CurrentPassword = ""
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Logged User is admin", func() {
|
||||
BeforeEach(func() {
|
||||
loggedUser.IsAdmin = true
|
||||
})
|
||||
It("can change other user's passwords without currentPassword", func() {
|
||||
user := &model.User{ID: "2", UserName: "johndoe"}
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(user, loggedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
It("requires currentPassword to change its own", func() {
|
||||
user := *loggedUser
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
var verr *rest.ValidationError
|
||||
errors.As(err, &verr)
|
||||
Expect(verr.Errors).To(HaveLen(1))
|
||||
Expect(verr.Errors).To(HaveKeyWithValue("currentPassword", "ra.validation.required"))
|
||||
})
|
||||
It("does not allow to change password to empty string", func() {
|
||||
loggedUser.Password = "abc123"
|
||||
user := *loggedUser
|
||||
user.CurrentPassword = "abc123"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
var verr *rest.ValidationError
|
||||
errors.As(err, &verr)
|
||||
Expect(verr.Errors).To(HaveLen(1))
|
||||
Expect(verr.Errors).To(HaveKeyWithValue("password", "ra.validation.required"))
|
||||
})
|
||||
It("fails if currentPassword does not match", func() {
|
||||
loggedUser.Password = "abc123"
|
||||
user := *loggedUser
|
||||
user.CurrentPassword = "current"
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
var verr *rest.ValidationError
|
||||
errors.As(err, &verr)
|
||||
Expect(verr.Errors).To(HaveLen(1))
|
||||
Expect(verr.Errors).To(HaveKeyWithValue("currentPassword", "ra.validation.passwordDoesNotMatch"))
|
||||
})
|
||||
It("can change own password if requirements are met", func() {
|
||||
loggedUser.Password = "abc123"
|
||||
user := *loggedUser
|
||||
user.CurrentPassword = "abc123"
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("Logged User is a regular user", func() {
|
||||
BeforeEach(func() {
|
||||
loggedUser.IsAdmin = false
|
||||
})
|
||||
It("requires currentPassword", func() {
|
||||
user := *loggedUser
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
var verr *rest.ValidationError
|
||||
errors.As(err, &verr)
|
||||
Expect(verr.Errors).To(HaveLen(1))
|
||||
Expect(verr.Errors).To(HaveKeyWithValue("currentPassword", "ra.validation.required"))
|
||||
})
|
||||
It("does not allow to change password to empty string", func() {
|
||||
loggedUser.Password = "abc123"
|
||||
user := *loggedUser
|
||||
user.CurrentPassword = "abc123"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
var verr *rest.ValidationError
|
||||
errors.As(err, &verr)
|
||||
Expect(verr.Errors).To(HaveLen(1))
|
||||
Expect(verr.Errors).To(HaveKeyWithValue("password", "ra.validation.required"))
|
||||
})
|
||||
It("fails if currentPassword does not match", func() {
|
||||
loggedUser.Password = "abc123"
|
||||
user := *loggedUser
|
||||
user.CurrentPassword = "current"
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
var verr *rest.ValidationError
|
||||
errors.As(err, &verr)
|
||||
Expect(verr.Errors).To(HaveLen(1))
|
||||
Expect(verr.Errors).To(HaveKeyWithValue("currentPassword", "ra.validation.passwordDoesNotMatch"))
|
||||
})
|
||||
It("can change own password if requirements are met", func() {
|
||||
loggedUser.Password = "abc123"
|
||||
user := *loggedUser
|
||||
user.CurrentPassword = "abc123"
|
||||
user.NewPassword = "new"
|
||||
err := validatePasswordChange(&user, loggedUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("validateUsernameUnique", func() {
|
||||
var repo *tests.MockedUserRepo
|
||||
var existingUser *model.User
|
||||
BeforeEach(func() {
|
||||
existingUser = &model.User{ID: "1", UserName: "johndoe"}
|
||||
repo = tests.CreateMockUserRepo()
|
||||
err := repo.Put(existingUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
It("allows unique usernames", func() {
|
||||
var newUser = &model.User{ID: "2", UserName: "unique_username"}
|
||||
err := validateUsernameUnique(repo, newUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
It("returns ValidationError if username already exists", func() {
|
||||
var newUser = &model.User{ID: "2", UserName: "johndoe"}
|
||||
err := validateUsernameUnique(repo, newUser)
|
||||
var verr *rest.ValidationError
|
||||
isValidationError := errors.As(err, &verr)
|
||||
|
||||
Expect(isValidationError).To(BeTrue())
|
||||
Expect(verr.Errors).To(HaveKeyWithValue("userName", "ra.validation.unique"))
|
||||
})
|
||||
It("returns generic error if repository call fails", func() {
|
||||
repo.Error = errors.New("fake error")
|
||||
|
||||
var newUser = &model.User{ID: "2", UserName: "newuser"}
|
||||
err := validateUsernameUnique(repo, newUser)
|
||||
Expect(err).To(MatchError("fake error"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Library Association Methods", func() {
|
||||
var userID string
|
||||
var library1, library2 model.Library
|
||||
|
||||
BeforeEach(func() {
|
||||
// Create a test user first to satisfy foreign key constraints
|
||||
testUser := model.User{
|
||||
ID: "test-user-id",
|
||||
UserName: "testuser",
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
NewPassword: "password",
|
||||
IsAdmin: false,
|
||||
}
|
||||
Expect(repo.Put(&testUser)).To(BeNil())
|
||||
userID = testUser.ID
|
||||
|
||||
library1 = model.Library{ID: 0, Name: "Library 500", Path: "/path/500"}
|
||||
library2 = model.Library{ID: 0, Name: "Library 501", Path: "/path/501"}
|
||||
|
||||
// Create test libraries
|
||||
libRepo := NewLibraryRepository(log.NewContext(context.TODO()), GetDBXBuilder())
|
||||
Expect(libRepo.Put(&library1)).To(BeNil())
|
||||
Expect(libRepo.Put(&library2)).To(BeNil())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up user-library associations to ensure test isolation
|
||||
_ = repo.SetUserLibraries(userID, []int{})
|
||||
|
||||
// Clean up test libraries to ensure isolation between test groups
|
||||
libRepo := NewLibraryRepository(log.NewContext(context.TODO()), GetDBXBuilder())
|
||||
_ = libRepo.(*libraryRepository).delete(squirrel.Eq{"id": []int{library1.ID, library2.ID}})
|
||||
})
|
||||
|
||||
Describe("GetUserLibraries", func() {
|
||||
It("returns empty list when user has no library associations", func() {
|
||||
libraries, err := repo.GetUserLibraries("non-existent-user")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libraries).To(HaveLen(0))
|
||||
})
|
||||
|
||||
It("returns user's associated libraries", func() {
|
||||
err := repo.SetUserLibraries(userID, []int{library1.ID, library2.ID})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
libraries, err := repo.GetUserLibraries(userID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libraries).To(HaveLen(2))
|
||||
|
||||
libIDs := []int{libraries[0].ID, libraries[1].ID}
|
||||
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("SetUserLibraries", func() {
|
||||
It("sets user's library associations", func() {
|
||||
libraryIDs := []int{library1.ID, library2.ID}
|
||||
err := repo.SetUserLibraries(userID, libraryIDs)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
libraries, err := repo.GetUserLibraries(userID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libraries).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("replaces existing associations", func() {
|
||||
// Set initial associations
|
||||
err := repo.SetUserLibraries(userID, []int{library1.ID, library2.ID})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Replace with just one library
|
||||
err = repo.SetUserLibraries(userID, []int{library1.ID})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
libraries, err := repo.GetUserLibraries(userID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libraries).To(HaveLen(1))
|
||||
Expect(libraries[0].ID).To(Equal(library1.ID))
|
||||
})
|
||||
|
||||
It("removes all associations when passed empty slice", func() {
|
||||
// Set initial associations
|
||||
err := repo.SetUserLibraries(userID, []int{library1.ID, library2.ID})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Remove all
|
||||
err = repo.SetUserLibraries(userID, []int{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
libraries, err := repo.GetUserLibraries(userID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libraries).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Admin User Auto-Assignment", func() {
|
||||
var (
|
||||
libRepo model.LibraryRepository
|
||||
library1 model.Library
|
||||
library2 model.Library
|
||||
initialLibCount int
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
libRepo = NewLibraryRepository(log.NewContext(context.TODO()), GetDBXBuilder())
|
||||
|
||||
// Count initial libraries
|
||||
existingLibs, err := libRepo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
initialLibCount = len(existingLibs)
|
||||
|
||||
library1 = model.Library{ID: 0, Name: "Admin Test Library 1", Path: "/admin/test/path1"}
|
||||
library2 = model.Library{ID: 0, Name: "Admin Test Library 2", Path: "/admin/test/path2"}
|
||||
|
||||
// Create test libraries
|
||||
Expect(libRepo.Put(&library1)).To(BeNil())
|
||||
Expect(libRepo.Put(&library2)).To(BeNil())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up test libraries and their associations
|
||||
_ = libRepo.(*libraryRepository).delete(squirrel.Eq{"id": []int{library1.ID, library2.ID}})
|
||||
|
||||
// Clean up user-library associations for these test libraries
|
||||
_, _ = repo.(*userRepository).executeSQL(squirrel.Delete("user_library").Where(squirrel.Eq{"library_id": []int{library1.ID, library2.ID}}))
|
||||
})
|
||||
|
||||
It("automatically assigns all libraries to admin users when created", func() {
|
||||
adminUser := model.User{
|
||||
ID: "admin-user-id-1",
|
||||
UserName: "adminuser1",
|
||||
Name: "Admin User",
|
||||
Email: "admin1@example.com",
|
||||
NewPassword: "password",
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
err := repo.Put(&adminUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Admin should automatically have access to all libraries (including existing ones)
|
||||
libraries, err := repo.GetUserLibraries(adminUser.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libraries).To(HaveLen(initialLibCount + 2)) // Initial libraries + our 2 test libraries
|
||||
|
||||
libIDs := make([]int, len(libraries))
|
||||
for i, lib := range libraries {
|
||||
libIDs[i] = lib.ID
|
||||
}
|
||||
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
|
||||
})
|
||||
|
||||
It("automatically assigns all libraries to admin users when updated", func() {
|
||||
// Create regular user first
|
||||
regularUser := model.User{
|
||||
ID: "regular-user-id-1",
|
||||
UserName: "regularuser1",
|
||||
Name: "Regular User",
|
||||
Email: "regular1@example.com",
|
||||
NewPassword: "password",
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
err := repo.Put(®ularUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Give them access to just one library
|
||||
err = repo.SetUserLibraries(regularUser.ID, []int{library1.ID})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Promote to admin
|
||||
regularUser.IsAdmin = true
|
||||
err = repo.Put(®ularUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Should now have access to all libraries (including existing ones)
|
||||
libraries, err := repo.GetUserLibraries(regularUser.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libraries).To(HaveLen(initialLibCount + 2)) // Initial libraries + our 2 test libraries
|
||||
|
||||
libIDs := make([]int, len(libraries))
|
||||
for i, lib := range libraries {
|
||||
libIDs[i] = lib.ID
|
||||
}
|
||||
// Should include our test libraries plus all existing ones
|
||||
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
|
||||
})
|
||||
|
||||
It("assigns default libraries to regular users", func() {
|
||||
regularUser := model.User{
|
||||
ID: "regular-user-id-2",
|
||||
UserName: "regularuser2",
|
||||
Name: "Regular User",
|
||||
Email: "regular2@example.com",
|
||||
NewPassword: "password",
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
err := repo.Put(®ularUser)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Regular user should be assigned to default libraries (library ID 1 from migration)
|
||||
libraries, err := repo.GetUserLibraries(regularUser.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(libraries).To(HaveLen(1))
|
||||
Expect(libraries[0].ID).To(Equal(1))
|
||||
Expect(libraries[0].DefaultNewUsers).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Libraries Field Population", func() {
|
||||
var (
|
||||
libRepo model.LibraryRepository
|
||||
library1 model.Library
|
||||
library2 model.Library
|
||||
testUser model.User
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
libRepo = NewLibraryRepository(log.NewContext(context.TODO()), GetDBXBuilder())
|
||||
library1 = model.Library{ID: 0, Name: "Field Test Library 1", Path: "/field/test/path1"}
|
||||
library2 = model.Library{ID: 0, Name: "Field Test Library 2", Path: "/field/test/path2"}
|
||||
|
||||
// Create test libraries
|
||||
Expect(libRepo.Put(&library1)).To(BeNil())
|
||||
Expect(libRepo.Put(&library2)).To(BeNil())
|
||||
|
||||
// Create test user
|
||||
testUser = model.User{
|
||||
ID: "field-test-user",
|
||||
UserName: "fieldtestuser",
|
||||
Name: "Field Test User",
|
||||
Email: "fieldtest@example.com",
|
||||
NewPassword: "password",
|
||||
IsAdmin: false,
|
||||
}
|
||||
Expect(repo.Put(&testUser)).To(BeNil())
|
||||
|
||||
// Assign libraries to user
|
||||
Expect(repo.SetUserLibraries(testUser.ID, []int{library1.ID, library2.ID})).To(BeNil())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
// Clean up test libraries and their associations
|
||||
_ = libRepo.(*libraryRepository).delete(squirrel.Eq{"id": []int{library1.ID, library2.ID}})
|
||||
_ = repo.(*userRepository).delete(squirrel.Eq{"id": testUser.ID})
|
||||
|
||||
// Clean up user-library associations for these test libraries
|
||||
_, _ = repo.(*userRepository).executeSQL(squirrel.Delete("user_library").Where(squirrel.Eq{"library_id": []int{library1.ID, library2.ID}}))
|
||||
})
|
||||
|
||||
It("populates Libraries field when getting a single user", func() {
|
||||
user, err := repo.Get(testUser.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(user.Libraries).To(HaveLen(2))
|
||||
|
||||
libIDs := []int{user.Libraries[0].ID, user.Libraries[1].ID}
|
||||
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
|
||||
|
||||
// Check that library details are properly populated
|
||||
for _, lib := range user.Libraries {
|
||||
switch lib.ID {
|
||||
case library1.ID:
|
||||
Expect(lib.Name).To(Equal("Field Test Library 1"))
|
||||
Expect(lib.Path).To(Equal("/field/test/path1"))
|
||||
case library2.ID:
|
||||
Expect(lib.Name).To(Equal("Field Test Library 2"))
|
||||
Expect(lib.Path).To(Equal("/field/test/path2"))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("populates Libraries field when getting all users", func() {
|
||||
users, err := repo.(*userRepository).GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Find our test user in the results
|
||||
found := slices.IndexFunc(users, func(u model.User) bool { return u.ID == testUser.ID })
|
||||
Expect(found).ToNot(Equal(-1))
|
||||
|
||||
foundUser := users[found]
|
||||
Expect(foundUser).ToNot(BeNil())
|
||||
Expect(foundUser.Libraries).To(HaveLen(2))
|
||||
|
||||
libIDs := []int{foundUser.Libraries[0].ID, foundUser.Libraries[1].ID}
|
||||
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
|
||||
})
|
||||
|
||||
It("populates Libraries field when finding user by username", func() {
|
||||
user, err := repo.FindByUsername(testUser.UserName)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(user.Libraries).To(HaveLen(2))
|
||||
|
||||
libIDs := []int{user.Libraries[0].ID, user.Libraries[1].ID}
|
||||
Expect(libIDs).To(ContainElements(library1.ID, library2.ID))
|
||||
})
|
||||
|
||||
It("returns default Libraries array for new regular users", func() {
|
||||
// Create a user with no explicit library associations - should get default libraries
|
||||
userWithoutLibs := model.User{
|
||||
ID: "no-libs-user",
|
||||
UserName: "nolibsuser",
|
||||
Name: "No Libs User",
|
||||
Email: "nolibs@example.com",
|
||||
NewPassword: "password",
|
||||
IsAdmin: false,
|
||||
}
|
||||
Expect(repo.Put(&userWithoutLibs)).To(BeNil())
|
||||
defer func() { _ = repo.(*userRepository).delete(squirrel.Eq{"id": userWithoutLibs.ID}) }()
|
||||
|
||||
user, err := repo.Get(userWithoutLibs.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(user.Libraries).ToNot(BeNil())
|
||||
// Regular users should be assigned to default libraries (library ID 1 from migration)
|
||||
Expect(user.Libraries).To(HaveLen(1))
|
||||
Expect(user.Libraries[0].ID).To(Equal(1))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("filters", func() {
|
||||
It("qualifies id filter with table name", func() {
|
||||
r := repo.(*userRepository)
|
||||
qo := r.parseRestOptions(r.ctx, rest.QueryOptions{Filters: map[string]any{"id": "123"}})
|
||||
sel := r.selectUserWithLibraries(qo)
|
||||
query, _, err := r.toSQL(sel)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(query).To(ContainSubstring("user.id = {:p0}"))
|
||||
})
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user