Fix some bugs with Groupby and CacheTable #14

Status: Open. Wants to merge 4 commits into base: master.

@@ -25,8 +25,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PersistedView, ViewType}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, HiveTableRelation}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -128,7 +127,7 @@ trait LineageParser {
exp.toAttribute,
if (!containsCountAll(exp.child)) references
else references + exp.toAttribute.withName(AGGREGATE_COUNT_COLUMN_IDENTIFIER))
case a: Attribute => a -> a.references
case a: Attribute => a -> AttributeSet(a)
}
ListMap(exps: _*)
}
@@ -149,6 +148,9 @@
attr.withQualifier(attr.qualifier.init)
case attr if attr.name.equalsIgnoreCase(AGGREGATE_COUNT_COLUMN_IDENTIFIER) =>
attr.withQualifier(qualifier)
case attr if isNameWithQualifier(attr, qualifier) =>
val newName = attr.name.split('.').last.stripPrefix("`").stripSuffix("`")
attr.withName(newName).withQualifier(qualifier)
})
}
} else {
@@ -160,6 +162,12 @@
}
}

private def isNameWithQualifier(attr: Attribute, qualifier: Seq[String]): Boolean = {
val nameTokens = attr.name.split('.')
val namespace = nameTokens.init.mkString(".")
nameTokens.length > 1 && namespace.endsWith(qualifier.mkString("."))
}

private def mergeRelationColumnLineage(
parentColumnsLineage: AttributeMap[AttributeSet],
relationOutput: Seq[Attribute],
@@ -327,6 +335,31 @@ trait LineageParser {
joinColumnsLineage(parentColumnsLineage, getSelectColumnLineage(p.aggregateExpressions))
p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)

case p: Expand =>
val references =
p.projections.transpose.map(_.flatMap(x => x.references)).map(AttributeSet(_))

val childColumnsLineage = ListMap(p.output.zip(references): _*)
val nextColumnsLineage =
joinColumnsLineage(parentColumnsLineage, childColumnsLineage)
p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)

case p: Window =>
val windowColumnsLineage =
ListMap(p.windowExpressions.map(exp => (exp.toAttribute, exp.references)): _*)

val nextColumnsLineage = if (parentColumnsLineage.isEmpty) {
ListMap(p.child.output.map(attr => (attr, attr.references)): _*) ++ windowColumnsLineage
} else {
parentColumnsLineage.map {
case (k, _) if windowColumnsLineage.contains(k) =>
k -> windowColumnsLineage(k)
case (k, attrs) =>
k -> AttributeSet(attrs.flatten(attr =>
windowColumnsLineage.getOrElse(attr, AttributeSet(attr))))
}
}
p.children.map(extractColumnsLineage(_, nextColumnsLineage)).reduce(mergeColumnsLineage)
case p: Join =>
p.joinType match {
case LeftSemi | LeftAnti =>

This patch makes a series of changes to the Spark SQL column-lineage code. The main changes are:

  1. Consolidates the ScalarSubquery import into the existing catalyst expressions import.
  2. Maps a plain Attribute to AttributeSet(a) instead of a.references when building select-column lineage.
  3. Adds an isNameWithQualifier helper to normalize attribute names that carry an embedded qualifier (see the standalone sketch below).
  4. Adds column-lineage support for the Expand and Window operators.

The code is otherwise quite readable, though some unused code could still be cleaned up. Earlier revisions may have contained bugs that are now fixed, or parts awaiting optimization, but that cannot be assessed from this diff alone.
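
As a standalone illustration of the new qualifier handling, the sketch below re-implements the string logic in plain Scala, outside of Spark; the attribute name and qualifier values are made up for the example and are not taken from the patch.

```scala
// Minimal sketch of the qualifier-stripping logic added by this patch.
// The real code operates on Spark's Attribute; plain strings stand in here.
object QualifierNameSketch {

  // Mirrors isNameWithQualifier: true when the attribute name itself contains
  // a dotted prefix that ends with the current qualifier, e.g. "default.c1.b".
  def isNameWithQualifier(name: String, qualifier: Seq[String]): Boolean = {
    val nameTokens = name.split('.')
    val namespace = nameTokens.init.mkString(".")
    nameTokens.length > 1 && namespace.endsWith(qualifier.mkString("."))
  }

  def main(args: Array[String]): Unit = {
    val qualifier = Seq("default", "c1") // hypothetical qualifier of a cached relation
    val attrName = "default.c1.`b`"      // hypothetical name with the qualifier baked in

    if (isNameWithQualifier(attrName, qualifier)) {
      // Keep only the last token, strip surrounding backticks, and re-attach
      // the qualifier separately, as the patched pattern match does.
      val newName = attrName.split('.').last.stripPrefix("`").stripSuffix("`")
      println(s"name = $newName, qualifier = ${qualifier.mkString(".")}")
      // prints: name = b, qualifier = default.c1
    }
  }
}
```
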

@@ -1094,6 +1094,125 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
}
}

test("test group by") {
withTable("t1", "t2", "v2_catalog.db.t1", "v2_catalog.db.t2") { _ =>
spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
spark.sql("CREATE TABLE t2 (a string, b string, c string) USING hive")
spark.sql("CREATE TABLE v2_catalog.db.t1 (a string, b string, c string)")
spark.sql("CREATE TABLE v2_catalog.db.t2 (a string, b string, c string)")
val ret0 =
exectractLineage(
s"insert into table t1 select a," +
s"concat_ws('/', collect_set(b))," +
s"count(distinct(b)) * count(distinct(c))" +
s"from t2 group by a")
assert(ret0 == Lineage(
List("default.t2"),
List("default.t1"),
List(
("default.t1.a", Set("default.t2.a")),
("default.t1.b", Set("default.t2.b")),
("default.t1.c", Set("default.t2.b", "default.t2.c")))))

val ret1 =
exectractLineage(
s"insert into table v2_catalog.db.t1 select a," +
s"concat_ws('/', collect_set(b))," +
s"count(distinct(b)) * count(distinct(c))" +
s"from v2_catalog.db.t2 group by a")
assert(ret1 == Lineage(
List("v2_catalog.db.t2"),
List("v2_catalog.db.t1"),
List(
("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b")),
("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))

val ret2 =
exectractLineage(
s"insert into table v2_catalog.db.t1 select a," +
s"count(distinct(b+c))," +
s"count(distinct(b)) * count(distinct(c))" +
s"from v2_catalog.db.t2 group by a")
assert(ret2 == Lineage(
List("v2_catalog.db.t2"),
List("v2_catalog.db.t1"),
List(
("v2_catalog.db.t1.a", Set("v2_catalog.db.t2.a")),
("v2_catalog.db.t1.b", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")),
("v2_catalog.db.t1.c", Set("v2_catalog.db.t2.b", "v2_catalog.db.t2.c")))))
}
}

test("test grouping sets") {
withTable("t1", "t2") { _ =>
spark.sql("CREATE TABLE t1 (a string, b string, c string) USING hive")
spark.sql("CREATE TABLE t2 (a string, b string, c string, d string) USING hive")
val ret0 =
exectractLineage(
s"insert into table t1 select a,b,GROUPING__ID " +
s"from t2 group by a,b,c,d grouping sets ((a,b,c), (a,b,d))")
assert(ret0 == Lineage(
List("default.t2"),
List("default.t1"),
List(
("default.t1.a", Set("default.t2.a")),
("default.t1.b", Set("default.t2.b")),
("default.t1.c", Set()))))
}
}

test("test catch table with window function") {
withTable("t1", "t2") { _ =>
spark.sql("CREATE TABLE t1 (a string, b string) USING hive")
spark.sql("CREATE TABLE t2 (a string, b string) USING hive")

spark.sql(
s"cache table c1 select * from (" +
s"select a, b, row_number() over (partition by a order by b asc ) rank from t2)" +
s" where rank=1")
val ret0 = exectractLineage("insert overwrite table t1 select a, b from c1")
assert(ret0 == Lineage(
List("default.t2"),
List("default.t1"),
List(
("default.t1.a", Set("default.t2.a")),
("default.t1.b", Set("default.t2.b")))))

val ret1 = exectractLineage("insert overwrite table t1 select a, rank from c1")
assert(ret1 == Lineage(
List("default.t2"),
List("default.t1"),
List(
("default.t1.a", Set("default.t2.a")),
("default.t1.b", Set("default.t2.a", "default.t2.b")))))

spark.sql(
s"cache table c2 select * from (" +
s"select b, a, row_number() over (partition by a order by b asc ) rank from t2)" +
s" where rank=1")
val ret2 = exectractLineage("insert overwrite table t1 select a, b from c2")
assert(ret2 == Lineage(
List("default.t2"),
List("default.t1"),
List(
("default.t1.a", Set("default.t2.a")),
("default.t1.b", Set("default.t2.b")))))

spark.sql(
s"cache table c3 select * from (" +
s"select a as aa, b as bb, row_number() over (partition by a order by b asc ) rank" +
s" from t2) where rank=1")
val ret3 = exectractLineage("insert overwrite table t1 select aa, bb from c3")
assert(ret3 == Lineage(
List("default.t2"),
List("default.t1"),
List(
("default.t1.a", Set("default.t2.a")),
("default.t1.b", Set("default.t2.b")))))
}
}

private def exectractLineageWithoutExecuting(sql: String): Lineage = {
val parsed = spark.sessionState.sqlParser.parsePlan(sql)
val analyzed = spark.sessionState.analyzer.execute(parsed)
@@ -26,6 +26,7 @@
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
import org.apache.hive.service.rpc.thrift.TStatus;
import org.apache.hive.service.rpc.thrift.TStatusCode;
import org.slf4j.Logger;
@@ -193,12 +194,20 @@ public static JdbcConnectionParams extractURLComponents(String uri, Properties i
}
}

Pattern confPattern = Pattern.compile("([^;]*)([^;]*);?");

// parse hive conf settings
String confStr = jdbcURI.getQuery();
if (confStr != null) {
Matcher confMatcher = pattern.matcher(confStr);
Matcher confMatcher = confPattern.matcher(confStr);
while (confMatcher.find()) {
connParams.getHiveConfs().put(confMatcher.group(1), confMatcher.group(2));
String connParam = confMatcher.group(1);
if (StringUtils.isNotBlank(connParam) && connParam.contains("=")) {
int symbolIndex = connParam.indexOf('=');
connParams
.getHiveConfs()
.put(connParam.substring(0, symbolIndex), connParam.substring(symbolIndex + 1));
}
}
}

@@ -477,4 +486,4 @@ public static String getCanonicalHostName(String hostName) {
public static boolean isKyuubiOperationHint(String hint) {
return KYUUBI_OPERATION_HINT_PATTERN.matcher(hint).matches();
}
}
}

This change does the following:

  1. Imports org.apache.commons.lang3.StringUtils.
  2. Reworks how confStr is parsed, using a dedicated Pattern and Matcher.
  3. Checks that each conf parameter is non-blank and contains "=", then splits it at the first "=" into a key and a value and puts the pair into connParams.getHiveConfs() (sketched below).

Potential risks/bugs:

  1. Exceptions that may occur while parsing the URL are not caught; adding error handling would surface failures and limit their impact on the rest of the code.

Possible improvements:

  1. Consider adding unit tests to verify correctness, including the exception handling paths.
  2. Improve the method-level comments to make the code easier to read and understand.
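
Setting the JDBC plumbing aside, the intended parsing behaviour can be sketched on its own: split the query part of the URI on ';', keep only tokens that are non-blank and contain '=', and split each token at the first '=' so that values may themselves contain '='. The Scala snippet below is a simplified sketch of that logic with a made-up query string; it is not the actual Utils.java code.

```scala
// Sketch of the conf parsing this patch aims for (simplified, no regex):
// each "key=value" pair from the URI query part goes into the Hive conf map.
object HiveConfParsingSketch {

  def parseHiveConfs(confStr: String): Map[String, String] =
    confStr
      .split(';')
      .filter(p => p.trim.nonEmpty && p.contains("=")) // skip blank or malformed params
      .map { param =>
        val i = param.indexOf('=')                      // split at the FIRST '=' only,
        param.substring(0, i) -> param.substring(i + 1) // so the value may contain '='
      }
      .toMap

  def main(args: Array[String]): Unit = {
    // Query part of a hypothetical jdbc:hive2://host:10018/db;k1=v1?k2=v2;k3=a=b#k4=v4
    println(parseHiveConfs("k2=v2;k3=a=b")) // Map(k2 -> v2, k3 -> a=b)
  }
}
```
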

@@ -21,8 +21,13 @@
import static org.apache.kyuubi.jdbc.hive.Utils.extractURLComponents;
import static org.junit.Assert.assertEquals;

import com.google.common.collect.ImmutableMap;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.Properties;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -35,23 +40,76 @@ public class UtilsTest {
private String expectedPort;
private String expectedCatalog;
private String expectedDb;
private Map<String, String> expectedHiveConf;
private String uri;

@Parameterized.Parameters
public static Collection<String[]> data() {
public static Collection<Object[]> data() throws UnsupportedEncodingException {
return Arrays.asList(
new String[][] {
{"localhost", "10009", null, "db", "jdbc:hive2:///db;k1=v1?k2=v2#k3=v3"},
{"localhost", "10009", null, "default", "jdbc:hive2:///"},
{"localhost", "10009", null, "default", "jdbc:kyuubi://"},
{"localhost", "10009", null, "default", "jdbc:hive2://"},
{"hostname", "10018", null, "db", "jdbc:hive2://hostname:10018/db;k1=v1?k2=v2#k3=v3"},
new Object[][] {
{
"localhost",
"10009",
null,
"db",
new ImmutableMap.Builder<String, String>().put("k2", "v2").build(),
"jdbc:hive2:///db;k1=v1?k2=v2#k3=v3"
},
{
"localhost",
"10009",
null,
"default",
new ImmutableMap.Builder<String, String>().build(),
"jdbc:hive2:///"
},
{
"localhost",
"10009",
null,
"default",
new ImmutableMap.Builder<String, String>().build(),
"jdbc:kyuubi://"
},
{
"localhost",
"10009",
null,
"default",
new ImmutableMap.Builder<String, String>().build(),
"jdbc:hive2://"
},
{
"hostname",
"10018",
null,
"db",
new ImmutableMap.Builder<String, String>().put("k2", "v2").build(),
"jdbc:hive2://hostname:10018/db;k1=v1?k2=v2#k3=v3"
},
{
"hostname",
"10018",
"catalog",
"db",
new ImmutableMap.Builder<String, String>().put("k2", "v2").build(),
"jdbc:hive2://hostname:10018/catalog/db;k1=v1?k2=v2#k3=v3"
},
{
"hostname",
"10018",
"catalog",
"db",
new ImmutableMap.Builder<String, String>()
.put("k2", "v2")
.put("k3", "-Xmx2g -XX:+PrintGCDetails -XX:HeapDumpPath=/heap.hprof")
.build(),
"jdbc:hive2://hostname:10018/catalog/db;k1=v1?"
+ URLEncoder.encode(
"k2=v2;k3=-Xmx2g -XX:+PrintGCDetails -XX:HeapDumpPath=/heap.hprof",
StandardCharsets.UTF_8.toString())
.replaceAll("\\+", "%20")
+ "#k4=v4"
}
});
}
@@ -61,11 +119,13 @@ public UtilsTest(
String expectedPort,
String expectedCatalog,
String expectedDb,
Map<String, String> expectedHiveConf,
String uri) {
this.expectedHost = expectedHost;
this.expectedPort = expectedPort;
this.expectedCatalog = expectedCatalog;
this.expectedDb = expectedDb;
this.expectedHiveConf = expectedHiveConf;
this.uri = uri;
}

@@ -76,5 +136,6 @@ public void testExtractURLComponents() throws JdbcUriParseException {
assertEquals(Integer.parseInt(expectedPort), jdbcConnectionParams1.getPort());
assertEquals(expectedCatalog, jdbcConnectionParams1.getCatalogName());
assertEquals(expectedDb, jdbcConnectionParams1.getDbName());
assertEquals(expectedHiveConf, jdbcConnectionParams1.getHiveConfs());
}
}
}
@@ -97,6 +97,12 @@ SCOPE_TABLE: 'SCOPE_TABLE';
SOURCE_DATA_TYPE: 'SOURCE_DATA_TYPE';
IS_AUTOINCREMENT: 'IS_AUTOINCREMENT';
IS_GENERATEDCOLUMN: 'IS_GENERATEDCOLUMN';
VARCHAR: 'VARCHAR';
SMALLINT: 'SMALLINT';
CAST: 'CAST';
AS: 'AS';
KEY_SEQ: 'KEY_SEQ';
PK_NAME: 'PK_NAME';

fragment SEARCH_STRING_ESCAPE: '\'' '\\' '\'';

This change only defines a few new lexer tokens, so there is no obvious risk of error. A couple of suggestions:

  1. Consider keeping these token definitions in a dedicated file, or alongside the code that uses them, to make the grammar easier to organize and maintain.
  2. If any of these tokens duplicate constants defined elsewhere, make sure they carry the same meaning, to avoid confusion and unnecessary errors.

@@ -47,6 +47,13 @@ statement
SOURCE_DATA_TYPE COMMA IS_AUTOINCREMENT COMMA IS_GENERATEDCOLUMN FROM SYSTEM_JDBC_COLUMNS
(WHERE tableCatalogFilter? AND? tableSchemaFilter? AND? tableNameFilter? AND? colNameFilter?)?
ORDER BY TABLE_CAT COMMA TABLE_SCHEM COMMA TABLE_NAME COMMA ORDINAL_POSITION #getColumns
| SELECT CAST LEFT_PAREN NULL AS VARCHAR RIGHT_PAREN TABLE_CAT COMMA
CAST LEFT_PAREN NULL AS VARCHAR RIGHT_PAREN TABLE_SCHEM COMMA
CAST LEFT_PAREN NULL AS VARCHAR RIGHT_PAREN TABLE_NAME COMMA
CAST LEFT_PAREN NULL AS VARCHAR RIGHT_PAREN COLUMN_NAME COMMA
CAST LEFT_PAREN NULL AS SMALLINT RIGHT_PAREN KEY_SEQ COMMA
CAST LEFT_PAREN NULL AS VARCHAR RIGHT_PAREN PK_NAME
WHERE FALSE #getPrimaryKeys
| .*? #passThrough
;

This patch extends a grammar rule that matches multi-line SQL statements. A brief walkthrough:

statement: the name of the previously defined grammar rule being extended.

The existing #getColumns alternative matches a SQL template that retrieves all column names and related metadata from a table.

The new #getPrimaryKeys alternative matches a SQL template that retrieves primary-key information; the query uses the constant NULL, cast to the expected types, as placeholders, and WHERE FALSE so it always returns an empty result (a concrete example is sketched below).

.*? is a regular-expression-style wildcard matching any characters zero or more times; the #passThrough alternative is a catch-all that forwards anything the other alternatives do not handle.

Overall there are no significant logical defects or risks here. With more context, more specific and valuable improvement suggestions could be given.
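
For reference, the kind of statement the new #getPrimaryKeys alternative is meant to match looks roughly like the constant query below, which a JDBC driver can issue for getPrimaryKeys when no key metadata is exposed. The exact text is an assumption; the grammar only constrains the token sequence.

```scala
object GetPrimaryKeysSqlExample {
  // Hypothetical example of a statement matched by the new #getPrimaryKeys
  // alternative: every column is a typed NULL and WHERE FALSE guarantees an
  // empty result set.
  val sql: String =
    """SELECT CAST(NULL AS VARCHAR) TABLE_CAT,
      |       CAST(NULL AS VARCHAR) TABLE_SCHEM,
      |       CAST(NULL AS VARCHAR) TABLE_NAME,
      |       CAST(NULL AS VARCHAR) COLUMN_NAME,
      |       CAST(NULL AS SMALLINT) KEY_SEQ,
      |       CAST(NULL AS VARCHAR) PK_NAME
      |WHERE FALSE""".stripMargin
}
```
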
