Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion prqlc/prqlc/src/ir/pl/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ pub enum TransformKind {
range: Range,
pipeline: Box<Expr>,
},
Append(Box<Expr>),
Append {
by: AppendBy,
bottom: Box<Expr>,
},
Loop(Box<Expr>),
}

Expand All @@ -115,6 +118,12 @@ pub enum JoinSide {
Full,
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
pub enum AppendBy {
Name,
Position,
}

impl Expr {
pub fn new(kind: impl Into<ExprKind>) -> Self {
Expr {
Expand Down
5 changes: 4 additions & 1 deletion prqlc/prqlc/src/ir/pl/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ pub fn fold_transform_kind<T: ?Sized + PlFold>(
with: Box::new(fold.fold_expr(*with)?),
filter: Box::new(fold.fold_expr(*filter)?),
},
Append(bottom) => Append(Box::new(fold.fold_expr(*bottom)?)),
Append { by, bottom } => Append {
by,
bottom: Box::new(fold.fold_expr(*bottom)?),
},
Group { by, pipeline } => Group {
by: Box::new(fold.fold_expr(*by)?),
pipeline: Box::new(fold.fold_expr(*pipeline)?),
Expand Down
5 changes: 4 additions & 1 deletion prqlc/prqlc/src/ir/rq/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ pub fn fold_transform<T: ?Sized + RqFold>(
with: fold.fold_table_ref(with)?,
filter: fold.fold_expr(filter)?,
},
Append(bottom) => Append(fold.fold_table_ref(bottom)?),
Append { by, bottom } => Append {
by,
bottom: fold.fold_table_ref(bottom)?,
},
Loop(transforms) => Loop(fold_transforms(fold, transforms)?),
};
Ok(transform)
Expand Down
7 changes: 5 additions & 2 deletions prqlc/prqlc/src/ir/rq/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use super::*;
use crate::ir::generic::ColumnSort;
use crate::ir::generic::WindowFrame;
use crate::ir::pl::JoinSide;
use crate::ir::pl::{AppendBy, JoinSide};

/// Transformation of a table.
#[derive(
Expand All @@ -27,7 +27,10 @@ pub enum Transform {
with: TableRef,
filter: Expr,
},
Append(TableRef),
Append {
by: AppendBy,
bottom: TableRef,
},
Loop(Vec<Transform>),
}

Expand Down
4 changes: 2 additions & 2 deletions prqlc/prqlc/src/semantic/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -612,11 +612,11 @@ impl Lowerer {
};
self.pipeline.push(transform);
}
pl::TransformKind::Append(bottom) => {
pl::TransformKind::Append { by, bottom } => {
let mut bottom = self.lower_table_ref(*bottom)?;
bottom.prefer_cte = false;

self.pipeline.push(Transform::Append(bottom));
self.pipeline.push(Transform::Append { by, bottom });
}
pl::TransformKind::Loop(pipeline) => {
let relation = self.lower_relation(*pipeline)?;
Expand Down
2 changes: 1 addition & 1 deletion prqlc/prqlc/src/semantic/reporting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl PlFold for FrameCollector {
pl::TransformKind::Derive { assigns: ref e }
| pl::TransformKind::Select { assigns: ref e }
| pl::TransformKind::Filter { filter: ref e }
| pl::TransformKind::Append(ref e)
| pl::TransformKind::Append { bottom: ref e, .. }
| pl::TransformKind::Loop(ref e)
| pl::TransformKind::Group {
pipeline: ref e, ..
Expand Down
14 changes: 9 additions & 5 deletions prqlc/prqlc/src/semantic/resolver/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,11 @@ impl PlFold for Flattener {
// in scope for downstream transforms in the outer pipeline. Per the PRQL
// spec a join retains the left (input) side's order, so snapshot the
// input's sort and restore it after folding the kind.
let input_sort =
matches!(kind, TransformKind::Join { .. } | TransformKind::Append(_))
.then(|| self.sort.clone());
let input_sort = matches!(
kind,
TransformKind::Join { .. } | TransformKind::Append { .. }
)
.then(|| self.sort.clone());

let kind = fold_transform_kind(self, kind)?;

Expand All @@ -177,8 +179,10 @@ impl PlFold for Flattener {
// derive {`album_name` = `name`}
// select {`artist_id`, `album_name`}
// ) (this.id == that.artist_id)
let sort = if matches!(kind, TransformKind::Join { .. } | TransformKind::Append(_))
{
let sort = if matches!(
kind,
TransformKind::Join { .. } | TransformKind::Append { .. }
) {
vec![]
} else {
self.sort.clone()
Expand Down
7 changes: 5 additions & 2 deletions prqlc/prqlc/src/semantic/resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl Resolver<'_> {
pub(super) mod test {
use insta::assert_yaml_snapshot;

use crate::ir::pl::{Expr, Lineage, PlFold};
use crate::ir::pl::{AppendBy, Expr, Lineage, PlFold};
use crate::{Errors, Result};

pub fn erase_ids(expr: Expr) -> Expr {
Expand Down Expand Up @@ -463,7 +463,10 @@ pub(super) mod test {
bottom_expr.lineage = Some(bottom_lineage);

let transform_call = TransformCall {
kind: Box::new(TransformKind::Append(Box::new(bottom_expr))),
kind: Box::new(TransformKind::Append {
by: AppendBy::Position,
bottom: Box::new(bottom_expr),
}),
input: Box::new(top_expr),
partition: None,
frame: crate::ir::pl::WindowFrame::default(),
Expand Down
118 changes: 99 additions & 19 deletions prqlc/prqlc/src/semantic/resolver/transforms.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::iter::zip;

use itertools::Itertools;
use serde::Deserialize;

use super::types::{ty_tuple_kind, type_intersection};
use super::types::{ty_tuple_kind, type_intersection, type_union_of_tuples};
use super::Resolver;
use crate::codegen::write_ty_kind;
use crate::ir::decl::{Decl, DeclKind, Module};
use crate::ir::generic::{SortDirection, WindowKind};
use crate::ir::pl::*;
Expand Down Expand Up @@ -254,9 +255,35 @@ impl Resolver<'_> {
(transform_kind, tbl)
}
"append" => {
let [bottom, top] = unpack::<2>(func.args);
let [by, bottom, top] = unpack::<3>(func.args);

(TransformKind::Append(Box::new(bottom)), top)
let by = {
let span = by.span;
let ident = by
.clone()
.try_cast(ExprKind::into_ident, Some("by"), "ident")?;

match ident.to_string().as_str() {
"position" => AppendBy::Position,
"name" => AppendBy::Name,
_ => {
return Err(Error::new(Reason::Expected {
who: Some("`by`".to_string()),
expected: "position or name".to_string(),
found: ident.to_string(),
})
.with_span(span))
}
}
};

(
TransformKind::Append {
by,
bottom: Box::new(bottom),
},
top,
)
}
"loop" => {
let [pipeline, tbl] = unpack::<2>(func.args);
Expand Down Expand Up @@ -600,11 +627,26 @@ impl Resolver<'_> {
let pipeline = pipeline.kind.into_function().unwrap().unwrap();
pipeline.return_ty.map(|x| *x)
}
TransformKind::Append(bottom) => {
TransformKind::Append { bottom, by } => {
let top = transform_call.input.ty.clone().unwrap();
let bottom_span = bottom.span;
let bottom = bottom.ty.clone().unwrap();

Some(type_intersection(top, bottom).with_span(transform_call.input.span)?)
if *by == AppendBy::Position {
Some(type_intersection(top, bottom).with_span(bottom_span)?)
} else if top.clone().is_relation() && bottom.clone().is_relation() {
Some(type_union_of_tuples(
top.into_relation().unwrap(),
bottom.into_relation().unwrap(),
)?)
} else {
return Err(Error::new_simple(format!(
"cannot append type `{}` to `{}`",
write_ty_kind(&bottom.kind),
write_ty_kind(&top.kind)
)))
.with_span(bottom_span);
}
}
})
}
Expand Down Expand Up @@ -757,10 +799,13 @@ impl TransformCall {
let right = lineage_or_default(with)?;
join(left, right)
}
Append(bottom) => {
Append { bottom, by } => {
let top = lineage_or_default(&self.input)?;
let bottom = lineage_or_default(bottom)?;
append(top, bottom)?
let bot = lineage_or_default(bottom)?;
match by {
AppendBy::Position => append(top, bot).with_span(bottom.span)?,
AppendBy::Name => append_by_name(top, bot).with_span(bottom.span)?,
}
}
Loop(_) => lineage_or_default(&self.input)?,
Sort { .. } | Filter { .. } | Take { .. } => lineage_or_default(&self.input)?,
Expand Down Expand Up @@ -855,6 +900,41 @@ fn append(mut top: Lineage, bottom: Lineage) -> Result<Lineage, Error> {
Ok(top)
}

fn append_by_name(mut top: Lineage, bottom: Lineage) -> Result<Lineage, Error> {
// Merge inputs from both relations so lineage can track both sources
// This is similar to how `join` handles inputs
top.inputs.extend(bottom.inputs);

// start with all the columns from top
let mut columns = top.columns.clone();
let top_names: HashSet<String> = top
.columns
.into_iter()
.filter_map(|c| match c {
LineageColumn::Single { name, .. } => Some(name?.name),
_ => None,
})
.collect();

// add columns from bottom that aren't already in top
for column in bottom.columns {
match column {
LineageColumn::Single { ref name, .. } => {
if let Some(name) = name.clone() {
if !top_names.contains(&name.name) {
columns.push(column);
}
}
}
LineageColumn::All { .. } => todo!(),
}
}

log::trace!("append_by_name columns: {columns:#?}");
top.columns = columns;
Ok(top)
}

impl Lineage {
pub fn clear(&mut self) {
self.prev_columns.clear();
Expand Down Expand Up @@ -1183,17 +1263,17 @@ mod tests {
assert_snapshot!(crate::tests::compile(
"from a | select {x, y} | append (from b | select {x})"
)
.unwrap_err(), @r"
.unwrap_err(), @r###"
Error:
╭─[ :1:10 ]
╭─[ :1:43 ]
1 │ from a | select {x, y} | append (from b | select {x})
│ ──────┬──────
╰──────── cannot combine relations with different numbers of columns
─────────
────── cannot combine relations with different numbers of columns
│ Help: `append` requires both tables to have matching columns
───╯
");
"###);
}

// `append` of relations whose columns have incompatible types must also
Expand All @@ -1203,15 +1283,15 @@ mod tests {
assert_snapshot!(crate::tests::compile(
"from a | select {x = 1} | append (from b | select {x = 1.0})"
)
.unwrap_err(), @r"
.unwrap_err(), @r###"
Error:
╭─[ :1:10 ]
╭─[ :1:44 ]
1 │ from a | select {x = 1} | append (from b | select {x = 1.0})
│ ─────────────
──────── cannot combine types `int` and `float`
────────┬───────
╰───────── cannot combine types `int` and `float`
───╯
");
"###);
}

#[test]
Expand Down
37 changes: 37 additions & 0 deletions prqlc/prqlc/src/semantic/resolver/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::ir::pl::*;
use crate::pr::{PrimitiveSet, Ty, TyKind, TyTupleField};
use crate::Result;
use crate::{Error, Reason, WithErrorInfo};
use itertools::Itertools;

impl Resolver<'_> {
pub fn infer_type(expr: &Expr) -> Result<Option<Ty>> {
Expand Down Expand Up @@ -455,3 +456,39 @@ fn different_column_count_error() -> Error {
Error::new_simple("cannot combine relations with different numbers of columns")
.push_hint("`append` requires both tables to have matching columns")
}

pub fn type_union_of_tuples(a: Vec<TyTupleField>, b: Vec<TyTupleField>) -> Result<Ty> {
let has_other = a.iter().any(|f| f.is_wildcard()) || b.iter().any(|f| f.is_wildcard());

let mut fields: Vec<TyTupleField> = a.into_iter().filter(|f| f.is_single()).collect_vec();

for b_field in b.into_iter().filter(|f| f.is_single()) {
match b_field {
TyTupleField::Single(b_name, b_ty) => {
match fields
.iter()
.position(|f| f.clone().into_single().ok().unwrap().0 == b_name)
{
Some(i) => {
let TyTupleField::Single(a_name, a_ty) = fields[i].clone() else {
unreachable!()
};
if let (Some(a_ty), Some(b_ty)) = (a_ty, b_ty) {
fields[i] =
TyTupleField::Single(a_name, Some(type_intersection(a_ty, b_ty)?));
}
}
None => {
fields.push(TyTupleField::Single(b_name, b_ty));
}
}
}
_ => unreachable!(),
}
}
if has_other {
fields.push(TyTupleField::Wildcard(None));
}

Ok(Ty::new(TyKind::Tuple(fields)))
}
7 changes: 6 additions & 1 deletion prqlc/prqlc/src/semantic/std.prql
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,12 @@ let window = func
tbl <relation>
-> <relation> internal window

let append = `default_db.bottom`<relation> top<relation> -> <relation> internal append
let append = func
`noresolve.by`:position
`default_db.bottom`<relation>
top<relation>
-> <relation> internal append

let intersect = `default_db.bottom`<relation> top<relation> -> <relation> (
t = top
join (b = bottom) (tuple_reduce std.and (tuple_map _eq (tuple_zip t.* b.*)))
Expand Down
Loading
Loading