diff --git a/models/unittest/fixtures.go b/models/unittest/fixtures.go index c653ce1e38..9ce0909589 100644 --- a/models/unittest/fixtures.go +++ b/models/unittest/fixtures.go @@ -7,6 +7,7 @@ package unittest import ( "fmt" "os" + "path/filepath" "time" "code.gitea.io/gitea/models/db" @@ -28,6 +29,16 @@ func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) { return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine) } +func OverrideFixtures(opts FixturesOptions, engine ...*xorm.Engine) func() { + old := fixturesLoader + if err := InitFixtures(opts, engine...); err != nil { + panic(err) + } + return func() { + fixturesLoader = old + } +} + // InitFixtures initialize test fixtures for a test database func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { e := GetXORMEngine(engine...) @@ -37,6 +48,12 @@ func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { } else { fixtureOptionFiles = testfixtures.Files(opts.Files...) } + var fixtureOptionDirs []func(*testfixtures.Loader) error + if opts.Dirs != nil { + for _, dir := range opts.Dirs { + fixtureOptionDirs = append(fixtureOptionDirs, testfixtures.Directory(filepath.Join(opts.Base, dir))) + } + } dialect := "unknown" switch e.Dialect().URI().DBType { case schemas.POSTGRES: @@ -57,6 +74,7 @@ func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { testfixtures.DangerousSkipTestDatabaseCheck(), fixtureOptionFiles, } + loaderOptions = append(loaderOptions, fixtureOptionDirs...) if e.Dialect().URI().DBType == schemas.POSTGRES { loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences()) diff --git a/models/unittest/testdb.go b/models/unittest/testdb.go index 6db99cd393..c6ff292f05 100644 --- a/models/unittest/testdb.go +++ b/models/unittest/testdb.go @@ -209,6 +209,8 @@ func MainTest(m *testing.M, testOpts ...*TestOptions) { type FixturesOptions struct { Dir string Files []string + Dirs []string + Base string } // CreateTestEngine creates a memory database and loads the fixture data from fixturesDir diff --git a/tests/test_utils.go b/tests/test_utils.go index 50049e73f0..8e456783cf 100644 --- a/tests/test_utils.go +++ b/tests/test_utils.go @@ -267,3 +267,13 @@ func PrintCurrentTest(t testing.TB, skip ...int) func() { func Printf(format string, args ...any) { testlogger.Printf(format, args...) } + +func AddFixtures(dirs ...string) func() { + return unittest.OverrideFixtures( + unittest.FixturesOptions{ + Dir: filepath.Join(filepath.Dir(setting.AppPath), "models/fixtures/"), + Base: filepath.Dir(setting.AppPath), + Dirs: dirs, + }, + ) +}