diff --git a/pypika_tortoise/functions.py b/pypika_tortoise/functions.py index f2af4d4..25c428c 100644 --- a/pypika_tortoise/functions.py +++ b/pypika_tortoise/functions.py @@ -9,8 +9,9 @@ from typing import TYPE_CHECKING, Any from .context import SqlContext -from .enums import SqlTypes -from .terms import AggregateFunction, Function, Star, Term +from .enums import Dialects, SqlTypes +from .exceptions import DialectNotSupported +from .terms import AggregateFunction, Function, Star, Term, ValueWrapper from .utils import builder if TYPE_CHECKING: @@ -242,8 +243,82 @@ def __init__(self, term: Any, alias: str | None = None) -> None: class Trim(Function): + def __init__( + self, + term: Any, + trim_chars: str = " ", + alias: str | None = None, + ) -> None: + args = [term] + if trim_chars != " ": + args.append(ValueWrapper(trim_chars)) + + super().__init__("TRIM", *args, alias=alias) + + def get_function_sql(self, ctx: SqlContext) -> str: + if len(self.args) == 1: + return super().get_function_sql(ctx) + + args_sql = [self.get_arg_sql(arg, ctx) for arg in self.args] + if ctx.dialect == Dialects.SQLITE: + args = ",".join(args_sql) + else: + args = f"BOTH {args_sql[1]} FROM {args_sql[0]}" + + return "{name}({args})".format( + name=self.get_dialect_special_name(ctx.dialect) or self.name, + args=args, + ) + + +class LTrim(Function): + def __init__(self, term: Any, alias: str | None = None) -> None: + super().__init__("LTRIM", term, alias=alias) + + +class RTrim(Function): def __init__(self, term: Any, alias: str | None = None) -> None: - super().__init__("TRIM", term, alias=alias) + super().__init__("RTRIM", term, alias=alias) + + +class _Pad(Function): + db_function: str + + def __init__( + self, term: Any, length: int, fill_text: str = " ", alias: str | None = None + ) -> None: + super().__init__( + self.db_function, + term, + ValueWrapper(length), + ValueWrapper(fill_text), + alias=alias, + ) + + def get_sql(self, ctx: SqlContext) -> str: + if ctx.dialect in [Dialects.SQLITE, Dialects.MSSQL]: + raise DialectNotSupported(f"{self.db_function} is not supported in {ctx.dialect}.") + + return super().get_sql(ctx) + + +class LPad(_Pad): + db_function = "LPAD" + + +class RPad(_Pad): + db_function = "RPAD" + + +class Replace(Function): + def __init__(self, term: Any, search: str, replacement: str, alias: str | None = None) -> None: + super().__init__( + "REPLACE", + term, + ValueWrapper(search), + ValueWrapper(replacement), + alias=alias, + ) class SplitPart(Function): diff --git a/tests/test_functions.py b/tests/test_functions.py index 18313c2..101d4ae 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -6,8 +6,11 @@ from pypika_tortoise import Table as T from pypika_tortoise import functions as fn from pypika_tortoise.context import DEFAULT_SQL_CONTEXT +from pypika_tortoise.dialects.mssql import MSSQLQuery from pypika_tortoise.dialects.postgresql import PostgreSQLQuery +from pypika_tortoise.dialects.sqlite import SQLLiteQuery from pypika_tortoise.enums import SqlTypes +from pypika_tortoise.exceptions import DialectNotSupported class FunctionTests(unittest.TestCase): @@ -521,6 +524,118 @@ def test__substring(self): self.assertEqual('SELECT SUBSTRING("foo",2,6) FROM "abc"', str(q)) + def test__trim__field(self): + q = Q.from_(self.t).select(fn.Trim(self.t.foo)) + + self.assertEqual('SELECT TRIM("foo") FROM "abc"', str(q)) + + def test__trim__field__chars(self): + q = Q.from_(self.t).select(fn.Trim(self.t.foo, trim_chars="x")) + + self.assertEqual( + 'SELECT TRIM(BOTH \'x\' FROM "foo") FROM "abc"', q.get_sql(PostgreSQLQuery.SQL_CONTEXT) + ) + + def test__trim__field__chars__sqlite(self): + q = Q.from_(self.t).select(fn.Trim(self.t.foo, trim_chars="x")) + + self.assertEqual('SELECT TRIM("foo",\'x\') FROM "abc"', str(q)) + + def test__trim__str(self): + q = Q.select(fn.Trim(" abc ")) + + self.assertEqual("SELECT TRIM(' abc ')", str(q)) + + def test__trim__str__chars(self): + q = Q.select(fn.Trim("xxabcxx", trim_chars="x")) + + self.assertEqual( + "SELECT TRIM(BOTH 'x' FROM 'xxabcxx')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT) + ) + + def test__trim__str__chars__sqlite(self): + q = Q.select(fn.Trim("xxabcxx", trim_chars="x")) + + self.assertEqual("SELECT TRIM('xxabcxx','x')", str(q)) + + def test__ltrim__str(self): + q = Q.select(fn.LTrim(" abc")) + + self.assertEqual("SELECT LTRIM(' abc')", str(q)) + + def test__ltrim__field(self): + q = Q.from_(self.t).select(fn.LTrim(self.t.foo)) + + self.assertEqual('SELECT LTRIM("foo") FROM "abc"', str(q)) + + def test__rtrim__str(self): + q = Q.select(fn.RTrim("abc ")) + + self.assertEqual("SELECT RTRIM('abc ')", str(q)) + + def test__rtrim__field(self): + q = Q.from_(self.t).select(fn.RTrim(self.t.foo)) + + self.assertEqual('SELECT RTRIM("foo") FROM "abc"', str(q)) + + def test__lpad__str(self): + q = Q.select(fn.LPad("abc", 5)) + + self.assertEqual("SELECT LPAD('abc',5,' ')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT)) + + def test__lpad__str_with_fill(self): + q = Q.select(fn.LPad("abc", 5, "x")) + + self.assertEqual("SELECT LPAD('abc',5,'x')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT)) + + def test__lpad__field(self): + q = Q.from_(self.t).select(fn.LPad(self.t.foo, 10, "-")) + + self.assertEqual( + 'SELECT LPAD("foo",10,\'-\') FROM "abc"', q.get_sql(PostgreSQLQuery.SQL_CONTEXT) + ) + + def test__lpad__sqlite__mssql__raises(self): + q = Q.select(fn.LPad("abc", 5)) + + for dialect in [SQLLiteQuery.SQL_CONTEXT, MSSQLQuery.SQL_CONTEXT]: + with self.subTest(dialect=dialect), self.assertRaises(DialectNotSupported): + q.get_sql(dialect) + + def test__rpad__str(self): + q = Q.select(fn.RPad("abc", 5)) + + self.assertEqual("SELECT RPAD('abc',5,' ')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT)) + + def test__rpad__str_with_fill(self): + q = Q.select(fn.RPad("abc", 5, "x")) + + self.assertEqual("SELECT RPAD('abc',5,'x')", q.get_sql(PostgreSQLQuery.SQL_CONTEXT)) + + def test__rpad__field(self): + q = Q.from_(self.t).select(fn.RPad(self.t.foo, 10, "-")) + + self.assertEqual( + 'SELECT RPAD("foo",10,\'-\') FROM "abc"', q.get_sql(PostgreSQLQuery.SQL_CONTEXT) + ) + + def test__rpad__sqlite__mssql__raises(self): + q = Q.select(fn.RPad("abc", 5)) + + for dialect in [SQLLiteQuery.SQL_CONTEXT, MSSQLQuery.SQL_CONTEXT]: + with self.subTest(dialect=dialect), self.assertRaises(DialectNotSupported): + q.get_sql(dialect) + + def test__replace__str(self): + q = Q.select(fn.Replace("abcde", "cd", "xx")) + + self.assertEqual("SELECT REPLACE('abcde','cd','xx')", str(q)) + + def test__replace__field(self): + q = Q.from_(self.t).select(fn.Replace(self.t.foo, "old", "new")) + + self.assertEqual("SELECT REPLACE(\"foo\",'old','new') FROM \"abc\"", str(q)) + class CastTests(unittest.TestCase): t = T("abc")