WARNING: THIS SITE IS A MIRROR OF GITHUB.COM. YOU CANNOT LOG IN OR REGISTER AN ACCOUNT HERE. THE CONTENTS ARE PROVIDED AS-IS, AND THIS SITE ASSUMES NO RESPONSIBILITY FOR ANY DISPLAYED CONTENT OR LINKS. IF YOU FIND SOMETHING THAT MAY NOT BE SUITABLE FOR EVERYONE, CONTACT THE ADMIN AT ilovescratch@foxmail.com
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 72 additions & 3 deletions sqlglot/optimizer/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,25 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequenc
# in bigquery, unnest structs are automatically scoped as tables, so you can
# directly select a struct field in a query.
# this handles the case where the unnest is statically defined.
if self.dialect.UNNEST_COLUMN_ONLY:
if source.expression.is_type(exp.DataType.Type.STRUCT):
for k in source.expression.type.expressions: # type: ignore
if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
unnest = source.expression

# if type is not annotated yet, try to get it from the schema
if not unnest.type or unnest.type.is_type(exp.DataType.Type.UNKNOWN):
unnest_expr = seq_get(unnest.expressions, 0)
if isinstance(unnest_expr, exp.Column) and self.scope.parent:
col_type = self._get_unnest_column_type(unnest_expr)
# extract element type if it's an ARRAY
if col_type and col_type.is_type(exp.DataType.Type.ARRAY):
element_types = col_type.expressions
if element_types:
unnest.type = element_types[0].copy()
else:
if col_type:
unnest.type = col_type.copy()
# check if the result type is a STRUCT - extract struct field names
if unnest.is_type(exp.DataType.Type.STRUCT):
for k in unnest.type.expressions: # type: ignore
columns.append(k.name)
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
columns = self.get_source_columns_from_set_op(source.expression)
Expand Down Expand Up @@ -299,3 +315,56 @@ def _get_unambiguous_columns(
unambiguous_columns[column] = table

return unambiguous_columns

def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
    """
    Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.

    The column is looked up in the parent of this resolver's scope. If the column is
    already qualified, its table name selects the source directly; otherwise a resolver
    built for the parent scope is asked which table the column belongs to.

    Args:
        column: The column expression being unnested.

    Returns:
        The DataType of the column, or None if there is no parent scope, the column
        cannot be attributed to a table, or no source with that name exists.
    """
    scope = self.scope.parent
    if scope is None:
        # without an enclosing scope there is nowhere to look the column up
        return None

    # if column is qualified, use that table, otherwise disambiguate using the resolver
    table_name = column.table
    if not table_name:
        # use the parent scope's resolver to disambiguate the column
        parent_resolver = Resolver(scope, self.schema, self._infer_schema)
        table_identifier = parent_resolver.get_table(column)
        if not table_identifier:
            return None
        table_name = table_identifier.name

    source = scope.sources.get(table_name)
    return self._get_column_type_from_scope(source, column) if source else None

def _get_column_type_from_scope(
    self, source: t.Union[Scope, exp.Table], column: exp.Column
) -> t.Optional[exp.DataType]:
    """
    Get a column's type by tracing through scopes/tables to find the base table.

    A Table is resolved directly against the schema. A Scope is searched recursively
    through each of its own sources, and the first known type found is returned.

    Args:
        source: The source to search - can be a Scope (to iterate its sources) or a Table.
        column: The column to find the type for.

    Returns:
        The DataType of the column, or None if not found (UNKNOWN types are treated
        as not found).
    """
    if isinstance(source, exp.Table):
        # base table - get the column type from schema
        col_type: t.Optional[exp.DataType] = self.schema.get_column_type(source, column)
        if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
            return col_type
    elif isinstance(source, Scope):
        # only the nested sources matter here; their names are irrelevant
        for nested_source in source.sources.values():
            col_type = self._get_column_type_from_scope(nested_source, column)
            if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
                return col_type

    return None
42 changes: 42 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,48 @@ def test_qualify_columns(self, logger):
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id",
)

self.assertEqual(
optimizer.qualify.qualify(
parse_one(
"""
SELECT
(SELECT SUM(c.amount)
FROM UNNEST(credits) AS c
WHERE type != 'promotion') as total
FROM billing
""",
read="bigquery",
),
schema={"billing": {"credits": "ARRAY<STRUCT<amount FLOAT64, type STRING>>"}},
dialect="bigquery",
).sql(dialect="bigquery"),
"SELECT (SELECT SUM(`c`.`amount`) AS `_col_0` FROM UNNEST(`billing`.`credits`) AS `c` WHERE `type` <> 'promotion') AS `total` FROM `billing` AS `billing`",
)

self.assertEqual(
optimizer.qualify.qualify(
parse_one(
"""
WITH cte AS (SELECT * FROM base_table)
SELECT
(SELECT SUM(item.price)
FROM UNNEST(items) AS item
WHERE category = 'electronics') as electronics_total
FROM cte
""",
read="bigquery",
),
schema={
"base_table": {
"id": "INT64",
"items": "ARRAY<STRUCT<price FLOAT64, category STRING>>",
}
},
dialect="bigquery",
).sql(dialect="bigquery"),
"WITH `cte` AS (SELECT `base_table`.`id` AS `id`, `base_table`.`items` AS `items` FROM `base_table` AS `base_table`) SELECT (SELECT SUM(`item`.`price`) AS `_col_0` FROM UNNEST(`cte`.`items`) AS `item` WHERE `category` = 'electronics') AS `electronics_total` FROM `cte` AS `cte`",
)

self.check_file(
"qualify_columns",
qualify_columns,
Expand Down