Getting the name / alias of a column in PySpark

pkln4tw6 · posted 2023-01-08 · in Spark

I have a column object defined as follows:

from pyspark.sql import functions as F

column = F.col('foo').alias('bar')

I know I can get the full expression with str(column), but how can I get just the column's alias?
In this example I'm looking for a function get_column_name, where get_column_name(column) returns the string bar.
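
For reference, this is what the full expression looks like when printed. The exact format varies across Spark versions (the output below assumes PySpark 3.x; answer 1 shows the older backtick form):

from pyspark.sql import functions as F

column = F.col('foo').alias('bar')
print(str(column))
# Column<'foo AS bar'>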

mwkjh3gx 1#

One way is via regular expressions:

from pyspark.sql.functions import col

column = col('foo').alias('bar')
print(column)
# Column<foo AS `bar`>

import re
print(re.findall(r"(?<=AS `)\w+(?=`>$)", str(column))[0])
# 'bar'
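
Note that this pattern targets the older backtick representation. On PySpark 3.x, where the alias is wrapped in single quotes instead (see answer 4 below), the lookarounds would need adjusting. A minimal sketch, assuming the 3.x format Column<'foo AS bar'>:

import re
from pyspark.sql.functions import col

column = col('foo').alias('bar')
# Match the word between the last " AS " and the closing "'>"
print(re.findall(r"(?<= AS )\w+(?='>$)", str(column))[0])
# 'bar'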

yqyhoc1h 2#

Alternatively, we can use a wrapper function to adjust the behavior of the Column.alias and Column.name methods so that they also store the alias in an AS attribute:

from pyspark.sql import Column, SparkSession
from pyspark.sql.functions import col, explode, array, struct, lit

spark = SparkSession.builder.getOrCreate()

def alias_wrapper(self, *alias, **kwargs):
    # Delegate to the original implementation, then record the alias
    renamed_col = Column._alias(self, *alias, **kwargs)
    renamed_col.AS = alias[0] if len(alias) == 1 else alias
    return renamed_col

# Keep the original method as Column._alias, patch alias and name,
# and default AS to None for columns that were never aliased
Column._alias, Column.alias, Column.name, Column.AS = Column.alias, alias_wrapper, alias_wrapper, None

which guarantees:

assert(col("foo").alias("bar").AS == "bar")
# `name` should act like `alias`
assert(col("foo").name("bar").AS == "bar")
# column without alias should have None in `AS`
assert(col("foo").AS is None)
# multialias should be handled
assert(explode(array(struct(lit(1), lit("a")))).alias("foo", "bar").AS == ("foo", "bar"))
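
With the wrapper installed, the alias can be read back without any string parsing. A minimal illustrative sketch (the select_list name is ours, not part of the answer):

# Assumes the alias_wrapper patch above has been applied
select_list = [col("a").alias("x"), col("b")]
print([c.AS for c in select_list])
# ['x', None]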

rmbxnbpk 3#

I've noticed that on some systems you may get backticks around the column name.
Option 1 (no regex): str(col).replace("`", "").split("'")[-2].split(" AS ")[-1]

from pyspark.sql.functions import col
col_1 = col('foo')
col_2 = col('foo').alias('bar')
col_3 = col('foo').alias('bar').alias('baz')

s = str(col_1)
print(col_1)
print(s.replace("`", "").split("'")[-2].split(" AS ")[-1])
# Column<'foo'>
# foo

s = str(col_2)
print(col_2)
print(s.replace("`", "").split("'")[-2].split(" AS ")[-1])
# Column<'foo AS bar'>
# bar

s = str(col_3)
print(col_3)
print(s.replace("`", "").split("'")[-2].split(" AS ")[-1])
# Column<'foo AS bar AS baz'>
# baz

Option 2 (regex): the pattern r"'.*?`?(\w+)`?'" looks safe enough:
re.search(r"'.*?`?(\w+)`?'", str(col)).group(1)

from pyspark.sql.functions import col
col_1 = col('foo')
col_2 = col('foo').alias('bar')
col_3 = col('foo').alias('bar').alias('baz')

import re

print(col_1)
print(re.search(r"'.*?`?(\w+)`?'", str(col_1)).group(1))
# Column<'foo'>
# foo

print(col_2)
print(re.search(r"'.*?`?(\w+)`?'", str(col_2)).group(1))
# Column<'foo AS bar'>
# bar

print(col_3)
print(re.search(r"'.*?`?(\w+)`?'", str(col_3)).group(1))
# Column<'foo AS bar AS baz'>
# baz

kdfy810k 4#

For PySpark 3.x the backticks appear to have been replaced by single quotes, so this may not work out of the box on earlier Spark versions, but it should be easy to modify.

from pyspark.sql import Column

def get_column_name(col: Column) -> str:
    """
    PySpark doesn't allow you to directly access the column name with respect to aliases
    from an unbound column. We have to parse this out from the string representation.

    This works on columns with one or more aliases as well as unaliased columns.

    Returns:
        Col name as str, with respect to aliasing
    """
    # str.lstrip/rstrip treat their argument as a character set and would
    # over-strip names like 'count'; removeprefix/removesuffix (Python 3.9+)
    # strip the exact substring instead.
    c = str(col).removeprefix("Column<'").removesuffix("'>")
    return c.split(' AS ')[-1]

A few tests to verify the behavior:

import pytest
from pyspark.sql import SparkSession
from pyspark.sql import functions as f

@pytest.fixture(scope="session")
def spark() -> SparkSession:
    # Provide a session-scoped Spark fixture for all tests
    yield SparkSession.builder.getOrCreate()

def test_get_col_name(spark):
    col = f.col('a')
    actual = get_column_name(col)
    assert actual == 'a'

def test_get_col_name_alias(spark):
    col = f.col('a').alias('b')
    actual = get_column_name(col)
    assert actual == 'b'

def test_get_col_name_multiple_alias(spark):
    col = f.col('a').alias('b').alias('c')
    actual = get_column_name(col)
    assert actual == 'c'
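
Not from any of the answers above, just a sketch of an alternative: when the DataFrame the column will be applied to is at hand, the analyzer can resolve the output name itself, which avoids depending on a string representation that changes between Spark versions:

from pyspark.sql import SparkSession
from pyspark.sql import functions as f

spark = SparkSession.builder.getOrCreate()

# Let Spark resolve the output name by projecting the column and
# reading the resulting schema (.columns does not trigger a job)
df = spark.createDataFrame([(1,)], ['a'])
print(df.select(f.col('a').alias('b')).columns[0])
# 'b'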
