# Reconstructed from rendered documentation for module: mitto.iov2.steps.builtin
# (original "Table Of Contents" navigation residue removed)

"""Builtin steps"""
import copy
import datetime
import logging
from typing import Any, Dict, Optional

import google.api_core.exceptions
import sqlalchemy
from import bigquery
from sqlalchemy.dialects import mssql
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.schema import CreateSchema
from sqlalchemy.sql.sqltypes import String, Text

from mitto import tz
from mitto.exc import Done
# NOTE(review): the following import statement was truncated during
# extraction ("from import ("). The names listed are the environ keys and
# constants referenced throughout this module -- confirm the actual module
# path against the repository before relying on it.
from mitto.iov2.environ import (
    CLIENT,
    COUNT,
    CREATED,
    DROPPED,
    ENGINE,
    INPUTTER,
    SCHEMA,
    SDL,
    SESSION,
    STORE,
    STORE_UPDATES,
    TABLE,
    TEMP,
    TIMESTAMP,
)
from mitto.iov2 import MITTO_LAST_MODIFIED
from mitto.iov2.input import StoreInput
from mitto.iov2.transform import builtin
from mitto.iov2.utils import log_iter_time_wrapper, has_table
from mitto.jsonpath import jsonpath_parse
from mitto.sample.value import MAX_VARCHAR
from mitto.sdl import sdl_build_table, sdl_column_sqltype, sdl_reflect_table
from mitto.utils import column_spec, table_spec, table_spec_parse

from .base import BaseStep

# VARCHAR(n), 1 <= n <= 8000 has 'length == n'
# VARCHAR(max) has 'length == None'
# SDL 'String' with 'length == n' converts to VARCHAR(n) for 1 <= n <= 8000
# SDL 'String' with 'length == n' converts to VARCHAR(max) for n > 8000
# SDL 'Text' converts to VARCHAR(max)
# FIXME: This should be defined elsewhere
# NOTE(review): this constant was lost in extraction; 8000 is grounded in
# the comment block above and in the use by _sqlserver_adapt_column().
SQLSERVER_VARCHARN_MAX = 8000

SNOWFLAKE_VARCHAR_MAX = 16 * 1024 * 1024  # Snowflake allows up to 16M

