Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions src/ast/dml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ use crate::{

use super::{
display_comma_separated, helpers::attached_token::AttachedToken, query::InputFormatClause,
Assignment, Expr, FromTable, Ident, InsertAliases, MysqlInsertPriority, ObjectName, OnInsert,
OptimizerHint, OrderByExpr, Query, SelectInto, SelectItem, Setting, SqliteOnConflict,
TableFactor, TableObject, TableWithJoins, UpdateTableFromKind, Values,
Assignment, Expr, FromTable, Ident, InsertAliases, InsertTableAlias, MysqlInsertPriority,
ObjectName, OnInsert, OptimizerHint, OrderByExpr, Query, SelectInto, SelectItem, Setting,
SqliteOnConflict, TableFactor, TableObject, TableWithJoins, UpdateTableFromKind, Values,
};

/// INSERT statement.
Expand All @@ -56,8 +56,9 @@ pub struct Insert {
pub into: bool,
/// TABLE
pub table: TableObject,
/// table_name as foo (for PostgreSQL)
pub table_alias: Option<Ident>,
/// `table_name as foo` (for PostgreSQL)
/// `table_name foo` (for Oracle)
pub table_alias: Option<InsertTableAlias>,
/// COLUMNS
pub columns: Vec<Ident>,
/// Overwrite (Hive)
Expand Down Expand Up @@ -125,8 +126,13 @@ pub struct Insert {
impl Display for Insert {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// SQLite OR conflict has a special format: INSERT OR ... INTO table_name
let table_name = if let Some(alias) = &self.table_alias {
format!("{0} AS {alias}", self.table)
let table_name = if let Some(table_alias) = &self.table_alias {
format!(
"{table} {as_keyword}{alias}",
table = self.table,
as_keyword = if table_alias.explicit { "AS " } else { "" },
alias = table_alias.alias
)
} else {
self.table.to_string()
};
Expand Down
11 changes: 11 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6451,6 +6451,17 @@ pub struct InsertAliases {
pub col_aliases: Option<Vec<Ident>>,
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
/// Optional alias for an `INSERT` table; i.e. the table to be inserted into
pub struct InsertTableAlias {
/// `true` if the aliases was explicitly introduced with the "AS" keyword
pub explicit: bool,
/// the alias name itself
pub alias: Ident,
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
Expand Down
7 changes: 5 additions & 2 deletions src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1589,6 +1589,7 @@ pub enum TableFactor {
///
/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#pivot_operator)
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/constructs/pivot)
/// [Oracle](https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/SELECT.html#GUID-CFA006CA-6FF1-4972-821E-6996142A51C6__GUID-68257B27-1C4C-4C47-8140-5C60E0E65D35)
Pivot {
/// The input table to pivot.
table: Box<TableFactor>,
Expand All @@ -1610,8 +1611,10 @@ pub enum TableFactor {
/// table UNPIVOT [ { INCLUDE | EXCLUDE } NULLS ] (value FOR name IN (column1, [ column2, ... ])) [ alias ]
/// ```
///
/// See <https://docs.snowflake.com/en/sql-reference/constructs/unpivot>.
/// See <https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot>.
/// [Snowflake](https://docs.snowflake.com/en/sql-reference/constructs/unpivot)
/// [Databricks](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot)
/// [BigQuery](https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#unpivot_operator)
/// [Oracle](https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/SELECT.html#GUID-CFA006CA-6FF1-4972-821E-6996142A51C6__GUID-9B4E0389-413C-4014-94A1-0A0571BDF7E1)
Unpivot {
/// The input table to unpivot.
table: Box<TableFactor>,
Expand Down
2 changes: 1 addition & 1 deletion src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ impl Spanned for Insert {
union_spans(
core::iter::once(insert_token.0.span)
.chain(core::iter::once(table.span()))
.chain(table_alias.as_ref().map(|i| i.span))
.chain(table_alias.iter().map(|k| k.alias.span))
.chain(columns.iter().map(|i| i.span))
.chain(source.as_ref().map(|q| q.span()))
.chain(assignments.iter().map(|i| i.span()))
Expand Down
88 changes: 58 additions & 30 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4543,7 +4543,13 @@ impl<'a> Parser<'a> {
///
/// Returns true if the current token matches the expected keyword.
pub fn peek_keyword(&self, expected: Keyword) -> bool {
matches!(&self.peek_token_ref().token, Token::Word(w) if expected == w.keyword)
self.peek_keyword_one_of(&[expected])
}

#[must_use]
/// Checks whether the current token is one of the expected keywords without consuming it.
fn peek_keyword_one_of(&self, expected: &[Keyword]) -> bool {
matches!(&self.peek_token_ref().token, Token::Word(w) if expected.contains(&w.keyword))
}

/// If the current token is the `expected` keyword followed by
Expand Down Expand Up @@ -13617,7 +13623,7 @@ impl<'a> Parser<'a> {
Keyword::PIVOT => {
self.expect_token(&Token::LParen)?;
let aggregate_functions =
self.parse_comma_separated(Self::parse_aliased_function_call)?;
self.parse_comma_separated(Self::parse_pivot_aggregate_function)?;
self.expect_keyword_is(Keyword::FOR)?;
let value_column = self.parse_period_separated(|p| p.parse_identifier())?;
self.expect_keyword_is(Keyword::IN)?;
Expand Down Expand Up @@ -16242,20 +16248,6 @@ impl<'a> Parser<'a> {
})
}

fn parse_aliased_function_call(&mut self) -> Result<ExprWithAlias, ParserError> {
let function_name = match self.next_token().token {
Token::Word(w) => Ok(w.value),
_ => self.expected("a function identifier", self.peek_token()),
}?;
let expr = self.parse_function(ObjectName::from(vec![Ident::new(function_name)]))?;
let alias = if self.parse_keyword(Keyword::AS) {
Some(self.parse_identifier()?)
} else {
None
};

Ok(ExprWithAlias { expr, alias })
}
/// Parses an expression with an optional alias
///
/// Examples:
Expand Down Expand Up @@ -16289,13 +16281,40 @@ impl<'a> Parser<'a> {
Ok(ExprWithAlias { expr, alias })
}

/// Parse an expression followed by an optional alias; Unlike
/// [Self::parse_expr_with_alias] the "AS" keyword between the expression
/// and the alias is optional.
fn parse_expr_with_alias_optional_as_keyword(&mut self) -> Result<ExprWithAlias, ParserError> {
let expr = self.parse_expr()?;
let alias = self.parse_identifier_optional_alias()?;
Ok(ExprWithAlias { expr, alias })
}

/// Parses a plain function call with an optional alias for the `PIVOT` clause
fn parse_pivot_aggregate_function(&mut self) -> Result<ExprWithAlias, ParserError> {
let function_name = match self.next_token().token {
Token::Word(w) => Ok(w.value),
_ => self.expected("a function identifier", self.peek_token()),
}?;
let expr = self.parse_function(ObjectName::from(vec![Ident::new(function_name)]))?;
let alias = {
fn validator(explicit: bool, kw: &Keyword, parser: &mut Parser) -> bool {
// ~ for a PIVOT aggregate function the alias must not be a "FOR"; in any dialect
kw != &Keyword::FOR && parser.dialect.is_select_item_alias(explicit, kw, parser)
}
self.parse_optional_alias_inner(None, validator)?
};
Ok(ExprWithAlias { expr, alias })
}

/// Parse a PIVOT table factor (ClickHouse/Oracle style pivot), returning a TableFactor.
pub fn parse_pivot_table_factor(
&mut self,
table: TableFactor,
) -> Result<TableFactor, ParserError> {
self.expect_token(&Token::LParen)?;
let aggregate_functions = self.parse_comma_separated(Self::parse_aliased_function_call)?;
let aggregate_functions =
self.parse_comma_separated(Self::parse_pivot_aggregate_function)?;
self.expect_keyword_is(Keyword::FOR)?;
let value_column = if self.peek_token_ref().token == Token::LParen {
self.parse_parenthesized_column_list_inner(Mandatory, false, |p| {
Expand All @@ -16317,7 +16336,9 @@ impl<'a> Parser<'a> {
} else if self.peek_sub_query() {
PivotValueSource::Subquery(self.parse_query()?)
} else {
PivotValueSource::List(self.parse_comma_separated(Self::parse_expr_with_alias)?)
PivotValueSource::List(
self.parse_comma_separated(Self::parse_expr_with_alias_optional_as_keyword)?,
)
};
self.expect_token(&Token::RParen)?;

Expand Down Expand Up @@ -17118,12 +17139,26 @@ impl<'a> Parser<'a> {
let table = self.parse_keyword(Keyword::TABLE);
let table_object = self.parse_table_object()?;

let table_alias =
if dialect_of!(self is PostgreSqlDialect) && self.parse_keyword(Keyword::AS) {
Some(self.parse_identifier()?)
let table_alias = if dialect_of!(self is OracleDialect) {
if !self.peek_sub_query()
&& !self.peek_keyword_one_of(&[Keyword::DEFAULT, Keyword::VALUES])
{
self.maybe_parse(|parser| parser.parse_identifier())?
.map(|alias| InsertTableAlias {
explicit: false,
alias,
})
} else {
None
};
}
} else if dialect_of!(self is PostgreSqlDialect) && self.parse_keyword(Keyword::AS) {
Some(InsertTableAlias {
explicit: true,
alias: self.parse_identifier()?,
})
} else {
None
};

let is_mysql = dialect_of!(self is MySqlDialect);

Expand Down Expand Up @@ -19349,14 +19384,7 @@ impl<'a> Parser<'a> {

/// Returns true if the next keyword indicates a sub query, i.e. SELECT or WITH
fn peek_sub_query(&mut self) -> bool {
if self
.parse_one_of_keywords(&[Keyword::SELECT, Keyword::WITH])
.is_some()
{
self.prev_token();
return true;
}
false
self.peek_keyword_one_of(&[Keyword::SELECT, Keyword::WITH])
}

pub(crate) fn parse_show_stmt_options(&mut self) -> Result<ShowStatementOptions, ParserError> {
Expand Down
12 changes: 12 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11357,6 +11357,18 @@ fn parse_pivot_table() {
verified_stmt(multiple_value_columns_sql).to_string(),
multiple_value_columns_sql
);

// assert optional "AS" keyword for aliases for pivot values
one_statement_parses_to(
"SELECT * FROM t PIVOT(SUM(1) FOR a.abc IN (1 x, 'two' y, three z))",
"SELECT * FROM t PIVOT(SUM(1) FOR a.abc IN (1 AS x, 'two' AS y, three AS z))",
);

// assert optional "AS" keyword for aliases for pivot aggregate function
one_statement_parses_to(
"SELECT * FROM t PIVOT(SUM(1) x, COUNT(42) y FOR a.abc IN (1))",
"SELECT * FROM t PIVOT(SUM(1) AS x, COUNT(42) AS y FOR a.abc IN (1))",
);
}

#[test]
Expand Down
57 changes: 56 additions & 1 deletion tests/sqlparser_oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
use pretty_assertions::assert_eq;

use sqlparser::{
ast::{BinaryOperator, Expr, Ident, QuoteDelimitedString, Value, ValueWithSpan},
ast::{
BinaryOperator, Expr, Ident, Insert, InsertTableAlias, ObjectName, QuoteDelimitedString,
Statement, TableObject, Value, ValueWithSpan,
},
dialect::OracleDialect,
parser::ParserError,
tokenizer::Span,
Expand Down Expand Up @@ -414,3 +417,55 @@ fn test_connect_by() {
ORDER BY \"Employee\", \"Manager\", \"Pathlen\", \"Path\"",
);
}

#[test]
fn test_insert_with_table_alias() {
let oracle_dialect = oracle();

fn verify_table_name_with_alias(stmt: &Statement, exp_table_name: &str, exp_table_alias: &str) {
assert!(matches!(stmt,
Statement::Insert(Insert {
table: TableObject::TableName(table_name),
table_alias: Some(InsertTableAlias {
explicit: false,
alias: Ident {
value: table_alias,
quote_style: None,
span: _
}
}),
..
})
if table_alias == exp_table_alias
&& table_name == &ObjectName::from(vec![Ident {
value: exp_table_name.into(),
quote_style: None,
span: Span::empty(),
}])
));
}

let stmt = oracle_dialect.verified_stmt(
"INSERT INTO foo_t t \
SELECT 1, 2, 3 FROM dual",
);
verify_table_name_with_alias(&stmt, "foo_t", "t");

let stmt = oracle_dialect.verified_stmt(
"INSERT INTO foo_t asdf (a, b, c) \
SELECT 1, 2, 3 FROM dual",
);
verify_table_name_with_alias(&stmt, "foo_t", "asdf");

let stmt = oracle_dialect.verified_stmt(
"INSERT INTO foo_t t (a, b, c) \
VALUES (1, 2, 3)",
);
verify_table_name_with_alias(&stmt, "foo_t", "t");

let stmt = oracle_dialect.verified_stmt(
"INSERT INTO foo_t t \
VALUES (1, 2, 3)",
);
verify_table_name_with_alias(&stmt, "foo_t", "t");
}
33 changes: 21 additions & 12 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5406,10 +5406,13 @@ fn test_simple_postgres_insert_with_alias() {
quote_style: None,
span: Span::empty(),
}])),
table_alias: Some(Ident {
value: "test_table".to_string(),
quote_style: None,
span: Span::empty(),
table_alias: Some(InsertTableAlias {
explicit: true,
alias: Ident {
value: "test_table".to_string(),
quote_style: None,
span: Span::empty(),
}
}),
columns: vec![
Ident {
Expand Down Expand Up @@ -5482,10 +5485,13 @@ fn test_simple_postgres_insert_with_alias() {
quote_style: None,
span: Span::empty(),
}])),
table_alias: Some(Ident {
value: "test_table".to_string(),
quote_style: None,
span: Span::empty(),
table_alias: Some(InsertTableAlias {
explicit: true,
alias: Ident {
value: "test_table".to_string(),
quote_style: None,
span: Span::empty(),
}
}),
columns: vec![
Ident {
Expand Down Expand Up @@ -5560,10 +5566,13 @@ fn test_simple_insert_with_quoted_alias() {
quote_style: None,
span: Span::empty(),
}])),
table_alias: Some(Ident {
value: "Test_Table".to_string(),
quote_style: Some('"'),
span: Span::empty(),
table_alias: Some(InsertTableAlias {
explicit: true,
alias: Ident {
value: "Test_Table".to_string(),
quote_style: Some('"'),
span: Span::empty(),
}
}),
columns: vec![
Ident {
Expand Down
Loading