diff options
Diffstat (limited to 'internal/postgres/insert_module.go')
| -rw-r--r-- | internal/postgres/insert_module.go | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/internal/postgres/insert_module.go b/internal/postgres/insert_module.go index 9e81744b..39065541 100644 --- a/internal/postgres/insert_module.go +++ b/internal/postgres/insert_module.go @@ -23,6 +23,7 @@ import ( "golang.org/x/pkgsite/internal" "golang.org/x/pkgsite/internal/database" "golang.org/x/pkgsite/internal/derrors" + "golang.org/x/pkgsite/internal/licenses" "golang.org/x/pkgsite/internal/log" "golang.org/x/pkgsite/internal/stdlib" "golang.org/x/pkgsite/internal/version" @@ -51,9 +52,6 @@ func (db *DB) InsertModule(ctx context.Context, m *internal.Module) (err error) // inserted. Rows that currently exist should not be missing from the // new module. We want to be sure that we will overwrite every row that // pertains to the module. - if err := db.compareLicenses(ctx, m); err != nil { - return err - } if err := db.comparePaths(ctx, m); err != nil { return err } @@ -81,6 +79,13 @@ func (db *DB) saveModule(ctx context.Context, m *internal.Module) (err error) { if err != nil { return err } + // Compare existing data from the database, and the module to be + // inserted. Rows that currently exist should not be missing from the + // new module. We want to be sure that we will overwrite every row that + // pertains to the module. + if err := db.compareLicenses(ctx, moduleID, m.Licenses); err != nil { + return err + } if err := insertLicenses(ctx, tx, m, moduleID); err != nil { return err } @@ -561,15 +566,15 @@ func validateModule(m *internal.Module) (err error) { // compareLicenses compares m.Licenses with the existing licenses for // m.ModulePath and m.Version in the database. It returns an error if there // are licenses in the licenses table that are not present in m.Licenses. -func (db *DB) compareLicenses(ctx context.Context, m *internal.Module) (err error) { - defer derrors.Wrap(&err, "compareLicenses(ctx, %q, %q)", m.ModulePath, m.Version) - dbLicenses, err := db.getModuleLicenses(ctx, m.ModulePath, m.Version) +func (db *DB) compareLicenses(ctx context.Context, moduleID int, lics []*licenses.License) (err error) { + defer derrors.Wrap(&err, "compareLicenses(ctx, %d)", moduleID) + dbLicenses, err := db.getModuleLicenses(ctx, moduleID) if err != nil { return err } set := map[string]bool{} - for _, l := range m.Licenses { + for _, l := range lics { set[l.FilePath] = true } for _, l := range dbLicenses { |