# Map SDL column types to their BigQuery-compatible equivalents.
# NOTE(review): the assignment target and braces were lost in extraction;
# the name BIGQUERY_TYPE_TRANSLATE_MAP is recovered from its use in
# CreateTable._adapt_to_bigquery() below.
BIGQUERY_TYPE_TRANSLATE_MAP = {
    "BigInteger": "Integer",
    "DateTime": "TIMESTAMP",
    "SmallInteger": "Integer",
    "Text": "String",
    "Unicode": "String",
}
[docs]class CreateTable(BaseStep): """Create an output database table from SDL. Parameters ---------- environ The Mitto job's environment. It is an error to provide a value for this parameter in a job configuration -- Mitto provides this value automatically. max_varchar_length Maximum length of strings that can be saved in a column with type VARCHAR. If a column contains string values longer than this, the column will be assigned a type of TEXT. The default value is 65,536. kwargs A dict containing additional keys/values that are passed to `mitto.iov2.db.todb()` via that function's `kwargs` argument. Currently, the only supported key/value pair is `batch_size`, which controls the number of rows sent to the database at a single time. The default value is 5,000. Examples -------- Examples of using `CreateTable` as a step in a Mitto job configuration. .. code-block:: :caption: Example One -- use default values "steps": [ ... { "use": "mitto.iov2.steps.CreateTable" }, ... ] .. code-block:: :caption: Example Two -- override default values "steps": [ ... { "use": "mitto.iov2.steps.CreateTable", "max_varchar_length": 16384, "batch_size": 10000 }, ... ] """ DROP = "DROP TABLE IF EXISTS {};" RENAME = "ALTER TABLE {} RENAME TO \"{}\"" def __init__( self, environ, max_varchar_length: int = MAX_VARCHAR, **kwargs: Dict[str, Any], ): super().__init__(environ) self.engine = self.environ[ENGINE] # Different DB engines require different quote chars (e.g. 
PostgreSQL # requires double quotes " which will not work in MySQL which requires # a backtick `) self.preparer = self.engine.dialect.identifier_preparer self.tablename = self.environ[TABLE] self.schema = self.environ[SCHEMA] self.sdl = self.environ[SDL] self.tablespec = table_spec( self.tablename, schema=self.schema, preparer=self.preparer) self.max_varchar_length = max_varchar_length self.kwargs = kwargs # sanity checks assert self.tablename @property def is_sqlserver(self): """ Determine whether or not the engine is MS SQLServer """ return str(self.engine.url).startswith("mssql+") @property def is_mysql(self): """Determine whether the engine is MySQL.""" return self.engine.url.get_backend_name() == "mysql" @property def alter_alter(self): """ Generate the ALTER TABLE SQL in the correct dialect """ if self.is_sqlserver: return "ALTER TABLE {} ALTER COLUMN {} {}" elif self.is_mysql: return "ALTER TABLE {} MODIFY {} {}" return "ALTER TABLE {} ALTER COLUMN {} TYPE {}" @property def alter_add(self): """ Generate 'ALTER TABLE ... ADD ...' 
SQL in the correct dialect """ if self.is_sqlserver or self.is_mysql: return "ALTER TABLE {} ADD {} {} NULL" return "ALTER TABLE {} ADD COLUMN {} {} NULL" def __call__(self): engine_name = column_names = [col["name"] for col in self.sdl["columns"]] dups = {col for col in column_names if column_names.count(col) > 1} if dups: raise RuntimeError("Duplicated columns found %s" % dups) if self.schema: if hasattr(self.engine.dialect, "has_schema"): # only postgres implements has_schema if not self.engine.dialect.has_schema(self.engine, self.schema): self.engine.execute(CreateSchema(self.schema)) elif engine_name.startswith("snowflake") or self.is_mysql: self.engine.execute( "CREATE SCHEMA IF NOT EXISTS " + self.schema ) elif engine_name.startswith("bigquery"): # Create given or default dataset if not exists client = self.environ[CLIENT] dataset = bigquery.Dataset(f"{client.project}.{self.schema}") dataset.location = client.location try: dataset = client.create_dataset(dataset, exists_ok=False) "Created dataset %s.%s", client.project, dataset.dataset_id ) except google.api_core.exceptions.Conflict: pass self._set_initial_sdl_string_columns() if engine_name.startswith("redshift"): self._adapt_to_redshift(self.sdl) if engine_name.startswith("snowflake"): self._adapt_to_snowflake(self.sdl) elif engine_name.startswith("bigquery"): self._adapt_to_bigquery(self.sdl) elif self.is_mysql: self._adapt_mysql_text_type(self.sdl) if has_table(self.engine, self.tablename, schema=self.schema): "Table %s already exists, adapting to new data", self.tablespec, ) self._adapt_table() else:"Creating table: %s", self.tablespec) self._create_table() @staticmethod def _sqlserver_adapt_column(old_col: Optional[dict], new_type): """ Adjust column types for sqlserver quirks """ if new_type.python_type == datetime.datetime: return mssql.DATETIME2() old_col_type = old_col["type"] if old_col else None old_col_length = old_col.get("length") if old_col else None if isinstance(new_type, (String, Text)): 
if old_col_type == "Text" or isinstance(new_type, Text): return mssql.VARCHAR() if old_col_type == "String" and old_col_length is None: # Already VARCHAR(max) return mssql.VARCHAR() new_type_length = new_type.length or 0 if new_type_length > SQLSERVER_VARCHARN_MAX: # Adapt to VARCHAR(max) return mssql.VARCHAR() old_col_length = old_col_length or 0 if new_type_length > old_col_length: # Adapt VARCHAR(n) to VARCHAR(n+m) # n+m <= SQLSERVER_VARCHARN_MAX return mssql.VARCHAR(new_type.length) return new_type def _adapt_redshift(self, main_table): # redshift does not support table alteration # we need to create a new table sdl = copy.deepcopy(self.sdl) sdl["name"] = "cttmp__{}".format( sdl["schema"] = main_table.schema new_table = sdl_build_table(sdl) try: new_table.drop(bind=self.engine) except ProgrammingError: pass new_table.create(bind=self.engine) select_columns = [ sqlalchemy.sql.cast(column, sqlalchemy.Integer) if isinstance( column.type, sqlalchemy.Boolean ) else column for column in main_table.columns ] # pylint: disable=no-value-for-parameter insert = new_table.insert().from_select( main_table.c,, ) self.engine.execute(insert) drop_str = self.get_drop_str(main_table) self.engine.execute(drop_str) rename_str = self.get_rename_str(new_table, main_table) self.engine.execute(rename_str) def _snowflake_swap_table_stmt( self, src: sqlalchemy.Table, dest: sqlalchemy.Table ) -> str: # can't use table.fullname since snowflake allows periods # in table names... src_fullname = table_spec(, src.schema, self.preparer) dest_fullname = table_spec(, dest.schema, self.preparer) return "ALTER TABLE {} SWAP WITH {}".format( src_fullname, dest_fullname )
[docs] def get_drop_str(self, table: sqlalchemy.Table) -> str: """Constructs a DROP statement string with schema and table name. Full table name is escaped regarding the engine's dialect e.g. for the PostgreSQL it will return: DROP TABLE IF EXISTS "schema"."table_name"; """ table_name = table_spec(, table.schema, self.preparer) return self.DROP.format(table_name)
[docs] def get_rename_str(self, old_table: sqlalchemy.Table, new_table: sqlalchemy.Table) -> str: """Constructs a RENAME statement string with schema and table name. Full table name is escaped regarding the engine's dialect e.g. for the PostgreSQL it will return: ALTER TABLE "schema.table_name" RENAME TO table_name; """ old_table_name = table_spec(, old_table.schema, self.preparer) return self.RENAME.format(old_table_name,
def _adapt_snowflake(self, main_table): # snowflake does not support table alteration so # we need to create a new table sdl = copy.deepcopy(self.sdl) sdl["name"] = "cttmp__{}".format( sdl["schema"] = main_table.schema for col in sdl["columns"]: col["name"] = col["name"].lower() new_table = sdl_build_table(sdl) try: new_table.drop(bind=self.engine) except ProgrammingError: pass new_table.create(bind=self.engine) select_columns = [ sqlalchemy.func.to_number(column) if isinstance( column.type, sqlalchemy.Boolean ) else column for column in main_table.columns ] # pylint: disable=no-value-for-parameter insert = new_table.insert().from_select( main_table.c,, ) self.engine.execute(insert) sql = self._snowflake_swap_table_stmt( new_table, main_table ) self.engine.execute(sql) # The ctmp__ table is now the old structure. new_table.drop(bind=self.engine) def _adapt_bigquery(self, main_table): """Adapt BigQuery table to a new SDL. BigQuery does not support table alteration so we need to create a new table. """ sdl = copy.deepcopy(self.sdl) sdl["name"] = "cttmp__{}".format( new_table = sdl_build_table(sdl) client = self.environ[CLIENT] dataset_id = self.schema or self.engine.dialect.dataset_id main_table_id = f"{client.project}.{dataset_id}.{}" new_table_id = f"{client.project}.{dataset_id}.{}" client.delete_table(new_table_id, not_found_ok=True) new_table.create(bind=self.engine) # pylint: disable=no-value-for-parameter insert = new_table.insert().from_select( main_table.c,, ) self.engine.execute(insert) client.delete_table(main_table_id) job = client.copy_table(new_table_id, main_table_id) job.result() # Wait for the job to complete. 
def _adapt_other(self, altered_columns): for new_col, old_col in altered_columns: if old_col is None: sql_template = self.alter_add else: sql_template = self.alter_alter if self.engine.url.get_backend_name() == "postgresql": if new_col["type"] == "Integer": sql_template += f" USING {new_col['name']}::integer" elif (old_col["type"] == "Boolean" and new_col["type"] == "Float"): # postgres doesn't allow boolean to float, so # add intermediate int step: bool->int->float sql_template += \ f" USING {new_col['name']}::int::double precision" if self.is_mysql and old_col["type"] == "Boolean": # MySQL "Boolean" implemented using TINYINT + CHECK # CONSTRAINT. We have to drop the constraint before # the column type upgrade. insp = Inspector.from_engine(self.engine) check_constraints = insp.get_check_constraints( self.tablename, self.schema, ) for constraint in check_constraints: if old_col["name"] in constraint["sqltext"]: self.engine.execute( f"ALTER TABLE {self.tablespec} " f"DROP CHECK {constraint['name']}" ) new_type = sdl_column_sqltype(new_col) if self.is_sqlserver: new_type = self._sqlserver_adapt_column(old_col, new_type) elif old_col and (old_col["type"] == "Text"): # all other engines - never adapt TEXT to VARCHAR new_type = sdl_column_sqltype(old_col) sql = sql_template.format( self.tablespec, column_spec(new_col["name"], self.preparer), new_type.compile(self.engine.dialect) ) self.engine.execute(sql) @staticmethod def _adapt_to_redshift(sdl): """adapt sdl to redshift""" for col in sdl["columns"]: # redshift has no TEXT column if col["type"] == "Text": col["type"] = "String" col["length"] = 65535 @staticmethod def _adapt_to_snowflake(sdl): """ As per the snowflake documentation - There is no performance difference between using the full-length VARCHAR declaration VARCHAR(16777216) or a smaller size. STRING and TEXT are synonymous with VARCHAR. 
""" for col in sdl["columns"]: # redshift has no TEXT column if col["type"] == "Text" or col["type"] == "String": col["type"] = "String" col["length"] = SNOWFLAKE_VARCHAR_MAX @staticmethod def _adapt_to_bigquery(sdl): """Adapt SDL to Google Big Query.""" for col in sdl["columns"]: if "length" in col: # There is no "length" property in BigQuery columns so remove it # to not alter the columns based on it. col.pop("length") translated_type = BIGQUERY_TYPE_TRANSLATE_MAP.get(col["type"]) if translated_type is not None: col["type"] = translated_type @staticmethod def _adapt_mysql_text_type(sdl): """Adapt SDL Text columns to MySQL.""" for column in sdl["columns"]: if column["type"] == "Text" and "length" in column: del column["length"] def _set_initial_sdl_string_columns(self): """ Make initial attempt at determining column types for strings. May later be overridden by engine-specific tweaks. Must call before _adapt_to_redshift() """ for col in self.sdl["columns"]: if col["type"] == "String": length = int(col.get("length", MAX_VARCHAR + 1)) if length > self.max_varchar_length: col["type"] = "Text" def _adapt_table(self): """adapts table to sdl""" meta = sqlalchemy.MetaData(schema=self.schema) main_table = sqlalchemy.Table( self.tablename, meta, autoload=True, autoload_with=self.engine, ) old_sdl = sdl_reflect_table(main_table) engine_name = if engine_name.startswith("bigquery"): self._adapt_to_bigquery(old_sdl) old_columns = {col["name"]: col for col in old_sdl["columns"]} altered_columns = [] for sdl_column in self.sdl["columns"]: old_column = old_columns.get(sdl_column["name"]) if not old_column: altered_columns.append((sdl_column, None)) continue is_different = ( old_column["type"] != sdl_column["type"] or old_column.get("length") != sdl_column.get("length") or ( # allows _sqlserver_adapt_column to adapt TEXT to VARCHAR self.is_sqlserver and "Text" in [sdl_column["type"], old_column["type"]] ) ) if is_different: altered_columns.append((sdl_column, old_column)) continue 
if not altered_columns: return for new_col, old_col in altered_columns: old_type = "" if old_col: old_type = "{} {}".format( old_col["type"], old_col.get("length") or "", ) new_type = "{} {}".format( new_col["type"], new_col.get("length") or "", ) msg = "syncing column %s was: %s will be: %s" msg, new_col["name"], old_type, new_type, ) if engine_name.startswith("redshift"): self._adapt_redshift(main_table) elif engine_name.startswith("snowflake"): self._adapt_snowflake(main_table) elif engine_name.startswith("bigquery"): self._adapt_bigquery(main_table) else: self._adapt_other(altered_columns) def _create_table(self): sdl = copy.deepcopy(self.sdl) sdl["name"] = self.tablename sdl["schema"] = self.schema if not sdl["columns"]: if COUNT not in self.environ: self.environ[COUNT] = 0 raise Done("No columns, cannot create table") table = sdl_build_table(sdl, **self.kwargs) if self.is_sqlserver: for column in table.c: column.type = self._sqlserver_adapt_column(None, column.type) table.create(self.engine) self.environ[CREATED] = self.tablespec
[docs]class MaxTimestamp(BaseStep): """ Find the maximum values for a given timestamp column """ def __init__( self, environ, column, delta: dict = None, tablename: str = None, schema: str = None, params: dict = None, tz_default: str = None, ): super().__init__(environ) self.column = column = delta self.tablename = tablename or self.environ[TABLE] self.schema = schema or self.environ[SCHEMA] self.params = params or dict() self.tz_default = tz_default def _find_column(self, table, column_name): """ Find the column object named 'self.column' raises a runtime error if not found. """ for column in table.c: if == column_name: return column
[docs] def ensure_tz_aware(self, timestamp): """ If timestamp is a datetime object, ensure that it is tz-aware. """ if ((not isinstance(timestamp, tz.datetime)) or getattr(timestamp, "tzinfo")): return timestamp chosen_tz = self.tz_default if self.tz_default in \ tz.pytz.all_timezones else "UTC""using timezone '%s' for timestamp", chosen_tz) timezone = tz.pytz.timezone(chosen_tz) timestamp_with_tz = timezone.localize(timestamp) return timestamp_with_tz
def __call__(self): dbo = self.environ[ENGINE] meta = sqlalchemy.MetaData(schema=self.schema) tablespec = table_spec(self.tablename, schema=self.schema) if not dbo.has_table(self.tablename, schema=self.schema): "Table %s does not exist, skipping timestamp", tablespec) return table = sqlalchemy.Table(self.tablename, meta, autoload=True, autoload_with=dbo) timestamp_column = self._find_column(table, self.column) if timestamp_column is None:"MaxTimestamp: column %s not found", self.column) return query =[sqlalchemy.func.max(timestamp_column)]) for key, value in self.params.items(): filter_column = self._find_column(table, key) if filter_column is None:"MaxTimestamp: params column %s not found," " skipping timestamp", key) return query = query.where(filter_column == value) timestamp = dbo.execute(query).scalar() timestamp = self.ensure_tz_aware(timestamp) fmt = "timestamp MAX(%s)%s: %s" if not timestamp:, self.column, self.params, None) return if isinstance(timestamp, str): timestamp = tz.parse(timestamp) if"timestamp delta=%s", time_delta = tz.timedelta(** timestamp = timestamp - time_delta if isinstance(timestamp, (int, float)): self.environ[TIMESTAMP] = timestamp else: self.environ[TIMESTAMP] = timestamp.isoformat() fmt, self.column, self.params, self.environ[TIMESTAMP], )
class DropTable(BaseStep): """ Post-run hook that drops a table without checking for it's existence """ def __init__(self, environ, tablename): super().__init__(environ) self.tablename = tablename def __call__(self): engine = self.environ[ENGINE] schema = self.environ[SCHEMA] meta = sqlalchemy.MetaData(schema=schema) table = sqlalchemy.Table(self.tablename, meta, autoload=True, autoload_with=engine) table.drop(bind=engine) tablespec = table_spec(self.tablename, schema=schema)"Dropped table '%s'", tablespec) self.environ[DROPPED] = tablespec class CleanupTempTables(BaseStep): """ Drop all temporary tables if they exist """ def __call__(self): tables = self.environ.get(TEMP) if not tables: return engine = self.environ[ENGINE] if isinstance(tables, str): tables = [tables] for tablespec in tables: schema, tablename = table_spec_parse(tablespec) meta = sqlalchemy.MetaData(schema=schema) table = sqlalchemy.Table(tablename, meta, autoload=True, autoload_with=engine) table.drop(bind=engine)"Dropped table '%s'", tablespec) self.environ[DROPPED] = tablespec class CollectMeta(BaseStep): """ Count the rows in the output table """ meta_metadata = sqlalchemy.MetaData(schema="meta") Base = declarative_base(metadata=meta_metadata) class Meta(Base): """ Model that stores meta information of given table """ __tablename__ = "meta" schema = sqlalchemy.Column(sqlalchemy.String(length=128)) tablename = sqlalchemy.Column(sqlalchemy.String(length=128)) count = sqlalchemy.Column(sqlalchemy.Integer) created_at = sqlalchemy.Column(sqlalchemy.DateTime) updated_at = sqlalchemy.Column(sqlalchemy.DateTime) __table_args__ = ( sqlalchemy.PrimaryKeyConstraint("schema", "tablename"), ) def __init__(self, environ, tablename=None, schema=None): super().__init__(environ) self.engine = self.environ[ENGINE] self.tablename = tablename or environ[TABLE] self.schema = schema or self._schema() self.Session = environ[SESSION] def _schema(self): if self.environ.get(SCHEMA): return self.environ[SCHEMA] if not 
self.engine: return None insp = sqlalchemy.inspect(self.engine) return insp.default_schema_name def _store_meta(self, count): meta_schema = self.meta_metadata.schema if hasattr(self.engine.dialect, "has_schema") and \ not self.engine.dialect.has_schema(self.engine, meta_schema): self.engine.execute(CreateSchema(meta_schema)) else: inspector = Inspector.from_engine(self.engine) schema_names = inspector.get_schema_names() if meta_schema not in schema_names: self.engine.execute(CreateSchema(meta_schema))"Created schema: %s", meta_schema) if not self.engine.has_table( self.Meta.__tablename__, schema=meta_schema): # pylint: disable=bare-except try: self.engine.execute(CreateSchema(meta_schema)) except Exception: # pylint: disable=broad-except # Let create_all fail if the schema couldn't be created. # This works if the schema exists, but the table doesn't. pass self.meta_metadata.create_all(self.environ[ENGINE]) session = self.Session() query = session.query(self.Meta).filter_by( schema=self.schema, tablename=self.tablename, ) meta_obj = query.first() if not meta_obj: meta_obj = self.Meta( schema=self.schema, tablename=self.tablename,, ) meta_obj.updated_at = meta_obj.count = count session.add(meta_obj) session.commit() return meta_obj def __call__(self): meta = sqlalchemy.MetaData(schema=self.schema) table = sqlalchemy.Table(self.tablename, meta, autoload=True, autoload_with=self.engine) query =[sqlalchemy.func.count()]).select_from(table) count = self.engine.execute(query).scalar() if self.can_store(): self._store_meta(count) else:"Meta cannot be stored in the output database.") self.environ[COUNT] = count tablespec = table_spec(self.tablename, schema=self.schema)"'%s' row count: %d", tablespec, count) def can_store(self) -> bool: """Check if Meta can be stored in the database.""" dbo = str(self.engine.url).lower() return not any(dbo.startswith(name) for name in ( "bigquery", "snowflake" )) class CollectStoreMeta(BaseStep): """collects metadata about store and save it 
in environ""" def __init__(self, environ): super().__init__(environ) = environ[STORE] def __call__(self): if not"No store") return count = self.environ[COUNT] ="Count %s", count) class SyncStore(BaseStep): """ sync store with the input """ def __init__(self, environ, key=None, updated_at=None): super().__init__(environ) self.input_ = environ[INPUTTER] self.updated_at = updated_at self.key = [key] if isinstance(key, str) else key self.key_jpaths = None self.updated_at_jpath = None if self.key: self.key_jpaths = [jsonpath_parse(jpa) for jpa in self.key] self.updated_at_jpath = jsonpath_parse(self.updated_at) = environ[STORE] def __call__(self): """run inputter and sync the content to store""" last_modified = if last_modified and hasattr(self.input_, "updated_at"): self.input_.updated_at(last_modified.isoformat()) if self.updated_at == "$." + MITTO_LAST_MODIFIED: input_ = builtin.FilterLastModifiedTransform( self.input_, self.environ, ) else: input_ = builtin.FilterAfterJobStartTimeTransform( self.input_, self.environ, ) data = log_iter_time_wrapper(iter(input_), self.logger) "Store last modified date: %s [%s]", last_modified and last_modified.isoformat(), self.updated_at, ) index = -1 with as batch: for index, row in enumerate(data): checkpoint = getattr(row, "checkpoint", False) if checkpoint: batch.commit() key = self._get_key(row) or [str(index)] updated_at = self._get_updated_at(row), row, updated_at)"Synchronized %s records", index + 1) self.environ[STORE_UPDATES] = index + 1 self.environ[INPUTTER] = StoreInput( def _get_key(self, row): if self.key: try: return [ parser.find(row)[0].value for parser in self.key_jpaths ] except IndexError: raise RuntimeError( "No store key in the row {}".format(self.key) ) def _get_updated_at(self, row): if self.updated_at_jpath: try: return self.updated_at_jpath.find(row)[0].value except IndexError: raise RuntimeError( "No store updated_at " "column in the row {}".format(self.updated_at) ) class DropStore(BaseStep): """ Drop Job 
related store""" def __call__(self): store = self.environ.get(STORE) if not store: return store.drop()"Store %s dropped.",