diff --git a/src/ast/dml.rs b/src/ast/dml.rs index f9c8823a2..06e939e1e 100644 --- a/src/ast/dml.rs +++ b/src/ast/dml.rs @@ -31,9 +31,9 @@ use crate::{ use super::{ display_comma_separated, helpers::attached_token::AttachedToken, query::InputFormatClause, - Assignment, Expr, FromTable, Ident, InsertAliases, MysqlInsertPriority, ObjectName, OnInsert, - OptimizerHint, OrderByExpr, Query, SelectInto, SelectItem, Setting, SqliteOnConflict, - TableFactor, TableObject, TableWithJoins, UpdateTableFromKind, Values, + Assignment, Expr, FromTable, Ident, InsertAliases, InsertTableAlias, MysqlInsertPriority, + ObjectName, OnInsert, OptimizerHint, OrderByExpr, Query, SelectInto, SelectItem, Setting, + SqliteOnConflict, TableFactor, TableObject, TableWithJoins, UpdateTableFromKind, Values, }; /// INSERT statement. @@ -56,8 +56,9 @@ pub struct Insert { pub into: bool, /// TABLE pub table: TableObject, - /// table_name as foo (for PostgreSQL) - pub table_alias: Option, + /// `table_name as foo` (for PostgreSQL) + /// `table_name foo` (for Oracle) + pub table_alias: Option, /// COLUMNS pub columns: Vec, /// Overwrite (Hive) @@ -125,8 +126,13 @@ pub struct Insert { impl Display for Insert { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // SQLite OR conflict has a special format: INSERT OR ... INTO table_name - let table_name = if let Some(alias) = &self.table_alias { - format!("{0} AS {alias}", self.table) + let table_name = if let Some(table_alias) = &self.table_alias { + format!( + "{table} {as_keyword}{alias}", + table = self.table, + as_keyword = if table_alias.explicit { "AS " } else { "" }, + alias = table_alias.alias + ) } else { self.table.to_string() }; diff --git a/src/ast/mod.rs b/src/ast/mod.rs index eda282260..d51ea6667 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -6451,6 +6451,17 @@ pub struct InsertAliases { pub col_aliases: Option>, } +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +/// Optional alias for an `INSERT` table; i.e. the table to be inserted into +pub struct InsertTableAlias { + /// `true` if the aliases was explicitly introduced with the "AS" keyword + pub explicit: bool, + /// the alias name itself + pub alias: Ident, +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] diff --git a/src/ast/query.rs b/src/ast/query.rs index b8f605be5..6d95216df 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -1589,6 +1589,7 @@ pub enum TableFactor { /// /// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#pivot_operator) /// [Snowflake](https://docs.snowflake.com/en/sql-reference/constructs/pivot) + /// [Oracle](https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/SELECT.html#GUID-CFA006CA-6FF1-4972-821E-6996142A51C6__GUID-68257B27-1C4C-4C47-8140-5C60E0E65D35) Pivot { /// The input table to pivot. table: Box, @@ -1610,8 +1611,10 @@ pub enum TableFactor { /// table UNPIVOT [ { INCLUDE | EXCLUDE } NULLS ] (value FOR name IN (column1, [ column2, ... ])) [ alias ] /// ``` /// - /// See . - /// See . + /// [Snowflake](https://docs.snowflake.com/en/sql-reference/constructs/unpivot) + /// [Databricks](https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot) + /// [BigQuery](https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#unpivot_operator) + /// [Oracle](https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/SELECT.html#GUID-CFA006CA-6FF1-4972-821E-6996142A51C6__GUID-9B4E0389-413C-4014-94A1-0A0571BDF7E1) Unpivot { /// The input table to unpivot. table: Box, diff --git a/src/ast/spans.rs b/src/ast/spans.rs index f4bdf85a3..abca7138d 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -1324,7 +1324,7 @@ impl Spanned for Insert { union_spans( core::iter::once(insert_token.0.span) .chain(core::iter::once(table.span())) - .chain(table_alias.as_ref().map(|i| i.span)) + .chain(table_alias.iter().map(|k| k.alias.span)) .chain(columns.iter().map(|i| i.span)) .chain(source.as_ref().map(|q| q.span())) .chain(assignments.iter().map(|i| i.span())) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index e708217da..d2006b037 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -4543,7 +4543,13 @@ impl<'a> Parser<'a> { /// /// Returns true if the current token matches the expected keyword. pub fn peek_keyword(&self, expected: Keyword) -> bool { - matches!(&self.peek_token_ref().token, Token::Word(w) if expected == w.keyword) + self.peek_keyword_one_of(&[expected]) + } + + #[must_use] + /// Checks whether the current token is one of the expected keywords without consuming it. + fn peek_keyword_one_of(&self, expected: &[Keyword]) -> bool { + matches!(&self.peek_token_ref().token, Token::Word(w) if expected.contains(&w.keyword)) } /// If the current token is the `expected` keyword followed by @@ -13617,7 +13623,7 @@ impl<'a> Parser<'a> { Keyword::PIVOT => { self.expect_token(&Token::LParen)?; let aggregate_functions = - self.parse_comma_separated(Self::parse_aliased_function_call)?; + self.parse_comma_separated(Self::parse_pivot_aggregate_function)?; self.expect_keyword_is(Keyword::FOR)?; let value_column = self.parse_period_separated(|p| p.parse_identifier())?; self.expect_keyword_is(Keyword::IN)?; @@ -16242,20 +16248,6 @@ impl<'a> Parser<'a> { }) } - fn parse_aliased_function_call(&mut self) -> Result { - let function_name = match self.next_token().token { - Token::Word(w) => Ok(w.value), - _ => self.expected("a function identifier", self.peek_token()), - }?; - let expr = self.parse_function(ObjectName::from(vec![Ident::new(function_name)]))?; - let alias = if self.parse_keyword(Keyword::AS) { - Some(self.parse_identifier()?) - } else { - None - }; - - Ok(ExprWithAlias { expr, alias }) - } /// Parses an expression with an optional alias /// /// Examples: @@ -16289,13 +16281,40 @@ impl<'a> Parser<'a> { Ok(ExprWithAlias { expr, alias }) } + /// Parse an expression followed by an optional alias; Unlike + /// [Self::parse_expr_with_alias] the "AS" keyword between the expression + /// and the alias is optional. + fn parse_expr_with_alias_optional_as_keyword(&mut self) -> Result { + let expr = self.parse_expr()?; + let alias = self.parse_identifier_optional_alias()?; + Ok(ExprWithAlias { expr, alias }) + } + + /// Parses a plain function call with an optional alias for the `PIVOT` clause + fn parse_pivot_aggregate_function(&mut self) -> Result { + let function_name = match self.next_token().token { + Token::Word(w) => Ok(w.value), + _ => self.expected("a function identifier", self.peek_token()), + }?; + let expr = self.parse_function(ObjectName::from(vec![Ident::new(function_name)]))?; + let alias = { + fn validator(explicit: bool, kw: &Keyword, parser: &mut Parser) -> bool { + // ~ for a PIVOT aggregate function the alias must not be a "FOR"; in any dialect + kw != &Keyword::FOR && parser.dialect.is_select_item_alias(explicit, kw, parser) + } + self.parse_optional_alias_inner(None, validator)? + }; + Ok(ExprWithAlias { expr, alias }) + } + /// Parse a PIVOT table factor (ClickHouse/Oracle style pivot), returning a TableFactor. pub fn parse_pivot_table_factor( &mut self, table: TableFactor, ) -> Result { self.expect_token(&Token::LParen)?; - let aggregate_functions = self.parse_comma_separated(Self::parse_aliased_function_call)?; + let aggregate_functions = + self.parse_comma_separated(Self::parse_pivot_aggregate_function)?; self.expect_keyword_is(Keyword::FOR)?; let value_column = if self.peek_token_ref().token == Token::LParen { self.parse_parenthesized_column_list_inner(Mandatory, false, |p| { @@ -16317,7 +16336,9 @@ impl<'a> Parser<'a> { } else if self.peek_sub_query() { PivotValueSource::Subquery(self.parse_query()?) } else { - PivotValueSource::List(self.parse_comma_separated(Self::parse_expr_with_alias)?) + PivotValueSource::List( + self.parse_comma_separated(Self::parse_expr_with_alias_optional_as_keyword)?, + ) }; self.expect_token(&Token::RParen)?; @@ -17118,12 +17139,26 @@ impl<'a> Parser<'a> { let table = self.parse_keyword(Keyword::TABLE); let table_object = self.parse_table_object()?; - let table_alias = - if dialect_of!(self is PostgreSqlDialect) && self.parse_keyword(Keyword::AS) { - Some(self.parse_identifier()?) + let table_alias = if dialect_of!(self is OracleDialect) { + if !self.peek_sub_query() + && !self.peek_keyword_one_of(&[Keyword::DEFAULT, Keyword::VALUES]) + { + self.maybe_parse(|parser| parser.parse_identifier())? + .map(|alias| InsertTableAlias { + explicit: false, + alias, + }) } else { None - }; + } + } else if dialect_of!(self is PostgreSqlDialect) && self.parse_keyword(Keyword::AS) { + Some(InsertTableAlias { + explicit: true, + alias: self.parse_identifier()?, + }) + } else { + None + }; let is_mysql = dialect_of!(self is MySqlDialect); @@ -19349,14 +19384,7 @@ impl<'a> Parser<'a> { /// Returns true if the next keyword indicates a sub query, i.e. SELECT or WITH fn peek_sub_query(&mut self) -> bool { - if self - .parse_one_of_keywords(&[Keyword::SELECT, Keyword::WITH]) - .is_some() - { - self.prev_token(); - return true; - } - false + self.peek_keyword_one_of(&[Keyword::SELECT, Keyword::WITH]) } pub(crate) fn parse_show_stmt_options(&mut self) -> Result { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 5822153ac..182854d13 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11357,6 +11357,18 @@ fn parse_pivot_table() { verified_stmt(multiple_value_columns_sql).to_string(), multiple_value_columns_sql ); + + // assert optional "AS" keyword for aliases for pivot values + one_statement_parses_to( + "SELECT * FROM t PIVOT(SUM(1) FOR a.abc IN (1 x, 'two' y, three z))", + "SELECT * FROM t PIVOT(SUM(1) FOR a.abc IN (1 AS x, 'two' AS y, three AS z))", + ); + + // assert optional "AS" keyword for aliases for pivot aggregate function + one_statement_parses_to( + "SELECT * FROM t PIVOT(SUM(1) x, COUNT(42) y FOR a.abc IN (1))", + "SELECT * FROM t PIVOT(SUM(1) AS x, COUNT(42) AS y FOR a.abc IN (1))", + ); } #[test] diff --git a/tests/sqlparser_oracle.rs b/tests/sqlparser_oracle.rs index 0dbccdb5e..34149dc57 100644 --- a/tests/sqlparser_oracle.rs +++ b/tests/sqlparser_oracle.rs @@ -21,7 +21,10 @@ use pretty_assertions::assert_eq; use sqlparser::{ - ast::{BinaryOperator, Expr, Ident, QuoteDelimitedString, Value, ValueWithSpan}, + ast::{ + BinaryOperator, Expr, Ident, Insert, InsertTableAlias, ObjectName, QuoteDelimitedString, + Statement, TableObject, Value, ValueWithSpan, + }, dialect::OracleDialect, parser::ParserError, tokenizer::Span, @@ -414,3 +417,55 @@ fn test_connect_by() { ORDER BY \"Employee\", \"Manager\", \"Pathlen\", \"Path\"", ); } + +#[test] +fn test_insert_with_table_alias() { + let oracle_dialect = oracle(); + + fn verify_table_name_with_alias(stmt: &Statement, exp_table_name: &str, exp_table_alias: &str) { + assert!(matches!(stmt, + Statement::Insert(Insert { + table: TableObject::TableName(table_name), + table_alias: Some(InsertTableAlias { + explicit: false, + alias: Ident { + value: table_alias, + quote_style: None, + span: _ + } + }), + .. + }) + if table_alias == exp_table_alias + && table_name == &ObjectName::from(vec![Ident { + value: exp_table_name.into(), + quote_style: None, + span: Span::empty(), + }]) + )); + } + + let stmt = oracle_dialect.verified_stmt( + "INSERT INTO foo_t t \ + SELECT 1, 2, 3 FROM dual", + ); + verify_table_name_with_alias(&stmt, "foo_t", "t"); + + let stmt = oracle_dialect.verified_stmt( + "INSERT INTO foo_t asdf (a, b, c) \ + SELECT 1, 2, 3 FROM dual", + ); + verify_table_name_with_alias(&stmt, "foo_t", "asdf"); + + let stmt = oracle_dialect.verified_stmt( + "INSERT INTO foo_t t (a, b, c) \ + VALUES (1, 2, 3)", + ); + verify_table_name_with_alias(&stmt, "foo_t", "t"); + + let stmt = oracle_dialect.verified_stmt( + "INSERT INTO foo_t t \ + VALUES (1, 2, 3)", + ); + verify_table_name_with_alias(&stmt, "foo_t", "t"); +} diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index f8c738136..a831a510b 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -5406,10 +5406,13 @@ fn test_simple_postgres_insert_with_alias() { quote_style: None, span: Span::empty(), }])), - table_alias: Some(Ident { - value: "test_table".to_string(), - quote_style: None, - span: Span::empty(), + table_alias: Some(InsertTableAlias { + explicit: true, + alias: Ident { + value: "test_table".to_string(), + quote_style: None, + span: Span::empty(), + } }), columns: vec![ Ident { @@ -5482,10 +5485,13 @@ fn test_simple_postgres_insert_with_alias() { quote_style: None, span: Span::empty(), }])), - table_alias: Some(Ident { - value: "test_table".to_string(), - quote_style: None, - span: Span::empty(), + table_alias: Some(InsertTableAlias { + explicit: true, + alias: Ident { + value: "test_table".to_string(), + quote_style: None, + span: Span::empty(), + } }), columns: vec![ Ident { @@ -5560,10 +5566,13 @@ fn test_simple_insert_with_quoted_alias() { quote_style: None, span: Span::empty(), }])), - table_alias: Some(Ident { - value: "Test_Table".to_string(), - quote_style: Some('"'), - span: Span::empty(), + table_alias: Some(InsertTableAlias { + explicit: true, + alias: Ident { + value: "Test_Table".to_string(), + quote_style: Some('"'), + span: Span::empty(), + } }), columns: vec![ Ident {