From b7c944879a193aa16e3b951f7b24b11b4d351c28 Mon Sep 17 00:00:00 2001 From: Nikolay Kuznetsov Date: Thu, 2 Apr 2026 02:20:07 +0300 Subject: [PATCH 1/7] validateOnConflictColumns --- internal/compiler/analyze.go | 84 +++++++++++++++++++ .../postgresql/pgx/exec.json | 2 +- .../postgresql/pgx/query.sql | 18 +++- .../postgresql/pgx/stderr.txt | 4 +- 4 files changed, 102 insertions(+), 6 deletions(-) diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 0d7d507575..c4454e7d4b 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -9,6 +9,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/named" "github.com/sqlc-dev/sqlc/internal/sql/rewrite" + "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" "github.com/sqlc-dev/sqlc/internal/sql/validate" ) @@ -152,6 +153,9 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(err); err != nil { return nil, err } + if err := check(c.validateOnConflictColumns(n)); err != nil { + return nil, err + } } if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil { @@ -213,3 +217,83 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) Named: namedParams, }, rerr } + +// validateOnConflictColumns checks column names in an ON CONFLICT DO UPDATE +// clause against the target table: +// - ON CONFLICT (col, ...) conflict target columns +// - DO UPDATE SET col = ... assignment target columns +// - EXCLUDED.col references in assignment values +func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error { + if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { + return nil + } + fqn, err := ParseTableName(n.Relation) + if err != nil { + return err + } + table, err := c.catalog.GetTable(fqn) + if err != nil { + return err + } + colSet := make(map[string]struct{}, len(table.Columns)) + for _, col := range table.Columns { + colSet[col.Name] = struct{}{} + } + + // Validate ON CONFLICT (col, ...) conflict target columns. + if n.OnConflictClause.Infer != nil { + for _, item := range n.OnConflictClause.Infer.IndexElems.Items { + elem, ok := item.(*ast.IndexElem) + if !ok || elem.Name == nil { + continue + } + if _, exists := colSet[*elem.Name]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name) + e.Location = n.OnConflictClause.Infer.Location + return e + } + } + } + + // Validate DO UPDATE SET col = ... and EXCLUDED.col references. + for _, item := range n.OnConflictClause.TargetList.Items { + target, ok := item.(*ast.ResTarget) + if !ok || target.Name == nil { + continue + } + // Validate the assignment target column. + if _, exists := colSet[*target.Name]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name) + e.Location = target.Location + return e + } + // Validate EXCLUDED.col references in the assigned value. + if ref, ok := target.Val.(*ast.ColumnRef); ok { + if col, ok := excludedColumn(ref); ok { + if _, exists := colSet[col]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, col) + e.Location = ref.Location + return e + } + } + } + } + return nil +} + +// excludedColumn returns the column name if the ColumnRef is an EXCLUDED.col +// reference, and ok=true. Returns "", false otherwise. +func excludedColumn(ref *ast.ColumnRef) (string, bool) { + if ref.Fields == nil || len(ref.Fields.Items) != 2 { + return "", false + } + first, ok := ref.Fields.Items[0].(*ast.String) + if !ok || first.Str != "excluded" { + return "", false + } + second, ok := ref.Fields.Items[1].(*ast.String) + if !ok { + return "", false + } + return second.Str, true +} diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/exec.json b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/exec.json index ee1b7ecd9e..2e996ca79d 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/exec.json +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/exec.json @@ -1,3 +1,3 @@ { - "contexts": ["managed-db"] + "contexts": ["base"] } diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql index 3f9e9d9b86..256bea1c2c 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql @@ -1,4 +1,14 @@ --- name: UpsertServer :exec -INSERT INTO servers(code, name) VALUES ($1, $2) -ON CONFLICT (code) -DO UPDATE SET name_typo = 1111; \ No newline at end of file +-- name: UpsertServerSetColumnTypo :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT (code) +DO UPDATE SET name_typo = 1111; + +-- name: UpsertServerConflictTargetTypo :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT (code_typo) +DO UPDATE SET name = 1111; + +-- name: UpsertServerExcludedColumnTypo :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT (code) +DO UPDATE SET name = EXCLUDED.name_typo; diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt index adbb13a418..6f6ebd87ad 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt @@ -1,2 +1,4 @@ # package querytest -query.sql:4:15: column "name_typo" of relation "servers" does not exist \ No newline at end of file +query.sql:4:15: column "name_typo" of relation "servers" does not exist +query.sql:8:13: column "code_typo" of relation "servers" does not exist +query.sql:14:22: column "name_typo" of relation "servers" does not exist From fca278db92e597dba76b151164f4fe55aa881c16 Mon Sep 17 00:00:00 2001 From: Nikolay Kuznetsov Date: Thu, 2 Apr 2026 02:20:07 +0300 Subject: [PATCH 2/7] validateOnConflictTypes --- internal/compiler/analyze.go | 82 +++++++++++++++++++ .../postgresql/pgx/query.sql | 11 +++ .../postgresql/pgx/schema.sql | 7 +- .../postgresql/pgx/stderr.txt | 2 + internal/engine/postgresql/utils.go | 69 ++++++++++++++++ 5 files changed, 168 insertions(+), 3 deletions(-) diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index c4454e7d4b..16cd7713fc 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -1,10 +1,12 @@ package compiler import ( + "fmt" "sort" analyzer "github.com/sqlc-dev/sqlc/internal/analysis" "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/engine/postgresql" "github.com/sqlc-dev/sqlc/internal/source" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/named" @@ -143,6 +145,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar) var table *ast.TableName + var insertStmt *ast.InsertStmt switch n := raw.Stmt.(type) { case *ast.InsertStmt: if err := check(validate.InsertStmt(n)); err != nil { @@ -156,6 +159,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(c.validateOnConflictColumns(n)); err != nil { return nil, err } + insertStmt = n } if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil { @@ -189,6 +193,11 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(err); err != nil { return nil, err } + if c.conf.Engine == config.EnginePostgreSQL { + if err := check(c.validateOnConflictTypes(insertStmt, params)); err != nil { + return nil, err + } + } cols, err := c.outputColumns(qc, raw.Stmt) if err := check(err); err != nil { return nil, err @@ -281,6 +290,79 @@ func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error { return nil } +// validateOnConflictTypes checks that $N params used in DO UPDATE SET assignments +// are type-compatible with the target column, based on the type already resolved +// for that param from the INSERT columns. +func (c *Compiler) validateOnConflictTypes(n *ast.InsertStmt, params []Parameter) error { + if n == nil || n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { + return nil + } + fqn, err := ParseTableName(n.Relation) + if err != nil { + return err + } + table, err := c.catalog.GetTable(fqn) + if err != nil { + return err + } + + // Build param number → resolved DataType string from already-resolved params. + // Skips params with "any" type (unresolved). + paramDataTypes := make(map[int]string, len(params)) + for i := range params { + if params[i].Column != nil && params[i].Column.DataType != "any" { + paramDataTypes[params[i].Number] = params[i].Column.DataType + } + } + + // Build column name → DataType string using the same dataType() function + // used by resolveCatalogRefs, so formats are comparable. + colDataTypes := make(map[string]string, len(table.Columns)) + for _, col := range table.Columns { + colDataTypes[col.Name] = dataType(&col.Type) + } + + for _, item := range n.OnConflictClause.TargetList.Items { + target, ok := item.(*ast.ResTarget) + if !ok || target.Name == nil { + continue + } + colDT, ok := colDataTypes[*target.Name] + if !ok { + continue + } + switch val := target.Val.(type) { + case *ast.ParamRef: + paramDT, ok := paramDataTypes[val.Number] + if !ok { + continue + } + if postgresql.TypeFamily(paramDT) != postgresql.TypeFamily(colDT) { + return &sqlerr.Error{ + Message: fmt.Sprintf("parameter $%d has type %q but column %q has type %q", val.Number, paramDT, *target.Name, colDT), + Location: val.Location, + } + } + case *ast.ColumnRef: + excludedCol, ok := excludedColumn(val) + if !ok { + continue + } + excludedDT, ok := colDataTypes[excludedCol] + if !ok { + continue + } + if postgresql.TypeFamily(excludedDT) != postgresql.TypeFamily(colDT) { + return &sqlerr.Error{ + Message: fmt.Sprintf("EXCLUDED.%s has type %q but column %q has type %q", excludedCol, excludedDT, *target.Name, colDT), + Location: val.Location, + } + } + } + } + return nil +} + // excludedColumn returns the column name if the ColumnRef is an EXCLUDED.col // reference, and ok=true. Returns "", false otherwise. func excludedColumn(ref *ast.ColumnRef) (string, bool) { diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql index 256bea1c2c..adfd9ae2ec 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql @@ -12,3 +12,14 @@ DO UPDATE SET name = 1111; INSERT INTO servers(code, name) VALUES ($1, $2) ON CONFLICT (code) DO UPDATE SET name = EXCLUDED.name_typo; + +-- name: UpsertServerSetParamTypeMismatch :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT (code) +DO UPDATE SET count = $2; + +-- name: UpsertServerExcludedTypeMismatch :exec +INSERT INTO servers(code, name, count) VALUES ($1, $2, $3) +ON CONFLICT (code) +DO UPDATE SET count = EXCLUDED.code; + diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql index 3ff1ccd6b3..ff73489c3a 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql @@ -1,4 +1,5 @@ CREATE TABLE servers ( - code varchar PRIMARY KEY, - name text NOT NULL -); \ No newline at end of file + code varchar PRIMARY KEY, + name text NOT NULL, + count integer NOT NULL DEFAULT 0 +); diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt index 6f6ebd87ad..74e56e8136 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt @@ -2,3 +2,5 @@ query.sql:4:15: column "name_typo" of relation "servers" does not exist query.sql:8:13: column "code_typo" of relation "servers" does not exist query.sql:14:22: column "name_typo" of relation "servers" does not exist +query.sql:19:23: parameter $2 has type "text" but column "count" has type "pg_catalog.int4" +query.sql:24:23: EXCLUDED.code has type "pg_catalog.varchar" but column "count" has type "pg_catalog.int4" diff --git a/internal/engine/postgresql/utils.go b/internal/engine/postgresql/utils.go index b8d49b1a97..d5856439f5 100644 --- a/internal/engine/postgresql/utils.go +++ b/internal/engine/postgresql/utils.go @@ -39,6 +39,75 @@ func IsNamedParamSign(node *nodes.Node) bool { return ok && joinNodes(expr.AExpr.Name, ".") == "@" } +// TypeFamily maps a PostgreSQL DataType string to a canonical type family name, +// grouping compatible type aliases together. This is used for type compatibility +// checks rather than exact string equality, because PostgreSQL considers many +// type aliases assignment-compatible (e.g. text and varchar are both string types). +// +// The groupings are derived from postgresType() in +// internal/codegen/golang/postgresql_type.go, which maps these aliases to the +// same Go type. We cannot call postgresType() directly for type compatibility +// checking because it requires *plugin.GenerateRequest — a protobuf codegen +// struct constructed after compilation — and driver-specific opts.Options. +func TypeFamily(dt string) string { + switch dt { + case "serial", "serial4", "pg_catalog.serial4", + "integer", "int", "int4", "pg_catalog.int4": + return "int32" + case "bigserial", "serial8", "pg_catalog.serial8", + "bigint", "int8", "pg_catalog.int8", + "interval", "pg_catalog.interval": + return "int64" + case "smallserial", "serial2", "pg_catalog.serial2", + "smallint", "int2", "pg_catalog.int2": + return "int16" + case "float", "double precision", "float8", "pg_catalog.float8": + return "float64" + case "real", "float4", "pg_catalog.float4": + return "float32" + case "numeric", "pg_catalog.numeric", "money": + return "numeric" + case "boolean", "bool", "pg_catalog.bool": + return "bool" + case "json", "pg_catalog.json": + return "json" + case "jsonb", "pg_catalog.jsonb": + return "jsonb" + case "bytea", "blob", "pg_catalog.bytea": + return "bytes" + case "date": + return "date" + case "pg_catalog.time": + return "time" + case "pg_catalog.timetz": + return "timetz" + case "pg_catalog.timestamp", "timestamp": + return "timestamp" + case "pg_catalog.timestamptz", "timestamptz": + return "timestamptz" + case "text", "pg_catalog.varchar", "pg_catalog.bpchar", + "string", "citext", "name", + "ltree", "lquery", "ltxtquery": + return "text" + case "uuid": + return "uuid" + case "inet": + return "inet" + case "cidr": + return "cidr" + case "macaddr", "macaddr8": + return "macaddr" + case "bit", "varbit", "pg_catalog.bit", "pg_catalog.varbit": + return "bits" + case "hstore": + return "hstore" + case "vector": + return "vector" + default: + return dt + } +} + func makeByte(s string) byte { var b byte if s == "" { From fe19865ac9bbd59f3cc5e97528d290d02c6a7701 Mon Sep 17 00:00:00 2001 From: Nikolay Kuznetsov Date: Thu, 2 Apr 2026 02:20:07 +0300 Subject: [PATCH 3/7] validateOnConflictClause --- internal/compiler/analyze.go | 80 +++++++++++++++--------------------- 1 file changed, 33 insertions(+), 47 deletions(-) diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 16cd7713fc..2121ff447c 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -145,7 +145,6 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar) var table *ast.TableName - var insertStmt *ast.InsertStmt switch n := raw.Stmt.(type) { case *ast.InsertStmt: if err := check(validate.InsertStmt(n)); err != nil { @@ -156,10 +155,6 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(err); err != nil { return nil, err } - if err := check(c.validateOnConflictColumns(n)); err != nil { - return nil, err - } - insertStmt = n } if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil { @@ -193,8 +188,8 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(err); err != nil { return nil, err } - if c.conf.Engine == config.EnginePostgreSQL { - if err := check(c.validateOnConflictTypes(insertStmt, params)); err != nil { + if n, ok := raw.Stmt.(*ast.InsertStmt); ok { + if err := check(c.validateOnConflictClause(n, params)); err != nil { return nil, err } } @@ -227,12 +222,13 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) }, rerr } -// validateOnConflictColumns checks column names in an ON CONFLICT DO UPDATE -// clause against the target table: -// - ON CONFLICT (col, ...) conflict target columns -// - DO UPDATE SET col = ... assignment target columns -// - EXCLUDED.col references in assignment values -func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error { +// validateOnConflictClause validates an ON CONFLICT DO UPDATE clause against +// the target table. It checks: +// - ON CONFLICT (col, ...) conflict target columns exist +// - DO UPDATE SET col = ... assignment target columns exist +// - EXCLUDED.col references exist +// - For PostgreSQL: $N param and EXCLUDED.col type compatibility with SET target +func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt, params []Parameter) error { if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { return nil } @@ -244,9 +240,11 @@ func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error { if err != nil { return err } - colSet := make(map[string]struct{}, len(table.Columns)) + + // Build column name → DataType from catalog for existence and type checks. + colDataTypes := make(map[string]string, len(table.Columns)) for _, col := range table.Columns { - colSet[col.Name] = struct{}{} + colDataTypes[col.Name] = dataType(&col.Type) } // Validate ON CONFLICT (col, ...) conflict target columns. @@ -256,7 +254,7 @@ func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error { if !ok || elem.Name == nil { continue } - if _, exists := colSet[*elem.Name]; !exists { + if _, exists := colDataTypes[*elem.Name]; !exists { e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name) e.Location = n.OnConflictClause.Infer.Location return e @@ -264,49 +262,44 @@ func (c *Compiler) validateOnConflictColumns(n *ast.InsertStmt) error { } } - // Validate DO UPDATE SET col = ... and EXCLUDED.col references. + // Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references. for _, item := range n.OnConflictClause.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok || target.Name == nil { continue } - // Validate the assignment target column. - if _, exists := colSet[*target.Name]; !exists { + if _, exists := colDataTypes[*target.Name]; !exists { e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name) e.Location = target.Location return e } - // Validate EXCLUDED.col references in the assigned value. if ref, ok := target.Val.(*ast.ColumnRef); ok { - if col, ok := excludedColumn(ref); ok { - if _, exists := colSet[col]; !exists { - e := sqlerr.ColumnNotFound(table.Rel.Name, col) + if excludedCol, ok := excludedColumn(ref); ok { + if _, exists := colDataTypes[excludedCol]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol) e.Location = ref.Location return e } } } } - return nil -} -// validateOnConflictTypes checks that $N params used in DO UPDATE SET assignments -// are type-compatible with the target column, based on the type already resolved -// for that param from the INSERT columns. -func (c *Compiler) validateOnConflictTypes(n *ast.InsertStmt, params []Parameter) error { - if n == nil || n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { - return nil - } - fqn, err := ParseTableName(n.Relation) - if err != nil { - return err - } - table, err := c.catalog.GetTable(fqn) - if err != nil { - return err + // Type compatibility checks (PostgreSQL only). + // To remove type checking: delete this block and validateOnConflictSetTypes. + if c.conf.Engine == config.EnginePostgreSQL { + if err := c.validateOnConflictSetTypes(n, params, colDataTypes); err != nil { + return err + } } - // Build param number → resolved DataType string from already-resolved params. + return nil +} + +// validateOnConflictSetTypes checks that values in DO UPDATE SET assignments +// are type-compatible with their target columns (PostgreSQL only). +// It handles $N params (typed from INSERT VALUES) and EXCLUDED.col references. +func (c *Compiler) validateOnConflictSetTypes(n *ast.InsertStmt, params []Parameter, colDataTypes map[string]string) error { + // Build param number → resolved DataType from already-resolved params. // Skips params with "any" type (unresolved). paramDataTypes := make(map[int]string, len(params)) for i := range params { @@ -315,13 +308,6 @@ func (c *Compiler) validateOnConflictTypes(n *ast.InsertStmt, params []Parameter } } - // Build column name → DataType string using the same dataType() function - // used by resolveCatalogRefs, so formats are comparable. - colDataTypes := make(map[string]string, len(table.Columns)) - for _, col := range table.Columns { - colDataTypes[col.Name] = dataType(&col.Type) - } - for _, item := range n.OnConflictClause.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok || target.Name == nil { From 95e05526cd87287563c4776a24e38088f8ea25ff Mon Sep 17 00:00:00 2001 From: Nikolay Kuznetsov Date: Thu, 2 Apr 2026 02:20:07 +0300 Subject: [PATCH 4/7] remove validateOnConflictSetTypes --- internal/compiler/analyze.go | 89 +++---------------- .../postgresql/pgx/query.sql | 10 --- .../postgresql/pgx/schema.sql | 3 +- .../postgresql/pgx/stderr.txt | 2 - internal/engine/postgresql/utils.go | 69 -------------- 5 files changed, 15 insertions(+), 158 deletions(-) diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 2121ff447c..484e70962a 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -1,12 +1,10 @@ package compiler import ( - "fmt" "sort" analyzer "github.com/sqlc-dev/sqlc/internal/analysis" "github.com/sqlc-dev/sqlc/internal/config" - "github.com/sqlc-dev/sqlc/internal/engine/postgresql" "github.com/sqlc-dev/sqlc/internal/source" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/named" @@ -189,7 +187,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } if n, ok := raw.Stmt.(*ast.InsertStmt); ok { - if err := check(c.validateOnConflictClause(n, params)); err != nil { + if err := check(c.validateOnConflictClause(n)); err != nil { return nil, err } } @@ -227,34 +225,35 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) // - ON CONFLICT (col, ...) conflict target columns exist // - DO UPDATE SET col = ... assignment target columns exist // - EXCLUDED.col references exist -// - For PostgreSQL: $N param and EXCLUDED.col type compatibility with SET target -func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt, params []Parameter) error { +func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt) error { if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { return nil } + fqn, err := ParseTableName(n.Relation) if err != nil { return err } + table, err := c.catalog.GetTable(fqn) if err != nil { return err } - // Build column name → DataType from catalog for existence and type checks. - colDataTypes := make(map[string]string, len(table.Columns)) + // Build set of column names for existence checks. + colNames := make(map[string]struct{}, len(table.Columns)) for _, col := range table.Columns { - colDataTypes[col.Name] = dataType(&col.Type) + colNames[col.Name] = struct{}{} } // Validate ON CONFLICT (col, ...) conflict target columns. - if n.OnConflictClause.Infer != nil { + if n.OnConflictClause.Infer != nil && n.OnConflictClause.Infer.IndexElems != nil { for _, item := range n.OnConflictClause.Infer.IndexElems.Items { elem, ok := item.(*ast.IndexElem) if !ok || elem.Name == nil { continue } - if _, exists := colDataTypes[*elem.Name]; !exists { + if _, exists := colNames[*elem.Name]; !exists { e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name) e.Location = n.OnConflictClause.Infer.Location return e @@ -263,19 +262,22 @@ func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt, params []Paramete } // Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references. + if n.OnConflictClause.TargetList == nil { + return nil + } for _, item := range n.OnConflictClause.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok || target.Name == nil { continue } - if _, exists := colDataTypes[*target.Name]; !exists { + if _, exists := colNames[*target.Name]; !exists { e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name) e.Location = target.Location return e } if ref, ok := target.Val.(*ast.ColumnRef); ok { if excludedCol, ok := excludedColumn(ref); ok { - if _, exists := colDataTypes[excludedCol]; !exists { + if _, exists := colNames[excludedCol]; !exists { e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol) e.Location = ref.Location return e @@ -283,69 +285,6 @@ func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt, params []Paramete } } } - - // Type compatibility checks (PostgreSQL only). - // To remove type checking: delete this block and validateOnConflictSetTypes. - if c.conf.Engine == config.EnginePostgreSQL { - if err := c.validateOnConflictSetTypes(n, params, colDataTypes); err != nil { - return err - } - } - - return nil -} - -// validateOnConflictSetTypes checks that values in DO UPDATE SET assignments -// are type-compatible with their target columns (PostgreSQL only). -// It handles $N params (typed from INSERT VALUES) and EXCLUDED.col references. -func (c *Compiler) validateOnConflictSetTypes(n *ast.InsertStmt, params []Parameter, colDataTypes map[string]string) error { - // Build param number → resolved DataType from already-resolved params. - // Skips params with "any" type (unresolved). - paramDataTypes := make(map[int]string, len(params)) - for i := range params { - if params[i].Column != nil && params[i].Column.DataType != "any" { - paramDataTypes[params[i].Number] = params[i].Column.DataType - } - } - - for _, item := range n.OnConflictClause.TargetList.Items { - target, ok := item.(*ast.ResTarget) - if !ok || target.Name == nil { - continue - } - colDT, ok := colDataTypes[*target.Name] - if !ok { - continue - } - switch val := target.Val.(type) { - case *ast.ParamRef: - paramDT, ok := paramDataTypes[val.Number] - if !ok { - continue - } - if postgresql.TypeFamily(paramDT) != postgresql.TypeFamily(colDT) { - return &sqlerr.Error{ - Message: fmt.Sprintf("parameter $%d has type %q but column %q has type %q", val.Number, paramDT, *target.Name, colDT), - Location: val.Location, - } - } - case *ast.ColumnRef: - excludedCol, ok := excludedColumn(val) - if !ok { - continue - } - excludedDT, ok := colDataTypes[excludedCol] - if !ok { - continue - } - if postgresql.TypeFamily(excludedDT) != postgresql.TypeFamily(colDT) { - return &sqlerr.Error{ - Message: fmt.Sprintf("EXCLUDED.%s has type %q but column %q has type %q", excludedCol, excludedDT, *target.Name, colDT), - Location: val.Location, - } - } - } - } return nil } diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql index adfd9ae2ec..624a4bc367 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql @@ -13,13 +13,3 @@ INSERT INTO servers(code, name) VALUES ($1, $2) ON CONFLICT (code) DO UPDATE SET name = EXCLUDED.name_typo; --- name: UpsertServerSetParamTypeMismatch :exec -INSERT INTO servers(code, name) VALUES ($1, $2) -ON CONFLICT (code) -DO UPDATE SET count = $2; - --- name: UpsertServerExcludedTypeMismatch :exec -INSERT INTO servers(code, name, count) VALUES ($1, $2, $3) -ON CONFLICT (code) -DO UPDATE SET count = EXCLUDED.code; - diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql index ff73489c3a..c3dec12e49 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/schema.sql @@ -1,5 +1,4 @@ CREATE TABLE servers ( code varchar PRIMARY KEY, - name text NOT NULL, - count integer NOT NULL DEFAULT 0 + name text NOT NULL ); diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt index 74e56e8136..6f6ebd87ad 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt @@ -2,5 +2,3 @@ query.sql:4:15: column "name_typo" of relation "servers" does not exist query.sql:8:13: column "code_typo" of relation "servers" does not exist query.sql:14:22: column "name_typo" of relation "servers" does not exist -query.sql:19:23: parameter $2 has type "text" but column "count" has type "pg_catalog.int4" -query.sql:24:23: EXCLUDED.code has type "pg_catalog.varchar" but column "count" has type "pg_catalog.int4" diff --git a/internal/engine/postgresql/utils.go b/internal/engine/postgresql/utils.go index d5856439f5..b8d49b1a97 100644 --- a/internal/engine/postgresql/utils.go +++ b/internal/engine/postgresql/utils.go @@ -39,75 +39,6 @@ func IsNamedParamSign(node *nodes.Node) bool { return ok && joinNodes(expr.AExpr.Name, ".") == "@" } -// TypeFamily maps a PostgreSQL DataType string to a canonical type family name, -// grouping compatible type aliases together. This is used for type compatibility -// checks rather than exact string equality, because PostgreSQL considers many -// type aliases assignment-compatible (e.g. text and varchar are both string types). -// -// The groupings are derived from postgresType() in -// internal/codegen/golang/postgresql_type.go, which maps these aliases to the -// same Go type. We cannot call postgresType() directly for type compatibility -// checking because it requires *plugin.GenerateRequest — a protobuf codegen -// struct constructed after compilation — and driver-specific opts.Options. -func TypeFamily(dt string) string { - switch dt { - case "serial", "serial4", "pg_catalog.serial4", - "integer", "int", "int4", "pg_catalog.int4": - return "int32" - case "bigserial", "serial8", "pg_catalog.serial8", - "bigint", "int8", "pg_catalog.int8", - "interval", "pg_catalog.interval": - return "int64" - case "smallserial", "serial2", "pg_catalog.serial2", - "smallint", "int2", "pg_catalog.int2": - return "int16" - case "float", "double precision", "float8", "pg_catalog.float8": - return "float64" - case "real", "float4", "pg_catalog.float4": - return "float32" - case "numeric", "pg_catalog.numeric", "money": - return "numeric" - case "boolean", "bool", "pg_catalog.bool": - return "bool" - case "json", "pg_catalog.json": - return "json" - case "jsonb", "pg_catalog.jsonb": - return "jsonb" - case "bytea", "blob", "pg_catalog.bytea": - return "bytes" - case "date": - return "date" - case "pg_catalog.time": - return "time" - case "pg_catalog.timetz": - return "timetz" - case "pg_catalog.timestamp", "timestamp": - return "timestamp" - case "pg_catalog.timestamptz", "timestamptz": - return "timestamptz" - case "text", "pg_catalog.varchar", "pg_catalog.bpchar", - "string", "citext", "name", - "ltree", "lquery", "ltxtquery": - return "text" - case "uuid": - return "uuid" - case "inet": - return "inet" - case "cidr": - return "cidr" - case "macaddr", "macaddr8": - return "macaddr" - case "bit", "varbit", "pg_catalog.bit", "pg_catalog.varbit": - return "bits" - case "hstore": - return "hstore" - case "vector": - return "vector" - default: - return dt - } -} - func makeByte(s string) byte { var b byte if s == "" { From 1ed27efc715be49f70c6691c341ec885bf32dea8 Mon Sep 17 00:00:00 2001 From: Nikolay Kuznetsov Date: Thu, 2 Apr 2026 02:20:07 +0300 Subject: [PATCH 5/7] move to validate --- internal/compiler/analyze.go | 94 +--------------------------- internal/sql/validate/insert_stmt.go | 85 ++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 93 deletions(-) diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 484e70962a..ad7525d0b5 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -9,7 +9,6 @@ import ( "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/named" "github.com/sqlc-dev/sqlc/internal/sql/rewrite" - "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" "github.com/sqlc-dev/sqlc/internal/sql/validate" ) @@ -143,11 +142,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar) var table *ast.TableName - switch n := raw.Stmt.(type) { - case *ast.InsertStmt: - if err := check(validate.InsertStmt(n)); err != nil { - return nil, err - } + if n, ok := raw.Stmt.(*ast.InsertStmt); ok { var err error table, err = ParseTableName(n.Relation) if err := check(err); err != nil { @@ -187,7 +182,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) return nil, err } if n, ok := raw.Stmt.(*ast.InsertStmt); ok { - if err := check(c.validateOnConflictClause(n)); err != nil { + if err := check(validate.InsertStmt(n, table, c.catalog)); err != nil { return nil, err } } @@ -219,88 +214,3 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) Named: namedParams, }, rerr } - -// validateOnConflictClause validates an ON CONFLICT DO UPDATE clause against -// the target table. It checks: -// - ON CONFLICT (col, ...) conflict target columns exist -// - DO UPDATE SET col = ... assignment target columns exist -// - EXCLUDED.col references exist -func (c *Compiler) validateOnConflictClause(n *ast.InsertStmt) error { - if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { - return nil - } - - fqn, err := ParseTableName(n.Relation) - if err != nil { - return err - } - - table, err := c.catalog.GetTable(fqn) - if err != nil { - return err - } - - // Build set of column names for existence checks. - colNames := make(map[string]struct{}, len(table.Columns)) - for _, col := range table.Columns { - colNames[col.Name] = struct{}{} - } - - // Validate ON CONFLICT (col, ...) conflict target columns. - if n.OnConflictClause.Infer != nil && n.OnConflictClause.Infer.IndexElems != nil { - for _, item := range n.OnConflictClause.Infer.IndexElems.Items { - elem, ok := item.(*ast.IndexElem) - if !ok || elem.Name == nil { - continue - } - if _, exists := colNames[*elem.Name]; !exists { - e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name) - e.Location = n.OnConflictClause.Infer.Location - return e - } - } - } - - // Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references. - if n.OnConflictClause.TargetList == nil { - return nil - } - for _, item := range n.OnConflictClause.TargetList.Items { - target, ok := item.(*ast.ResTarget) - if !ok || target.Name == nil { - continue - } - if _, exists := colNames[*target.Name]; !exists { - e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name) - e.Location = target.Location - return e - } - if ref, ok := target.Val.(*ast.ColumnRef); ok { - if excludedCol, ok := excludedColumn(ref); ok { - if _, exists := colNames[excludedCol]; !exists { - e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol) - e.Location = ref.Location - return e - } - } - } - } - return nil -} - -// excludedColumn returns the column name if the ColumnRef is an EXCLUDED.col -// reference, and ok=true. Returns "", false otherwise. -func excludedColumn(ref *ast.ColumnRef) (string, bool) { - if ref.Fields == nil || len(ref.Fields.Items) != 2 { - return "", false - } - first, ok := ref.Fields.Items[0].(*ast.String) - if !ok || first.Str != "excluded" { - return "", false - } - second, ok := ref.Fields.Items[1].(*ast.String) - if !ok { - return "", false - } - return second.Str, true -} diff --git a/internal/sql/validate/insert_stmt.go b/internal/sql/validate/insert_stmt.go index dd8041ea23..94667114b2 100644 --- a/internal/sql/validate/insert_stmt.go +++ b/internal/sql/validate/insert_stmt.go @@ -1,11 +1,14 @@ package validate import ( + "strings" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/catalog" "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) -func InsertStmt(stmt *ast.InsertStmt) error { +func InsertStmt(stmt *ast.InsertStmt, fqn *ast.TableName, c *catalog.Catalog) error { sel, ok := stmt.SelectStmt.(*ast.SelectStmt) if !ok { return nil @@ -35,5 +38,85 @@ func InsertStmt(stmt *ast.InsertStmt) error { Message: "INSERT has more expressions than target columns", } } + return onConflictClause(stmt, fqn, c) +} + +// onConflictClause validates an ON CONFLICT DO UPDATE clause against the target +// table. It checks: +// - ON CONFLICT (col, ...) conflict target columns exist +// - DO UPDATE SET col = ... assignment target columns exist +// - EXCLUDED.col references exist +func onConflictClause(n *ast.InsertStmt, fqn *ast.TableName, c *catalog.Catalog) error { + if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { + return nil + } + + table, err := c.GetTable(fqn) + if err != nil { + return err + } + + // Build set of column names for existence checks. + colNames := make(map[string]struct{}, len(table.Columns)) + for _, col := range table.Columns { + colNames[col.Name] = struct{}{} + } + + // Validate ON CONFLICT (col, ...) conflict target columns. + if n.OnConflictClause.Infer != nil && n.OnConflictClause.Infer.IndexElems != nil { + for _, item := range n.OnConflictClause.Infer.IndexElems.Items { + elem, ok := item.(*ast.IndexElem) + if !ok || elem.Name == nil { + continue + } + if _, exists := colNames[*elem.Name]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name) + e.Location = n.OnConflictClause.Infer.Location + return e + } + } + } + + // Validate DO UPDATE SET col = ... assignment target columns and EXCLUDED.col references. + if n.OnConflictClause.TargetList == nil { + return nil + } + for _, item := range n.OnConflictClause.TargetList.Items { + target, ok := item.(*ast.ResTarget) + if !ok || target.Name == nil { + continue + } + if _, exists := colNames[*target.Name]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name) + e.Location = target.Location + return e + } + if ref, ok := target.Val.(*ast.ColumnRef); ok { + if excludedCol, ok := excludedColumnRef(ref); ok { + if _, exists := colNames[excludedCol]; !exists { + e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol) + e.Location = ref.Location + return e + } + } + } + } return nil } + +// excludedColumnRef returns the column name if the ColumnRef is an EXCLUDED.col +// reference, and ok=true. Returns "", false otherwise. +func excludedColumnRef(ref *ast.ColumnRef) (string, bool) { + if ref.Fields == nil || len(ref.Fields.Items) != 2 { + return "", false + } + first, ok := ref.Fields.Items[0].(*ast.String) + if !ok || !strings.EqualFold(first.Str, "excluded") { + return "", false + } + second, ok := ref.Fields.Items[1].(*ast.String) + if !ok { + return "", false + } + return second.Str, true +} From dc3771270046bb944afc315af911ba8d24677478 Mon Sep 17 00:00:00 2001 From: Nikolay Kuznetsov Date: Thu, 2 Apr 2026 02:20:07 +0300 Subject: [PATCH 6/7] params order --- internal/compiler/analyze.go | 11 +++++------ internal/sql/validate/insert_stmt.go | 6 +++--- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index ad7525d0b5..264dbef6f5 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -142,12 +142,16 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) raw, namedParams, edits := rewrite.NamedParameters(c.conf.Engine, raw, numbers, dollar) var table *ast.TableName - if n, ok := raw.Stmt.(*ast.InsertStmt); ok { + switch n := raw.Stmt.(type) { + case *ast.InsertStmt: var err error table, err = ParseTableName(n.Relation) if err := check(err); err != nil { return nil, err } + if err := check(validate.InsertStmt(c.catalog, table, n)); err != nil { + return nil, err + } } if err := check(validate.FuncCall(c.catalog, c.combo, raw)); err != nil { @@ -181,11 +185,6 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) if err := check(err); err != nil { return nil, err } - if n, ok := raw.Stmt.(*ast.InsertStmt); ok { - if err := check(validate.InsertStmt(n, table, c.catalog)); err != nil { - return nil, err - } - } cols, err := c.outputColumns(qc, raw.Stmt) if err := check(err); err != nil { return nil, err diff --git a/internal/sql/validate/insert_stmt.go b/internal/sql/validate/insert_stmt.go index 94667114b2..54651d9c11 100644 --- a/internal/sql/validate/insert_stmt.go +++ b/internal/sql/validate/insert_stmt.go @@ -8,7 +8,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) -func InsertStmt(stmt *ast.InsertStmt, fqn *ast.TableName, c *catalog.Catalog) error { +func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) error { sel, ok := stmt.SelectStmt.(*ast.SelectStmt) if !ok { return nil @@ -38,7 +38,7 @@ func InsertStmt(stmt *ast.InsertStmt, fqn *ast.TableName, c *catalog.Catalog) er Message: "INSERT has more expressions than target columns", } } - return onConflictClause(stmt, fqn, c) + return onConflictClause(c, fqn, stmt) } // onConflictClause validates an ON CONFLICT DO UPDATE clause against the target @@ -46,7 +46,7 @@ func InsertStmt(stmt *ast.InsertStmt, fqn *ast.TableName, c *catalog.Catalog) er // - ON CONFLICT (col, ...) conflict target columns exist // - DO UPDATE SET col = ... assignment target columns exist // - EXCLUDED.col references exist -func onConflictClause(n *ast.InsertStmt, fqn *ast.TableName, c *catalog.Catalog) error { +func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt) error { if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { return nil } From c3d992202eb0468f534fd4256469dc695ea39fac Mon Sep 17 00:00:00 2001 From: Nikolay Kuznetsov Date: Thu, 2 Apr 2026 02:20:07 +0300 Subject: [PATCH 7/7] ON CONSTRAINT --- .../postgresql/pgx/query.sql | 8 ++++++ .../postgresql/pgx/stderr.txt | 4 ++- internal/sql/validate/insert_stmt.go | 27 ++++++++++++++++--- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql index 624a4bc367..2484d998b6 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/query.sql @@ -13,3 +13,11 @@ INSERT INTO servers(code, name) VALUES ($1, $2) ON CONFLICT (code) DO UPDATE SET name = EXCLUDED.name_typo; +-- name: UpsertServerMissingConflictTarget :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT DO UPDATE SET name = EXCLUDED.name; + +-- name: UpsertServerOnConstraintExcludedTypo :exec +INSERT INTO servers(code, name) VALUES ($1, $2) +ON CONFLICT ON CONSTRAINT servers_pkey DO UPDATE SET name = EXCLUDED.name_typo; + diff --git a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt index 6f6ebd87ad..b9692daa64 100644 --- a/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt +++ b/internal/endtoend/testdata/update_set_on_conflict/postgresql/pgx/stderr.txt @@ -1,4 +1,6 @@ # package querytest query.sql:4:15: column "name_typo" of relation "servers" does not exist query.sql:8:13: column "code_typo" of relation "servers" does not exist -query.sql:14:22: column "name_typo" of relation "servers" does not exist +query.sql:14:22: column "name_typo" of relation "EXCLUDED" does not exist +query.sql:17:1: ON CONFLICT DO UPDATE requires inference specification or constraint name +query.sql:22:61: column "name_typo" of relation "EXCLUDED" does not exist diff --git a/internal/sql/validate/insert_stmt.go b/internal/sql/validate/insert_stmt.go index 54651d9c11..236ddbfceb 100644 --- a/internal/sql/validate/insert_stmt.go +++ b/internal/sql/validate/insert_stmt.go @@ -8,6 +8,8 @@ import ( "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" ) +const excludedTable = "EXCLUDED" + func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) error { sel, ok := stmt.SelectStmt.(*ast.SelectStmt) if !ok { @@ -38,6 +40,7 @@ func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) er Message: "INSERT has more expressions than target columns", } } + return onConflictClause(c, fqn, stmt) } @@ -47,7 +50,7 @@ func InsertStmt(c *catalog.Catalog, fqn *ast.TableName, stmt *ast.InsertStmt) er // - DO UPDATE SET col = ... assignment target columns exist // - EXCLUDED.col references exist func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt) error { - if n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { + if fqn == nil || n.OnConflictClause == nil || n.OnConflictClause.Action != ast.OnConflictActionUpdate { return nil } @@ -62,13 +65,22 @@ func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt) colNames[col.Name] = struct{}{} } + // DO UPDATE requires a conflict target: ON CONFLICT (col) or ON CONFLICT ON CONSTRAINT name. + if n.OnConflictClause.Infer == nil { + return &sqlerr.Error{ + Code: "42601", + Message: "ON CONFLICT DO UPDATE requires inference specification or constraint name", + } + } + // Validate ON CONFLICT (col, ...) conflict target columns. - if n.OnConflictClause.Infer != nil && n.OnConflictClause.Infer.IndexElems != nil { + if n.OnConflictClause.Infer.IndexElems != nil { for _, item := range n.OnConflictClause.Infer.IndexElems.Items { elem, ok := item.(*ast.IndexElem) if !ok || elem.Name == nil { continue } + if _, exists := colNames[*elem.Name]; !exists { e := sqlerr.ColumnNotFound(table.Rel.Name, *elem.Name) e.Location = n.OnConflictClause.Infer.Location @@ -81,26 +93,30 @@ func onConflictClause(c *catalog.Catalog, fqn *ast.TableName, n *ast.InsertStmt) if n.OnConflictClause.TargetList == nil { return nil } + for _, item := range n.OnConflictClause.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok || target.Name == nil { continue } + if _, exists := colNames[*target.Name]; !exists { e := sqlerr.ColumnNotFound(table.Rel.Name, *target.Name) e.Location = target.Location return e } + if ref, ok := target.Val.(*ast.ColumnRef); ok { if excludedCol, ok := excludedColumnRef(ref); ok { if _, exists := colNames[excludedCol]; !exists { - e := sqlerr.ColumnNotFound(table.Rel.Name, excludedCol) + e := sqlerr.ColumnNotFound(excludedTable, excludedCol) e.Location = ref.Location return e } } } } + return nil } @@ -110,13 +126,16 @@ func excludedColumnRef(ref *ast.ColumnRef) (string, bool) { if ref.Fields == nil || len(ref.Fields.Items) != 2 { return "", false } + first, ok := ref.Fields.Items[0].(*ast.String) - if !ok || !strings.EqualFold(first.Str, "excluded") { + if !ok || !strings.EqualFold(first.Str, excludedTable) { return "", false } + second, ok := ref.Fields.Items[1].(*ast.String) if !ok { return "", false } + return second.Str, true }