diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 0d7d507575..264dbef6f5 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -144,14 +144,14 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) var table *ast.TableName switch n := raw.Stmt.(type) { case *ast.InsertStmt: - if err := check(validate.InsertStmt(n)); err != nil { - return nil, err - } var err error table, err = ParseTableName(n.Relation) if err := check(err); err != nil { return nil, err } + if err := check(validate.InsertStmt(c.catalog, table, n)); err != nil { + return nil, err + } } if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil { diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/exec.json b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/exec.json index ee1b7ecd9e..2e996ca79d 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/exec.json +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/exec.json @@ -1,3 +1,3 @@ { - "contexts": ["managed-db"] + "contexts": ["base"] } diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql index 3f9e9d9b86..2484d998b6 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql @@ -1,4 +1,23 @@ --- name: UpsertServer :exec -INSERT INTO servers(code, name) VALUES ($1, $2) -ON CONFLICT (code) -DO UPDATE SET name_typo = 1111; \ No newline at end of file +-- name: UpsertServerSetColumnTypo :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT (code) +DO UPDATE SET name_typo = 1111; + +-- name: UpsertServerConflictTargetTypo :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT (code_typo) +DO UPDATE SET name = 1111; + +-- name: UpsertServerExcludedColumnTypo :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT (code) +DO UPDATE SET name = EXCLUDED.name_typo; + +-- name: UpsertServerMissingConflictTarget :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT DO UPDATE SET name = EXCLUDED.name; + +-- name: UpsertServerOnConstraintExcludedTypo :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT ON CONSTRAINT servers_pkey DO UPDATE SET name = EXCLUDED.name_typo; + diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql index 3ff1ccd6b3..c3dec12e49 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql @@ -1,4 +1,4 @@ CREATE TABLE servers ( - code varchar PRIMARY KEY, - name text NOT NULL -); \ No newline at end of file + code varchar PRIMARY KEY, + name text NOT NULL +); diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt index adbb13a418..b9692daa64 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt @@ -1,2 +1,6 @@ # package querytest -query.sql:4:15: column "name_typo" of relation "servers" does not exist \ No newline at end of file +query.sql:4:15: column "name_typo" of relation "servers" does not exist +query.sql:8:13: column "code_typo" of relation "servers" does not exist +query.sql:14:22: column "name_typo" of relation "EXCLUDED" does not exist +query.sql:17:1: ON CONFLICT DO UPDATE requires inference specification or constraint name +query.sql:22:61: column "name_typo" of relation "EXCLUDED" does not exist diff --git a/internal/sql/validate/insert_stmt.go b/internal/sql/validate/insert_stmt.go index dd8041ea23..236ddbfceb 100644 --- a/internal/sql/validate/insert_stmt.go +++ b/internal/sql/validate/insert_stmt.go @@ -1,11 +1,16 @@ package validate import ( + "strings" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) -func InsertStmt(stmt *ast.InsertStmt) error { +const excludedTable = "EXCLUDED" + +func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) error { sel, ok := stmt.SelectStmt.(*ast.SelectStmt) if !ok { return nil @@ -35,5 +40,102 @@ func InsertStmt(stmt *ast.InsertStmt) error { Message: "INSERT has more expressions than target columns", } } + + return onConflictClause(c, fqn, stmt) +} + +// onConflictClause validates an ON CONFLICT DO UPDATE clause against the target +// table. It checks: +// - ON CONFLICT (col, ...) conflict target columns exist +// - DO UPDATE SET col = ... assignment target columns exist +// - EXCLUDED.col references exist +func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt) error { + if fqn == nil || n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { + return nil + } + + table, err := c.GetTable(fqn) + if err != nil { + return err + } + + // Build set of column names for existence checks. + colNames := make(map[string]struct{}, len(table.Columns)) + for _, col := range table.Columns { + colNames[col.Name] = struct{}{} + } + + // DO UPDATE requires a conflict target: ON CONFLICT (col) or ON CONFLICT ON CONSTRAINT name. + if n.OnConflictClause.Infer == nil { + return &sqlerr.Error{ + Code: "42601", + Message: "ON CONFLICT DO UPDATE requires inference specification or constraint name", + } + } + + // Validate ON CONFLICT (col, ...) conflict target columns. + if n.OnConflictClause.Infer.IndexElems != nil { + for _, item := range n.OnConflictClause.Infer.IndexElems.Items { + elem, ok := item.(*ast.IndexElem) + if !ok || elem.Name == nil { + continue + } + + if _, exists := colNames[*elem.Name]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name) + e.Location = n.OnConflictClause.Infer.Location + return e + } + } + } + + // Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references. + if n.OnConflictClause.TargetList == nil { + return nil + } + + for _, item := range n.OnConflictClause.TargetList.Items { + target, ok := item.(*ast.ResTarget) + if !ok || target.Name == nil { + continue + } + + if _, exists := colNames[*target.Name]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name) + e.Location = target.Location + return e + } + + if ref, ok := target.Val.(*ast.ColumnRef); ok { + if excludedCol, ok := excludedColumnRef(ref); ok { + if _, exists := colNames[excludedCol]; !exists { + e := sqlerr.ColumnNotFound(excludedTable, excludedCol) + e.Location = ref.Location + return e + } + } + } + } + return nil } + +// excludedColumnRef returns the column name if the ColumnRef is an EXCLUDED.col +// reference, and ok=true. Returns "", false otherwise. +func excludedColumnRef(ref *ast.ColumnRef) (string, bool) { + if ref.Fields == nil || len(ref.Fields.Items) != 2 { + return "", false + } + + first, ok := ref.Fields.Items[0].(*ast.String) + if !ok || !strings.EqualFold(first.Str, excludedTable) { + return "", false + } + + second, ok := ref.Fields.Items[1].(*ast.String) + if !ok { + return "", false + } + + return second.Str, true +}